Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions annotation/typed.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,10 @@ def _check_type_constraint(value, constraint):
else:
return False

class IncorrectArgument(TypeError):
def __str__(self):
return '{} expected {!r} actual {!r}'.format(*self.args)

def _check_argument_types(signature, *args, **kwargs):
"""Check that the arguments of a function match the given signature."""
bound_arguments = signature.bind(*args, **kwargs)
Expand All @@ -464,16 +468,19 @@ def _check_argument_types(signature, *args, **kwargs):
if annotation is EMPTY_ANNOTATION:
annotation = AnyType
if not _check_type_constraint(value, annotation):
raise TypeError('Incorrect type for "{0}"'.format(name))
raise IncorrectArgument(name, annotation, value)

class IncorrectReturnType(TypeError):
def __str__(self):
return 'expected: {!r} actual: {!r}'.format(*self.args)

def _check_return_type(signature, return_value):
"""Check that the return value of a function matches the signature."""
annotation = signature.return_annotation
if annotation is EMPTY_ANNOTATION:
annotation = AnyType
if not _check_type_constraint(return_value, annotation):
raise TypeError('Incorrect return type')
raise IncorrectReturnType(annotation, return_value)
return return_value


Expand Down
53 changes: 26 additions & 27 deletions tests/test_annontations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
from collections import namedtuple

from annotation.typed import (typechecked, Interface, union, AnyType, predicate,
optional, typedef, options, only)

optional, typedef, options, only, IncorrectReturnType, IncorrectArgument)


class TypecheckedTest(unittest.TestCase):
Expand All @@ -15,8 +14,8 @@ def test(a: int):
return a

self.assertEqual(1, test(1))
self.assertRaises(TypeError, test, 'string')
self.assertRaises(TypeError, test, 1.2)
self.assertRaises(IncorrectArgument, test, 'string')
self.assertRaises(IncorrectArgument, test, 1.2)

def test_single_argument_with_class(self):

Expand All @@ -29,7 +28,7 @@ def test(a: MyClass):

value = MyClass()
self.assertEqual(value, test(value))
self.assertRaises(TypeError, test, 'string')
self.assertRaises(IncorrectArgument, test, 'string')

def test_single_argument_with_subclass(self):

Expand All @@ -42,7 +41,7 @@ def test(a: MyClass):

value = MySubClass()
self.assertEqual(value, test(value))
self.assertRaises(TypeError, test, 'string')
self.assertRaises(IncorrectArgument, test, 'string')

def test_single_argument_with_union_annotation(self):
from decimal import Decimal
Expand All @@ -54,7 +53,7 @@ def test(a: union(int, float, Decimal)):
self.assertEqual(1, test(1))
self.assertEqual(1.5, test(1.5))
self.assertEqual(Decimal('2.5'), test(Decimal('2.5')))
self.assertRaises(TypeError, test, 'string')
self.assertRaises(IncorrectArgument, test, 'string')

def test_single_argument_with_predicate_annotation(self):

Expand All @@ -63,7 +62,7 @@ def test(a: predicate(lambda x: x > 0)):
return a

self.assertEqual(1, test(1))
self.assertRaises(TypeError, test, 0)
self.assertRaises(IncorrectArgument, test, 0)

def test_single_argument_with_optional_annotation(self):

Expand Down Expand Up @@ -91,7 +90,7 @@ def f2(a: str, b: str):
pass

self.assertEqual({}, test(f1))
self.assertRaises(TypeError, test, f2)
self.assertRaises(IncorrectArgument, test, f2)

def test_single_argument_with_options_annotation(self):

Expand All @@ -101,7 +100,7 @@ def test(a: options('open', 'write')):

self.assertEqual('open', test('open'))
self.assertEqual('write', test('write'))
self.assertRaises(TypeError, test, 'other')
self.assertRaises(IncorrectArgument, test, 'other')

def test_single_argument_with_only_annotation(self):

Expand All @@ -110,7 +109,7 @@ def test(a: only(int)):
return a

