How to (repeatedly) read from .NET SslStream with a timeout?

Posted 2019-04-04 04:10

Question:

I just need to read up to N bytes from a SslStream but if no byte has been received before a timeout, cancel, while leaving the stream in a valid state in order to try again later. (*)

This can be done easily for non-SSL streams such as NetworkStream, simply by setting the ReadTimeout property, which makes the stream throw an exception on timeout. Unfortunately, this approach doesn't work on SslStream, per the official docs:

SslStream assumes that a timeout along with any other IOException when one is thrown from the inner stream will be treated as fatal by its caller. Reusing a SslStream instance after a timeout will return garbage. An application should Close the SslStream and throw an exception in these cases.
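For comparison, the NetworkStream pattern mentioned above might look like the following minimal sketch. The helper name TimedRead.TryRead and the convention of returning null on timeout are assumptions made for illustration, and the sketch assumes the timeout surfaces as an IOException wrapping a SocketException with SocketError.TimedOut. It is meant for a plain NetworkStream only; per the docs quoted above, it must not be applied to an SslStream.

```csharp
using System.IO;
using System.Net.Sockets;

public static class TimedRead
{
    // Hypothetical helper: returns the number of bytes read (0 means the
    // peer closed the connection), or null if nothing arrived in time.
    // timeoutMs must be positive, or -1 for no timeout.
    public static int? TryRead(NetworkStream stream, byte[] buffer, int timeoutMs)
    {
        stream.ReadTimeout = timeoutMs; // affects synchronous Read only
        try
        {
            return stream.Read(buffer, 0, buffer.Length);
        }
        catch (IOException ex) when (
            ex.InnerException is SocketException se &&
            se.SocketErrorCode == SocketError.TimedOut)
        {
            // Timeout: report "no data yet"; the caller can try again later.
            return null;
        }
    }
}
```

Any other IOException is deliberately left to propagate, so a genuine connection failure is not mistaken for a timeout.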

[Update 1] I tried a different approach like this:

task = stream->ReadAsync(buffer, 0, buffer->Length);
if (task->Wait(timeout_ms)) {
   count = task->Result;
   ...
}

But this doesn't work if Wait() returned false: when calling ReadAsync() again later it throws an exception:

Exception thrown: 'System.NotSupportedException' in System.dll
Tests.exe Warning: 0 : Failed reading from socket: System.NotSupportedException: The BeginRead method cannot be called when another read operation is pending.

[Update 2] I tried yet another approach to implement timeouts: calling Poll(timeout, ...READ) on the underlying TcpClient socket. If it returns true, call Read() on the SslStream; if it returns false, treat it as a timeout. This doesn't work either: because SslStream presumably uses its own internal intermediate buffers, Poll() can return false even when there is data left to be read from the SslStream.
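The check from Update 2 can be sketched roughly as follows. SocketPolling.WaitReadable is a hypothetical helper name used only for illustration. Note that Socket.Poll takes its timeout in microseconds, not milliseconds. The result describes the raw socket only: it says nothing about bytes that SslStream has already pulled off the socket and is holding internally, which is why this check alone cannot serve as a read timeout for SslStream.

```csharp
using System.Net.Sockets;

public static class SocketPolling
{
    // Hypothetical helper: true if the raw socket becomes readable within
    // timeoutMs. "Readable" also covers a closed or reset connection.
    public static bool WaitReadable(Socket socket, int timeoutMs)
    {
        // Poll expects microseconds; 'checked' turns an overflow for very
        // large timeouts into an exception instead of a wrong wait time.
        int timeoutMicroseconds = checked(timeoutMs * 1000);
        return socket.Poll(timeoutMicroseconds, SelectMode.SelectRead);
    }
}
```

A false result here means only that no new bytes are waiting on the wire, so a caller that skips SslStream.Read() in that case can miss data that is already buffered.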

[Update 3] Another possibility would be to write a custom Stream subclass that sits between NetworkStream and SslStream, catches the timeout exception, and returns 0 bytes to SslStream instead. I'm not sure how to do this and, more importantly, I have no idea whether returning 0 bytes read to SslStream would corrupt it in some other way.

(*) The reason I'm trying to do this is that reading synchronously with a timeout from a non-secure or secure socket is the pattern I'm already using on iOS, OS X, Linux and Android for some cross-platform code. It works for non-secure sockets in .NET so the only case remaining is SslStream.

Answer 1:

You can certainly make approach #1 work. You simply need to keep track of the Task and continue waiting without calling ReadAsync again. So, very roughly:

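private Task<int> readTask;     // class-level field: the read still in flight, if any
...
  // C# syntax; in C++/CLI use -> as in the question.
  // Start a new read only when no earlier one is still pending.
  if (readTask == null) readTask = stream.ReadAsync(buffer, 0, buffer.Length);
  if (readTask.Wait(timeout_ms)) {
     try {
         count = readTask.Result;
         ...
     }
     finally {
         readTask = null;   // finished: the next call may start a fresh read
     }
  }
  // else: timed out; keep readTask and wait on it again next time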

This needs to be fleshed out a bit so that the caller can see that the read hasn't completed yet, but the snippet in the question is too small to give more concrete advice.
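One way the pattern above could be fleshed out is sketched below. The class name PendingReadReader, its Data property, and the convention of returning null on timeout are all assumptions made for illustration, not part of any .NET API. The key point is that a timed-out ReadAsync task is kept and waited on again, rather than a second read being started on the same stream.

```csharp
using System;
using System.IO;
using System.Threading.Tasks;

// Hypothetical wrapper: at most one read is ever in flight on the stream.
public sealed class PendingReadReader
{
    private readonly Stream stream;
    private readonly byte[] buffer;
    private Task<int> pendingRead; // non-null while a read is in flight

    public PendingReadReader(Stream stream, int bufferSize)
    {
        this.stream = stream;
        this.buffer = new byte[bufferSize];
    }

    // Valid only after TryRead has returned a byte count.
    public byte[] Data => buffer;

    // Returns the number of bytes read (0 means end of stream),
    // or null if the read has not completed within timeoutMs.
    public int? TryRead(int timeoutMs)
    {
        // Reuse the outstanding read if the previous call timed out.
        if (pendingRead == null)
            pendingRead = stream.ReadAsync(buffer, 0, buffer.Length);

        try
        {
            if (!pendingRead.Wait(timeoutMs))
                return null; // still pending: keep the task for next time

            int count = pendingRead.Result;
            pendingRead = null; // completed: the next call starts a new read
            return count;
        }
        catch (AggregateException)
        {
            pendingRead = null; // the read failed; do not wait on it again
            throw;
        }
    }
}
```

While TryRead keeps returning null, the buffer still belongs to the pending read, so the caller should neither read nor modify Data until a byte count comes back.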



Answer 2:

I also encountered this problem, with an SslStream returning five bytes of garbage data on the read after a timeout, and I independently came up with a solution that is similar to the OP's Update #3.

I created a wrapper class that wraps the TCP NetworkStream object before it is passed into the SslStream constructor. The wrapper class passes all calls on to the underlying NetworkStream, except that the Read() method includes an extra try...catch to suppress the timeout exception and return 0 bytes instead.

SslStream works correctly in this instance, including raising the appropriate IOException if the socket is closed. Note that our Stream returning 0 from a Read() is different from a TcpClient or Socket returning 0 from a Read() (which typically means a socket disconnect).

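using System;
using System.IO;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;

class SocketTimeoutSuppressedStream : Stream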
{
    NetworkStream mStream;

    public SocketTimeoutSuppressedStream(NetworkStream pStream)
    {
        mStream = pStream;
    }

    public override int Read(byte[] buffer, int offset, int count)
    {
        try
        {
            return mStream.Read(buffer, offset, count);
        }
        catch (IOException lException)
        {
            SocketException lInnerException = lException.InnerException as SocketException;
            if (lInnerException != null && lInnerException.SocketErrorCode == SocketError.TimedOut)
            {
                // Normally, a simple TimeOut on the read will cause SslStream to flip its lid
                // However, if we suppress the IOException and just return 0 bytes read, this is ok.
                // Note that this is not a "Socket.Read() returning 0 means the socket closed",
                // this is a "Stream.Read() returning 0 means that no data is available"
                return 0;
            }
            throw;
        }
    }


    public override bool CanRead => mStream.CanRead;
    public override bool CanSeek => mStream.CanSeek;
    public override bool CanTimeout => mStream.CanTimeout;
    public override bool CanWrite => mStream.CanWrite;
    public virtual bool DataAvailable => mStream.DataAvailable;
    public override long Length => mStream.Length;
    public override IAsyncResult BeginRead(byte[] buffer, int offset, int size, AsyncCallback callback, object state) => mStream.BeginRead(buffer, offset, size, callback, state);
    public override IAsyncResult BeginWrite(byte[] buffer, int offset, int size, AsyncCallback callback, object state) => mStream.BeginWrite(buffer, offset, size, callback, state);
    public void Close(int timeout) => mStream.Close(timeout);
    public override int EndRead(IAsyncResult asyncResult) => mStream.EndRead(asyncResult);
    public override void EndWrite(IAsyncResult asyncResult) => mStream.EndWrite(asyncResult);
    public override void Flush() => mStream.Flush();
    public override Task FlushAsync(CancellationToken cancellationToken) => mStream.FlushAsync(cancellationToken);
    public override long Seek(long offset, SeekOrigin origin) => mStream.Seek(offset, origin);
    public override void SetLength(long value) => mStream.SetLength(value);
    public override void Write(byte[] buffer, int offset, int count) => mStream.Write(buffer, offset, count);

    public override long Position
    {
        get { return mStream.Position; }
        set { mStream.Position = value; }
    }

    public override int ReadTimeout
    {
        get { return mStream.ReadTimeout; }
        set { mStream.ReadTimeout = value; }
    }

    public override int WriteTimeout
    {
        get { return mStream.WriteTimeout; }
        set { mStream.WriteTimeout = value; }
    }
}

This can then be used by wrapping the TcpClient NetworkStream object before it's passed to the SslStream, as follows:

NetworkStream lTcpStream = lTcpClient.GetStream();
SocketTimeoutSuppressedStream lSuppressedStream = new SocketTimeoutSuppressedStream(lTcpStream);
using (lSslStream = new SslStream(lSuppressedStream, true, ServerCertificateValidation, SelectLocalCertificate, EncryptionPolicy.RequireEncryption))

The problem comes down to SslStream corrupting its internal state on any exception from the underlying stream, even a harmless timeout. Oddly, the five (or so) bytes of data that the next Read() returns are actually the start of the TLS encrypted payload data from the wire.

Hope this helps