How to allow a Server to accept both SSL and plain

Posted 2019-04-28 02:24

Question:

I am trying to create a server that can accept both secure SSL connections and insecure plain-text connections (for backwards compatibility). My code is almost working, except that the first data received from an insecure client loses its first 5 bytes (chars) at the server. More specifically, if I transmit 30 bytes on an insecure connection, then when the server reaches the line "int iRx = nwStream.EndRead(asyn);" in OnClientDataReceived(), iRx = 25. Any subsequent messages transmitted from the client contain all sent bytes/chars as expected. I suspect that the initial assumption that the connection is an SSLStream strips the first 5 bytes, and when authentication then fails, those 5 bytes have already been extracted from the buffer and are no longer available. Does anybody know of another approach I could take so that the server can switch automatically on the fly?

I am trying to avoid doing the following:

  • Require that a client connect using a plain text NetworkStream and then request to upgrade to an SSL stream
  • Setting up two TcpListeners on two different ports (one for secure, one for insecure)

Here is my code:

/// Each client that connects gets an instance of the ConnectedClient class.
class Pseudo_ConnectedClient
{
    //Properties
    byte[] Buffer; //Holds temporary buffer of read bytes from BeginRead()
    TcpClient TCPClient; //Reference to the connected client
    Socket ClientSocket; //The outer Socket Reference of the connected client
    StringBuilder CurrentMessage; //concatenated chunks of data in buffer until we have a complete message (ends with <ETX>)
    Stream Stream; //SSLStream or NetworkStream depending on client
    ArrayList MessageQueue; //Array of complete messages received from client that need to be processed
}

/// When a new client connects (OnClientConnection callback is executed), the server creates the ConnectedClient object and stores its 
/// reference in a local dictionary, then configures the callbacks for incoming data (WaitForClientData)
void OnClientConnection(IAsyncResult result)
{
    TcpListener listener = result.AsyncState as TcpListener;
    TcpClient clnt = null;

    try
    {
        if (!IsRunning) //then stop was called, so don't call EndAcceptTcpClient because it will throw an ObjectDisposedException
            return;

        //Start accepting the next connection...
        listener.BeginAcceptTcpClient(this.onClientConnection, listener);

        //Get reference to client and set flag to indicate connection accepted.
        clnt = listener.EndAcceptTcpClient(result);

        //Add the reference to our ArrayList of Connected Clients
        ConnectedClient conClnt = new ConnectedClient(clnt);
        _clientList.Add(conClnt);

        //Configure client to listen for incoming data
        WaitForClientData(conClnt);
    }
    catch (Exception ex)
    {
        Trace.WriteLine("Server:OnClientConnection: Exception - " + ex.ToString());
    }
}

/// WaitForClientData registers the AsyncCallback to handle incoming data from a client (OnClientDataReceived).  
/// If a certificate has been provided, then it listens for clients to connect on an SSLStream and configures the 
/// BeginAuthenticateAsServer callback.  If no certificate is provided, then it only sets up a NetworkStream 
/// and prepares for the BeginRead callback.
private void WaitForClientData(ConnectedClient clnt)
{
    if (!IsRunning) return; //Then stop was called, so don't do anything

    SslStream sslStream = null;

    try
    {
        if (_pfnClientDataCallBack == null) //then define the call back function to invoke when data is received from a connected client
            _pfnClientDataCallBack = new AsyncCallback(OnClientDataReceived);

        NetworkStream nwStream = clnt.TCPClient.GetStream();

        //Check if we can establish a secure connection
        if (this.SSLCertificate != null) //Then we have the ability to make an SSL connection (SSLCertificate is a X509Certificate2 object)
        {
            if (this.certValidationCallback != null)
                sslStream = new SslStream(nwStream, true, this.certValidationCallback);
            else
                sslStream = new SslStream(nwStream, true);

            clnt.Stream = sslStream;

            //Start Listening for incoming (secure) data
            sslStream.BeginAuthenticateAsServer(this.SSLCertificate, false, SslProtocols.Default, false, onAuthenticateAsServer, clnt);
        }
        else //No certificate available to make a secure connection, so use insecure (unless not allowed)
        {
            if (this.RequireSecureConnection == false) //Then we can try to read from the insecure stream
            {
                clnt.Stream = nwStream;

                //Start Listening for incoming (unsecure) data
                nwStream.BeginRead(clnt.Buffer, 0, clnt.Buffer.Length, _pfnClientDataCallBack, clnt);
            }
            else //we can't do anything - report config problem
            {
                throw new InvalidOperationException("A PFX certificate is not loaded and the server is configured to require a secure connection");
            }
        }
    }
    catch (Exception ex)
    {
        DisconnectClient(clnt);
    }
}

/// OnAuthenticateAsServer first checks if the stream is authenticated, if it isn't it gets the TCPClient's reference 
/// to the outer NetworkStream (client.TCPClient.GetStream()) - the insecure stream and calls the BeginRead on that.  
/// If the stream is authenticated, then it keeps the reference to the SSLStream and calls BeginRead on it.
private void OnAuthenticateAsServer(IAsyncResult result)
{
    ConnectedClient clnt = null;
    SslStream sslStream = null;

    if (this.IsRunning == false) return;

    try
    {
        clnt = result.AsyncState as ConnectedClient;
        sslStream = clnt.Stream as SslStream;

        if (sslStream.IsAuthenticated)
            sslStream.EndAuthenticateAsServer(result);
        else //Try to switch to an insecure connection
        {
            if (this.RequireSecureConnection == false) //Then we are allowed to accept insecure connections
            {
                if (clnt.TCPClient.Connected)
                    clnt.Stream = clnt.TCPClient.GetStream();
            }
            else //Insecure connections are not allowed, close the connection
            {
                DisconnectClient(clnt);
            }
        }
    }
    catch (Exception ex)
    {
        DisconnectClient(clnt);
    }

    if( clnt.Stream != null) //Then we have a stream to read, start Async read
        clnt.Stream.BeginRead(clnt.Buffer, 0, clnt.Buffer.Length, _pfnClientDataCallBack, clnt);
}

/// OnClientDataReceived callback is triggered by the BeginRead async when data is available from a client.  
/// It determines if the stream (as assigned by OnAuthenticateAsServer) is an SSLStream or a NetworkStream 
/// and then reads the data out of the stream accordingly.  The logic to parse and process the message has 
/// been removed because it isn't relevant to the question.
private void OnClientDataReceived(IAsyncResult asyn)
{
    try
    {
        ConnectedClient connectClnt = asyn.AsyncState as ConnectedClient;

        if (!connectClnt.TCPClient.Connected) //Then the client is no longer connected >> clean up
        {
            DisconnectClient(connectClnt);
            return;
        }

        Stream nwStream = null;
        if( connectClnt.Stream is SslStream) //Then this client is connected via a secure stream
            nwStream = connectClnt.Stream as SslStream;
        else //this is a plain text stream
            nwStream = connectClnt.Stream as NetworkStream;

        // Complete the asynchronous BeginRead() call with EndRead(), which
        // returns the number of bytes read from the stream into the buffer
        int iRx = nwStream.EndRead(asyn); //Returns the numbers of bytes in the read buffer
        char[] chars = new char[iRx];   

        // Extract the characters as a buffer and create a String
        Decoder d = ASCIIEncoding.UTF8.GetDecoder();
        d.GetChars(connectClnt.Buffer, 0, iRx, chars, 0);

        //string data = ASCIIEncoding.ASCII.GetString(buff, 0, buff.Length);
        string data = new string(chars);

        if (iRx > 0) //Then there was data in the buffer
        {
            //Append the current packet with any additional data that was already received
            connectClnt.CurrentMessage.Append(data);

            //Do work here to check for a complete message
            //Make sure two complete messages didn't get concatenated in one transmission (mobile devices)
            //Add each message to the client's messageQueue
            //Clear the currentMessage
            //Any partial message at the end of the buffer needs to be added to the currentMessage

            //Start reading again
            nwStream.BeginRead(connectClnt.Buffer, 0, connectClnt.Buffer.Length, OnClientDataReceived, connectClnt);
        }
        else //zero-length packet received - Disconnecting socket
        {
            DisconnectClient(connectClnt);
        }                
    }
    catch (Exception ex)
    {
        return;
    }
}

What works:

  • If the server doesn't have a certificate, only a NetworkStream is used, and all bytes are received from the client for all messages.
  • If the server does have a certificate (an SSLStream is set up) and a secure connection can be established (web browser using https://), the full message is received for all messages.

What doesn't work:

  • If the server does have a certificate (an SSLStream is setup) and an insecure connection is made from a client, when the first message is received from that client, the code does correctly detect the SSLStream is not authenticated and switches to the NetworkStream of the TCPClient. However, when EndRead is called on that NetworkStream for the first message, the first 5 chars (bytes) are missing from the message that was sent, but only for the first message. All remaining messages are complete as long as the TCPClient is connected. If the client disconnects and then reconnects, the first message is clipped, then all subsequent messages are good again.

What is causing those first 5 bytes to be clipped, and how can I avoid it?

My project is currently using .NET v3.5... I would like to remain at this version and not step up to 4.0 if I can avoid it.


Follow-up Question

Damien's answer below does allow me to retain those missing 5 bytes. However, I would prefer to stick with the BeginRead and EndRead methods in my code to avoid blocking. Are there any good tutorials showing best practices for overriding these? More specifically, how should I work with the IAsyncResult object? I get that I would need to return any content stored in the RestartableStream buffers first, then fall through to the inner stream (base) to get the rest and return the total. But since the IAsyncResult object is a custom class, I can't figure out a generic way to combine the buffers of RestartableStream with those of the inner stream before returning. Do I also need to implement BeginRead() so that I know which buffer the caller wants the content stored in? I guess the other solution, since the dropped-bytes problem only affects the first message from the client (after that I know whether to treat it as an SSLStream or a NetworkStream), would be to handle that first message by calling the Read() method of RestartableStream directly (temporarily blocking the code), and then use the async callbacks for all future messages as I do now.

Answer 1:

Okay, I think the best you can do is place your own class in between SslStream and NetworkStream where you implement some customized buffering. I've done a few tests on the below, but I'd recommend a few more before you put it in production (and probably some more robust error handling). I think I've avoided any 4.0 or 4.5isms:

  public sealed class RestartableReadStream : Stream
  {
    private Stream _inner;
    private List<byte[]> _buffers;
    private bool _buffering;
    private int? _currentBuffer = null;
    private int? _currentBufferPosition = null;
    public RestartableReadStream(Stream inner)
    {
      if (!inner.CanRead) throw new NotSupportedException(); //Don't know what else is being expected of us
      if (inner.CanSeek) throw new NotSupportedException(); //Just use the underlying stream's ability to seek, no need for this class
      _inner = inner;
      _buffering = true;
      _buffers = new List<byte[]>();
    }

    public void StopBuffering()
    {
      _buffering = false;
      if (!_currentBuffer.HasValue)
      {
        //We aren't currently using the buffers
        _buffers = null;
        _currentBufferPosition = null;
      }
    }

    public void Restart()
    {
      if (!_buffering) throw new NotSupportedException();  //Buffering got turned off already
      if (_buffers.Count == 0) return;
      _currentBuffer = 0;
      _currentBufferPosition = 0;
    }

    public override int Read(byte[] buffer, int offset, int count)
    {
      if (_currentBuffer.HasValue)
      {
        //Try to satisfy the read request from the current buffer
        byte[] rbuffer = _buffers[_currentBuffer.Value];
        int roffset = _currentBufferPosition.Value;
        if ((rbuffer.Length - roffset) <= count)
        {
          //Just give them what we have in the current buffer (exhausting it)
          count = (rbuffer.Length - roffset);
          for (int i = 0; i < count; i++)
          {
            buffer[offset + i] = rbuffer[roffset + i];
          }

          _currentBuffer++;
          _currentBufferPosition = 0; //Start the next buffer (if any) from its beginning
          if (_currentBuffer.Value == _buffers.Count)
          {
            //We've stopped reading from the buffers
            if (!_buffering)
              _buffers = null;
            _currentBuffer = null;
            _currentBufferPosition = null;
          }
          return count;
        }
        else
        {
          for (int i = 0; i < count; i++)
          {
            buffer[offset + i] = rbuffer[roffset + i];
          }
          _currentBufferPosition += count;
          return count;
        }
      }
      //If we reach here, we're currently using the inner stream. But may be buffering the results
      int ncount = _inner.Read(buffer, offset, count);
      if (_buffering)
      {
        byte[] rbuffer = new byte[ncount];
        for (int i = 0; i < ncount; i++)
        {
          rbuffer[i] = buffer[offset + i];
        }
        _buffers.Add(rbuffer);
      }
      return ncount;
    }

    public override bool CanRead
    {
      get { return true; }
    }

    public override bool CanSeek
    {
      get { return false; }
    }

    //SslStream requires its inner stream to be writable (the handshake sends data
    //back to the client), so writes pass straight through to the inner stream.
    public override bool CanWrite
    {
      get { return _inner.CanWrite; }
    }

    public override void Flush()
    {
      _inner.Flush();
    }

    public override void Write(byte[] buffer, int offset, int count)
    {
      _inner.Write(buffer, offset, count);
    }

    //No more interesting code below here

    public override long Length
    {
      get { throw new NotSupportedException(); }
    }

    public override long Position
    {
      get
      {
        throw new NotSupportedException();
      }
      set
      {
        throw new NotSupportedException();
      }
    }

    public override long Seek(long offset, SeekOrigin origin)
    {
      throw new NotSupportedException();
    }

    public override void SetLength(long value)
    {
      throw new NotSupportedException();
    }
  }
  }

Usage:

Construct a RestartableReadStream around your NetworkStream. Pass that instance to SslStream. If you decide that SSL was the wrong way to do things, call Restart() and then use it again however you want to. You can even try more than two strategies (calling Restart() between each one).

Once you've settled on which strategy (e.g. SSL or non-SSL) is correct, call StopBuffering(). Once it's finished replaying any buffers it had available, it will revert to just calling Read on its inner stream. If you don't call StopBuffering, then the entire history of reads from the stream will be kept in the _buffers list, which could add quite a bit of memory pressure.

Note that none of the above particularly accounts for multi-threaded access. But if you've got multiple threads calling Read() on a single stream (especially one that's network based), I wouldn't expect any sanity anyway.
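As for why exactly five bytes go missing: every SSL 3.0 / TLS record starts with a 5-byte header (1 byte content type, 2 bytes protocol version, 2 bytes payload length), so SslStream presumably consumes that much from the NetworkStream before deciding the data is not a valid handshake. The sketch below decodes such a header; `TlsRecordHeader` and its members are hypothetical names used only for illustration, and the code sticks to C# 3.0 syntax so it should compile under .NET 3.5.

```csharp
using System;

// Hypothetical helper (not part of the answer above): decodes the 5-byte
// header that starts every SSL 3.0 / TLS record.
public struct TlsRecordHeader
{
    public const int Size = 5;

    public byte ContentType;   // 22 (0x16) means "handshake"
    public byte MajorVersion;  // 3 for SSL 3.0 and TLS 1.x
    public byte MinorVersion;
    public int PayloadLength;  // big-endian 16-bit length of the record body

    public static TlsRecordHeader Parse(byte[] data, int offset)
    {
        if (data == null) throw new ArgumentNullException("data");
        if (offset < 0 || data.Length - offset < Size)
            throw new ArgumentException("Need at least 5 bytes.");

        TlsRecordHeader h = new TlsRecordHeader();
        h.ContentType = data[offset];
        h.MajorVersion = data[offset + 1];
        h.MinorVersion = data[offset + 2];
        h.PayloadLength = (data[offset + 3] << 8) | data[offset + 4];
        return h;
    }

    // True when the header looks like the start of a TLS handshake record.
    public bool LooksLikeHandshake
    {
        get { return ContentType == 22 && MajorVersion == 3; }
    }
}
```

When a plain-text client sends, say, "HELLO WORLD", the first five characters are what ends up being interpreted as this header, which is why they are gone by the time the code falls back to the raw stream unless something like the buffering wrapper replays them.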



Answer 2:

I spent hours searching for a way to avoid writing a stream wrapper around NetworkStream, and finally came across SocketFlags.Peek on MSDN, which worked for me. I kept finding suggestions to just write a wrapper or use separate ports, but I have a problem listening to authority or reason.

Here's my code. NLOLOL (No laughing out loud or lectures) I haven't completely figured out if I need to Peek at more than the first byte for all scenarios.

Private Async Sub ProcessTcpClient(__TcpClient As Net.Sockets.TcpClient)

        If __TcpClient Is Nothing OrElse Not __TcpClient.Connected Then Return

        Dim __RequestBuffer(0) As Byte
        Dim __BytesRead As Integer

        Using __NetworkStream As Net.Sockets.NetworkStream = __TcpClient.GetStream

            __BytesRead = __TcpClient.Client.Receive(__RequestBuffer, 0, 1, SocketFlags.Peek)
            If __BytesRead = 1 AndAlso __RequestBuffer(0) = 22 Then
                Await Me.ProcessTcpClientSsl(__NetworkStream)
            Else
                Await Me.ProcessTcpClientNonSsl(__NetworkStream)
            End If

        End Using

        __TcpClient.Close()

End Sub
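For readers on C# (and .NET 3.5, where Async/Await is not available), the same peek idea could look roughly like the sketch below. `ProtocolSniffer`, `StartsLikeTls` and `PeekIsTls` are hypothetical names; only `Socket.Receive` and `SocketFlags.Peek` come from the framework. It also checks the second byte (the record's major version, 3 for SSL 3.0 and TLS 1.x) when available, as one possible answer to whether the first byte alone is enough. This is a sketch under those assumptions, not a complete detector.

```csharp
using System;
using System.Net.Sockets;

// Hypothetical helper names; only Socket.Receive and SocketFlags.Peek are framework APIs.
public static class ProtocolSniffer
{
    // Pure check on bytes that were peeked (not consumed) from the socket.
    public static bool StartsLikeTls(byte[] peeked, int count)
    {
        if (peeked == null || count < 1) return false;
        if (peeked[0] != 0x16) return false;                 // 22 = handshake record
        if (count >= 2 && peeked[1] != 0x03) return false;   // major version of SSL 3.0 / TLS 1.x
        return true;
    }

    // Blocks until at least one byte is available, then inspects it without
    // removing it from the socket's receive buffer.
    public static bool PeekIsTls(Socket socket)
    {
        byte[] peeked = new byte[3];
        int n = socket.Receive(peeked, 0, peeked.Length, SocketFlags.Peek);
        return StartsLikeTls(peeked, n); // n == 0 means the peer closed the connection
    }
}
```

Because the peek consumes nothing, the caller can afterwards wrap `TcpClient.GetStream()` in an SslStream or read from it directly, and no bytes are lost either way. Two caveats: `Receive` blocks, so a server built on callbacks may prefer the `Socket.BeginReceive` overload that takes a `SocketFlags` argument; and a very old client sending an SSL 2.0 style hello would not start with byte 22 and so would not match this check.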