From b3f9228cb4802facca2d16da8b716182fea5c813 Mon Sep 17 00:00:00 2001 From: Nikita Yakovlev Date: Tue, 21 Oct 2025 17:34:57 +0300 Subject: [PATCH 1/2] rise if redis is not async --- fast_cache_middleware/__init__.py | 1 - fast_cache_middleware/storages/redis_storage.py | 6 +++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/fast_cache_middleware/__init__.py b/fast_cache_middleware/__init__.py index 4dabaa3..71a8b8c 100644 --- a/fast_cache_middleware/__init__.py +++ b/fast_cache_middleware/__init__.py @@ -8,7 +8,6 @@ TODO: - add check for dependencies for middleware exists. and raise error if not. - - automatically add x-cache-age to the OpenAPI schema (openapi_extra) based on caching dependency. """ from .controller import Controller diff --git a/fast_cache_middleware/storages/redis_storage.py b/fast_cache_middleware/storages/redis_storage.py index ccd88ee..f978355 100644 --- a/fast_cache_middleware/storages/redis_storage.py +++ b/fast_cache_middleware/storages/redis_storage.py @@ -31,9 +31,9 @@ def __init__( ttl: Optional[Union[int, float]] = None, namespace: str = "cache", ) -> None: - if redis is None: - raise ImportError( - "Redis is required for RedisStorage. " + if not isinstance(redis_client, redis.Redis) or redis_client is None: + raise StorageError( + "Redis async is required for RedisStorage. " "Install with Redis: fast-cache-middleware[redis]" ) From e07b6b20f834d3b406259bd6e03faad81bed90ea Mon Sep 17 00:00:00 2001 From: Nikita Yakovlev Date: Thu, 30 Oct 2025 14:58:29 +0300 Subject: [PATCH 2/2] fix redis storage test --- tests/storages/test_redis_storage.py | 75 ++++++++++++++++------------ 1 file changed, 44 insertions(+), 31 deletions(-) diff --git a/tests/storages/test_redis_storage.py b/tests/storages/test_redis_storage.py index bf5b3d2..6d47184 100644 --- a/tests/storages/test_redis_storage.py +++ b/tests/storages/test_redis_storage.py @@ -3,6 +3,7 @@ from unittest.mock import AsyncMock, MagicMock import pytest +from redis.asyncio import Redis as AsyncRedis from starlette.requests import Request from starlette.responses import Response @@ -28,7 +29,7 @@ async def test_redis_storage_init_validation( ttl: float, expect_error: Type[BaseException] | None ) -> None: - mock_redis = AsyncMock() + mock_redis = AsyncMock(spec=AsyncRedis) if expect_error: with pytest.raises(expect_error): @@ -41,7 +42,11 @@ async def test_redis_storage_init_validation( @pytest.mark.asyncio async def test_store_and_retrieve_works() -> None: - mock_redis = AsyncMock() + mock_redis = AsyncMock(spec=AsyncRedis) + mock_redis.exists = AsyncMock(return_value=True) + mock_redis.set = AsyncMock() + mock_redis.get = AsyncMock() + mock_redis.delete = AsyncMock() mock_serializer = MagicMock() serialized_value = b"serialized" @@ -56,8 +61,6 @@ async def test_store_and_retrieve_works() -> None: response = Response(content="hello", status_code=200) metadata: dict[str, str | int] = {} - mock_redis.exists.return_value = True - await storage.set("key1", response, request, metadata) mock_redis.set.assert_awaited_with("cache:key1", serialized_value, ex=1) @@ -69,7 +72,10 @@ async def test_store_and_retrieve_works() -> None: @pytest.mark.asyncio async def test_store_overwrites_existing_key() -> None: - mock_redis = AsyncMock() + mock_redis = AsyncMock(spec=AsyncRedis) + mock_redis.exists = AsyncMock(return_value=True) + mock_redis.delete = AsyncMock() + mock_redis.set = AsyncMock() mock_serializer = MagicMock() serialized_value = b"serialized" @@ -81,8 +87,6 @@ async def test_store_overwrites_existing_key() -> None: response = Response(content="updated", status_code=200) metadata: dict[str, str] = {} - mock_redis.exists.return_value = True - await storage.set("existing_key", response, request, metadata) mock_redis.delete.assert_awaited_with("cache:existing_key") @@ -91,62 +95,66 @@ async def test_store_overwrites_existing_key() -> None: @pytest.mark.asyncio async def test_retrieve_returns_none_on_missing_key() -> None: - mock_redis = AsyncMock() + mock_redis = AsyncMock(spec=AsyncRedis) + mock_redis.exists = AsyncMock(return_value=False) + mock_redis.get = AsyncMock(return_value=None) + storage = RedisStorage(redis_client=mock_redis) - mock_redis.get.return_value = None - with pytest.raises(NotFoundStorageError, match="Data not found"): + with pytest.raises(TTLExpiredStorageError, match="cache:missing"): await storage.get("missing") @pytest.mark.asyncio async def test_retrieve_returns_none_on_deserialization_error() -> None: - mock_redis = AsyncMock() + mock_redis = AsyncMock(spec=AsyncRedis) + mock_redis.exists = AsyncMock(return_value=True) + mock_redis.get = AsyncMock(return_value=b"invalid") + mock_serializer = MagicMock() def raise_error(_): - raise NotFoundStorageError("missing") + raise NotFoundStorageError("Data not found") - mock_serializer = MagicMock() mock_serializer.loads = raise_error - mock_serializer.dumps = AsyncMock(return_value=b"serialized") storage = RedisStorage(redis_client=mock_redis, serializer=mock_serializer) - mock_redis.get.return_value = b"invalid" - with pytest.raises(NotFoundStorageError, match="Data not found"): await storage.get("missing") @pytest.mark.asyncio async def test_retrieve_returns_none_if_ttl_expired() -> None: - mock_redis = AsyncMock() - - def raise_error(_) -> None: - raise TTLExpiredStorageError("corrupt") + mock_redis = AsyncMock(spec=AsyncRedis) + mock_redis.exists = AsyncMock(return_value=True) + mock_redis.get = AsyncMock(return_value=b"invalid") mock_serializer = MagicMock() - mock_serializer.loads = raise_error + def raise_error(_): + raise TTLExpiredStorageError("TTL expired. Key: cache:corrupt") + + mock_serializer.loads = raise_error mock_serializer.dumps = AsyncMock(return_value=b"serialized") storage = RedisStorage(redis_client=mock_redis, serializer=mock_serializer) - mock_redis.get.return_value = b"invalid" - with pytest.raises(TTLExpiredStorageError, match="TTL expired"): - result = await storage.get("corrupt") - print(result) + await storage.get("corrupt") @pytest.mark.asyncio async def test_remove_by_regex() -> None: - mock_redis = AsyncMock() + mock_redis = AsyncMock(spec=AsyncRedis) storage = RedisStorage(redis_client=mock_redis, namespace="myspace") pattern = re.compile(r"^/api/.*") - mock_redis.scan.return_value = (0, ["myspace:/api/test1", "myspace:/api/test2"]) + + mock_redis.scan = AsyncMock( + return_value=(0, ["myspace:/api/test1", "myspace:/api/test2"]) + ) + mock_redis.delete = AsyncMock() await storage.delete(pattern) @@ -157,19 +165,24 @@ async def test_remove_by_regex() -> None: @pytest.mark.asyncio async def test_remove_with_no_matches_logs_warning() -> None: - mock_redis = AsyncMock() + mock_redis = AsyncMock(spec=AsyncRedis) storage = RedisStorage(redis_client=mock_redis, namespace="myspace") pattern = re.compile(r"^/nothing.*") - mock_redis.scan.return_value = (0, []) + + mock_redis.scan = AsyncMock(return_value=(0, [])) + mock_redis.delete = AsyncMock() await storage.delete(pattern) + mock_redis.delete.assert_not_called() @pytest.mark.asyncio async def test_close_flushes_database() -> None: - mock_redis = AsyncMock() + mock_redis = AsyncMock(spec=AsyncRedis) + mock_redis.flushdb = AsyncMock() + storage = RedisStorage(redis_client=mock_redis) await storage.close() @@ -177,7 +190,7 @@ async def test_close_flushes_database() -> None: def test_full_key() -> None: - mock_redis = AsyncMock() + mock_redis = AsyncMock(spec=AsyncRedis) storage = RedisStorage(redis_client=mock_redis, namespace="custom") assert storage._full_key("abc") == "custom:abc"