Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion fast_cache_middleware/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

TODO:
- add check for dependencies for middleware exists. and raise error if not.
- automatically add x-cache-age to the OpenAPI schema (openapi_extra) based on caching dependency.
"""

from .controller import Controller
Expand Down
6 changes: 3 additions & 3 deletions fast_cache_middleware/storages/redis_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ def __init__(
ttl: Optional[Union[int, float]] = None,
namespace: str = "cache",
) -> None:
if redis is None:
raise ImportError(
"Redis is required for RedisStorage. "
if not isinstance(redis_client, redis.Redis) or redis_client is None:
raise StorageError(
"Redis async is required for RedisStorage. "
"Install with Redis: fast-cache-middleware[redis]"
)

Expand Down
75 changes: 44 additions & 31 deletions tests/storages/test_redis_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from unittest.mock import AsyncMock, MagicMock

import pytest
from redis.asyncio import Redis as AsyncRedis
from starlette.requests import Request
from starlette.responses import Response

Expand All @@ -28,7 +29,7 @@
async def test_redis_storage_init_validation(
ttl: float, expect_error: Type[BaseException] | None
) -> None:
mock_redis = AsyncMock()
mock_redis = AsyncMock(spec=AsyncRedis)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Я тут задумался о том, что бы убрать моки и гонять тесты, развёртывая реальный Редис в контейнере, Action такое предоставляет.

Сложность будет в том, что локально тоже надо будет запускать у себя контейнер что б тесты не фейлились.


if expect_error:
with pytest.raises(expect_error):
Expand All @@ -41,7 +42,11 @@ async def test_redis_storage_init_validation(

@pytest.mark.asyncio
async def test_store_and_retrieve_works() -> None:
mock_redis = AsyncMock()
mock_redis = AsyncMock(spec=AsyncRedis)
mock_redis.exists = AsyncMock(return_value=True)
mock_redis.set = AsyncMock()
mock_redis.get = AsyncMock()
mock_redis.delete = AsyncMock()

mock_serializer = MagicMock()
serialized_value = b"serialized"
Expand All @@ -56,8 +61,6 @@ async def test_store_and_retrieve_works() -> None:
response = Response(content="hello", status_code=200)
metadata: dict[str, str | int] = {}

mock_redis.exists.return_value = True

await storage.set("key1", response, request, metadata)
mock_redis.set.assert_awaited_with("cache:key1", serialized_value, ex=1)

Expand All @@ -69,7 +72,10 @@ async def test_store_and_retrieve_works() -> None:

@pytest.mark.asyncio
async def test_store_overwrites_existing_key() -> None:
mock_redis = AsyncMock()
mock_redis = AsyncMock(spec=AsyncRedis)
mock_redis.exists = AsyncMock(return_value=True)
mock_redis.delete = AsyncMock()
mock_redis.set = AsyncMock()

mock_serializer = MagicMock()
serialized_value = b"serialized"
Expand All @@ -81,8 +87,6 @@ async def test_store_overwrites_existing_key() -> None:
response = Response(content="updated", status_code=200)
metadata: dict[str, str] = {}

mock_redis.exists.return_value = True

await storage.set("existing_key", response, request, metadata)

mock_redis.delete.assert_awaited_with("cache:existing_key")
Expand All @@ -91,62 +95,66 @@ async def test_store_overwrites_existing_key() -> None:

@pytest.mark.asyncio
async def test_retrieve_returns_none_on_missing_key() -> None:
mock_redis = AsyncMock()
mock_redis = AsyncMock(spec=AsyncRedis)
mock_redis.exists = AsyncMock(return_value=False)
mock_redis.get = AsyncMock(return_value=None)

storage = RedisStorage(redis_client=mock_redis)
mock_redis.get.return_value = None

with pytest.raises(NotFoundStorageError, match="Data not found"):
with pytest.raises(TTLExpiredStorageError, match="cache:missing"):
await storage.get("missing")


@pytest.mark.asyncio
async def test_retrieve_returns_none_on_deserialization_error() -> None:
mock_redis = AsyncMock()
mock_redis = AsyncMock(spec=AsyncRedis)
mock_redis.exists = AsyncMock(return_value=True)
mock_redis.get = AsyncMock(return_value=b"invalid")
mock_serializer = MagicMock()

def raise_error(_):
raise NotFoundStorageError("missing")
raise NotFoundStorageError("Data not found")

mock_serializer = MagicMock()
mock_serializer.loads = raise_error

mock_serializer.dumps = AsyncMock(return_value=b"serialized")

storage = RedisStorage(redis_client=mock_redis, serializer=mock_serializer)

mock_redis.get.return_value = b"invalid"

with pytest.raises(NotFoundStorageError, match="Data not found"):
await storage.get("missing")


@pytest.mark.asyncio
async def test_retrieve_returns_none_if_ttl_expired() -> None:
mock_redis = AsyncMock()

def raise_error(_) -> None:
raise TTLExpiredStorageError("corrupt")
mock_redis = AsyncMock(spec=AsyncRedis)
mock_redis.exists = AsyncMock(return_value=True)
mock_redis.get = AsyncMock(return_value=b"invalid")

mock_serializer = MagicMock()
mock_serializer.loads = raise_error

def raise_error(_):
raise TTLExpiredStorageError("TTL expired. Key: cache:corrupt")

mock_serializer.loads = raise_error
mock_serializer.dumps = AsyncMock(return_value=b"serialized")

storage = RedisStorage(redis_client=mock_redis, serializer=mock_serializer)

mock_redis.get.return_value = b"invalid"

with pytest.raises(TTLExpiredStorageError, match="TTL expired"):
result = await storage.get("corrupt")
print(result)
await storage.get("corrupt")


@pytest.mark.asyncio
async def test_remove_by_regex() -> None:
mock_redis = AsyncMock()
mock_redis = AsyncMock(spec=AsyncRedis)
storage = RedisStorage(redis_client=mock_redis, namespace="myspace")

pattern = re.compile(r"^/api/.*")
mock_redis.scan.return_value = (0, ["myspace:/api/test1", "myspace:/api/test2"])

mock_redis.scan = AsyncMock(
return_value=(0, ["myspace:/api/test1", "myspace:/api/test2"])
)
mock_redis.delete = AsyncMock()

await storage.delete(pattern)

Expand All @@ -157,27 +165,32 @@ async def test_remove_by_regex() -> None:

@pytest.mark.asyncio
async def test_remove_with_no_matches_logs_warning() -> None:
mock_redis = AsyncMock()
mock_redis = AsyncMock(spec=AsyncRedis)
storage = RedisStorage(redis_client=mock_redis, namespace="myspace")

pattern = re.compile(r"^/nothing.*")
mock_redis.scan.return_value = (0, [])

mock_redis.scan = AsyncMock(return_value=(0, []))
mock_redis.delete = AsyncMock()

await storage.delete(pattern)

mock_redis.delete.assert_not_called()


@pytest.mark.asyncio
async def test_close_flushes_database() -> None:
mock_redis = AsyncMock()
mock_redis = AsyncMock(spec=AsyncRedis)
mock_redis.flushdb = AsyncMock()

storage = RedisStorage(redis_client=mock_redis)

await storage.close()
mock_redis.flushdb.assert_awaited_once()


def test_full_key() -> None:
mock_redis = AsyncMock()
mock_redis = AsyncMock(spec=AsyncRedis)
storage = RedisStorage(redis_client=mock_redis, namespace="custom")

assert storage._full_key("abc") == "custom:abc"