self.assertEqual(1, test(1))
self.assertRaises(TypeError, test, True)
self.assertRaises(IncorrectArgument, test, True)

def test_single_argument_with_interface(self):

Expand All @@ -129,7 +128,7 @@ def test(a: Test):
return 1

self.assertEqual(1, test(TestImplementation()))
self.assertRaises(TypeError, test, Other())
self.assertRaises(IncorrectArgument, test, Other())

def test_single_argument_with_no_annotation(self):

Expand All @@ -148,17 +147,17 @@ def test(a: int, b: str):
return a, b

self.assertEqual((1, 'string'), test(1, 'string'))
self.assertRaises(TypeError, test, 1, 1)
self.assertRaises(TypeError, test, 'string', 'string')
self.assertRaises(TypeError, test, 'string', 1)
self.assertRaises(IncorrectArgument, test, 1, 1)
self.assertRaises(IncorrectArgument, test, 'string', 'string')
self.assertRaises(IncorrectArgument, test, 'string', 1)

def test_single_argument_with_none_value(self):

@typechecked
def test(a: int):
return a

self.assertRaises(TypeError, test, None)
self.assertRaises(IncorrectArgument, test, None)

def test_multiple_arguments_some_with_annotations(self):

Expand All @@ -168,8 +167,8 @@ def test(a, b: str):

self.assertEqual((1, 'string'), test(1, 'string'))
self.assertEqual(('string', 'string'), test('string', 'string'))
self.assertRaises(TypeError, test, 1, 1)
self.assertRaises(TypeError, test, 'string', 1)
self.assertRaises(IncorrectArgument, test, 1, 1)
self.assertRaises(IncorrectArgument, test, 'string', 1)

def test_return_with_builtin_type(self):

Expand All @@ -178,8 +177,8 @@ def test(a) -> int:
return a

self.assertEqual(1, test(1))
self.assertRaises(TypeError, test, 'string')
self.assertRaises(TypeError, test, 1.2)
self.assertRaises(IncorrectReturnType, test, 'string')
self.assertRaises(IncorrectReturnType, test, 1.2)

def test_return_with_class(self):

Expand All @@ -195,7 +194,7 @@ def test2() -> MyClass:
return 1

self.assertIsInstance(test1(), MyClass)
self.assertRaises(TypeError, test2)
self.assertRaises(IncorrectReturnType, test2)

def test_return_with_sublass(self):

Expand All @@ -211,7 +210,7 @@ def test2() -> MyClass:
return 1

self.assertIsInstance(test1(), MyClass)
self.assertRaises(TypeError, test2)
self.assertRaises(IncorrectReturnType, test2)

def test_return_with_union(self):

Expand All @@ -221,7 +220,7 @@ def test(a) -> union(int, float):

self.assertEqual(1, test(1))
self.assertEqual(1.1, test(1.1))
self.assertRaises(TypeError, test, 'string')
self.assertRaises(IncorrectReturnType, test, 'string')

def test_return_with_interface(self):

Expand All @@ -244,15 +243,15 @@ def test2() -> Test:
return 1

self.assertIsInstance(test1(), TestImplementation)
self.assertRaises(TypeError, test2)
self.assertRaises(IncorrectReturnType, test2)

def test_return_with_none_value(self):

@typechecked
def test(a) -> int:
return a

self.assertRaises(TypeError, test, None)
self.assertRaises(IncorrectReturnType, test, None)

def test_complex_types(self):
simple_types = [ 'a', 1, None, 1.1, False ]
Expand Down Expand Up @@ -328,12 +327,12 @@ def test(a: check.test) -> check.test:
@typechecked
def test(a: check.test):
pass
self.assertRaises(TypeError, test, value)
self.assertRaises(IncorrectArgument, test, value)

@typechecked
def test(a) -> check.test:
return value
self.assertRaises(TypeError, test, value)
self.assertRaises(IncorrectReturnType, test, value)

class UnionTest(unittest.TestCase):

Expand Down