diff --git a/.gitignore b/.gitignore index ebcdeb8..d48988f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ build/ dist/ +.idea *.pyc *.so diff --git a/setup.py b/setup.py index ec59766..1b8bbea 100644 --- a/setup.py +++ b/setup.py @@ -53,6 +53,12 @@ 'Programming Language :: Python :: 3.4', 'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', + 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', 'Programming Language :: Python :: Implementation :: CPython', 'Operating System :: POSIX', 'Operating System :: Unix', diff --git a/sslpsk/sslpsk.py b/sslpsk/sslpsk.py index bd0f12b..c1f51ce 100644 --- a/sslpsk/sslpsk.py +++ b/sslpsk/sslpsk.py @@ -14,27 +14,30 @@ from __future__ import absolute_import -import ssl import _ssl -import sys +import ssl import weakref from sslpsk import _sslpsk _callbacks = {} + class FinalizerRef(weakref.ref): """subclass weakref.ref so that attributes can be added""" pass + def _register_callback(sock, ssl_id, callback): _callbacks[ssl_id] = callback callback.unregister = FinalizerRef(sock, _unregister_callback) callback.unregister.ssl_id = ssl_id + def _unregister_callback(ref): del _callbacks[ref.ssl_id] + def _python_psk_client_callback(ssl_id, hint): """Called by _sslpsk.c to return the (psk, identity) tuple for the socket with the specified ssl socket. @@ -46,6 +49,7 @@ def _python_psk_client_callback(ssl_id, hint): res = _callbacks[ssl_id](hint) return res if isinstance(res, tuple) else (res, b"") + def _sslobj(sock): """Returns the underlying PySLLSocket object with which the C extension functions interface. @@ -57,6 +61,7 @@ def _sslobj(sock): else: return sock._sslobj._sslobj + def _python_psk_server_callback(ssl_id, identity): """Called by _sslpsk.c to return the psk for the socket with the specified ssl socket. @@ -67,46 +72,83 @@ def _python_psk_server_callback(ssl_id, identity): else: return _callbacks[ssl_id](identity) + _sslpsk.sslpsk_set_python_psk_client_callback(_python_psk_client_callback) _sslpsk.sslpsk_set_python_psk_server_callback(_python_psk_server_callback) - + + def _ssl_set_psk_client_callback(sock, psk_cb): ssl_id = _sslpsk.sslpsk_set_psk_client_callback(_sslobj(sock)) _register_callback(sock, ssl_id, psk_cb) + def _ssl_set_psk_server_callback(sock, psk_cb, hint): ssl_id = _sslpsk.sslpsk_set_accept_state(_sslobj(sock)) - _ = _sslpsk.sslpsk_set_psk_server_callback(_sslobj(sock)) - _ = _sslpsk.sslpsk_use_psk_identity_hint(_sslobj(sock), hint if hint else b"") + _ = _sslpsk.sslpsk_set_psk_server_callback(_sslobj(sock)) + _ = _sslpsk.sslpsk_use_psk_identity_hint(_sslobj(sock), hint if hint else b"") _register_callback(sock, ssl_id, psk_cb) -def wrap_socket(*args, **kwargs): - """ - """ - do_handshake_on_connect = kwargs.get('do_handshake_on_connect', True) - kwargs['do_handshake_on_connect'] = False - - psk = kwargs.setdefault('psk', None) - del kwargs['psk'] - - hint = kwargs.setdefault('hint', None) - del kwargs['hint'] - server_side = kwargs.setdefault('server_side', False) +def _ssl_setup_psk_callbacks(sslobj): + psk = sslobj.context.psk + hint = sslobj.context.hint if psk: - del kwargs['server_side'] # bypass need for cert - - sock = ssl.wrap_socket(*args, **kwargs) - - if psk: - if server_side: + if sslobj.server_side: cb = psk if callable(psk) else lambda _identity: psk - _ssl_set_psk_server_callback(sock, cb, hint) + _ssl_set_psk_server_callback(sslobj, cb, hint) else: cb = psk if callable(psk) else lambda _hint: psk if isinstance(psk, tuple) else (psk, b"") - _ssl_set_psk_client_callback(sock, cb) + _ssl_set_psk_client_callback(sslobj, cb) + + +class SSLPSKContext(ssl.SSLContext): + @property + def psk(self): + return getattr(self, "_psk", None) + + @psk.setter + def psk(self, psk): + self._psk = psk + + @property + def hint(self): + return getattr(self, "_hint", None) + + @hint.setter + def hint(self, hint): + self._hint = hint + + +class SSLPSKObject(ssl.SSLObject): + def do_handshake(self, *args, **kwargs): + _ssl_setup_psk_callbacks(self) + super().do_handshake(*args, **kwargs) + + +class SSLPSKSocket(ssl.SSLSocket): + def do_handshake(self, *args, **kwargs): + _ssl_setup_psk_callbacks(self) + super().do_handshake(*args, **kwargs) + + +SSLPSKContext.sslobject_class = SSLPSKObject +SSLPSKContext.sslsocket_class = SSLPSKSocket + - if do_handshake_on_connect: - sock.do_handshake() +def wrap_socket(sock, psk, hint=None, + server_side=False, + ssl_version=ssl.PROTOCOL_TLS, + do_handshake_on_connect=True, + suppress_ragged_eofs=True, + ciphers=None): + context = SSLPSKContext(ssl_version) + if ciphers: + context.set_ciphers(ciphers) + context.psk = psk + context.hint = hint - return sock + return context.wrap_socket( + sock=sock, server_side=server_side, + do_handshake_on_connect=do_handshake_on_connect, + suppress_ragged_eofs=suppress_ragged_eofs + )