Problems using SslStream in a TLS web socket server

2019-08-14 04:11发布

问题:

I followed this example to create my test certificates. I used Certificate.pfx for the server and Certificate.cer for the client:

makecert -r -pe -n "CN=Test Certificate" -sky exchange Certificate.cer -sv Key.pvk -eku 1.3.6.1.5.5.7.3.1,1.3.6.1.5.5.7.3.2

"C:\Program Files (x86)\Windows Kits\8.1\bin\x64\pvk2pfx.exe" -pvk Key.pvk -spc Certificate.cer -pfx Certificate.pfx

I am trying to create a web socket server and properly validate certificates from both the client and server sides of the communication. Here is my entire console application which I am currently building:

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Security;
using System.Net.Sockets;
using System.Net.WebSockets;
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using System.Threading.Tasks;

namespace WebSockets
{    
    /// <summary>
    /// Console test harness: starts a TLS listener on the loopback interface and
    /// connects a single TLS client to it.
    /// </summary>
    class Program
    {
        /// <summary>
        /// Starts the server on port 1337, connects a client bound to local port 1338,
        /// then blocks until a key is pressed.
        /// </summary>
        static void Main(string[] args)
        {
            IPEndPoint serverEndpoint = CreateWebSocketServer(1337);
            CreateWebSocketClient(serverEndpoint, 1338);
            Console.WriteLine("Press any key to exit.");
            Console.ReadKey();
        }

        /// <summary>
        /// Binds a listening socket to the loopback address and begins accepting one connection.
        /// The accepted connection is wrapped in an <see cref="SslStream"/>, authenticated as the
        /// server with "Certificate.pfx", sent "Hello", and then read from once.
        /// </summary>
        /// <param name="port">TCP port to listen on.</param>
        /// <returns>The endpoint the listener was bound to.</returns>
        private static IPEndPoint CreateWebSocketServer(int port)
        {
            Socket listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.IP);
            IPEndPoint endpoint = new IPEndPoint(IPAddress.Loopback, port);
            listener.Bind(endpoint);
            listener.Listen(Int32.MaxValue);

            // Accepts whatever the client presents, regardless of the reported policy errors.
            RemoteCertificateValidationCallback acceptAnything =
                (sender, certificate, chain, sslPolicyErrors) => true;

            // Always offers the certificate stored in Certificate.pfx.
            LocalCertificateSelectionCallback chooseLocalCertificate =
                (sender, targetHost, localCertificates, remoteCertificate, acceptableIssuers) =>
                    new X509Certificate2("Certificate.pfx");

            AsyncCallback onClientAccepted = asyncResult =>
            {
                Socket clientSocket = listener.EndAccept(asyncResult);
                Console.WriteLine("{0}: Connected to the client at {1}.", DateTime.Now, clientSocket.RemoteEndPoint);

                NetworkStream transport = new NetworkStream(clientSocket);
                using (SslStream stream = new SslStream(transport, false, acceptAnything, chooseLocalCertificate, EncryptionPolicy.RequireEncryption))
                {
                    X509Certificate2 serverCertificate = new X509Certificate2("Certificate.pfx");

                    // Second argument: a client certificate is required; last argument: revocation is checked.
                    stream.AuthenticateAsServer(serverCertificate, true, SslProtocols.Tls12, true);

                    Byte[] greeting = "Hello".ToByteArray();
                    stream.Write(greeting);

                    // ReadMessage is evaluated inline so the timestamp is taken before the blocking read.
                    Console.WriteLine("{0}: Read \"{1}\" from the client at {2}.", DateTime.Now, stream.ReadMessage(), clientSocket.RemoteEndPoint);
                }
            };

            // Only one accept is ever issued, so a single client is served.
            listener.BeginAccept(onClientAccepted, null);

            Console.WriteLine("{0}: Web socket server started at {1}.", DateTime.Now, listener.LocalEndPoint);
            return endpoint;
        }

        /// <summary>
        /// Binds a socket to the given local loopback port and begins connecting to the server.
        /// Once connected, the connection is wrapped in an <see cref="SslStream"/>, authenticated
        /// as a client using "Certificate.cer", sent "Hello", and then read from once.
        /// </summary>
        /// <param name="remoteEndpoint">Endpoint of the server to connect to.</param>
        /// <param name="port">Local TCP port the client socket is bound to.</param>
        private static void CreateWebSocketClient(IPEndPoint remoteEndpoint, int port)
        {
            Socket connection = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.IP);
            IPEndPoint localEndpoint = new IPEndPoint(IPAddress.Loopback, port);
            connection.Bind(localEndpoint);

            // Accepts whatever the server presents, regardless of the reported policy errors.
            RemoteCertificateValidationCallback acceptAnything =
                (sender, certificate, chain, sslPolicyErrors) => true;

            // Always offers the certificate stored in Certificate.cer.
            LocalCertificateSelectionCallback chooseLocalCertificate =
                (sender, targetHost, localCertificates, remoteCertificate, acceptableIssuers) =>
                    new X509Certificate2("Certificate.cer");

            AsyncCallback onConnected = asyncResult =>
            {
                connection.EndConnect(asyncResult);
                Console.WriteLine("{0}: Connected to the server at {1}.", DateTime.Now, remoteEndpoint);

                NetworkStream transport = new NetworkStream(connection);
                using (SslStream stream = new SslStream(transport, false, acceptAnything, chooseLocalCertificate, EncryptionPolicy.RequireEncryption))
                {
                    // The target host name handed to the handshake is the endpoint's "address:port" text.
                    String targetHost = remoteEndpoint.ToString();
                    X509Certificate2Collection clientCertificates =
                        new X509Certificate2Collection(new X509Certificate2[] { new X509Certificate2("Certificate.cer") });

                    stream.AuthenticateAsClient(targetHost, clientCertificates, SslProtocols.Tls12, true);

                    Byte[] greeting = "Hello".ToByteArray();
                    stream.Write(greeting);

                    // ReadMessage is evaluated inline so the timestamp is taken before the blocking read.
                    Console.WriteLine("{0}: Read \"{1}\" from the server at {2}.", DateTime.Now, stream.ReadMessage(), remoteEndpoint);
                }
            };

            connection.BeginConnect(remoteEndpoint, onConnected, null);
        }
    }

    /// <summary>
    /// Helpers that move text across a stream as the raw in-memory bytes of its characters
    /// (two bytes per <see cref="Char"/>).
    /// </summary>
    public static class StringExtensions
    {
        /// <summary>
        /// Copies the characters of <paramref name="value"/> into a new byte array,
        /// two bytes per character.
        /// </summary>
        /// <param name="value">Text to convert.</param>
        /// <returns>A byte array of length <c>value.Length * sizeof(Char)</c>.</returns>
        public static Byte[] ToByteArray(this String value)
        {
            Char[] source = value.ToCharArray();
            Int32 byteCount = source.Length * sizeof(Char);
            Byte[] raw = new Byte[byteCount];
            Buffer.BlockCopy(source, 0, raw, 0, byteCount);
            return raw;
        }

        /// <summary>
        /// Reinterprets <paramref name="bytes"/> as characters and strips NUL characters
        /// from both ends of the result.
        /// </summary>
        /// <param name="bytes">Raw character bytes, two per character.</param>
        /// <returns>The decoded text without leading or trailing NUL characters.</returns>
        public static String FromByteArray(this Byte[] bytes)
        {
            Int32 characterCount = bytes.Length / sizeof(Char);
            Char[] decoded = new Char[characterCount];
            Buffer.BlockCopy(bytes, 0, decoded, 0, bytes.Length);

            String text = new String(decoded);
            return text.Trim(new Char[] { (Char)0 });
        }

        /// <summary>Size in bytes of the buffer used by <see cref="ReadMessage"/>.</summary>
        public static int BufferSize = 0x400;

        /// <summary>
        /// Performs one read of up to <see cref="BufferSize"/> bytes from the stream and
        /// decodes the whole buffer with <see cref="FromByteArray"/>.
        /// </summary>
        /// <param name="stream">Authenticated stream to read from.</param>
        /// <returns>The decoded text.</returns>
        public static String ReadMessage(this SslStream stream)
        {
            Byte[] chunk = new Byte[BufferSize];

            // The number of bytes actually read is not used; unused buffer space stays zeroed
            // and is removed by the NUL trim in FromByteArray.
            stream.Read(chunk, 0, BufferSize);
            return chunk.FromByteArray();
        }
    }
}

Communication between server and client works fine when you run it, but I am not sure how I should implement the callbacks, specifically because sslPolicyErrors = RemoteCertificateNotAvailable when the RemoteCertificateValidationCallback is called on the server side and sslPolicyErrors = RemoteCertificateNameMismatch | RemoteCertificateChainErrors when the RemoteCertificateValidationCallback is called on the client side. Also, certificate and chain are null on the server side but appear on the callback from the client side. Why is that? What are the problems with my implementation and how can I make my implementation validate SSL certificates properly? I have tried searching online about the SslStream but I have yet to see a full, X509-based TLS server-client implementation that does the type of certificate validation I need.

回答1:

I had three separate problems. My initial approach was good, but:

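  1. I misused the certificates: using the .pfx certificate on the client side resolves my RemoteCertificateNotAvailable problem. I am not sure why the .cer did not work; most likely it is because a .cer file contains only the public certificate, whereas client authentication also needs the private key, which only the .pfx holds.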

  2. I have specified the wrong subject name in my call to AuthenticateAsClient, as using "Test Certificate" for the first argument instead of remoteEndpoint.ToString() solves my RemoteCertificateNameMismatch.

  3. Because the certificate is self-signed, it is not trusted by default; to get around the RemoteCertificateChainErrors error, I had to add it to the Trusted People store under my current user account.

Some other small refinements are included. My resulting code, which now also accepts multiple clients (as I had fixed some bugs above), is shown below. Please don't copy it verbatim: it still needs a lot of Pokemon ("catch 'em all") exception handling in different places; proper clean-up logic; use of the byte count returned by Read calls instead of trimming NUL; the introduction of some Unicode character such as EOT to mark the end of a message, and parsing for it; handling of odd-sized buffers, which are not supported because a C# character is 2 bytes; handling of odd-length reads; and so on. It needs a lot of refinement before it ever sees the light of a production system and serves only as an example or a proof of concept, if you will:

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Security;
using System.Net.Sockets;
using System.Net.WebSockets;
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace WebSockets
{
    /// <summary>
    /// Proof-of-concept console harness: a loopback TLS listener that requires client
    /// certificates, plus two clients that connect to it. Both sides load "Certificate.pfx"
    /// and reject the peer unless validation reports no policy errors.
    /// </summary>
    class Program
    {
        /// <summary>
        /// Starts the server on port 1337, connects two clients bound to local ports 1338 and
        /// 1339, then blocks until a key is pressed.
        /// </summary>
        static void Main(string[] args)
        {
            IPEndPoint server = CreateWebSocketServer(1337);
            CreateWebSocketClient(server, 1338);
            CreateWebSocketClient(server, 1339);
            Console.WriteLine("Press any key to exit.");
            Console.ReadKey();
        }

        /// <summary>
        /// Binds a listening socket to the loopback address and starts accepting clients.
        /// </summary>
        /// <param name="port">TCP port to listen on.</param>
        /// <returns>The endpoint the listener was bound to.</returns>
        private static IPEndPoint CreateWebSocketServer(int port)
        {
            var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.IP);
            IPEndPoint endpoint = new IPEndPoint(IPAddress.Loopback, port);
            socket.Bind(endpoint);
            socket.Listen(Int32.MaxValue);
            ListenForClients(socket);
            Console.WriteLine("{0}: Web socket server started at {1}.", DateTime.Now, socket.LocalEndPoint);
            return endpoint;
        }

        /// <summary>
        /// Issues one asynchronous accept on <paramref name="socket"/>. When a client arrives the
        /// next accept is queued, the client is authenticated over TLS, sent "Hello", read from
        /// once, and its socket is closed.
        /// </summary>
        /// <param name="socket">Listening socket to accept from.</param>
        private static void ListenForClients(Socket socket)
        {
            socket.BeginAccept((result) =>
            {
                // Queue the next accept before handling this client so other clients are not kept
                // waiting. BeginAccept does not block, so a pool thread is enough; a dedicated
                // thread per connection is unnecessary.
                ThreadPool.QueueUserWorkItem((state) => ListenForClients(socket));

                Socket clientSocket;
                try
                {
                    clientSocket = socket.EndAccept(result);
                }
                catch (SocketException ex)
                {
                    Console.WriteLine("{0}: Failed to accept a client: {1}", DateTime.Now, ex.Message);
                    return;
                }

                // Captured up front so it can still be reported after a failure.
                EndPoint clientEndpoint = clientSocket.RemoteEndPoint;
                try
                {
                    Console.WriteLine("{0}: Connected to the client at {1}.", DateTime.Now, clientEndpoint);

                    // Loaded once per connection and shared by the selection callback and the handshake.
                    var serverCertificate = new X509Certificate2("Certificate.pfx");
                    using (var stream = new SslStream(new NetworkStream(clientSocket), false, ValidateRemoteCertificate,
                        (sender, targetHost, localCertificates, remoteCertificate, acceptableIssuers) => serverCertificate,
                        EncryptionPolicy.RequireEncryption))
                    {
                        // Second argument: a client certificate is required; last argument: revocation is checked.
                        stream.AuthenticateAsServer(serverCertificate, true, SslProtocols.Tls12, true);
                        stream.Write("Hello".ToByteArray());
                        Console.WriteLine("{0}: Read \"{1}\" from the client at {2}.", DateTime.Now, stream.ReadMessage(), clientEndpoint);
                    }
                }
                catch (AuthenticationException ex)
                {
                    ReportFailure("client", clientEndpoint, ex);
                }
                catch (IOException ex)
                {
                    ReportFailure("client", clientEndpoint, ex);
                }
                finally
                {
                    // The NetworkStream does not own the socket, so release it explicitly.
                    clientSocket.Close();
                }
            }, null);
        }

        /// <summary>
        /// Binds a socket to the given local loopback port and begins connecting to the server.
        /// Once connected, the client authenticates over TLS against the certificate subject
        /// "Test Certificate", sends "Hello", reads once, and closes its socket.
        /// </summary>
        /// <param name="remoteEndpoint">Endpoint of the server to connect to.</param>
        /// <param name="port">Local TCP port the client socket is bound to.</param>
        private static void CreateWebSocketClient(IPEndPoint remoteEndpoint, int port)
        {
            var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.IP);
            IPEndPoint localEndpoint = new IPEndPoint(IPAddress.Loopback, port);
            socket.Bind(localEndpoint);
            socket.BeginConnect(remoteEndpoint, (result) =>
            {
                try
                {
                    socket.EndConnect(result);
                    Console.WriteLine("{0}: Client at {1} connected to the server at {2}.", DateTime.Now, localEndpoint, remoteEndpoint);

                    // Loaded once and shared by the selection callback and the handshake.
                    var clientCertificate = new X509Certificate2("Certificate.pfx");
                    using (var stream = new SslStream(new NetworkStream(socket), false, ValidateRemoteCertificate,
                        (sender, targetHost, localCertificates, remoteCertificate, acceptableIssuers) => clientCertificate,
                        EncryptionPolicy.RequireEncryption))
                    {
                        // The target host must match the certificate subject, otherwise validation
                        // reports RemoteCertificateNameMismatch.
                        stream.AuthenticateAsClient("Test Certificate", new X509Certificate2Collection(clientCertificate), SslProtocols.Tls12, true);
                        stream.Write("Hello".ToByteArray());
                        Console.WriteLine("{0}: Client at {1} read \"{2}\" from the server at {3}.", DateTime.Now, localEndpoint, stream.ReadMessage(), remoteEndpoint);
                    }
                }
                catch (SocketException ex)
                {
                    ReportFailure("server", remoteEndpoint, ex);
                }
                catch (AuthenticationException ex)
                {
                    ReportFailure("server", remoteEndpoint, ex);
                }
                catch (IOException ex)
                {
                    ReportFailure("server", remoteEndpoint, ex);
                }
                finally
                {
                    // The NetworkStream does not own the socket, so release it explicitly.
                    socket.Close();
                }
            }, null);
        }

        /// <summary>
        /// Certificate validation shared by both sides: the peer is trusted only when no
        /// policy errors were reported.
        /// </summary>
        private static bool ValidateRemoteCertificate(object sender, X509Certificate certificate, X509Chain chain, SslPolicyErrors sslPolicyErrors)
        {
            return sslPolicyErrors == SslPolicyErrors.None;
        }

        /// <summary>
        /// Logs a failed exchange so that one bad connection does not terminate the process
        /// through an unhandled exception on a callback thread.
        /// </summary>
        private static void ReportFailure(string peerRole, EndPoint peerEndpoint, Exception error)
        {
            Console.WriteLine("{0}: Communication with the {1} at {2} failed: {3}", DateTime.Now, peerRole, peerEndpoint, error.Message);
        }
    }

    /// <summary>
    /// Helpers that move text across a stream as the raw in-memory bytes of its characters
    /// (two bytes per <see cref="Char"/>).
    /// </summary>
    public static class StringExtensions
    {
        /// <summary>
        /// Copies the characters of <paramref name="value"/> into a new byte array,
        /// two bytes per character.
        /// </summary>
        /// <param name="value">Text to convert.</param>
        /// <returns>A byte array of length <c>value.Length * sizeof(Char)</c>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="value"/> is null.</exception>
        public static Byte[] ToByteArray(this String value)
        {
            if (value == null)
            {
                throw new ArgumentNullException("value");
            }

            Byte[] bytes = new Byte[value.Length * sizeof(Char)];
            Buffer.BlockCopy(value.ToCharArray(), 0, bytes, 0, bytes.Length);
            return bytes;
        }

        /// <summary>
        /// Reinterprets <paramref name="bytes"/> as characters and strips NUL characters from
        /// both ends of the result. A trailing byte that does not complete a character is ignored.
        /// </summary>
        /// <param name="bytes">Raw character bytes, two per character.</param>
        /// <returns>The decoded text without leading or trailing NUL characters.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="bytes"/> is null.</exception>
        public static String FromByteArray(this Byte[] bytes)
        {
            if (bytes == null)
            {
                throw new ArgumentNullException("bytes");
            }

            Char[] characters = new Char[bytes.Length / sizeof(Char)];

            // Copy whole characters only: asking BlockCopy for an odd byte count would exceed
            // the destination array and throw ArgumentException.
            Buffer.BlockCopy(bytes, 0, characters, 0, characters.Length * sizeof(Char));
            return new String(characters).Trim(new Char[] { (Char)0 });
        }

        /// <summary>Size in bytes of the buffer used by <see cref="ReadMessage"/>.</summary>
        public static int BufferSize = 0x400;

        /// <summary>
        /// Performs one read of up to <see cref="BufferSize"/> bytes from the stream and decodes
        /// only the bytes that were actually received.
        /// </summary>
        /// <param name="stream">Authenticated stream to read from.</param>
        /// <returns>The decoded text; empty when no bytes were read.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="stream"/> is null.</exception>
        public static String ReadMessage(this SslStream stream)
        {
            if (stream == null)
            {
                throw new ArgumentNullException("stream");
            }

            var buffer = new Byte[BufferSize];

            // Use buffer.Length rather than re-reading the mutable BufferSize field, so the
            // requested count always matches the array that was allocated.
            int bytesRead = stream.Read(buffer, 0, buffer.Length);

            Byte[] received = new Byte[bytesRead];
            Buffer.BlockCopy(buffer, 0, received, 0, bytesRead);
            return FromByteArray(received);
        }
    }
}

I hope this helps others demystify web sockets, SSL streams, X509 certificates, and so forth, in C#. Happy coding. :) I may end up posting its final edition on my blog.