Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
build/
dist/
.idea

*.pyc
*.so
Expand Down
6 changes: 6 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@
'Programming Language :: Python :: 3.4',
'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
'Programming Language :: Python :: 3.11',
'Programming Language :: Python :: 3.12',
'Programming Language :: Python :: Implementation :: CPython',
'Operating System :: POSIX',
'Operating System :: Unix',
Expand Down
98 changes: 70 additions & 28 deletions sslpsk/sslpsk.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,30 @@

from __future__ import absolute_import

import ssl
import _ssl
import sys
import ssl
import weakref

from sslpsk import _sslpsk

_callbacks = {}


class FinalizerRef(weakref.ref):
"""subclass weakref.ref so that attributes can be added"""
pass


def _register_callback(sock, ssl_id, callback):
_callbacks[ssl_id] = callback
callback.unregister = FinalizerRef(sock, _unregister_callback)
callback.unregister.ssl_id = ssl_id


def _unregister_callback(ref):
del _callbacks[ref.ssl_id]


def _python_psk_client_callback(ssl_id, hint):
"""Called by _sslpsk.c to return the (psk, identity) tuple for the socket with
the specified ssl socket.
Expand All @@ -46,6 +49,7 @@ def _python_psk_client_callback(ssl_id, hint):
res = _callbacks[ssl_id](hint)
return res if isinstance(res, tuple) else (res, b"")


def _sslobj(sock):
"""Returns the underlying PySLLSocket object with which the C extension
functions interface.
Expand All @@ -57,6 +61,7 @@ def _sslobj(sock):
else:
return sock._sslobj._sslobj


def _python_psk_server_callback(ssl_id, identity):
"""Called by _sslpsk.c to return the psk for the socket with the specified
ssl socket.
Expand All @@ -67,46 +72,83 @@ def _python_psk_server_callback(ssl_id, identity):
else:
return _callbacks[ssl_id](identity)


_sslpsk.sslpsk_set_python_psk_client_callback(_python_psk_client_callback)
_sslpsk.sslpsk_set_python_psk_server_callback(_python_psk_server_callback)



def _ssl_set_psk_client_callback(sock, psk_cb):
ssl_id = _sslpsk.sslpsk_set_psk_client_callback(_sslobj(sock))
_register_callback(sock, ssl_id, psk_cb)


def _ssl_set_psk_server_callback(sock, psk_cb, hint):
ssl_id = _sslpsk.sslpsk_set_accept_state(_sslobj(sock))
_ = _sslpsk.sslpsk_set_psk_server_callback(_sslobj(sock))
_ = _sslpsk.sslpsk_use_psk_identity_hint(_sslobj(sock), hint if hint else b"")
_ = _sslpsk.sslpsk_set_psk_server_callback(_sslobj(sock))
_ = _sslpsk.sslpsk_use_psk_identity_hint(_sslobj(sock), hint if hint else b"")
_register_callback(sock, ssl_id, psk_cb)

def wrap_socket(*args, **kwargs):
"""
"""
do_handshake_on_connect = kwargs.get('do_handshake_on_connect', True)
kwargs['do_handshake_on_connect'] = False

psk = kwargs.setdefault('psk', None)
del kwargs['psk']

hint = kwargs.setdefault('hint', None)
del kwargs['hint']

server_side = kwargs.setdefault('server_side', False)
def _ssl_setup_psk_callbacks(sslobj):
psk = sslobj.context.psk
hint = sslobj.context.hint
if psk:
del kwargs['server_side'] # bypass need for cert

sock = ssl.wrap_socket(*args, **kwargs)

if psk:
if server_side:
if sslobj.server_side:
cb = psk if callable(psk) else lambda _identity: psk
_ssl_set_psk_server_callback(sock, cb, hint)
_ssl_set_psk_server_callback(sslobj, cb, hint)
else:
cb = psk if callable(psk) else lambda _hint: psk if isinstance(psk, tuple) else (psk, b"")
_ssl_set_psk_client_callback(sock, cb)
_ssl_set_psk_client_callback(sslobj, cb)


class SSLPSKContext(ssl.SSLContext):
@property
def psk(self):
return getattr(self, "_psk", None)

@psk.setter
def psk(self, psk):
self._psk = psk

@property
def hint(self):
return getattr(self, "_hint", None)

@hint.setter
def hint(self, hint):
self._hint = hint


class SSLPSKObject(ssl.SSLObject):
def do_handshake(self, *args, **kwargs):
_ssl_setup_psk_callbacks(self)
super().do_handshake(*args, **kwargs)


class SSLPSKSocket(ssl.SSLSocket):
def do_handshake(self, *args, **kwargs):
_ssl_setup_psk_callbacks(self)
super().do_handshake(*args, **kwargs)


SSLPSKContext.sslobject_class = SSLPSKObject
SSLPSKContext.sslsocket_class = SSLPSKSocket


if do_handshake_on_connect:
sock.do_handshake()
def wrap_socket(sock, psk, hint=None,
server_side=False,
ssl_version=ssl.PROTOCOL_TLS,
do_handshake_on_connect=True,
suppress_ragged_eofs=True,
ciphers=None):
context = SSLPSKContext(ssl_version)
if ciphers:
context.set_ciphers(ciphers)
context.psk = psk
context.hint = hint

return sock
return context.wrap_socket(
sock=sock, server_side=server_side,
do_handshake_on_connect=do_handshake_on_connect,
suppress_ragged_eofs=suppress_ragged_eofs
)