From b56dfed0107239d7fc9ac10bcf533435eedf394f Mon Sep 17 00:00:00 2001 From: Ray Date: Thu, 18 Dec 2025 18:45:30 +1100 Subject: [PATCH 1/2] feat: extend key generator with Path, UUID, Decimal, Enum, datetime, and constrained numpy support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add cross-language compatible type normalization to CacheKeyGenerator._normalize(): - Path/PurePath: .as_posix() for cross-platform consistency - UUID: str(uuid) standard format - Decimal: str(decimal) exact representation - Enum: .value (recursive normalization) - datetime: .isoformat() (UTC only, naive rejected) - numpy.ndarray: 1D, ≤100KB, i32/i64/f32/f64 dtypes Array constraints (per security review): - 256-bit Blake2b hash (collision resistance) - Little-endian byte order (cross-platform determinism) - 100KB per-array limit, 5MB aggregate limit (DoS prevention) - Version prefix __array_v1__ for protocol evolution Add custom key= parameter escape hatch for complex types: - 2D arrays, DataFrames, large data, custom objects - Integrated in both sync and async wrapper paths Rejected by design: - set/frozenset (mixed-type sorting crashes) - pandas.DataFrame (Parquet non-deterministic) --- src/cachekit/config/decorator.py | 7 +- src/cachekit/decorators/wrapper.py | 28 +- src/cachekit/key_generator.py | 216 +++++++++++++- tests/unit/test_cache_key_generator.py | 389 ++++++++++++++++++++++++- tests/unit/test_custom_key_function.py | 232 +++++++++++++++ 5 files changed, 853 insertions(+), 19 deletions(-) create mode 100644 tests/unit/test_custom_key_function.py diff --git a/src/cachekit/config/decorator.py b/src/cachekit/config/decorator.py index 016cf69..9d70a51 100644 --- a/src/cachekit/config/decorator.py +++ b/src/cachekit/config/decorator.py @@ -172,6 +172,9 @@ def local_function(): integrity_checking: Enable checksums for corruption detection (default: True) All serializers use xxHash3-64 (8 bytes). 
Set to False for @cache.minimal (speed-first, no integrity guarantee) + key: Custom key function for complex types. Receives (*args, **kwargs) and returns str. + Use for numpy arrays, DataFrames, or cross-language cache sharing. + Example: @cache(key=lambda arr: hashlib.blake2b(arr.tobytes()).hexdigest()) refresh_ttl_on_get: Extend TTL on cache hit ttl_refresh_threshold: Minimum remaining TTL fraction (0.0-1.0) to trigger refresh backend: L2 backend (RedisBackend, HTTPBackend, None for L1-only) @@ -183,12 +186,13 @@ def local_function(): encryption: Client-side encryption configuration """ - # Core settings (5 fields) + # Core settings (6 fields) ttl: int | None = None namespace: str | None = None serializer: Union[str, SerializerProtocol] = "default" # type: ignore[assignment] # String name or protocol instance safe_mode: bool = False integrity_checking: bool = True # Checksums for corruption detection (xxHash3-64 for all serializers) + key: Callable[..., str] | None = None # Custom key function (escape hatch for complex types) # Performance (2 fields) refresh_ttl_on_get: bool = False @@ -251,6 +255,7 @@ def to_dict(self) -> dict[str, object]: "namespace": self.namespace, "serializer": self.serializer, "safe_mode": self.safe_mode, + "key": self.key, "refresh_ttl_on_get": self.refresh_ttl_on_get, "ttl_refresh_threshold": self.ttl_refresh_threshold, "backend": self.backend, diff --git a/src/cachekit/decorators/wrapper.py b/src/cachekit/decorators/wrapper.py index 7aaf814..3593bf2 100644 --- a/src/cachekit/decorators/wrapper.py +++ b/src/cachekit/decorators/wrapper.py @@ -412,6 +412,15 @@ def create_cache_wrapper( deployment_uuid = config.encryption.deployment_uuid master_key = config.encryption.master_key + # Custom key function (escape hatch for complex types) + custom_key_func = config.key + else: + custom_key_func = None + + # Re-scope custom_key_func for closure + if "custom_key_func" not in dir(): + custom_key_func = None + # Fast mode: Disable monitoring 
overhead, keep performance features use_circuit_breaker = circuit_breaker and not fast_mode use_adaptive_timeout = adaptive_timeout and not fast_mode @@ -541,7 +550,13 @@ def sync_wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: PLR0912 # Key generation - needed for both L1-only and L1+L2 modes try: - if fast_mode: + # Custom key function takes priority (escape hatch for complex types) + if custom_key_func is not None: + custom_key = custom_key_func(*args, **kwargs) + if not isinstance(custom_key, str): + raise TypeError(f"key function must return str, got {type(custom_key).__name__}") + cache_key = f"{namespace or 'default'}:{custom_key}" + elif fast_mode: # Minimal key generation - no string formatting overhead from ..hash_utils import cache_key_hash @@ -878,12 +893,17 @@ async def async_wrapper(*args: Any, **kwargs: Any) -> Any: cache_key = None func_start_time: float | None = None # Initialize for exception handlers try: - # Fast key generation path (for simple types) - if fast_mode: + # Custom key function takes priority (escape hatch for complex types) + if custom_key_func is not None: + custom_key = custom_key_func(*args, **kwargs) + if not isinstance(custom_key, str): + raise TypeError(f"key function must return str, got {type(custom_key).__name__}") + cache_key = f"{namespace or 'default'}:{custom_key}" + elif fast_mode: # Ultra-fast key generation for hot paths (10-50μs savings) from ..hash_utils import cache_key_hash - cache_namespace = namespace or namespace or "default" + cache_namespace = namespace or "default" args_kwargs_str = str(args) + str(kwargs) cache_key = cache_namespace + ":" + func_hash + ":" + cache_key_hash(args_kwargs_str) else: diff --git a/src/cachekit/key_generator.py b/src/cachekit/key_generator.py index 9a5a605..7f3d922 100644 --- a/src/cachekit/key_generator.py +++ b/src/cachekit/key_generator.py @@ -3,10 +3,25 @@ from __future__ import annotations import hashlib -from typing import Any, Callable, cast +import sys +from datetime 
import datetime +from decimal import Decimal +from enum import Enum +from pathlib import Path, PurePath +from typing import TYPE_CHECKING, Any, Callable, NoReturn, cast +from uuid import UUID import msgpack +if TYPE_CHECKING: + pass + +# Constants for constrained array support (per round-table review 2025-12-18) +ARRAY_MAX_BYTES = 100_000 # 100KB per array +ARRAY_AGGREGATE_MAX = 5_000_000 # 5MB total across all args +SUPPORTED_ARRAY_DTYPES = {"int32", "int64", "float32", "float64"} +DTYPE_MAP = {"int32": "i32", "int64": "i64", "float32": "f32", "float64": "f64"} + class CacheKeyGenerator: """Generates consistent cache keys from function calls. @@ -96,9 +111,12 @@ def _blake2b_hash(self, args: tuple, kwargs: dict) -> str: Raises: TypeError: If args/kwargs contain unsupported types (custom objects, numpy arrays, etc.) """ + # Track aggregate array bytes for DoS prevention + array_bytes_seen: list[int] = [0] + # Step 1: Normalize recursively - normalized_args = [self._normalize(arg) for arg in args] - normalized_kwargs = {k: self._normalize(v) for k, v in sorted(kwargs.items())} + normalized_args = [self._normalize(arg, array_bytes_seen) for arg in args] + normalized_kwargs = {k: self._normalize(v, array_bytes_seen) for k, v in sorted(kwargs.items())} # Step 2: Serialize with MessagePack try: @@ -112,27 +130,199 @@ def _blake2b_hash(self, args: tuple, kwargs: dict) -> str: # Step 3: Hash with Blake2b-256 return hashlib.blake2b(msgpack_bytes, digest_size=32).hexdigest() - def _normalize(self, obj: Any) -> Any: + def _normalize(self, obj: Any, _array_bytes_seen: list[int] | None = None) -> Any: """Normalize object for deterministic MessagePack encoding. - CRITICAL: Ensures identical serialization across Python, TypeScript, Go, PHP. + CRITICAL: Cross-language compatible types ONLY per Protocol v1.1. 
+ + Supported types (per round-table review 2025-12-18): + - Primitives: int, str, bytes, bool, None, float + - Collections: dict (sorted keys), list, tuple + - Extended: Path, UUID, Decimal, Enum, datetime (UTC only) + - Arrays: numpy.ndarray (1D, ≤100KB, i32/i64/f32/f64) + + Args: + obj: Object to normalize + _array_bytes_seen: Internal tracker for aggregate array size (DoS prevention) + + Returns: + Normalized object safe for MessagePack serialization + + Raises: + TypeError: For unsupported types with helpful guidance """ + # Initialize aggregate tracker if not provided + if _array_bytes_seen is None: + _array_bytes_seen = [0] + + # === COLLECTIONS (recursive) === if isinstance(obj, dict): - # Recursively normalize dict with sorted keys - return {k: self._normalize(v) for k, v in sorted(obj.items())} + return {k: self._normalize(v, _array_bytes_seen) for k, v in sorted(obj.items())} - elif isinstance(obj, (list, tuple)): - # Recursively normalize collections (tuple→list) - return [self._normalize(x) for x in obj] + if isinstance(obj, (list, tuple)): + return [self._normalize(x, _array_bytes_seen) for x in obj] - elif isinstance(obj, float): + # === FLOAT (cross-language compat) === + if isinstance(obj, float): # CRITICAL: Normalize -0.0 → 0.0 for cross-language compatibility return 0.0 if obj == 0.0 else obj - else: - # Primitives (int, str, bytes, bool, None) pass through unchanged + # === EXTENDED TYPES === + + # Path: normalize to POSIX format for cross-platform consistency + if isinstance(obj, (Path, PurePath)): + return obj.as_posix() + + # UUID: standard string format + if isinstance(obj, UUID): + return str(obj) + + # Decimal: exact string representation + if isinstance(obj, Decimal): + return str(obj) + + # Enum: use value (recursively normalize in case value is complex) + if isinstance(obj, Enum): + return self._normalize(obj.value, _array_bytes_seen) + + # datetime: UTC only, reject naive datetimes + if isinstance(obj, datetime): + if obj.tzinfo is 
None: + raise TypeError( + "Naive datetime not allowed in cache keys (timezone ambiguity). " + "Use timezone-aware datetime: datetime(..., tzinfo=timezone.utc)" + ) + return obj.isoformat() + + # === NUMPY ARRAY (constrained support) === + if self._is_numpy_array(obj): + return self._normalize_array(obj, _array_bytes_seen) + + # === PRIMITIVES (pass through) === + if isinstance(obj, (int, str, bytes, bool, type(None))): return obj + # === UNSUPPORTED: Fail fast with helpful message === + return self._raise_unsupported_type(obj) + + def _is_numpy_array(self, obj: Any) -> bool: + """Check if object is numpy array without importing numpy.""" + return type(obj).__module__ == "numpy" and type(obj).__name__ == "ndarray" + + def _normalize_array(self, arr: Any, _array_bytes_seen: list[int]) -> list[Any]: + """Normalize numpy array with strict constraints. + + Constraints (per round-table review 2025-12-18): + - 1D only (cross-language simplicity) + - ≤100KB (memory safety) + - 4 dtypes: i32, i64, f32, f64 (cross-language compatibility) + - Little-endian byte order (platform determinism) + - 256-bit Blake2b hash (collision resistance) + - Version prefix for future protocol changes + + Args: + arr: numpy.ndarray to normalize + _array_bytes_seen: Aggregate byte counter for DoS prevention + + Returns: + List of ["__array_v1__", shape_list, dtype_str, content_hash] + (list format for MessagePack compatibility with strict_types=True) + + Raises: + TypeError: If array doesn't meet constraints + """ + import numpy as np + + # Constraint 1: Size limit per array + if arr.nbytes > ARRAY_MAX_BYTES: + raise TypeError( + f"Array too large ({arr.nbytes:,} bytes, max {ARRAY_MAX_BYTES:,}). Use key= parameter for large arrays." + ) + + # Constraint 2: Aggregate size limit (DoS prevention) + _array_bytes_seen[0] += arr.nbytes + if _array_bytes_seen[0] > ARRAY_AGGREGATE_MAX: + raise TypeError( + f"Total array size exceeds {ARRAY_AGGREGATE_MAX:,} bytes. 
Use key= parameter for batch array operations." + ) + + # Constraint 3: 1D only + if arr.ndim != 1: + raise TypeError( + f"Only 1D arrays supported in cache keys (got {arr.ndim}D). " + f"Use key= parameter for multidimensional arrays, or flatten with arr.ravel()." + ) + + # Constraint 4: Supported dtypes only + dtype_name = arr.dtype.name + if dtype_name not in SUPPORTED_ARRAY_DTYPES: + raise TypeError( + f"Unsupported array dtype '{dtype_name}'. " + f"Supported: {', '.join(sorted(SUPPORTED_ARRAY_DTYPES))}. " + f"Cast with arr.astype(np.float64) or use key= parameter." + ) + + # Ensure C-contiguous memory layout + arr = np.ascontiguousarray(arr) + + # Force little-endian byte order for cross-platform determinism + if arr.dtype.byteorder not in ("=", "<", "|"): + arr = arr.astype(arr.dtype.newbyteorder("<")) + elif arr.dtype.byteorder == "=" and sys.byteorder == "big": + arr = arr.byteswap().newbyteorder("<") + + # 256-bit Blake2b hash (per security review) + content_hash = hashlib.blake2b(arr.tobytes(), digest_size=32).hexdigest() + + # Standardized dtype string for cross-language compatibility + dtype_str = DTYPE_MAP[dtype_name] + + # Version prefix for protocol evolution + # Return as list (not tuple) for MessagePack compatibility with strict_types=True + # Shape converted to list as well + return ["__array_v1__", list(arr.shape), dtype_str, content_hash] + + def _raise_unsupported_type(self, obj: Any) -> NoReturn: + """Raise helpful TypeError for unsupported types. + + Args: + obj: The unsupported object + + Raises: + TypeError: Always, with guidance on how to handle the type + """ + type_name = type(obj).__module__ + "." + type(obj).__qualname__ + + # Specific guidance for numpy arrays that don't meet constraints + if "numpy" in type_name and "ndarray" in type_name: + raise TypeError( + "numpy array doesn't meet cache key constraints. " + "Requirements: 1D, ≤100KB, dtype in (i32, i64, f32, f64). " + "Use key= parameter for other arrays." 
+ ) + + if "pandas" in type_name: + raise TypeError( + "pandas objects not supported as cache key arguments " + "(Parquet serialization is non-deterministic). " + "Recommended patterns:\n" + " 1. Pass identifier, return DataFrame: @cache def load(id: int) -> pd.DataFrame\n" + " 2. Use explicit key: @cache(key=lambda df: hashlib.blake2b(df.to_parquet()).hexdigest())" + ) + + if isinstance(obj, (set, frozenset)): + raise TypeError( + "set/frozenset not supported in cache keys (mixed-type sorting crashes). " + "Convert to sorted list: sorted(list(your_set))" + ) + + raise TypeError( + f"Unsupported type '{type_name}' for cache key. " + f"Supported: dict, list, tuple, int, float, str, bytes, bool, None, " + f"Path, UUID, Decimal, Enum, datetime (UTC), 1D numpy arrays (≤100KB, i32/i64/f32/f64). " + f"For custom types, use key= parameter." + ) + def _normalize_key(self, key: str) -> str: """Normalize key to ensure it's valid for cache backends. diff --git a/tests/unit/test_cache_key_generator.py b/tests/unit/test_cache_key_generator.py index aa0464a..5c6db4c 100644 --- a/tests/unit/test_cache_key_generator.py +++ b/tests/unit/test_cache_key_generator.py @@ -3,6 +3,11 @@ from __future__ import annotations from dataclasses import dataclass +from datetime import datetime, timezone +from decimal import Decimal +from enum import Enum +from pathlib import Path, PurePosixPath, PureWindowsPath +from uuid import UUID import pytest @@ -260,7 +265,7 @@ def test_func(user): user = User(1, "Alice") # Should raise TypeError for unsupported type - with pytest.raises(TypeError, match="Unsupported type for cache key generation"): + with pytest.raises(TypeError, match="Unsupported type.*for cache key"): key_generator.generate_key(test_func, (user,), {}) def test_performance_with_large_objects(self, key_generator): @@ -282,3 +287,385 @@ def test_func(data): assert (end - start) < 0.1 assert isinstance(key, str) assert len(key) > 0 + + +class TestExtendedTypeNormalization: + """Tests for 
Path, UUID, Decimal, Enum, datetime normalization. + + Per round-table review 2025-12-18: These are safe cross-language types. + """ + + @pytest.fixture + def key_generator(self): + """Create a basic key generator instance.""" + return CacheKeyGenerator() + + def test_path_uses_posix(self, key_generator): + """Path converts to POSIX string format.""" + assert key_generator._normalize(Path("/data/cache/foo")) == "/data/cache/foo" + # Windows paths also convert to POSIX (forward slashes) + assert key_generator._normalize(PureWindowsPath("C:\\data\\cache\\foo")) == "C:/data/cache/foo" + assert key_generator._normalize(PurePosixPath("/a/b/c")) == "/a/b/c" + + def test_path_in_key_generation(self, key_generator): + """Path works in full key generation.""" + + def func(path): + return str(path) + + key1 = key_generator.generate_key(func, (Path("/data/cache/foo"),), {}) + key2 = key_generator.generate_key(func, (Path("/data/cache/foo"),), {}) + key3 = key_generator.generate_key(func, (Path("/data/cache/bar"),), {}) + + assert key1 == key2 # Same path = same key + assert key1 != key3 # Different path = different key + + def test_uuid_string_format(self, key_generator): + """UUID normalizes to standard string format.""" + u = UUID("12345678-1234-5678-1234-567812345678") + assert key_generator._normalize(u) == "12345678-1234-5678-1234-567812345678" + + def test_uuid_in_key_generation(self, key_generator): + """UUID works in full key generation.""" + + def func(user_id): + return str(user_id) + + u1 = UUID("12345678-1234-5678-1234-567812345678") + u2 = UUID("12345678-1234-5678-1234-567812345679") + + key1 = key_generator.generate_key(func, (u1,), {}) + key2 = key_generator.generate_key(func, (u1,), {}) + key3 = key_generator.generate_key(func, (u2,), {}) + + assert key1 == key2 + assert key1 != key3 + + def test_decimal_exact_string(self, key_generator): + """Decimal preserves exact string representation.""" + # Critical for financial calculations - no floating point 
precision loss + d = Decimal("3.14159265358979323846") + assert key_generator._normalize(d) == "3.14159265358979323846" + + # Large decimal + big = Decimal("12345678901234567890.123456789") + assert key_generator._normalize(big) == "12345678901234567890.123456789" + + def test_decimal_in_key_generation(self, key_generator): + """Decimal works in full key generation.""" + + def func(price): + return float(price) + + key1 = key_generator.generate_key(func, (Decimal("19.99"),), {}) + key2 = key_generator.generate_key(func, (Decimal("19.99"),), {}) + key3 = key_generator.generate_key(func, (Decimal("20.00"),), {}) + + assert key1 == key2 + assert key1 != key3 + + def test_enum_uses_value(self, key_generator): + """Enum normalizes using .value.""" + + class Color(Enum): + RED = 1 + GREEN = 2 + BLUE = 3 + + assert key_generator._normalize(Color.RED) == 1 + assert key_generator._normalize(Color.GREEN) == 2 + + def test_enum_with_string_value(self, key_generator): + """Enum with string value normalizes correctly.""" + + class Status(Enum): + PENDING = "pending" + ACTIVE = "active" + + assert key_generator._normalize(Status.PENDING) == "pending" + + def test_enum_in_key_generation(self, key_generator): + """Enum works in full key generation.""" + + class Priority(Enum): + LOW = 1 + HIGH = 2 + + def func(priority): + return priority.value + + key1 = key_generator.generate_key(func, (Priority.LOW,), {}) + key2 = key_generator.generate_key(func, (Priority.LOW,), {}) + key3 = key_generator.generate_key(func, (Priority.HIGH,), {}) + + assert key1 == key2 + assert key1 != key3 + + def test_datetime_utc_only(self, key_generator): + """Datetime normalizes to ISO format (UTC required).""" + dt_utc = datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc) + result = key_generator._normalize(dt_utc) + + assert "2024-01-15" in result + assert isinstance(result, str) + + def test_datetime_naive_raises(self, key_generator): + """Naive datetime raises TypeError (timezone ambiguity).""" + 
dt_naive = datetime(2024, 1, 15, 12, 0, 0) + + with pytest.raises(TypeError, match="Naive datetime"): + key_generator._normalize(dt_naive) + + def test_datetime_in_key_generation(self, key_generator): + """Timezone-aware datetime works in full key generation.""" + + def func(timestamp): + return timestamp.isoformat() + + dt1 = datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc) + dt2 = datetime(2024, 1, 15, 13, 0, 0, tzinfo=timezone.utc) + + key1 = key_generator.generate_key(func, (dt1,), {}) + key2 = key_generator.generate_key(func, (dt1,), {}) + key3 = key_generator.generate_key(func, (dt2,), {}) + + assert key1 == key2 + assert key1 != key3 + + +class TestConstrainedArrayNormalization: + """Tests for numpy array support with strict constraints. + + Per round-table review 2025-12-18: + - 1D only (cross-language simplicity) + - ≤100KB (memory safety) + - 4 dtypes: i32, i64, f32, f64 + - 256-bit Blake2b hash + - Little-endian byte order + """ + + @pytest.fixture + def key_generator(self): + """Create a basic key generator instance.""" + return CacheKeyGenerator() + + @pytest.fixture + def np(self): + """Import numpy (skip if not available).""" + pytest.importorskip("numpy") + import numpy as np + + return np + + def test_1d_float64_works(self, key_generator, np): + """1D float64 array produces valid normalized list.""" + arr = np.array([1.0, 2.0, 3.0], dtype=np.float64) + result = key_generator._normalize(arr) + + assert result[0] == "__array_v1__" # Version prefix + assert result[1] == [3] # Shape as list + assert result[2] == "f64" # Dtype code + assert len(result[3]) == 64 # 256-bit = 64 hex chars + + def test_1d_float32_works(self, key_generator, np): + """1D float32 array produces valid normalized tuple.""" + arr = np.array([1.0, 2.0, 3.0], dtype=np.float32) + result = key_generator._normalize(arr) + + assert result[2] == "f32" + + def test_1d_int64_works(self, key_generator, np): + """1D int64 array produces valid normalized tuple.""" + arr = np.array([1, 2, 
3], dtype=np.int64) + result = key_generator._normalize(arr) + + assert result[2] == "i64" + + def test_1d_int32_works(self, key_generator, np): + """1D int32 array produces valid normalized tuple.""" + arr = np.array([1, 2, 3], dtype=np.int32) + result = key_generator._normalize(arr) + + assert result[2] == "i32" + + def test_same_content_same_hash(self, key_generator, np): + """Identical array content produces identical hash.""" + arr1 = np.array([1.0, 2.0, 3.0], dtype=np.float64) + arr2 = np.array([1.0, 2.0, 3.0], dtype=np.float64) + + result1 = key_generator._normalize(arr1) + result2 = key_generator._normalize(arr2) + + assert result1[3] == result2[3] # Same hash + + def test_different_content_different_hash(self, key_generator, np): + """Different array content produces different hash.""" + arr1 = np.array([1.0, 2.0, 3.0], dtype=np.float64) + arr2 = np.array([1.0, 2.0, 4.0], dtype=np.float64) + + result1 = key_generator._normalize(arr1) + result2 = key_generator._normalize(arr2) + + assert result1[3] != result2[3] # Different hash + + def test_2d_array_rejected(self, key_generator, np): + """2D arrays are rejected (cross-language simplicity).""" + arr = np.array([[1, 2], [3, 4]], dtype=np.int32) + + with pytest.raises(TypeError, match="Only 1D arrays"): + key_generator._normalize(arr) + + def test_3d_array_rejected(self, key_generator, np): + """3D arrays are rejected.""" + arr = np.zeros((2, 3, 4), dtype=np.float32) + + with pytest.raises(TypeError, match="Only 1D arrays"): + key_generator._normalize(arr) + + def test_large_array_rejected(self, key_generator, np): + """Arrays >100KB are rejected (memory safety).""" + # 100,001 bytes = just over 100KB limit + arr = np.zeros(100_001, dtype=np.int8) + + with pytest.raises(TypeError, match="Array too large"): + key_generator._normalize(arr) + + def test_at_limit_array_accepted(self, key_generator, np): + """Arrays exactly at 100KB are accepted.""" + # 100,000 bytes = exactly at limit (100KB) + arr = 
np.zeros(100_000, dtype=np.int8) + + # Should convert to supported dtype + with pytest.raises(TypeError, match="Unsupported array dtype"): + # int8 isn't supported, but we test size limit is okay + key_generator._normalize(arr) + + # With supported dtype at limit + arr_f32 = np.zeros(25_000, dtype=np.float32) # 25000 * 4 = 100KB + result = key_generator._normalize(arr_f32) + assert result[0] == "__array_v1__" + + def test_unsupported_dtype_rejected(self, key_generator, np): + """Unsupported dtypes are rejected with helpful message.""" + dtypes_to_reject = [np.float16, np.int8, np.uint32, np.complex64] + + for dtype in dtypes_to_reject: + arr = np.array([1, 2, 3], dtype=dtype) + with pytest.raises(TypeError, match="Unsupported array dtype"): + key_generator._normalize(arr) + + def test_aggregate_limit_enforced(self, key_generator, np): + """Total array size across all args limited to 5MB (DoS prevention).""" + # Create arrays that individually pass but aggregate exceeds 5MB + # 60 arrays of ~100KB each = 6MB > 5MB limit + arrays = [np.zeros(25_000, dtype=np.float32) for _ in range(60)] + + with pytest.raises(TypeError, match="Total array size exceeds"): + key_generator._normalize(arrays) + + def test_cross_platform_determinism(self, key_generator, np): + """Little-endian normalization produces consistent hashes.""" + # Create big-endian array + arr_be = np.array([1.0, 2.0, 3.0], dtype=">f8") # Big-endian float64 + result_be = key_generator._normalize(arr_be) + + # Create little-endian array + arr_le = np.array([1.0, 2.0, 3.0], dtype="<f8") # Little-endian float64 + result_le = key_generator._normalize(arr_le) + + # Same content must hash identically regardless of source byte order + assert result_be[3] == result_le[3] + + def test_array_in_key_generation(self, key_generator, np): + """Array works in full key generation.""" + + def func(name, arr, threshold): + return arr.sum() > threshold + + arr = np.array([1.0, 2.0, 3.0], dtype=np.float64) + + key1 = key_generator.generate_key(func, ("test", arr, 1.5), {}) + key2 = key_generator.generate_key(func, ("test", arr, 1.5), {}) + key3 = key_generator.generate_key(func, ("test", arr, 2.5), {}) + + assert key1 == key2 + assert key1 != key3 + + +class TestUnsupportedTypesWithGuidance: + """Tests for helpful error messages on unsupported types. 
+ + Per round-table review 2025-12-18: Fail fast with guidance. + """ + + @pytest.fixture + def key_generator(self): + """Create a basic key generator instance.""" + return CacheKeyGenerator() + + def test_set_rejected_with_guidance(self, key_generator): + """set raises TypeError with sorting crash explanation.""" + with pytest.raises(TypeError, match="mixed-type sorting"): + key_generator._normalize({1, 2, 3}) + + def test_frozenset_rejected_with_guidance(self, key_generator): + """frozenset raises TypeError with sorting crash explanation.""" + with pytest.raises(TypeError, match="mixed-type sorting"): + key_generator._normalize(frozenset({1, 2, 3})) + + def test_pandas_dataframe_rejected_with_guidance(self, key_generator): + """pandas DataFrame rejected with Parquet explanation.""" + pd = pytest.importorskip("pandas") + + df = pd.DataFrame({"a": [1, 2, 3]}) + with pytest.raises(TypeError, match="Parquet serialization is non-deterministic"): + key_generator._normalize(df) + + def test_pandas_series_rejected(self, key_generator): + """pandas Series rejected (same as DataFrame).""" + pd = pytest.importorskip("pandas") + + s = pd.Series([1, 2, 3], name="test") + with pytest.raises(TypeError, match="pandas"): + key_generator._normalize(s) + + def test_custom_class_rejected_with_key_guidance(self, key_generator): + """Custom class raises TypeError suggesting key= parameter.""" + + class CustomObject: + pass + + with pytest.raises(TypeError, match="key= parameter"): + key_generator._normalize(CustomObject()) + + def test_numpy_2d_array_guidance(self, key_generator): + """2D numpy array gives specific constraint guidance.""" + np = pytest.importorskip("numpy") + + arr = np.array([[1, 2], [3, 4]], dtype=np.int32) + with pytest.raises(TypeError, match="1D"): + key_generator._normalize(arr) + + def test_numpy_large_array_guidance(self, key_generator): + """Large numpy array gives size guidance.""" + np = pytest.importorskip("numpy") + + arr = np.zeros(200_000, 
dtype=np.float32) # 800KB > 100KB limit + with pytest.raises(TypeError, match="100,000"): + key_generator._normalize(arr) diff --git a/tests/unit/test_custom_key_function.py b/tests/unit/test_custom_key_function.py new file mode 100644 index 0000000..079add5 --- /dev/null +++ b/tests/unit/test_custom_key_function.py @@ -0,0 +1,232 @@ +"""Tests for custom key= parameter support in @cache decorator. + +Per round-table review 2025-12-18: Custom key function is the cross-language +escape hatch for complex types (2D arrays, DataFrames, large data, custom types). +""" + +from __future__ import annotations + +import hashlib +from typing import Any + +import pytest + + +class TestCustomKeyFunctionBasic: + """Basic custom key function tests.""" + + def test_custom_key_receives_args(self): + """Custom key function receives positional arguments.""" + received_args: list[tuple[Any, ...]] = [] + + def capture_key(*args): + received_args.append(args) + return f"key:{args[0]}" + + # Import here to avoid import errors if cachekit not fully set up + from cachekit.config import DecoratorConfig + + config = DecoratorConfig.minimal(key=capture_key) + + # Verify key field is set + assert config.key is capture_key + + def test_custom_key_receives_kwargs(self): + """Custom key function receives keyword arguments.""" + received_kwargs: list[dict[str, Any]] = [] + + def capture_key(*args, **kwargs): + received_kwargs.append(kwargs) + return f"key:{kwargs.get('x', 0)}" + + from cachekit.config import DecoratorConfig + + config = DecoratorConfig.minimal(key=capture_key) + assert config.key is capture_key + + def test_custom_key_must_return_string(self): + """Key function must return string type.""" + + def bad_key(*args): + return 42 # Returns int, not str + + from cachekit.config import DecoratorConfig + + config = DecoratorConfig.minimal(key=bad_key) + + # The config accepts any callable - validation happens at runtime + assert config.key is bad_key + + +class 
TestCustomKeyFunctionWithNumpyArrays: + """Test custom key function with numpy arrays.""" + + @pytest.fixture + def np(self): + """Import numpy (skip if not available).""" + pytest.importorskip("numpy") + import numpy as np + + return np + + def test_array_key_function_pattern(self, np): + """Demonstrate the array key function pattern.""" + + def array_key(arr): + """Custom key for numpy array using content hash.""" + return hashlib.blake2b(arr.tobytes(), digest_size=16).hexdigest() + + arr = np.array([[1, 2], [3, 4]]) # 2D array (would fail standard key gen) + key = array_key(arr) + + assert isinstance(key, str) + assert len(key) == 32 # 128-bit = 32 hex chars + + def test_array_with_metadata_key_pattern(self, np): + """Demonstrate array + metadata key pattern.""" + + def array_with_meta_key(arr, name: str): + """Key includes both array content and metadata.""" + content_hash = hashlib.blake2b(arr.tobytes(), digest_size=16).hexdigest() + return f"{name}:{arr.shape}:{arr.dtype}:{content_hash}" + + arr = np.array([[1, 2], [3, 4]], dtype=np.float64) + key = array_with_meta_key(arr, "matrix") + + assert "matrix" in key + assert "(2, 2)" in key + assert "float64" in key + + +class TestCustomKeyFunctionWithDataFrames: + """Test custom key function with pandas DataFrames.""" + + @pytest.fixture + def pd(self): + """Import pandas (skip if not available).""" + pytest.importorskip("pandas") + import pandas as pd + + return pd + + def test_dataframe_key_function_pattern(self, pd): + """Demonstrate the DataFrame key function pattern.""" + + def dataframe_key(df): + """Custom key for DataFrame using values hash.""" + # Use values.tobytes() for deterministic hashing + return hashlib.blake2b(df.values.tobytes(), digest_size=16).hexdigest() + + df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + key = dataframe_key(df) + + assert isinstance(key, str) + assert len(key) == 32 + + def test_dataframe_same_content_same_key(self, pd): + """Same DataFrame content produces same key.""" 
+ + def dataframe_key(df): + return hashlib.blake2b(df.values.tobytes(), digest_size=16).hexdigest() + + df1 = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + df2 = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + + assert dataframe_key(df1) == dataframe_key(df2) + + def test_dataframe_different_content_different_key(self, pd): + """Different DataFrame content produces different key.""" + + def dataframe_key(df): + return hashlib.blake2b(df.values.tobytes(), digest_size=16).hexdigest() + + df1 = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + df2 = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 7]}) + + assert dataframe_key(df1) != dataframe_key(df2) + + +class TestCustomKeyFunctionComposite: + """Test composite key patterns combining multiple arguments.""" + + def test_composite_key_pattern(self): + """Demonstrate composite key with multiple args.""" + + def composite_key(model_id: str, version: int, params: dict): + """Key combining model ID, version, and params hash.""" + import json + + params_str = json.dumps(params, sort_keys=True) + params_hash = hashlib.blake2b(params_str.encode(), digest_size=8).hexdigest() + return f"{model_id}:v{version}:{params_hash}" + + key = composite_key("bert-base", 2, {"max_length": 512, "batch_size": 32}) + + assert "bert-base" in key + assert "v2" in key + assert len(key.split(":")) == 3 + + def test_identity_based_key_pattern(self): + """Demonstrate identity-based key (ignoring content).""" + + def identity_key(user_id: int, _data: Any): + """Key based only on user_id, ignoring data content.""" + return f"user:{user_id}" + + # Same user_id = same key regardless of data + key1 = identity_key(123, {"name": "Alice"}) + key2 = identity_key(123, {"name": "Bob"}) + key3 = identity_key(456, {"name": "Alice"}) + + assert key1 == key2 # Same user, different data = same key + assert key1 != key3 # Different user = different key + + +class TestDecoratorConfigKeyField: + """Test DecoratorConfig key field behavior.""" + + def 
test_key_field_default_none(self): + """Key field defaults to None.""" + from cachekit.config import DecoratorConfig + + config = DecoratorConfig.minimal() + assert config.key is None + + def test_key_field_in_to_dict(self): + """Key field appears in to_dict output.""" + from cachekit.config import DecoratorConfig + + def my_key(*args): + return "test" + + config = DecoratorConfig.minimal(key=my_key) + config_dict = config.to_dict() + + assert "key" in config_dict + assert config_dict["key"] is my_key + + def test_key_field_accepts_lambda(self): + """Key field accepts lambda functions.""" + from cachekit.config import DecoratorConfig + + config = DecoratorConfig.minimal(key=lambda x: f"key:{x}") + assert config.key is not None + assert config.key(42) == "key:42" + + def test_key_field_accepts_callable_class(self): + """Key field accepts callable class instances.""" + + class MyKeyGenerator: + def __init__(self, prefix: str): + self.prefix = prefix + + def __call__(self, *args): + return f"{self.prefix}:{args}" + + from cachekit.config import DecoratorConfig + + gen = MyKeyGenerator("cache") + config = DecoratorConfig.minimal(key=gen) + + assert config.key is gen + assert config.key(1, 2) == "cache:(1, 2)" From 51c16f6ee78403d7d77044c5071ab05b2c71a21d Mon Sep 17 00:00:00 2001 From: Ray Date: Thu, 18 Dec 2025 19:30:12 +1100 Subject: [PATCH 2/2] test: add coverage for key normalization and custom key integration MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add TestKeyNormalization: long key shortening, special char replacement - Add integration tests for custom key= parameter with L1-only mode - Test custom key caching behavior, numpy array support, error handling - Test async custom key function - key_generator.py coverage: 93% → 96% - wrapper.py coverage: 5% → 33% --- .../test_custom_key_integration.py | 206 ++++++++++++++++++ tests/unit/test_cache_key_generator.py | 64 ++++++ 2 files changed, 270 insertions(+) create 
mode 100644 tests/integration/test_custom_key_integration.py diff --git a/tests/integration/test_custom_key_integration.py b/tests/integration/test_custom_key_integration.py new file mode 100644 index 0000000..5dc398c --- /dev/null +++ b/tests/integration/test_custom_key_integration.py @@ -0,0 +1,206 @@ +"""Integration tests for custom key= parameter in @cache decorator. + +Tests the actual decorator behavior with L1-only mode (no Redis required). +""" + +from __future__ import annotations + +import hashlib + +import pytest + +from cachekit import cache + + +class TestCustomKeyFunctionIntegration: + """Test custom key function with actual caching behavior.""" + + def test_custom_key_function_caches_correctly(self): + """Custom key function produces cache hits.""" + call_count = 0 + + @cache( + key=lambda x, y: f"{x}:{y}", + backend=None, # L1-only mode + l1_enabled=True, + ) + def add(x: int, y: int) -> int: + nonlocal call_count + call_count += 1 + return x + y + + # First call - cache miss + result1 = add(1, 2) + assert result1 == 3 + assert call_count == 1 + + # Second call with same args - cache hit + result2 = add(1, 2) + assert result2 == 3 + assert call_count == 1 # No additional call + + # Different args - cache miss + result3 = add(2, 3) + assert result3 == 5 + assert call_count == 2 + + def test_custom_key_function_different_args_same_key(self): + """Custom key can map different args to same cache entry.""" + call_count = 0 + + # Key function ignores second argument - must accept **kwargs + def user_key(user_id, include_deleted=False, **kwargs): + return f"user:{user_id}" + + @cache( + key=user_key, + backend=None, + l1_enabled=True, + ) + def get_user(user_id: int, include_deleted: bool = False) -> dict: + nonlocal call_count + call_count += 1 + return {"id": user_id, "name": f"User {user_id}"} + + # First call + result1 = get_user(123, include_deleted=False) + assert call_count == 1 + + # Same user_id, different include_deleted - should hit cache + # 
because key function ignores include_deleted + result2 = get_user(123, include_deleted=True) + assert call_count == 1 # Still 1 - cache hit + assert result1 == result2 + + def test_custom_key_with_numpy_array(self): + """Custom key enables numpy arrays as arguments.""" + np = pytest.importorskip("numpy") + call_count = 0 + + def array_key(arr): + return hashlib.blake2b(arr.tobytes(), digest_size=16).hexdigest() + + @cache( + key=array_key, + backend=None, + l1_enabled=True, + ) + def sum_array(arr) -> float: + nonlocal call_count + call_count += 1 + return float(arr.sum()) + + arr1 = np.array([1.0, 2.0, 3.0]) + arr2 = np.array([1.0, 2.0, 3.0]) # Same content + arr3 = np.array([4.0, 5.0, 6.0]) # Different content + + # First call + result1 = sum_array(arr1) + assert result1 == 6.0 + assert call_count == 1 + + # Same content array - cache hit + result2 = sum_array(arr2) + assert result2 == 6.0 + assert call_count == 1 + + # Different content - cache miss + result3 = sum_array(arr3) + assert result3 == 15.0 + assert call_count == 2 + + def test_custom_key_wrong_return_type_falls_through(self): + """Key function returning non-string falls through to function execution.""" + call_count = 0 + + @cache( + key=lambda x: x, # Returns int, not str - will fail + backend=None, + l1_enabled=True, + ) + def process(x: int) -> int: + nonlocal call_count + call_count += 1 + return x * 2 + + # Should execute function despite key error (graceful degradation) + result = process(42) + assert result == 84 + assert call_count == 1 + + # Second call also falls through (no caching due to key error) + result2 = process(42) + assert result2 == 84 + assert call_count == 2 # Called again - no cache + + def test_custom_key_error_executes_function(self): + """Key function error falls through to function execution.""" + call_count = 0 + + def bad_key(*args): + raise ValueError("Key generation failed") + + @cache( + key=bad_key, + backend=None, + l1_enabled=True, + ) + def add(x: int, y: int) -> 
int: + nonlocal call_count + call_count += 1 + return x + y + + # Should execute function despite key error + # (errors in key generation fall through to function) + result = add(1, 2) + assert result == 3 + assert call_count == 1 + + +class TestCustomKeyFunctionAsync: + """Test custom key function with async functions.""" + + @pytest.mark.asyncio + async def test_async_custom_key_function(self): + """Custom key works with async functions.""" + call_count = 0 + + @cache( + key=lambda x: f"async:{x}", + backend=None, + l1_enabled=True, + ) + async def async_double(x: int) -> int: + nonlocal call_count + call_count += 1 + return x * 2 + + # First call - cache miss + result1 = await async_double(5) + assert result1 == 10 + assert call_count == 1 + + # Second call - cache hit + result2 = await async_double(5) + assert result2 == 10 + assert call_count == 1 + + @pytest.mark.asyncio + async def test_async_custom_key_wrong_return_type_falls_through(self): + """Async: Key function returning non-string falls through to function execution.""" + call_count = 0 + + @cache( + key=lambda x: 123, # Returns int, not str - will fail + backend=None, + l1_enabled=True, + ) + async def process(x: int) -> int: + nonlocal call_count + call_count += 1 + return x * 2 + + # Should execute function despite key error (graceful degradation) + result = await process(42) + assert result == 84 + assert call_count == 1 diff --git a/tests/unit/test_cache_key_generator.py b/tests/unit/test_cache_key_generator.py index 5c6db4c..44ff5a6 100644 --- a/tests/unit/test_cache_key_generator.py +++ b/tests/unit/test_cache_key_generator.py @@ -669,3 +669,67 @@ def test_numpy_large_array_guidance(self, key_generator): arr = np.zeros(200_000, dtype=np.float32) # 800KB > 100KB limit with pytest.raises(TypeError, match="100,000"): key_generator._normalize(arr) + + +class TestKeyNormalization: + """Tests for cache key string normalization.""" + + @pytest.fixture + def key_generator(self): + """Create a basic key 
generator instance.""" + return CacheKeyGenerator() + + def test_long_key_gets_shortened(self, key_generator): + """Keys exceeding MAX_KEY_LENGTH are shortened with hash.""" + + def func_with_very_long_name(): + pass + + # Create args that will generate a very long key + long_arg = "x" * 300 # This alone exceeds MAX_KEY_LENGTH (250) + key = key_generator.generate_key(func_with_very_long_name, (long_arg,), {}) + + # Key should be shortened to MAX_KEY_LENGTH or less + assert len(key) <= key_generator.MAX_KEY_LENGTH + + def test_shortened_key_contains_prefix_and_hash(self, key_generator): + """Shortened keys contain readable prefix and hash for debugging.""" + + def my_func(): + pass + + # Generate key with very long namespace to force shortening + long_namespace = "a" * 300 + key = key_generator.generate_key(my_func, (), {}, namespace=long_namespace) + + # Should contain hash (32 hex chars after colon) + assert ":" in key + # Should be within limit + assert len(key) <= key_generator.MAX_KEY_LENGTH + + def test_special_characters_replaced(self, key_generator): + """Spaces, newlines, carriage returns are replaced with underscores.""" + + def func(): + pass + + # These characters would be in the key via function name or args + key = key_generator.generate_key(func, ("hello world",), {}) + + # Key should not contain problematic characters + assert " " not in key + assert "\n" not in key + assert "\r" not in key + + def test_deterministic_long_key_shortening(self, key_generator): + """Same long input produces same shortened key.""" + + def func(): + pass + + long_arg = "x" * 300 + + key1 = key_generator.generate_key(func, (long_arg,), {}) + key2 = key_generator.generate_key(func, (long_arg,), {}) + + assert key1 == key2