From 2aae637e7335c3e8b6562536ce920b4f071bf6f8 Mon Sep 17 00:00:00 2001 From: 27bslash6 <2221076+27bslash6@users.noreply.github.com> Date: Tue, 13 Jan 2026 22:01:24 +1100 Subject: [PATCH 1/2] fix: lazy-load ArrowSerializer to avoid ImportError without pyarrow ArrowSerializer was unconditionally imported at module load time, causing ImportError for users who install cachekit without the [data] extra (pyarrow). Changes: - Remove top-level ArrowSerializer import - Add lazy loading via _get_arrow_serializer() helper - Use __getattr__ for lazy attribute access - Update SERIALIZER_REGISTRY to use None placeholder - Fix benchmark_serializers and get_serializer_info to handle lazy loading gracefully Users can now: - pip install cachekit and use default serializers - pip install cachekit[data] to enable ArrowSerializer Fixes #41 --- src/cachekit/serializers/__init__.py | 63 ++++++++++++++++++++++------ 1 file changed, 50 insertions(+), 13 deletions(-) diff --git a/src/cachekit/serializers/__init__.py b/src/cachekit/serializers/__init__.py index 28a4a1c..ee55510 100644 --- a/src/cachekit/serializers/__init__.py +++ b/src/cachekit/serializers/__init__.py @@ -1,10 +1,11 @@ +from __future__ import annotations + import logging from threading import Lock -from typing import Any +from typing import TYPE_CHECKING, Any from cachekit._rust_serializer import ByteStorage -from .arrow_serializer import ArrowSerializer from .auto_serializer import AutoSerializer from .base import ( SerializationError, @@ -16,8 +17,25 @@ from .orjson_serializer import OrjsonSerializer from .standard_serializer import StandardSerializer +if TYPE_CHECKING: + from .arrow_serializer import ArrowSerializer + logger = logging.getLogger(__name__) +# Lazy import for optional ArrowSerializer (requires pyarrow from [data] extra) +_ArrowSerializer: type | None = None + + +def _get_arrow_serializer() -> type: + """Lazy-load ArrowSerializer. Raises ImportError if pyarrow not installed.""" + global _ArrowSerializer + if _ArrowSerializer is None: + from .arrow_serializer import ArrowSerializer + + _ArrowSerializer = ArrowSerializer + return _ArrowSerializer + + # Validate ByteStorage works correctly test_storage = ByteStorage("msgpack") test_data = b"test validation data" @@ -36,7 +54,7 @@ "auto": AutoSerializer, # Python-specific types (NumPy, pandas, datetime optimization) "default": StandardSerializer, # Language-agnostic MessagePack for multi-language caches "std": StandardSerializer, # Explicit StandardSerializer alias - "arrow": ArrowSerializer, + "arrow": None, # Lazy-loaded: requires pyarrow from [data] extra "orjson": OrjsonSerializer, "encrypted": EncryptionWrapper, # AutoSerializer + AES-256-GCM encryption } @@ -96,8 +114,13 @@ def get_serializer(name: str, enable_integrity_checking: bool = True) -> Seriali f"@cache(serializer=MySerializer())" ) + # Get serializer class (lazy-load arrow if needed) + if name == "arrow": + serializer_class = _get_arrow_serializer() + else: + serializer_class = SERIALIZER_REGISTRY[name] + # Instantiate with integrity checking configuration - serializer_class = SERIALIZER_REGISTRY[name] if name in ("default", "std", "auto", "arrow", "orjson"): # All core serializers use enable_integrity_checking parameter serializer = serializer_class(enable_integrity_checking=enable_integrity_checking) @@ -167,9 +190,9 @@ def get_available_serializers() -> dict[str, Any]: def benchmark_serializers() -> dict[str, Any]: """Get instantiated serializers for benchmarking.""" serializers = {} - for name, cls in get_available_serializers().items(): + for name in SERIALIZER_REGISTRY: try: - serializers[name] = cls() + serializers[name] = get_serializer(name) except Exception as e: logger.warning(f"Failed to instantiate {name} serializer: {e}") return serializers @@ -178,28 +201,42 @@ def benchmark_serializers() -> dict[str, Any]: def get_serializer_info() -> dict[str, dict[str, Any]]: """Get information about available serializers.""" info = {} - for name, cls in get_available_serializers().items(): + for name in SERIALIZER_REGISTRY: try: - instance = cls() + instance = get_serializer(name) info[name] = { - "class": cls.__name__, - "module": cls.__module__, + "class": type(instance).__name__, + "module": type(instance).__module__, "available": True, - "description": cls.__doc__ or "No description available", + "description": type(instance).__doc__ or "No description available", } # Add method info if available if hasattr(instance, "get_info"): info[name].update(instance.get_info()) + except ImportError as e: + info[name] = { + "class": "ArrowSerializer" if name == "arrow" else "Unknown", + "module": "cachekit.serializers.arrow_serializer", + "available": False, + "error": str(e), + } except Exception as e: info[name] = { - "class": cls.__name__, - "module": cls.__module__, + "class": "Unknown", + "module": "unknown", "available": False, "error": str(e), } return info +def __getattr__(name: str) -> Any: + """Lazy attribute access for optional ArrowSerializer.""" + if name == "ArrowSerializer": + return _get_arrow_serializer() + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + # Export the main interface __all__ = [ "ArrowSerializer", From b9cc5747feea477be10a6e487b1c2467d696a970 Mon Sep 17 00:00:00 2001 From: 27bslash6 <2221076+27bslash6@users.noreply.github.com> Date: Tue, 13 Jan 2026 22:02:57 +1100 Subject: [PATCH 2/2] fix: use platform.node() instead of os.uname() for Windows compatibility os.uname() is not available on Windows, causing ImportError when importing cachekit. platform.node() is cross-platform and provides the same hostname information. Fixes #43 --- src/cachekit/logging.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/cachekit/logging.py b/src/cachekit/logging.py index c8d8719..ad700c3 100644 --- a/src/cachekit/logging.py +++ b/src/cachekit/logging.py @@ -7,6 +7,7 @@ import json import logging import os +import platform import random import threading import time @@ -170,7 +171,7 @@ def __init__(self, name: str, mask_sensitive: bool = True): # Pre-computed values for performance self._sampling_threshold = int(SAMPLING_RATE * 100) - self._hostname = os.uname().nodename + self._hostname = platform.node() self._pid = os.getpid() # PII patterns to mask (pre-compiled for speed)