Using ssl context.set_servername_callback in Python

2019-05-08 02:45发布

问题:

I have a goal of allowing an ssl client to select from a number of valid certificate pairs from the server. The client has a CA certificate which it will use to validate the certificate coming from the server.

To accomplish this, I'm using ssl.SSLContext.set_servername_callback() on the server, in combination with the `server_hostname` parameter used when wrapping the client socket, to allow the client to specify which keypair to use. Here's what the code looks like:

Server code:

import sys
import pickle
import ssl
import socket
import select

# Sample messages for the demo: the client sends `request`, the server
# answers with `response`.
request = {
    'msgtype': 0,
    'value': 'Ping',
    # Every code point 0..255 as a one-character string.
    'test': list(map(chr, range(256))),
}
response = dict(msgtype=1, value='Pong')

def handle_client(c, a):
    """Serve one client: read a pickled request, reply with the pickled ``response``.

    Args:
        c: Connected socket-like object for the client (the TLS-wrapped socket).
        a: Peer address as returned by ``accept()``; its first two items are
            printed as ``host:port``.
    """
    print("Connection from {}:{}".format(*a))
    req_raw = c.recv(10000)
    # SECURITY: pickle.loads() on bytes received from the network can execute
    # arbitrary code. This is acceptable only for a local demo with a trusted
    # peer; use a safe format such as JSON for anything else.
    req = pickle.loads(req_raw)
    print("Received message: {}".format(req))
    res = pickle.dumps(response)
    print("Sending message: {}".format(response))
    # sendall() keeps writing until the whole payload has been sent, whereas
    # send() may transmit only part of it.
    c.sendall(res)

def run_server(hostname, port):
    """Listen on ``(hostname, port)`` and serve TLS clients one at a time.

    For every accepted connection a new server-side SSL context is created,
    loaded with the default certificate ``ssl/3.1/server`` and given an SNI
    callback that loads the certificate named by the host name the client
    requested. The loop ends on ``KeyboardInterrupt``, which closes the
    listening socket.

    Args:
        hostname: Address to bind the listening socket to.
        port: TCP port to listen on.
    """
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    s.bind((hostname, port))
    s.listen(8)
    print("Serving on {}:{}".format(hostname, port))

    try:
        while True:
            (c, a) = s.accept()

            # The callback is redefined, and a new context built, on every
            # iteration of the accept loop.
            def servername_callback(sock, req_hostname, cb_context, as_callback=True):
                print('Loading certs for {}'.format(req_hostname))
                server_cert = "ssl/{}/server".format(req_hostname)  # NOTE: This use of socket input is INSECURE
                cb_context.load_cert_chain(certfile="{}.crt".format(server_cert), keyfile="{}.key".format(server_cert))

                # Seems like this is designed usage: https://github.com/python/cpython/blob/3.4/Modules/_ssl.c#L1469
                # NOTE(review): cb_context is presumably the same context the
                # callback was registered on (and that wrap_socket() was called
                # on), in which case this assignment changes nothing -- confirm.
                sock.context = cb_context
                return None

            context = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH)
            context.set_servername_callback(servername_callback)
            # Certificate loaded before the handshake starts.
            default_cert = "ssl/3.1/server"
            context.load_cert_chain(certfile="{}.crt".format(default_cert), keyfile="{}.key".format(default_cert))
            ssl_sock = context.wrap_socket(c, server_side=True)

            try:
                handle_client(ssl_sock, a)
            finally:
                # Closes the accepted socket `c`, not the wrapped `ssl_sock`.
                c.close()

    except KeyboardInterrupt:
        s.close()

if __name__ == '__main__':
    # An empty host name binds the listener on all local interfaces.
    hostname, port = '', 6789
    run_server(hostname, port)

Client code:

import sys
import pickle
import socket
import ssl

# Sample messages for the demo: the client sends `request` and expects the
# server to answer with `response`.
request = {
    'msgtype': 0,
    'value': 'Ping',
    # Every code point 0..255 as a one-character string.
    'test': list(map(chr, range(256))),
}
response = dict(msgtype=1, value='Pong')


def client(hostname, port):
    """Connect to the server over TLS, send the pickled ``request`` and print the reply.

    The server certificate is verified against the CA file ``server_old.crt``.
    The SNI value ``'3.2'`` tells the server which certificate pair to present.

    Args:
        hostname: Host name or address of the server to connect to.
        port: TCP port of the server.
    """
    # ssl.SSLSocket can no longer be instantiated directly on current Python
    # versions; sockets must be created through SSLContext.wrap_socket().
    context = ssl.create_default_context(cafile="server_old.crt")
    # '3.2' is a certificate selector sent via SNI, not a DNS name, so
    # host-name matching is switched off (the original call never performed
    # it either). The certificate chain itself is still verified.
    context.check_hostname = False
    context.verify_mode = ssl.CERT_REQUIRED

    # The `with` blocks close the sockets even when an error occurs.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        print("Connecting to {}:{}".format(hostname, port))
        s.connect((hostname, port))

        with context.wrap_socket(s, server_hostname='3.2') as ssl_sock:
            print("Sending message: {}".format(request))
            req = pickle.dumps(request)
            # sendall() keeps writing until the whole payload has been sent.
            ssl_sock.sendall(req)

            resp_raw = ssl_sock.recv(10000)
            # SECURITY: pickle.loads() on bytes received from the network can
            # execute arbitrary code; only acceptable with a trusted server.
            resp = pickle.loads(resp_raw)
            print("Received message: {}".format(resp))

