diff --git a/tests/test_cached.py b/tests/test_cached.py index dbebc83..911ae40 100644 --- a/tests/test_cached.py +++ b/tests/test_cached.py @@ -1,5 +1,6 @@ import asyncio import contextlib +import contextvars from unittest.mock import AsyncMock, MagicMock, call import pytest @@ -68,6 +69,23 @@ async def test_params_are_passed_through(self): assert await decorated_fn("foo") == ("foo",) assert await decorated_fn("foo", bar="baz") == ("foo", ("bar", "baz")) + async def test_context_var_passed(self): + var = contextvars.ContextVar("var") + var.set("test") + + def read_contextvar(): + return var.get() + + async def example(): + var.set("example") + return read_contextvar() + + decorated_fn = cachetools_async.cached({})(example) + + actual = await decorated_fn() + assert actual == "example" + assert var.get() == "test" + async def test_multiple_calls(self): mock = AsyncMock()