diff --git a/.travis.yml b/.travis.yml index ede2df1..f6a6ac9 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,6 +1,10 @@ language: python python: - - "2.7" + - "2.7" + - "3.4" install: - - pip install tox -script: tox + - "pip install -e ." + - "pip install flake8" +script: + - "python setup.py test" + - "flake8" diff --git a/jose.py b/jose/__init__.py similarity index 84% rename from jose.py rename to jose/__init__.py index 039f1ec..2e7591f 100644 --- a/jose.py +++ b/jose/__init__.py @@ -11,6 +11,7 @@ import datetime from base64 import urlsafe_b64encode, urlsafe_b64decode +import binascii from collections import namedtuple from time import time @@ -21,6 +22,13 @@ from Crypto.Signature import PKCS1_v1_5 as PKCS1_v1_5_SIG +try: + # python 2 compatibility + unicode +except NameError: + unicode = str + + __all__ = ['encrypt', 'decrypt', 'sign', 'verify'] @@ -125,10 +133,17 @@ def encrypt(claims, jwk, adata='', add_header=None, alg='RSA-OAEP', :raises: :class:`~jose.Error` if there is an error producing the JWE """ - header = dict((add_header or {}).items() + [ - ('enc', enc), ('alg', alg)]) + header = {} - plaintext = json_encode(claims) + if add_header: + header.update(add_header) + + header.update({ + 'enc': enc, + 'alg': alg, + }) + + plaintext = json_encode(claims).encode('utf-8') # compress (if required) if compression is not None: @@ -145,6 +160,10 @@ def encrypt(claims, jwk, adata='', add_header=None, alg='RSA-OAEP', iv = rng(AES.block_size) encryption_key = rng((key_size // 8) + hash_mod.digest_size) + if not isinstance(adata, bytes): + # TODO this should probably just be an error + adata = adata.encode('utf-8') + ciphertext = cipher(plaintext, encryption_key[:-hash_mod.digest_size], iv) hash = hash_fn(_jwe_hash_str(plaintext, iv, adata), encryption_key[-hash_mod.digest_size:], hash_mod) @@ -153,15 +172,16 @@ def encrypt(claims, jwk, adata='', add_header=None, alg='RSA-OAEP', (cipher, _), _ = JWA[alg] encryption_key_ciphertext = cipher(encryption_key, jwk) - return JWE(*map(b64encode_url, - (json_encode(header), + return JWE(*list(map(b64encode_url, + (json_encode(header).encode('utf-8'), encryption_key_ciphertext, iv, ciphertext, - auth_tag(hash)))) + auth_tag(hash))))) -def decrypt(jwe, jwk, adata='', validate_claims=True, expiry_seconds=None): + +def decrypt(jwe, jwk, adata=b'', validate_claims=True, expiry_seconds=None): """ Decrypts a deserialized :class:`~jose.JWE` :param jwe: An instance of :class:`~jose.JWE` @@ -182,7 +202,8 @@ def decrypt(jwe, jwk, adata='', validate_claims=True, expiry_seconds=None): """ header, encryption_key_ciphertext, iv, ciphertext, tag = map( b64decode_url, jwe) - header = json_decode(header) + header = json_decode(header.decode('utf-8')) + # decrypt cek (_, decipher), _ = JWA[header['alg']] @@ -191,6 +212,10 @@ def decrypt(jwe, jwk, adata='', validate_claims=True, expiry_seconds=None): # decrypt body ((_, decipher), _), ((hash_fn, _), mod) = JWA[header['enc']] + if not isinstance(adata, bytes): + # TODO this should probably just be an error + adata = adata.encode('utf-8') + plaintext = decipher(ciphertext, encryption_key[:-mod.digest_size], iv) hash = hash_fn(_jwe_hash_str(plaintext, iv, adata), encryption_key[-mod.digest_size:], mod=mod) @@ -207,7 +232,7 @@ def decrypt(jwe, jwk, adata='', validate_claims=True, expiry_seconds=None): plaintext = decompress(plaintext) - claims = json_decode(plaintext) + claims = json_decode(plaintext.decode('utf-8')) _validate(claims, validate_claims, expiry_seconds) return JWT(header, claims) @@ -227,8 +252,17 @@ def sign(claims, jwk, add_header=None, alg='HS256'): """ (hash_fn, _), mod = JWA[alg] - header = dict((add_header or {}).items() + [('alg', alg)]) - header, payload = map(b64encode_url, map(json_encode, (header, claims))) + header = {} + + if add_header: + header.update(add_header) + + header.update({ + 'alg': alg, + }) + + header = b64encode_url(json_encode(header).encode('utf-8')) + payload = b64encode_url(json_encode(claims).encode('utf-8')) sig = b64encode_url(hash_fn(_jws_hash_str(header, payload), jwk['k'], mod=mod)) @@ -254,14 +288,14 @@ def verify(jws, jwk, validate_claims=True, expiry_seconds=None): :raises: :class:`~jose.Error` if there is an error decrypting the JWE """ header, payload, sig = map(b64decode_url, jws) - header = json_decode(header) + header = json_decode(header.decode('utf-8')) (_, verify_fn), mod = JWA[header['alg']] if not verify_fn(_jws_hash_str(jws.header, jws.payload), jwk['k'], sig, mod=mod): raise Error('Mismatched signatures') - claims = json_decode(b64decode_url(jws.payload)) + claims = json_decode(b64decode_url(jws.payload).decode('utf-8')) _validate(claims, validate_claims, expiry_seconds) return JWT(header, claims) @@ -270,28 +304,35 @@ def verify(jws, jwk, validate_claims=True, expiry_seconds=None): def b64decode_url(istr): """ JWT Tokens may be truncated without the usual trailing padding '=' symbols. Compensate by padding to the nearest 4 bytes. + + :param istr: A unicode string to decode + :returns: The byte string represented by `istr` """ - istr = encode_safe(istr) + # unicode check for python 2 compatibility + if not isinstance(istr, (str, unicode)): + raise ValueError("expected string, got %r" % type(istr)) + + # required for python 2 as urlsafe_b64decode does not like unicode objects + # safe as b64 encoded string should be only ascii anyway + istr = str(istr) + try: return urlsafe_b64decode(istr + '=' * (4 - (len(istr) % 4))) - except TypeError as e: + except (TypeError, binascii.Error) as e: raise Error('Unable to decode base64: %s' % (e)) def b64encode_url(istr): """ JWT Tokens may be truncated without the usual trailing padding '=' symbols. Compensate by padding to the nearest 4 bytes. - """ - return urlsafe_b64encode(encode_safe(istr)).rstrip('=') - -def encode_safe(istr, encoding='utf8'): - try: - return istr.encode(encoding) - except UnicodeDecodeError: - # this will fail if istr is already encoded - pass - return istr + :param istr: a byte string to encode + :returns: The base64 representation of the input byte string as a regular + `str` object + """ + if not isinstance(istr, bytes): + raise Exception("expected bytestring") + return urlsafe_b64encode(istr).rstrip(b'=').decode('ascii') def auth_tag(hmac): @@ -302,11 +343,17 @@ def auth_tag(hmac): def pad_pkcs7(s): sz = AES.block_size - (len(s) % AES.block_size) - return s + (chr(sz) * sz) + # TODO would be cleaner to do `bytes(sz) * sz` but python 2 behaves + # strangely + return s + (chr(sz) * sz).encode('ascii') def unpad_pkcs7(s): - return s[:-ord(s[-1])] + try: + return s[:-ord(s[-1])] + # Python 3 compatibility + except TypeError: + return s[:-s[-1]] def encrypt_oaep(plaintext, jwk): @@ -361,9 +408,15 @@ def const_compare(stra, strb): if len(stra) != len(strb): return False + try: + # python 2 compatibility + orda, ordb = list(map(ord, stra)), list(map(ord, strb)) + except TypeError: + orda, ordb = stra, strb + res = 0 - for a, b in zip(stra, strb): - res |= ord(a) ^ ord(b) + for a, b in zip(orda, ordb): + res |= a ^ b return res == 0 @@ -491,19 +544,19 @@ def _validate(claims, validate_claims, expiry_seconds): _check_not_before(now, not_before) -def _jwe_hash_str(plaintext, iv, adata=''): +def _jwe_hash_str(plaintext, iv, adata=b''): # http://tools.ietf.org/html/ # draft-ietf-jose-json-web-algorithms-24#section-5.2.2.1 - return '.'.join((adata, iv, plaintext, str(len(adata)))) + return b'.'.join((adata, iv, plaintext, bytes(len(adata)))) def _jws_hash_str(header, claims): - return '.'.join((header, claims)) + return b'.'.join((header.encode('ascii'), claims.encode('ascii'))) def cli_decrypt(jwt, key): - print decrypt(deserialize_compact(jwt), {'k':key}, - validate_claims=False) + print(decrypt(deserialize_compact(jwt), {'k':key}, + validate_claims=False)) def _cli(): diff --git a/tests.py b/jose/tests/__init__.py similarity index 80% rename from tests.py rename to jose/tests/__init__.py index dc1edb9..9e4cf76 100644 --- a/tests.py +++ b/jose/tests/__init__.py @@ -28,7 +28,7 @@ def test_serialize(self): jose.deserialize_compact('1.2.3.4') self.fail() except jose.Error as e: - self.assertEqual(e.message, 'Malformed JWT') + self.assertEqual(str(e), 'Malformed JWT') class TestJWE(unittest.TestCase): @@ -43,7 +43,7 @@ def test_jwe(self): # make sure the body can't be loaded as json (should be encrypted) try: - json.loads(jose.b64decode_url(jwe.ciphertext)) + json.loads(jose.b64decode_url(jwe.ciphertext).decode('utf-8')) self.fail() except ValueError: pass @@ -59,7 +59,7 @@ def test_jwe(self): jose.decrypt(jose.deserialize_compact(token), bad_key) self.fail() except jose.Error as e: - self.assertEqual(e.message, 'Incorrect decryption.') + self.assertEqual(str(e), 'Incorrect decryption.') def test_jwe_add_header(self): add_header = {'foo': 'bar'} @@ -85,25 +85,24 @@ def test_jwe_adata(self): rsa_priv_key) self.fail() except jose.Error as e: - self.assertEqual(e.message, 'Mismatched authentication tags') + self.assertEqual(str(e), 'Mismatched authentication tags') self.assertEqual(jwt.claims, claims) def test_jwe_invalid_base64(self): claims = {jose.CLAIM_EXPIRATION_TIME: int(time()) - 5} et = jose.serialize_compact(jose.encrypt(claims, rsa_pub_key)) - bad = b'\x00' + et + bad = '\x00' + et try: jose.decrypt(jose.deserialize_compact(bad), rsa_priv_key) - self.fail() # expecting error due to invalid base64 except jose.Error as e: - pass - - self.assertEquals( - e.args[0], - 'Unable to decode base64: Incorrect padding' - ) + self.assertEqual( + e.args[0], + 'Unable to decode base64: Incorrect padding' + ) + else: + self.fail() # expecting error due to invalid base64 def test_jwe_no_error_with_exp_claim(self): claims = {jose.CLAIM_EXPIRATION_TIME: int(time()) + 5} @@ -116,16 +115,15 @@ def test_jwe_expired_error_with_exp_claim(self): try: jose.decrypt(jose.deserialize_compact(et), rsa_priv_key) - self.fail() # expecting expired token except jose.Expired as e: - pass - - self.assertEquals( - e.args[0], - 'Token expired at {}'.format( - jose._format_timestamp(claims[jose.CLAIM_EXPIRATION_TIME]) + self.assertEqual( + e.args[0], + 'Token expired at {}'.format( + jose._format_timestamp(claims[jose.CLAIM_EXPIRATION_TIME]) + ) ) - ) + else: + self.fail() # expecting expired token def test_jwe_no_error_with_iat_claim(self): claims = {jose.CLAIM_ISSUED_AT: int(time()) - 15} @@ -142,17 +140,16 @@ def test_jwe_expired_error_with_iat_claim(self): try: jose.decrypt(jose.deserialize_compact(et), rsa_priv_key, expiry_seconds=expiry_seconds) - self.fail() # expecting expired token except jose.Expired as e: - pass - - expiration_time = claims[jose.CLAIM_ISSUED_AT] + expiry_seconds - self.assertEquals( - e.args[0], - 'Token expired at {}'.format( - jose._format_timestamp(expiration_time) + expiration_time = claims[jose.CLAIM_ISSUED_AT] + expiry_seconds + self.assertEqual( + e.args[0], + 'Token expired at {}'.format( + jose._format_timestamp(expiration_time) + ) ) - ) + else: + self.fail() # expecting expired token def test_jwe_no_error_with_nbf_claim(self): claims = {jose.CLAIM_NOT_BEFORE: int(time()) - 5} @@ -165,16 +162,15 @@ def test_jwe_not_yet_valid_error_with_nbf_claim(self): try: jose.decrypt(jose.deserialize_compact(et), rsa_priv_key) - self.fail() # expecting not valid yet except jose.NotYetValid as e: - pass - - self.assertEquals( - e.args[0], - 'Token not valid until {}'.format( - jose._format_timestamp(claims[jose.CLAIM_NOT_BEFORE]) + self.assertEqual( + e.args[0], + 'Token not valid until {}'.format( + jose._format_timestamp(claims[jose.CLAIM_NOT_BEFORE]) + ) ) - ) + else: + self.fail() # expecting not valid yet def test_jwe_ignores_expired_token_if_validate_claims_is_false(self): claims = {jose.CLAIM_EXPIRATION_TIME: int(time()) - 5} @@ -183,7 +179,7 @@ def test_jwe_ignores_expired_token_if_validate_claims_is_false(self): validate_claims=False) def test_format_timestamp(self): - self.assertEquals( + self.assertEqual( jose._format_timestamp(1403054056), '2014-06-18T01:14:16Z' ) @@ -191,7 +187,7 @@ def test_format_timestamp(self): def test_jwe_compression(self): local_claims = copy(claims) - for v in xrange(1000): + for v in range(1000): local_claims['dummy_' + str(v)] = '0' * 100 jwe = jose.serialize_compact(jose.encrypt(local_claims, rsa_pub_key)) @@ -210,21 +206,23 @@ def test_jwe_compression(self): def test_encrypt_invalid_compression_error(self): try: jose.encrypt(claims, rsa_pub_key, compression='BAD') - self.fail() except jose.Error: pass + else: + self.fail() def test_decrypt_invalid_compression_error(self): jwe = jose.encrypt(claims, rsa_pub_key, compression='DEF') - header = jose.b64encode_url('{"alg": "RSA-OAEP", ' - '"enc": "A128CBC-HS256", "zip": "BAD"}') + header = jose.b64encode_url(b'{"alg": "RSA-OAEP", ' + b'"enc": "A128CBC-HS256", "zip": "BAD"}') try: jose.decrypt(jose.JWE(*((header,) + (jwe[1:]))), rsa_priv_key) - self.fail() except jose.Error as e: - self.assertEqual(e.message, + self.assertEqual(str(e), 'Unsupported compression algorithm: BAD') + else: + self.fail() class TestJWS(unittest.TestCase): @@ -254,7 +252,7 @@ def test_jws_signature_mismatch_error(self): try: jose.verify(jose.JWS(jws.header, jws.payload, 'asd'), jwk) except jose.Error as e: - self.assertEqual(e.message, 'Mismatched signatures') + self.assertEqual(str(e), 'Mismatched signatures') class TestUtils(unittest.TestCase): @@ -264,18 +262,18 @@ def test_b64encode_url_utf8(self): self.assertEqual(jose.b64decode_url(encoded), istr) def test_b64encode_url_ascii(self): - istr = 'eric idle' + istr = b'eric idle' encoded = jose.b64encode_url(istr) self.assertEqual(jose.b64decode_url(encoded), istr) def test_b64encode_url(self): - istr = '{"alg": "RSA-OAEP", "enc": "A128CBC-HS256"}' + istr = b'{"alg": "RSA-OAEP", "enc": "A128CBC-HS256"}' # sanity check - self.assertEqual(b64encode(istr)[-1], '=') + self.assertTrue(b64encode(istr).endswith(b'=')) # actual test - self.assertNotEqual(jose.b64encode_url(istr), '=') + self.assertFalse(jose.b64encode_url(istr).endswith('=')) class TestJWA(unittest.TestCase): @@ -298,8 +296,14 @@ def test_invalid_error(self): jose.JWA['bad'] self.fail() except jose.Error as e: - self.assertTrue(e.message.startswith('Unsupported')) + self.assertTrue(str(e).startswith('Unsupported')) -if __name__ == '__main__': - unittest.main() +loader = unittest.TestLoader() +suite = unittest.TestSuite(( + loader.loadTestsFromTestCase(TestSerializeDeserialize), + loader.loadTestsFromTestCase(TestJWE), + loader.loadTestsFromTestCase(TestJWS), + loader.loadTestsFromTestCase(TestUtils), + loader.loadTestsFromTestCase(TestJWA), +)) diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 9fadf90..0000000 --- a/requirements.txt +++ /dev/null @@ -1 +0,0 @@ -pycrypto >= 2.6 diff --git a/setup.py b/setup.py index f56f5cc..1ce8800 100644 --- a/setup.py +++ b/setup.py @@ -5,8 +5,9 @@ from setuptools.command.bdist_rpm import bdist_rpm as _bdist_rpm here = os.path.abspath(os.path.dirname(__file__)) -REQUIRES = filter(lambda s: len(s) > 0, - open(os.path.join(here, 'requirements.txt')).read().split('\n')) +REQUIRES = [ + 'pycrypto >= 2.6', +] pkg_name = 'jose' pyver = ''.join(('python', '.'.join(map(str, sys.version_info[:2])))) @@ -61,4 +62,5 @@ def finalize_package_data(self): 'jose = jose:_cli', ) }, + test_suite='jose.tests.suite', ) diff --git a/tox.ini b/tox.ini index 56a1d96..753b1dc 100644 --- a/tox.ini +++ b/tox.ini @@ -2,8 +2,7 @@ ignore = E128 [tox] -envlist=py27 +envlist = py27, py34 [testenv] -deps=nose -commands=nosetests +commands = python setup.py test