if __name__ == '__main__':
    # Connect to the demo server on this machine.
    hostname, port = 'localhost', 6789
    client(hostname, port)

But it's not working. What seems to be happening is servername_callback is getting called, is getting the specified "hostname", and the call to context.load_cert_chain within the callback is not failing (though it does fail if it's given a path that doesn't exist). However, the server always returns the certificate pair that was loaded prior to calling context.wrap_socket(c, server_side=True). So my question is: is there some way, within the servername_callback, to modify the keypair used by the ssl context, and get that keypair's certificate to be used for the connection?

I should also note that I checked the traffic, and the server's certificate is NOT being sent until after the servername_callback function returns (and will never be sent if it fails to complete successfully, or returns a "failure" value).

回答1:

In your callback, cb_context is the same context on which wrap_socket() was called, and the same as sock.context, so sock.context = cb_context sets the context to the same one it already had.

Changing the certificate chain of a context does not affect the certificate used for the current wrap_socket() operation. The explanation for this lies in how openssl creates its underlying objects, in this case the underlying SSL structures have already been created and use copies of the chains:

NOTES

The chains associate with an SSL_CTX structure are copied to any SSL structures when SSL_new() is called. SSL structures will not be affected by any chains subsequently changed in the parent SSL_CTX.

When setting a new context, the SSL structures are updated, but that update is not performed when the new context is equal to the old one.

You need to set sock.context to a different context to make it work. You currently instantiate a new context on each new incoming connection, which is not needed. Instead you should instantiate your standard context only once and reuse that. Same goes for the dynamically loaded contexts, you could create them all on startup and put them in a dict so you can just do a lookup, e.g:

...

def _load_context(name):
    """Return a server-side SSL context loaded with ``ssl/<name>/server.{crt,key}``."""
    print('Loading certs for {}'.format(name))
    base = "ssl/{}/server".format(name)
    ctx = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH)
    ctx.load_cert_chain(
        certfile="{}.crt".format(base),
        keyfile="{}.key".format(base),
    )
    return ctx


# One context per entry of the "ssl" directory, keyed by the entry's name and
# built once at start-up so the SNI callback only needs a dictionary lookup.
contexts = {}

for hostname in os.listdir("ssl"):
    contexts[hostname] = _load_context(hostname)

def servername_callback(sock, req_hostname, cb_context, as_callback=True):
    """SNI callback: switch ``sock`` to the context preloaded for ``req_hostname``.

    When no context is registered under the requested name the socket is left
    on the context it already has.
    """
    selected = contexts.get(req_hostname)
    if selected is None:
        return  # handle unknown hostname case
    sock.context = selected

def run_server(hostname, port):
    """Accept TLS connections until interrupted, picking the certificate via SNI.

    A single default context (certificate ``ssl/3.1/server``) is created once
    and reused for every connection; ``servername_callback`` is registered on
    it so a per-host context can be swapped in during the handshake.

    Args:
        hostname: Address to bind the listening socket to.
        port: TCP port to listen on.
    """
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind((hostname, port))
        s.listen(8)
        print("Serving on {}:{}".format(hostname, port))

        context = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH)
        context.set_servername_callback(servername_callback)
        default_cert = "ssl/3.1/server"
        context.load_cert_chain(certfile="{}.crt".format(default_cert),
                                keyfile="{}.key".format(default_cert))

        while True:
            (c, a) = s.accept()
            try:
                ssl_sock = context.wrap_socket(c, server_side=True)
            except OSError as exc:
                # ssl.SSLError is an OSError. A failed handshake with one
                # client must not take the whole server down.
                print("TLS handshake with {}:{} failed: {}".format(a[0], a[1], exc))
                c.close()
                continue

            try:
                handle_client(ssl_sock, a)
            finally:
                # After wrapping, the connection belongs to the wrapped
                # socket, so that is the object that has to be closed.
                ssl_sock.close()

    except KeyboardInterrupt:
        pass
    finally:
        # Release the listening socket however the loop ends.
        s.close()


回答2:

So after looking at this post and a few others online, I put together a version of the code above, that worked for me perfectly... so I just thought I would share. In case it helps anyone else.

import sys
import ssl
import socket
import os

from pprint import pprint

# Maps a domain name to the SSL context loaded with that domain's certificate.
DOMAIN_CONTEXTS = dict()

# Root folder holding one sub-folder per domain (note the trailing slash).
ssl_root_path = "c:/ssl/"

# ----------------------------------------------------------------------------------------------------------------------
#
# Example layout: create one folder per domain under the ssl root path, e.g.
#
# c:/ssl/example.com
# c:/ssl/johndoe.com
# c:/ssl/test.com
#
# Then create a self-signed certificate for each domain to test with, and put
# it in the matching domain folder. The certificate and key files are expected
# to be named cert.pem and key.pem.
#

def setup_ssl_certs():
    """Fill ``DOMAIN_CONTEXTS`` with one server-side SSL context per domain folder.

    Every sub-directory of ``ssl_root_path`` is treated as a domain name and
    must contain ``cert.pem`` and ``key.pem``. Entries that are not
    directories (stray files in the root folder) are skipped instead of
    aborting start-up.

    Raises:
        OSError: If the root folder cannot be listed or a certificate or key
            file is missing.
        ssl.SSLError: If a certificate and its key cannot be loaded.
    """
    for hostname in os.listdir(ssl_root_path):
        domain_dir = os.path.join(ssl_root_path, hostname)

        # Only folders can hold a certificate pair.
        if not os.path.isdir(domain_dir):
            continue

        # Server-side context (it authenticates the server to clients).
        context = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH)
        context.load_cert_chain(
            certfile=os.path.join(domain_dir, "cert.pem"),
            keyfile=os.path.join(domain_dir, "key.pem"),
        )

        # The dict is mutated in place, so no `global` statement is needed.
        DOMAIN_CONTEXTS[hostname] = context

# ----------------------------------------------------------------------------------------------------------------------

def servername_callback(sock, req_hostname, cb_context, as_callback=True):
    """
    SNI callback registered on the default SSL context.

    Looks up the context stored for the requested domain and, if one exists,
    switches the socket to it and records the requested name on the socket.
    Sockets asking for an unknown domain are left untouched.
    """
    domain_context = DOMAIN_CONTEXTS.get(req_hostname)

    if not domain_context:
        return  # handle unknown hostname case

    try:
        sock.context = domain_context
    except Exception as error:
        print(error)
        return

    # Only reached when the context switch succeeded.
    sock.server_hostname = req_hostname


def handle_client(conn, a):
    """Answer one client with a minimal HTTP greeting naming the requested domain.

    Args:
        conn: Connected TLS socket; its ``server_hostname`` attribute is
            echoed back in the greeting.
        a: Peer address from ``accept()`` (unused; the address is read from
            ``conn.getpeername()`` instead).
    """
    request_domain = conn.server_hostname

    # Read the client's request; its content is not inspected.
    conn.recv()

    client_ip = conn.getpeername()[0]

    resp = 'Hello {cip} welcome, from domain {d} !'.format(cip=client_ip, d=request_domain)

    # HTTP header lines end with CRLF, and a blank CRLF line separates the
    # header from the body. sendall() sends the complete response, whereas
    # the deprecated write() may send only part of it.
    conn.sendall(b'HTTP/1.1 200 OK\r\n\r\n' + resp.encode())


def run_server(hostname, port):
    """Serve TLS clients until interrupted, choosing the certificate via SNI.

    The default context is loaded with the certificate pair from the
    ``default/`` folder under ``ssl_root_path``; ``servername_callback`` may
    switch a connection to a per-domain context during the handshake.

    Args:
        hostname: Address to bind the listening socket to.
        port: TCP port to listen on.
    """
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

    try:
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind((hostname, port))
        s.listen(8)

        context = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH)

        # For Python 3.4+
        context.set_servername_callback(servername_callback)

        # Python 3.7+ offers an attribute for the same purpose (untested here):
        #     context.sni_callback = servername_callback

        default_cert = "{rp}default/".format(rp=ssl_root_path)

        context.load_cert_chain(certfile="{}cert.pem".format(default_cert), keyfile="{}key.pem".format(default_cert))

        context.options |= ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1  # optional

        context.set_ciphers('EECDH+AESGCM:EDH+AESGCM:AES256+EECDH:AES256+EDH')

        while True:

            ssock, addr = s.accept()

            try:
                conn = context.wrap_socket(ssock, server_side=True)

            # The specific handler must come first: ssl.SSLError is a subclass
            # of Exception, so it is never reached after `except Exception`.
            except ssl.SSLError as e:
                print(e)
                ssock.close()
                continue

            except Exception as error:
                print('!!! Error, {e}'.format(e=error))
                ssock.close()
                continue

            try:
                handle_client(conn, addr)

            except OSError as error:
                # A broken connection with one client must not stop the server.
                print('!!! Error, {e}'.format(e=error))

            finally:
                # Close the connection even when handle_client() raises.
                conn.close()

    except KeyboardInterrupt:
        pass

    finally:
        # Release the listening socket however the loop ends.
        s.close()

# ----------------------------------------------------------------------------------------------------------------------

def main():
    """Load the per-domain certificates, then start the TLS server."""
    setup_ssl_certs()

    # Don't forget to update your static name resolution...  ie example.com = 127.0.0.1
    bind_host, bind_port = 'example.com', 443
    run_server(bind_host, bind_port)

# ----------------------------------------------------------------------------------------------------------------------

# Run the server only when executed as a script.
if __name__ == '__main__':
    main()