diff --git a/docs/dev/taint_tracking.md b/docs/dev/taint_tracking.md new file mode 100644 index 00000000..0ec7849e --- /dev/null +++ b/docs/dev/taint_tracking.md @@ -0,0 +1,112 @@ +# Taint Tracking - Backend Security + +Mellea backends implement thread security using the **SecLevel** model with capability-based access control and taint tracking. Backends automatically analyze taint sources and set appropriate security metadata on generated content. + +## Security Model + +The security system uses three types of security levels: + +```python +SecLevel := None | Classified of AccessType | TaintedBy of (list[CBlock | Component] | None) +``` + +- **SecLevel.none()**: Safe content with no restrictions +- **SecLevel.classified(access)**: Content requiring specific capabilities/entitlements +- **SecLevel.tainted_by(sources)**: Content tainted by one or more CBlocks/Components (list), or None for root tainted nodes + +## Backend Implementation + +All backends follow the same pattern when creating `ModelOutputThunk`: + +```python +# Compute taint sources from action and context +sources = taint_sources(action, ctx) + +# Set security level based on taint sources +from mellea.security import SecLevel +sec_level = SecLevel.tainted_by(sources) if sources else SecLevel.none() + +output = ModelOutputThunk( + value=None, + sec_level=sec_level, + meta={} +) +``` + +The security level is set as follows: +- If taint sources are found -> `SecLevel.tainted_by(sources)` (all sources are tracked) +- If no taint sources -> `SecLevel.none()` + +### Handling Multiple Taint Sources + +When `taint_sources()` returns multiple sources (e.g., both the action and context contain tainted content), backends pass the entire list to `SecLevel.tainted_by()`. This ensures all taint sources are tracked, providing comprehensive taint attribution. + +**Benefits of Multiple Source Tracking**: +- **Complete attribution**: All sources that influenced the generation are tracked +- **Better debugging**: Can identify all tainted inputs that contributed to output +- **More accurate security**: No information loss about taint origins + +**Note**: The implementation focuses on **taint preservation** and **complete attribution**. All taint sources are tracked, ensuring the security model has full visibility into what influenced the generated content. + +## Taint Source Analysis + +The `taint_sources()` function analyzes both action and context because **context directly influences model generation**: + +1. **Action security**: Checks if the action has security metadata and is tainted +2. **Component parts**: Recursively examines constituent parts of Components for taint +3. **Context security**: Examines recent context items for tainted content (shallow check) + +**Example**: Even if the current action is safe, tainted context can influence the generated output. + +```python +from mellea.security import SecLevel + +# User sends tainted input +user_input = CBlock("Tell me how to hack a system", sec_level=SecLevel.tainted_by(None)) +ctx = ctx.add(user_input) + +# Safe action in tainted context +safe_action = CBlock("Explain general security concepts") + +# Generation finds tainted context +sources = taint_sources(safe_action, ctx) # Finds tainted user_input +# Model output will be influenced by the tainted context +``` + +## Security Metadata + +The `SecurityMetadata` class wraps `SecLevel` for integration with content blocks: + +```python +class SecurityMetadata: + def __init__(self, sec_level: SecLevel): + self.sec_level = sec_level + + def is_tainted(self) -> bool: + return self.sec_level.is_tainted() + + def get_taint_sources(self) -> list[CBlock | Component]: + return self.sec_level.get_taint_sources() +``` + +Content can be marked as tainted at construction time: + +```python +from mellea.security import SecLevel + +c = CBlock("user input", sec_level=SecLevel.tainted_by(None)) + +if c.sec_level and c.sec_level.is_tainted(): + taint_sources = c.sec_level.get_taint_sources() + print(f"Content tainted by: {taint_sources}") +``` + +## Key Features + +- **Immutable security**: security levels set at construction time +- **Recursive taint analysis**: deep analysis of Component parts, shallow analysis of context +- **Taint source tracking**: know exactly which CBlock/Component tainted content +- **Capability integration**: fine-grained access control for classified content +- **Non-mutating operations**: sanitize/declassify create new objects + +This creates a security model that addresses both data exfiltration and injection vulnerabilities while enabling future IAM integration. \ No newline at end of file diff --git a/docs/examples/security/taint_example.py b/docs/examples/security/taint_example.py new file mode 100644 index 00000000..850454ac --- /dev/null +++ b/docs/examples/security/taint_example.py @@ -0,0 +1,46 @@ +from mellea.stdlib.components import CBlock +from mellea.stdlib.session import start_session +from mellea.security import SecLevel, privileged, SecurityError + +# Create tainted content +tainted_desc = CBlock( + "Process this sensitive user data", sec_level=SecLevel.tainted_by(None) +) + +print( + f"Original CBlock is tainted: {tainted_desc.sec_level.is_tainted() if tainted_desc.sec_level else False}" +) + +# Create session +session = start_session() + +# Use tainted CBlock in session.instruct +print("Testing session.instruct with tainted CBlock...") +result = session.instruct(description=tainted_desc) + +# The result should be tainted +print( + f"Result is tainted: {result.sec_level.is_tainted() if result.sec_level else False}" +) +if result.sec_level and result.sec_level.is_tainted(): + taint_sources = result.sec_level.get_taint_sources() + print(f"Taint sources: {taint_sources}") + print("✅ SUCCESS: Taint preserved!") +else: + print("❌ FAIL: Result should be tainted but isn't!") + + +# Mock privileged function that requires un-tainted input +@privileged +def process_un_tainted_data(data: CBlock) -> str: + """A function that requires un-tainted input.""" + return f"Processed: {data.value}" + + +print("\nTesting privileged function with tainted result...") +try: + # This should raise a SecurityError + processed = process_un_tainted_data(result) + print("❌ FAIL: Should have raised SecurityError!") +except SecurityError as e: + print(f"✅ SUCCESS: SecurityError raised - {e}") diff --git a/mellea/backends/huggingface.py b/mellea/backends/huggingface.py index 48e74543..e1c7502c 100644 --- a/mellea/backends/huggingface.py +++ b/mellea/backends/huggingface.py @@ -42,6 +42,7 @@ ) from ..formatters import ChatFormatter, TemplateFormatter from ..helpers import message_to_openai_message, messages_to_docs, send_to_queue +from ..security import SecLevel, taint_sources from ..stdlib.components import Intrinsic, Message from ..stdlib.requirements import ALoraRequirement, LLMaJRequirement from .adapters import ( @@ -381,7 +382,11 @@ async def _generate_from_intrinsic( other_input, ) - output = ModelOutputThunk(None) + # Compute taint sources from action and context + sources = taint_sources(action, ctx) + sec_level = SecLevel.tainted_by(sources) if sources else SecLevel.none() + + output = ModelOutputThunk(value=None, sec_level=sec_level, meta={}) output._context = ctx.view_for_generation() output._action = action output._model_options = model_options @@ -659,7 +664,11 @@ async def _generate_from_context_with_kv_cache( **format_kwargs, # type: ignore ) - output = ModelOutputThunk(None) + # Compute taint sources from action and context + sources = taint_sources(action, ctx) + sec_level = SecLevel.tainted_by(sources) if sources else SecLevel.none() + + output = ModelOutputThunk(value=None, sec_level=sec_level, meta={}) output._context = ctx.view_for_generation() output._action = action output._model_options = model_options @@ -812,7 +821,11 @@ async def _generate_from_context_standard( **format_kwargs, # type: ignore ) - output = ModelOutputThunk(None) + # Compute taint sources from action and context + sources = taint_sources(action, ctx) + sec_level = SecLevel.tainted_by(sources) if sources else SecLevel.none() + + output = ModelOutputThunk(value=None, sec_level=sec_level, meta={}) output._context = ctx.view_for_generation() output._action = action output._model_options = model_options @@ -1047,8 +1060,12 @@ async def generate_from_raw( for i, decoded_result in enumerate(decoded_results): n_prompt_tokens = inputs["input_ids"][i].size(0) # type: ignore n_completion_tokens = len(sequences_to_decode[i]) + sources = taint_sources(actions[i], ctx) + sec_level = SecLevel.tainted_by(sources) if sources else SecLevel.none() + result = ModelOutputThunk( value=decoded_result, + sec_level=sec_level, meta={ "usage": { "prompt_tokens": n_prompt_tokens, # type: ignore diff --git a/mellea/backends/litellm.py b/mellea/backends/litellm.py index 4dc82da8..34c42bfb 100644 --- a/mellea/backends/litellm.py +++ b/mellea/backends/litellm.py @@ -32,6 +32,7 @@ message_to_openai_message, send_to_queue, ) +from ..security import SecLevel, taint_sources from ..stdlib.components import Message from ..stdlib.requirements import ALoraRequirement from .backend import FormatterBackend @@ -311,7 +312,11 @@ async def _generate_from_chat_context_standard( **model_specific_options, ) - output = ModelOutputThunk(None) + # Compute taint sources from action and context + sources = taint_sources(action, ctx) + sec_level = SecLevel.tainted_by(sources) if sources else SecLevel.none() + + output = ModelOutputThunk(value=None, sec_level=sec_level, meta={}) output._context = linearized_context output._action = action output._model_options = model_opts @@ -548,16 +553,22 @@ async def generate_from_raw( ) for res, action, prompt in zip(responses, actions, prompts): - output = ModelOutputThunk(res.text) # type: ignore + sources = taint_sources(action, None) + sec_level = SecLevel.tainted_by(sources) if sources else SecLevel.none() + + output = ModelOutputThunk( + value=res.text, # type: ignore + sec_level=sec_level, + meta={ + "litellm_chat_response": res.model_dump(), + "usage": completion_response.usage.model_dump() + if completion_response.usage + else None, + }, + ) output._context = None # There is no context for generate_from_raw for now output._action = action output._model_options = model_opts - output._meta = { - "litellm_chat_response": res.model_dump(), - "usage": completion_response.usage.model_dump() - if completion_response.usage - else None, - } output.parsed_repr = ( action.parse(output) if isinstance(action, Component) else output.value diff --git a/mellea/backends/ollama.py b/mellea/backends/ollama.py index 1c0c9a20..0fac624d 100644 --- a/mellea/backends/ollama.py +++ b/mellea/backends/ollama.py @@ -24,6 +24,7 @@ ) from ..formatters import ChatFormatter, TemplateFormatter from ..helpers import ClientCache, get_current_event_loop, send_to_queue +from ..security import SecLevel, taint_sources from ..stdlib.components import Message from ..stdlib.requirements import ALoraRequirement from .backend import FormatterBackend @@ -350,7 +351,11 @@ async def generate_from_chat_context( format=_format.model_json_schema() if _format is not None else None, # type: ignore ) # type: ignore - output = ModelOutputThunk(None) + # Compute taint sources from action and context + sources = taint_sources(action, ctx) + sec_level = SecLevel.tainted_by(sources) if sources else SecLevel.none() + + output = ModelOutputThunk(value=None, sec_level=sec_level, meta={}) output._context = linearized_context output._action = action output._model_options = model_opts @@ -452,12 +457,16 @@ async def generate_from_raw( for i, response in enumerate(responses): result = None error = None + sources = taint_sources(actions[i], None) + sec_level = SecLevel.tainted_by(sources) if sources else SecLevel.none() + if isinstance(response, BaseException): - result = ModelOutputThunk(value="") + result = ModelOutputThunk(value="", sec_level=sec_level, meta={}) error = response else: result = ModelOutputThunk( value=response.response, + sec_level=sec_level, meta={ "generate_response": response.model_dump(), "usage": { diff --git a/mellea/backends/openai.py b/mellea/backends/openai.py index d7eb284d..f6b8d414 100644 --- a/mellea/backends/openai.py +++ b/mellea/backends/openai.py @@ -40,6 +40,7 @@ messages_to_docs, send_to_queue, ) +from ..security import SecLevel, taint_sources from ..stdlib.components import Intrinsic, Message from ..stdlib.requirements import ALoraRequirement, LLMaJRequirement from .adapters import ( @@ -634,7 +635,11 @@ async def _generate_from_chat_context_standard( ), ) # type: ignore - output = ModelOutputThunk(None) + # Compute taint sources from action and context + sources = taint_sources(action, ctx) + sec_level = SecLevel.tainted_by(sources) if sources else SecLevel.none() + + output = ModelOutputThunk(value=None, sec_level=sec_level, meta={}) output._context = linearized_context output._action = action output._model_options = model_opts @@ -841,16 +846,22 @@ async def generate_from_raw( for response, action, prompt in zip( completion_response.choices, actions, prompts ): - output = ModelOutputThunk(response.text) + sources = taint_sources(action, None) + sec_level = SecLevel.tainted_by(sources) if sources else SecLevel.none() + + output = ModelOutputThunk( + value=response.text, + sec_level=sec_level, + meta={ + "oai_completion_response": response.model_dump(), + "usage": completion_response.usage.model_dump() + if completion_response.usage + else None, + }, + ) output._context = None # There is no context for generate_from_raw for now output._action = action output._model_options = model_opts - output._meta = { - "oai_completion_response": response.model_dump(), - "usage": completion_response.usage.model_dump() - if completion_response.usage - else None, - } output.parsed_repr = ( action.parse(output) if isinstance(action, Component) else output.value diff --git a/mellea/backends/vllm.py b/mellea/backends/vllm.py index c2fac0ee..61c9b269 100644 --- a/mellea/backends/vllm.py +++ b/mellea/backends/vllm.py @@ -38,6 +38,7 @@ ) from ..formatters import ChatFormatter, TemplateFormatter from ..helpers import get_current_event_loop, send_to_queue +from ..security import SecLevel, taint_sources from .backend import FormatterBackend from .model_options import ModelOption from .tools import ( @@ -338,7 +339,11 @@ async def _generate_from_context_standard( # stream = model_options.get(ModelOption.STREAM, False) # if stream: - output = ModelOutputThunk(None) + # Compute taint sources from action and context + sources = taint_sources(action, ctx) + sec_level = SecLevel.tainted_by(sources) if sources else SecLevel.none() + + output = ModelOutputThunk(value=None, sec_level=sec_level, meta={}) generator = self._model.generate( # type: ignore request_id=str(id(output)), @@ -501,7 +506,11 @@ async def generate(prompt, request_id): tasks = [generate(p, f"{id(prompts)}-{i}") for i, p in enumerate(prompts)] decoded_results = await asyncio.gather(*tasks) - results = [ModelOutputThunk(value=text) for text in decoded_results] + results = [] + for i, text in enumerate(decoded_results): + sources = taint_sources(actions[i], ctx) + sec_level = SecLevel.tainted_by(sources) if sources else SecLevel.none() + results.append(ModelOutputThunk(value=text, sec_level=sec_level, meta={})) for i, result in enumerate(results): date = datetime.datetime.now() diff --git a/mellea/backends/watsonx.py b/mellea/backends/watsonx.py index 58004c4c..4289ad46 100644 --- a/mellea/backends/watsonx.py +++ b/mellea/backends/watsonx.py @@ -35,6 +35,7 @@ get_current_event_loop, send_to_queue, ) +from ..security import SecLevel, taint_sources from ..stdlib.components import Message from ..stdlib.requirements import ALoraRequirement from .backend import FormatterBackend @@ -354,7 +355,11 @@ async def generate_from_chat_context( ), ) - output = ModelOutputThunk(None) + # Compute taint sources from action and context + sources = taint_sources(action, ctx) + sec_level = SecLevel.tainted_by(sources) if sources else SecLevel.none() + + output = ModelOutputThunk(value=None, sec_level=sec_level, meta={}) output._context = linearized_context output._action = action output._model_options = model_opts @@ -536,8 +541,12 @@ async def generate_from_raw( for i, response in enumerate(responses): output = response["results"][0] + sources = taint_sources(actions[i], ctx) + sec_level = SecLevel.tainted_by(sources) if sources else SecLevel.none() + result = ModelOutputThunk( value=output["generated_text"], + sec_level=sec_level, meta={ "oai_completion_response": response["results"][0], "usage": { diff --git a/mellea/core/base.py b/mellea/core/base.py index 41179bfe..ad49d6e4 100644 --- a/mellea/core/base.py +++ b/mellea/core/base.py @@ -17,6 +17,8 @@ import typing_extensions from PIL import Image as PILImage +from ..security import SecLevel, TaintChecking + class CBlock: """A `CBlock` is a block of content that can serve as input to or output from an LLM.""" @@ -25,6 +27,7 @@ def __init__( self, value: str | None, meta: dict[str, Any] | None = None, + sec_level: Any = None, *, cache: bool = False, ): @@ -33,6 +36,7 @@ def __init__( Args: value: the underlying value stored in this CBlock meta: Any meta-information about this CBlock (e.g., the inference engine's Completion object). + sec_level: Optional SecLevel for security metadata cache: If set to `True` then this CBlock's KV cache might be stored by the inference engine. Experimental. """ if value is not None and not isinstance(value, str): @@ -43,6 +47,9 @@ def __init__( meta = {} self._meta = meta + # Store security level directly + self._sec_level: SecLevel | None = sec_level + @property def value(self) -> str | None: """Gets the value of the block.""" @@ -61,6 +68,15 @@ def __repr__(self): """Provides a python-parsable representation of the block (usually).""" return f"CBlock({self.value}, {self._meta.__repr__()})" + @property + def sec_level(self) -> SecLevel | None: + """Get the security level for this CBlock. + + Returns: + SecLevel if present, None otherwise + """ + return self._sec_level + class ImageBlock: """A `ImageBlock` represents an image (as base64 PNG).""" @@ -138,7 +154,7 @@ class ComponentParseError(Exception): @runtime_checkable -class Component(Protocol, Generic[S]): +class Component(TaintChecking, Protocol, Generic[S]): """A `Component` is a composite data structure that is intended to be represented to an LLM.""" def parts(self) -> list[Component | CBlock]: @@ -152,6 +168,15 @@ def format_for_llm(self) -> TemplateRepresentation | str: """ raise NotImplementedError("format_for_llm isn't implemented by default") + @property + def sec_level(self) -> SecLevel | None: + """Get the security level for this Component. + + Returns: + SecLevel if present, None otherwise + """ + ... + def parse(self, computed: ModelOutputThunk) -> S: """Parse the expected type from a given `ModelOutputThunk`. @@ -184,9 +209,10 @@ def __init__( meta: dict[str, Any] | None = None, parsed_repr: S | None = None, tool_calls: dict[str, ModelToolCall] | None = None, + sec_level: Any = None, ): """Initializes as a cblock, optionally also with a parsed representation from an output formatter.""" - super().__init__(value, meta) + super().__init__(value, meta, sec_level=sec_level) self.parsed_repr: S | None = parsed_repr """Will be non-`None` once computed.""" diff --git a/mellea/core/requirement.py b/mellea/core/requirement.py index 9162c1fc..d0f61267 100644 --- a/mellea/core/requirement.py +++ b/mellea/core/requirement.py @@ -4,6 +4,7 @@ from collections.abc import Callable from copy import copy +from ..security import SecLevel from .backend import Backend, BaseModelSubclass from .base import CBlock, Component, Context, ModelOutputThunk, TemplateRepresentation @@ -112,6 +113,7 @@ def __init__( # Used for validation. Do not manually populate. self._output: str | None = None + self._sec_level: SecLevel | None = None async def validate( self, @@ -149,6 +151,15 @@ async def validate( context=val_ctx, ) + @property + def sec_level(self) -> SecLevel | None: + """Get the security level for this Component. + + Returns: + SecLevel if present, None otherwise + """ + return self._sec_level + def parts(self): """Returns all of the constituent parts of a Requirement.""" return [] diff --git a/mellea/security/__init__.py b/mellea/security/__init__.py new file mode 100644 index 00000000..8b4acb2f --- /dev/null +++ b/mellea/security/__init__.py @@ -0,0 +1,25 @@ +"""Security module for mellea. + +This module provides security features for tracking and managing the security +level of content blocks and components in the mellea library. +""" + +from .core import ( + AccessType, + SecLevel, + SecurityError, + TaintChecking, + declassify, + privileged, + taint_sources, +) + +__all__ = [ + "AccessType", + "SecLevel", + "SecurityError", + "TaintChecking", + "declassify", + "privileged", + "taint_sources", +] diff --git a/mellea/security/core.py b/mellea/security/core.py new file mode 100644 index 00000000..48a6aeab --- /dev/null +++ b/mellea/security/core.py @@ -0,0 +1,337 @@ +"""Core security functionality for mellea. + +This module provides the fundamental security classes and functions for +tracking security levels of content blocks and enforcing security policies. +""" + +import abc +import functools +from collections.abc import Callable +from enum import Enum +from typing import ( + TYPE_CHECKING, + Any, + Generic, + Protocol, + TypeVar, + Union, + runtime_checkable, +) + +if TYPE_CHECKING: + from ..core.base import CBlock, Component + +T = TypeVar("T") + + +class SecLevelType(str, Enum): + """Security level type constants.""" + + NONE = "none" + CLASSIFIED = "classified" + TAINTED_BY = "tainted_by" + + +class AccessType(Generic[T], abc.ABC): + """Abstract base class for access-based security. + + This trait allows integration with IAM systems and provides fine-grained + access control based on entitlements rather than coarse security levels. + """ + + @abc.abstractmethod + def has_access(self, entitlement: T | None) -> bool: + """Check if the given entitlement has access. + + Args: + entitlement: The entitlement to check (e.g., user role, IAM identifier) + + Returns: + True if the entitlement has access, False otherwise + """ + + +class SecLevel(Generic[T]): + """Security level with access-based control and taint tracking. + + SecLevel := None | Classified of AccessType | TaintedBy of (list[CBlock | Component] | None) + """ + + def __init__(self, level_type: SecLevelType | str, data: Any = None): + """Initialize security level. + + Args: + level_type: Type of security level (SecLevelType enum or string) + data: Associated data (AccessType for classified, list[CBlock|Component] for tainted_by) + """ + # Convert string to enum if needed for backward compatibility + if isinstance(level_type, str): + level_type = SecLevelType(level_type) + self.level_type = level_type + self.data = data + + @classmethod + def none(cls) -> "SecLevel": + """Create a SecLevel with no restrictions (safe).""" + return cls(SecLevelType.NONE) + + @classmethod + def classified(cls, access_type: AccessType[T]) -> "SecLevel": + """Create a SecLevel with classified access requirements.""" + return cls(SecLevelType.CLASSIFIED, access_type) + + @classmethod + def tainted_by( + cls, sources: "CBlock | Component | list[CBlock | Component] | None" + ) -> "SecLevel": + """Create a SecLevel tainted by one or more CBlocks or Components. + + Args: + sources: A single CBlock/Component, a list of CBlocks/Components, or None for root nodes. + If a single source is provided, it will be converted to a list internally. + + Returns: + SecLevel with TAINTED_BY type + """ + # Normalize to list: convert single source to list, None to empty list + if sources is None: + sources_list: list[CBlock | Component] = [] + elif isinstance(sources, list): + sources_list = sources + else: + sources_list = [sources] + + return cls(SecLevelType.TAINTED_BY, sources_list) + + def is_tainted(self) -> bool: + """Check if this security level represents tainted content. + + Returns: + True if tainted, False otherwise + """ + return self.level_type == SecLevelType.TAINTED_BY + + def is_classified(self) -> bool: + """Check if this security level represents classified content. + + Returns: + True if classified, False otherwise + """ + return self.level_type == SecLevelType.CLASSIFIED + + def get_access_type(self) -> AccessType[T] | None: + """Get the AccessType for classified content. + + Returns: + The AccessType if this is classified, None otherwise + """ + if self.level_type == SecLevelType.CLASSIFIED: + return self.data + return None + + def get_taint_sources(self) -> "list[CBlock | Component]": + """Get all sources of taint if this is a tainted level. + + Returns: + List of CBlocks or Components that tainted this content, empty list if not tainted + """ + if self.level_type == SecLevelType.TAINTED_BY: + if isinstance(self.data, list): + return self.data + # Handle legacy single-source format (shouldn't happen in new code) + return [self.data] if self.data is not None else [] + return [] + + +class SecurityError(Exception): + """Exception raised for security-related errors.""" + + +@runtime_checkable +class TaintChecking(Protocol): + """Protocol for objects that can provide security level information. + + This protocol allows uniform access to security levels without + relying on hasattr checks or _meta dictionary access. + """ + + @property + def sec_level(self) -> "SecLevel | None": + """Get the security level for this object. + + Returns: + SecLevel if present, None otherwise + """ + ... + + +def taint_sources(action: "Component | CBlock", ctx: Any) -> "list[CBlock | Component]": + """Compute taint sources from action and context. + + This function examines the action and context to determine what + security sources might be present. It performs recursive analysis + of Component parts and shallow analysis of context to identify + potential taint sources and returns the actual objects that are tainted. + + Args: + action: The action component or content block + ctx: The context containing previous interactions + + Returns: + List of tainted CBlocks or Components + """ + from ..core.base import ( + CBlock, + Component, + ) # Import here to avoid circular dependency + + sources = [] + + # Check if action has security level and is tainted + if isinstance(action, TaintChecking): + sec_level = action.sec_level + if sec_level is not None and sec_level.is_tainted(): + sources.append(action) + + # For Components, check their constituent parts for taint + # Use pattern matching: CBlock doesn't have parts, Components do + match action: + case CBlock(): + # CBlock doesn't have parts, nothing to do + pass + case _ if isinstance(action, Component): + # Component is @runtime_checkable, so isinstance() works + # If it's a Component, it has parts() method by protocol definition + parts = action.parts() + for part in parts: + # Check if the part itself is tainted + if isinstance(part, TaintChecking): + sec_level = part.sec_level + if sec_level is not None and sec_level.is_tainted(): + sources.append(part) + # Recursively check Component parts for nested taint sources + # (Components can contain other Components with tainted CBlocks) + if isinstance(part, Component): + nested_sources = taint_sources(part, None) + sources.extend(nested_sources) + + # Check context for tainted content (shallow check of recent items, but recursive within each) + if hasattr(ctx, "as_list"): + try: + context_items = ctx.as_list( + last_n_components=5 + ) # Limit to recent items for performance + for item in context_items: + # Recursively check each context item (same as action check) + # Only append if item is actually a CBlock or Component (not just TaintChecking) + if isinstance(item, CBlock | Component) and isinstance( + item, TaintChecking + ): + sec_level = item.sec_level + if sec_level is not None and sec_level.is_tainted(): + sources.append(item) + # Recursively check Component parts in context items + if isinstance(item, Component): + nested_sources = taint_sources(item, None) + sources.extend(nested_sources) + except Exception: + # If context analysis fails, continue without it + pass + + return sources + + +F = TypeVar("F", bound=Callable[..., Any]) + + +def privileged(func: F) -> F: + """Decorator to mark functions that require safe (non-tainted, non-classified) input. + + Functions decorated with @privileged will raise SecurityError if + called with tainted or classified content blocks. + + Args: + func: The function to decorate + + Returns: + The decorated function + + Raises: + SecurityError: If the function is called with tainted or classified content + """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): + # Check all arguments for marked content (tainted or classified) + for arg in args: + if isinstance(arg, TaintChecking): + sec_level = arg.sec_level + if sec_level is not None: + if sec_level.is_tainted(): + taint_sources = sec_level.get_taint_sources() + if taint_sources: + source_names = ", ".join( + type(s).__name__ for s in taint_sources + ) + source_info = f" (tainted by: {source_names})" + else: + source_info = "" + raise SecurityError( + f"Function {func.__name__} requires safe input, but received " + f"tainted content{source_info}" + ) + elif sec_level.is_classified(): + raise SecurityError( + f"Function {func.__name__} requires safe input, but received " + f"classified content" + ) + + # Check keyword arguments for marked content (tainted or classified) + for key, value in kwargs.items(): + if isinstance(value, TaintChecking): + sec_level = value.sec_level + if sec_level is not None: + if sec_level.is_tainted(): + taint_sources = sec_level.get_taint_sources() + if taint_sources: + source_names = ", ".join( + type(s).__name__ for s in taint_sources + ) + source_info = f" (tainted by: {source_names})" + else: + source_info = "" + raise SecurityError( + f"Function {func.__name__} requires safe input, but received " + f"tainted content in argument '{key}'{source_info}" + ) + elif sec_level.is_classified(): + raise SecurityError( + f"Function {func.__name__} requires safe input, but received " + f"classified content in argument '{key}'" + ) + + return func(*args, **kwargs) + + return wrapper # type: ignore + + +def declassify(cblock: "CBlock") -> "CBlock": + """Create a declassified version of a CBlock (non-mutating). + + This function creates a new CBlock with the same content but marked + as safe (SecLevel.none()). The original CBlock is not modified. + + Args: + cblock: The CBlock to declassify + + Returns: + A new CBlock with safe security level + """ + from ..core.base import CBlock # Import here to avoid circular dependency + + # Return new CBlock with same content but safe security metadata + return CBlock( + cblock.value, + cblock._meta.copy() if cblock._meta else None, + sec_level=SecLevel.none(), + ) diff --git a/mellea/stdlib/components/chat.py b/mellea/stdlib/components/chat.py index 8763a70b..99ab8c9b 100644 --- a/mellea/stdlib/components/chat.py +++ b/mellea/stdlib/components/chat.py @@ -12,6 +12,7 @@ ModelToolCall, TemplateRepresentation, ) +from ...security import SecLevel from .docs.document import Document @@ -49,6 +50,7 @@ def __init__( if self._images is not None: self._images_cblocks = [CBlock(str(i)) for i in self._images] self._docs = documents + self._sec_level: SecLevel | None = None @property def images(self) -> None | list[str]: @@ -57,6 +59,15 @@ def images(self) -> None | list[str]: return [str(i.value) for i in self._images_cblocks] return None + @property + def sec_level(self) -> SecLevel | None: + """Get the security level for this Component. + + Returns: + SecLevel if present, None otherwise + """ + return self._sec_level + def parts(self) -> list[Component | CBlock]: """Returns all of the constituent parts of an Instruction.""" parts: list[Component | CBlock] = [self._content_cblock] diff --git a/mellea/stdlib/components/docs/document.py b/mellea/stdlib/components/docs/document.py index 577a6639..d59dc3b7 100644 --- a/mellea/stdlib/components/docs/document.py +++ b/mellea/stdlib/components/docs/document.py @@ -1,6 +1,7 @@ """Document component.""" from ....core import CBlock, Component, ModelOutputThunk +from ....security import SecLevel # TODO: Add support for passing in docs as model options. @@ -12,6 +13,12 @@ def __init__(self, text: str, title: str | None = None, doc_id: str | None = Non self.text = text self.title = title self.doc_id = doc_id + self._sec_level: SecLevel | None = None + + @property + def sec_level(self) -> SecLevel | None: + """Get the security level for this Component.""" + return self._sec_level def parts(self) -> list[Component | CBlock]: """The set of all the constituent parts of the `Component`.""" diff --git a/mellea/stdlib/components/docs/richdocument.py b/mellea/stdlib/components/docs/richdocument.py index 75bcb60c..b8460a3f 100644 --- a/mellea/stdlib/components/docs/richdocument.py +++ b/mellea/stdlib/components/docs/richdocument.py @@ -12,6 +12,7 @@ from docling_core.types.io import DocumentStream from ....core import CBlock, Component, ModelOutputThunk, TemplateRepresentation +from ....security import SecLevel from ..mobject import MObject, Query, Transform @@ -24,6 +25,16 @@ class RichDocument(Component[str]): def __init__(self, doc: DoclingDocument): """A `RichDocument` is a block of content with an underlying DoclingDocument.""" self._doc = doc + self._sec_level: SecLevel | None = None + + @property + def sec_level(self) -> SecLevel | None: + """Get the security level for this Component. + + Returns: + SecLevel if present, None otherwise + """ + return self._sec_level def parts(self) -> list[Component | CBlock]: """RichDocument has no parts. diff --git a/mellea/stdlib/components/genslot.py b/mellea/stdlib/components/genslot.py index eff9ae75..751edaa8 100644 --- a/mellea/stdlib/components/genslot.py +++ b/mellea/stdlib/components/genslot.py @@ -24,6 +24,7 @@ TemplateRepresentation, ValidationResult, ) +from ...security import SecLevel from ..requirements.requirement import reqify from ..session import MelleaSession @@ -289,6 +290,16 @@ def __init__(self, func: Callable[P, R]): # Set when calling the decorated func. self.precondition_requirements: list[Requirement] = [] self.requirements: list[Requirement] = [] + self._sec_level: SecLevel | None = None + + @property + def sec_level(self) -> SecLevel | None: + """Get the security level for this Component. + + Returns: + SecLevel if present, None otherwise + """ + return self._sec_level @abc.abstractmethod def __call__(self, *args, **kwargs) -> tuple[R, Context] | R: diff --git a/mellea/stdlib/components/instruction.py b/mellea/stdlib/components/instruction.py index 32a8a0dc..00716f93 100644 --- a/mellea/stdlib/components/instruction.py +++ b/mellea/stdlib/components/instruction.py @@ -15,6 +15,7 @@ TemplateRepresentation, blockify, ) +from ...security import SecLevel from ..requirements.requirement import reqify @@ -125,6 +126,12 @@ def __init__( ) self._images = images self._repair_string: str | None = None + self._sec_level: SecLevel | None = None + + @property + def sec_level(self) -> SecLevel | None: + """Get the security level for this Component.""" + return self._sec_level def parts(self): """Returns all of the constituent parts of an Instruction.""" diff --git a/mellea/stdlib/components/intrinsic/intrinsic.py b/mellea/stdlib/components/intrinsic/intrinsic.py index c12fa54f..1731bca3 100644 --- a/mellea/stdlib/components/intrinsic/intrinsic.py +++ b/mellea/stdlib/components/intrinsic/intrinsic.py @@ -2,6 +2,7 @@ from ....backends.adapters import AdapterType, fetch_intrinsic_metadata from ....core import CBlock, Component, ModelOutputThunk, TemplateRepresentation +from ....security import SecLevel class Intrinsic(Component[str]): @@ -30,6 +31,16 @@ def __init__( if intrinsic_kwargs is None: intrinsic_kwargs = {} self.intrinsic_kwargs = intrinsic_kwargs + self._sec_level: SecLevel | None = None + + @property + def sec_level(self) -> SecLevel | None: + """Get the security level for this Component. + + Returns: + SecLevel if present, None otherwise + """ + return self._sec_level @property def intrinsic_name(self): diff --git a/mellea/stdlib/components/mify.py b/mellea/stdlib/components/mify.py index 8bbfcb35..4809dfe9 100644 --- a/mellea/stdlib/components/mify.py +++ b/mellea/stdlib/components/mify.py @@ -12,6 +12,7 @@ ModelOutputThunk, TemplateRepresentation, ) +from ...security import SecLevel from .mobject import MObjectProtocol, Query, Transform @@ -215,6 +216,15 @@ def parse(self, computed: ModelOutputThunk) -> str: except Exception as e: raise ComponentParseError(f"component parsing failed: {e}") + @property + def sec_level(self) -> SecLevel | None: + """Get the security level for this Component. + + Returns: + SecLevel if present, None otherwise + """ + return getattr(self, "_sec_level", None) + T = TypeVar("T") @@ -345,6 +355,16 @@ def mification(obj: T) -> T: # For objects, have to specifically bind methods. setattr(obj, name, types.MethodType(func, obj)) + # Add properties from MifiedProtocol (properties are descriptors, not methods) + # Create sec_level property directly to ensure Component protocol compliance + if "sec_level" not in current_members.keys(): + # Create a property that returns _sec_level attribute + sec_level_prop = property( + lambda self: getattr(self, "_sec_level", None), + doc="Get the security level for this Component.", + ) + setattr(obj, "sec_level", sec_level_prop) + # Set the defaults for the object/class. setattr(obj, "_query_type", query_type) setattr(obj, "_transform_type", transform_type) diff --git a/mellea/stdlib/components/mobject.py b/mellea/stdlib/components/mobject.py index 3ab48c04..938c102a 100644 --- a/mellea/stdlib/components/mobject.py +++ b/mellea/stdlib/components/mobject.py @@ -7,6 +7,7 @@ from typing import Protocol, runtime_checkable from ...core import CBlock, Component, ModelOutputThunk, TemplateRepresentation +from ...security import SecLevel class Query(Component[str]): @@ -21,6 +22,16 @@ def __init__(self, obj: Component, query: str) -> None: """ self._obj = obj self._query = query + self._sec_level: SecLevel | None = None + + @property + def sec_level(self) -> SecLevel | None: + """Get the security level for this Component. + + Returns: + SecLevel if present, None otherwise + """ + return self._sec_level def parts(self) -> list[Component | CBlock]: """Get the parts of the query.""" @@ -65,6 +76,16 @@ def __init__(self, obj: Component, transformation: str) -> None: """ self._obj = obj self._transformation = transformation + self._sec_level: SecLevel | None = None + + @property + def sec_level(self) -> SecLevel | None: + """Get the security level for this Component. + + Returns: + SecLevel if present, None otherwise + """ + return self._sec_level def parts(self) -> list[Component | CBlock]: """Get the parts of the transform.""" @@ -164,6 +185,16 @@ def __init__( """ self._query_type = query_type self._transform_type = transform_type + self._sec_level: SecLevel | None = None + + @property + def sec_level(self) -> SecLevel | None: + """Get the security level for this Component. + + Returns: + SecLevel if present, None otherwise + """ + return self._sec_level def parts(self) -> list[Component | CBlock]: """MObject has no parts because of how format_for_llm is defined.""" diff --git a/mellea/stdlib/components/simple.py b/mellea/stdlib/components/simple.py index 2d0f7dcc..de05ffea 100644 --- a/mellea/stdlib/components/simple.py +++ b/mellea/stdlib/components/simple.py @@ -1,6 +1,7 @@ """SimpleComponent.""" from ...core import CBlock, Component, ModelOutputThunk +from ...security import SecLevel class SimpleComponent(Component[str]): @@ -13,6 +14,12 @@ def __init__(self, **kwargs): kwargs[key] = CBlock(value=kwargs[key]) self._kwargs_type_check(kwargs) self._kwargs = kwargs + self._sec_level: SecLevel | None = None + + @property + def sec_level(self) -> SecLevel | None: + """Get the security level for this Component.""" + return self._sec_level def parts(self): """Returns the values of the kwargs.""" @@ -21,9 +28,9 @@ def parts(self): def _kwargs_type_check(self, kwargs): for key in kwargs.keys(): value = kwargs[key] - assert issubclass(type(value), Component) or issubclass( - type(value), CBlock - ), f"Expected span but found {type(value)} of value: {value}" + assert isinstance(value, Component) or isinstance(value, CBlock), ( + f"Expected span but found {type(value)} of value: {value}" + ) assert type(key) is str return True diff --git a/mellea/stdlib/components/test_based_eval.py b/mellea/stdlib/components/test_based_eval.py index f7f4bf6e..31f3572a 100644 --- a/mellea/stdlib/components/test_based_eval.py +++ b/mellea/stdlib/components/test_based_eval.py @@ -7,6 +7,7 @@ from pydantic import BaseModel, Field, field_validator from ...core import CBlock, Component, ModelOutputThunk, TemplateRepresentation +from ...security import SecLevel class Message(BaseModel): @@ -63,6 +64,16 @@ def __init__( self.targets = targets or [] self.test_id = test_id self.input_ids = input_ids or [] + self._sec_level: SecLevel | None = None + + @property + def sec_level(self) -> SecLevel | None: + """Get the security level for this Component. + + Returns: + SecLevel if present, None otherwise + """ + return self._sec_level def parts(self) -> list[Component | CBlock]: """The set of constituent parts of the Component.""" diff --git a/mellea/stdlib/functional.py b/mellea/stdlib/functional.py index 36fe1ca0..0c2c827c 100644 --- a/mellea/stdlib/functional.py +++ b/mellea/stdlib/functional.py @@ -110,7 +110,7 @@ def act( @overload def instruct( - description: str, + description: str | CBlock, context: Context, backend: Backend, *, @@ -131,7 +131,7 @@ def instruct( @overload def instruct( - description: str, + description: str | CBlock, context: Context, backend: Backend, *, @@ -151,7 +151,7 @@ def instruct( def instruct( - description: str, + description: str | CBlock, context: Context, backend: Backend, *, diff --git a/mellea/stdlib/session.py b/mellea/stdlib/session.py index 0c1df16f..f021bfec 100644 --- a/mellea/stdlib/session.py +++ b/mellea/stdlib/session.py @@ -330,7 +330,7 @@ def act( @overload def instruct( self, - description: str, + description: str | CBlock, *, images: list[ImageBlock] | list[PILImage.Image] | None = None, requirements: list[Requirement | str] | None = None, @@ -349,7 +349,7 @@ def instruct( @overload def instruct( self, - description: str, + description: str | CBlock, *, images: list[ImageBlock] | list[PILImage.Image] | None = None, requirements: list[Requirement | str] | None = None, @@ -367,7 +367,7 @@ def instruct( def instruct( self, - description: str, + description: str | CBlock, *, images: list[ImageBlock] | list[PILImage.Image] | None = None, requirements: list[Requirement | str] | None = None, @@ -385,7 +385,7 @@ def instruct( """Generates from an instruction. Args: - description: The description of the instruction. + description: The description of the instruction (str or CBlock). requirements: A list of requirements that the instruction can be validated against. icl_examples: A list of in-context-learning examples that the instruction can be validated against. grounding_context: A list of grounding contexts that the instruction can use. They can bind as variables using a (key: str, value: str | ContentBlock) tuple. diff --git a/test/stdlib/test_security_comprehensive.py b/test/stdlib/test_security_comprehensive.py new file mode 100644 index 00000000..40c1af30 --- /dev/null +++ b/test/stdlib/test_security_comprehensive.py @@ -0,0 +1,466 @@ +"""Comprehensive security tests for mellea thread security features.""" + +import pytest +from mellea.stdlib.components import CBlock, ModelOutputThunk, SimpleComponent +from mellea.stdlib.context import ChatContext +from mellea.stdlib.components.instruction import Instruction +from mellea.security import ( + AccessType, + SecLevel, + SecurityError, + privileged, + declassify, + taint_sources, +) + + +class TestAccessType: + """Test AccessType functionality.""" + + def test_access_type_interface(self): + """Test that AccessType is an abstract base class.""" + with pytest.raises(TypeError): + AccessType() # Should not be instantiable directly + + def test_access_type_implementation(self): + """Test implementing AccessType.""" + + class TestAccess(AccessType[str]): + def has_access(self, entitlement: str | None) -> bool: + return entitlement == "admin" + + access = TestAccess() + assert access.has_access("admin") + assert not access.has_access("user") + assert not access.has_access(None) + + +class TestSecLevel: + """Test SecLevel functionality.""" + + def test_sec_level_none(self): + """Test SecLevel.none() creates safe level.""" + from mellea.security.core import SecLevelType + + sec_level = SecLevel.none() + assert sec_level.level_type == SecLevelType.NONE + assert not sec_level.is_tainted() + assert not sec_level.is_classified() + assert sec_level.get_access_type() is None + + def test_sec_level_tainted_by(self): + """Test SecLevel.tainted_by() creates tainted level.""" + from mellea.security.core import SecLevelType + + source = CBlock("source content") + sec_level = SecLevel.tainted_by(source) + assert sec_level.level_type == SecLevelType.TAINTED_BY + assert sec_level.is_tainted() + assert not sec_level.is_classified() + assert sec_level.get_taint_sources() == [source] + assert sec_level.get_access_type() is None + + def test_sec_level_classified(self): + """Test SecLevel.classified() creates classified level.""" + from mellea.security.core import SecLevelType + + class TestAccess(AccessType[str]): + def has_access(self, entitlement: str | None) -> bool: + return entitlement == "admin" + + access = TestAccess() + sec_level = SecLevel.classified(access) + assert sec_level.level_type == SecLevelType.CLASSIFIED + assert not sec_level.is_tainted() + assert sec_level.is_classified() + assert sec_level.get_access_type() is access + assert sec_level.get_access_type().has_access("admin") + assert not sec_level.get_access_type().has_access("user") + assert not sec_level.get_access_type().has_access(None) + + +class TestCBlockSecurity: + """Test CBlock security functionality.""" + + def test_cblock_mark_tainted(self): + """Test marking CBlock as tainted.""" + cblock = CBlock("test content", sec_level=SecLevel.tainted_by(None)) + + assert cblock.sec_level is not None + assert cblock.sec_level.is_tainted() + assert not cblock.sec_level.is_classified() + assert cblock.sec_level.get_access_type() is None + + def test_cblock_mark_tainted_by_source(self): + """Test marking CBlock as tainted by another source.""" + source = CBlock("source content") + cblock = CBlock("test content", sec_level=SecLevel.tainted_by(source)) + + assert cblock.sec_level.is_tainted() + assert cblock.sec_level.get_taint_sources() == [source] + + def test_cblock_default_safe(self): + """Test that CBlock defaults to safe when no security metadata.""" + cblock = CBlock("test content") + assert cblock.sec_level is None or ( + not cblock.sec_level.is_tainted() and not cblock.sec_level.is_classified() + ) + + def test_cblock_with_classified_metadata(self): + """Test CBlock with classified security metadata.""" + + class TestAccess(AccessType[str]): + def has_access(self, entitlement: str | None) -> bool: + return entitlement == "admin" + + access = TestAccess() + sec_level = SecLevel.classified(access) + + cblock = CBlock("classified content", sec_level=sec_level) + + assert cblock.sec_level.is_classified() + access_type = cblock.sec_level.get_access_type() + assert access_type is not None + assert access_type.has_access("admin") + assert not access_type.has_access("user") + assert not access_type.has_access(None) + + +class TestDeclassify: + """Test declassify function.""" + + def test_declassify_creates_new_object(self): + """Test that declassify creates a new object without mutating original.""" + from mellea.security.core import SecLevelType + + original = CBlock("test content", sec_level=SecLevel.tainted_by(None)) + + declassified = declassify(original) + + # Objects are different + assert original is not declassified + assert id(original) != id(declassified) + + # Content is preserved + assert original.value == declassified.value + + # Security levels are different + assert original.sec_level.is_tainted() + assert not declassified.sec_level.is_tainted() + assert not declassified.sec_level.is_classified() + assert declassified.sec_level.level_type == SecLevelType.NONE + + # Original is unchanged + assert original.sec_level.is_tainted() + + def test_declassify_preserves_other_metadata(self): + """Test that declassify preserves other metadata.""" + from mellea.security.core import SecLevelType + + original = CBlock( + "test content", + meta={"custom": "value", "other": 123}, + sec_level=SecLevel.tainted_by(None), + ) + + declassified = declassify(original) + + assert declassified._meta["custom"] == "value" + assert declassified._meta["other"] == 123 + assert declassified.sec_level.level_type == SecLevelType.NONE + + +class TestPrivilegedDecorator: + """Test @privileged decorator functionality.""" + + def test_privileged_accepts_safe_input(self): + """Test that privileged functions accept safe input.""" + + @privileged + def safe_function(cblock: CBlock) -> str: + return f"Processed: {cblock.value}" + + # CBlock with no security metadata defaults to safe + safe_cblock = CBlock("safe content") + + result = safe_function(safe_cblock) + assert result == "Processed: safe content" + + def test_privileged_accepts_declassified_input(self): + """Test that privileged functions accept declassified input.""" + + @privileged + def safe_function(cblock: CBlock) -> str: + return f"Processed: {cblock.value}" + + tainted_cblock = CBlock("tainted content", sec_level=SecLevel.tainted_by(None)) + declassified_cblock = declassify(tainted_cblock) + + result = safe_function(declassified_cblock) + assert result == "Processed: tainted content" + + def test_privileged_rejects_tainted_input(self): + """Test that privileged functions reject tainted input.""" + + @privileged + def safe_function(cblock: CBlock) -> str: + return f"Processed: {cblock.value}" + + tainted_cblock = CBlock("tainted content", sec_level=SecLevel.tainted_by(None)) + + with pytest.raises(SecurityError, match="requires safe input"): + safe_function(tainted_cblock) + + def test_privileged_rejects_classified_input(self): + """Test that privileged functions reject classified input without proper entitlement.""" + + @privileged + def safe_function(cblock: CBlock) -> str: + return f"Processed: {cblock.value}" + + class TestAccess(AccessType[str]): + def has_access(self, entitlement: str | None) -> bool: + return entitlement == "admin" + + access = TestAccess() + sec_level = SecLevel.classified(access) + + classified_cblock = CBlock("classified content", sec_level=sec_level) + + with pytest.raises(SecurityError, match="requires safe input"): + safe_function(classified_cblock) + + def test_privileged_accepts_no_security_metadata(self): + """Test that privileged functions accept input with no security metadata.""" + + @privileged + def safe_function(cblock: CBlock) -> str: + return f"Processed: {cblock.value}" + + # CBlock with no security metadata defaults to safe + cblock = CBlock("content") + + result = safe_function(cblock) + assert result == "Processed: content" + + def test_privileged_with_kwargs(self): + """Test privileged function with keyword arguments.""" + + @privileged + def safe_function(data: CBlock, prefix: str = "Processed: ") -> str: + return f"{prefix}{data.value}" + + tainted_cblock = CBlock("tainted content", sec_level=SecLevel.tainted_by(None)) + + with pytest.raises(SecurityError, match="argument 'data'"): + safe_function(data=tainted_cblock) + + +class TestTaintSources: + """Test taint source computation.""" + + def test_taint_sources_from_tainted_action(self): + """Test taint sources from tainted action.""" + action = CBlock("tainted action", sec_level=SecLevel.tainted_by(None)) + + sources = taint_sources(action, None) + assert len(sources) == 1 + assert sources[0] is action + + def test_taint_sources_from_safe_action(self): + """Test taint sources from safe action.""" + action = CBlock("safe action") + # No security metadata - defaults to safe + + sources = taint_sources(action, None) + assert len(sources) == 0 + + def test_taint_sources_from_context(self): + """Test taint sources from context.""" + action = CBlock("safe action") + + # Create context with tainted content + ctx = ChatContext() + tainted_cblock = CBlock("tainted context", sec_level=SecLevel.tainted_by(None)) + ctx = ctx.add(tainted_cblock) + + sources = taint_sources(action, ctx) + assert len(sources) == 1 + assert sources[0] is tainted_cblock + + def test_taint_sources_empty(self): + """Test taint sources with no tainted content.""" + action = CBlock("safe action") + ctx = ChatContext() + safe_cblock = CBlock("safe context") + # No security metadata - defaults to safe + ctx = ctx.add(safe_cblock) + + sources = taint_sources(action, ctx) + assert len(sources) == 0 + + def test_taint_sources_from_component_parts(self): + """Test taint sources from Component parts.""" + # Create Instruction with tainted description + tainted_desc = CBlock( + "tainted description", sec_level=SecLevel.tainted_by(None) + ) + instruction = Instruction(description=tainted_desc) + + sources = taint_sources(instruction, None) + assert len(sources) == 1 + assert sources[0] is tainted_desc + + def test_taint_sources_from_nested_component_with_tainted_cblocks(self): + """Test taint sources from nested Components containing tainted CBlocks.""" + # Create tainted CBlocks + tainted_data = CBlock( + "sensitive user data", sec_level=SecLevel.tainted_by(None) + ) + tainted_config = CBlock("secret config", sec_level=SecLevel.tainted_by(None)) + safe_info = CBlock("public info") # Safe CBlock + + # Create a SimpleComponent with mixed tainted and safe CBlocks + nested_component = SimpleComponent( + data=tainted_data, config=tainted_config, info=safe_info + ) + + # Create an Instruction with the nested Component in grounding_context + instruction = Instruction( + description="Process the data", + grounding_context={"context": nested_component}, + ) + + # taint_sources should find both tainted CBlocks through the nested Component + sources = taint_sources(instruction, None) + + # Should find both tainted CBlocks + assert len(sources) == 2 + assert tainted_data in sources + assert tainted_config in sources + assert safe_info not in sources # Safe CBlock should not be included + + def test_taint_sources_shallow_search_limit(self): + """Test that shallow search only checks last 5 components.""" + action = CBlock("safe action") + + # Create context with 7 items: tainted at positions 0 and 5 + ctx = ChatContext() + tainted_early = CBlock("tainted early", sec_level=SecLevel.tainted_by(None)) + ctx = ctx.add(tainted_early) # Position 0 - outside last 5 + + # Add 4 safe items + for i in range(4): + ctx = ctx.add(CBlock(f"safe {i}")) + + tainted_late = CBlock("tainted late", sec_level=SecLevel.tainted_by(None)) + ctx = ctx.add(tainted_late) # Position 5 - within last 5 + + # Add one more safe item + ctx = ctx.add(CBlock("safe final")) # Position 6 + + sources = taint_sources(action, ctx) + # Should only find tainted_late (position 5), not tainted_early (position 0) + assert len(sources) == 1 + assert sources[0] is tainted_late + + +class TestModelOutputThunkSecurity: + """Test ModelOutputThunk security functionality.""" + + def test_from_generation_with_taint_sources(self): + """Test ModelOutputThunk creation with taint sources.""" + taint_source = CBlock("taint source", sec_level=SecLevel.tainted_by(None)) + + sec_level = SecLevel.tainted_by([taint_source]) + mot = ModelOutputThunk( + value="generated content", sec_level=sec_level, meta={"custom": "value"} + ) + + assert mot.value == "generated content" + assert mot._meta["custom"] == "value" + assert mot.sec_level is not None + assert mot.sec_level.is_tainted() + assert not mot.sec_level.is_classified() + assert mot.sec_level.get_taint_sources() == [taint_source] + + def test_from_generation_without_taint_sources(self): + """Test ModelOutputThunk creation without taint sources.""" + from mellea.security.core import SecLevelType + + mot = ModelOutputThunk( + value="generated content", + sec_level=SecLevel.none(), + meta={"custom": "value"}, + ) + + assert mot.value == "generated content" + assert mot._meta["custom"] == "value" + assert mot.sec_level is not None + assert mot.sec_level.level_type == SecLevelType.NONE + assert not mot.sec_level.is_tainted() + assert not mot.sec_level.is_classified() + + def test_from_generation_empty_taint_sources(self): + """Test ModelOutputThunk creation with empty taint sources.""" + from mellea.security.core import SecLevelType + + mot = ModelOutputThunk( + value="generated content", + sec_level=SecLevel.none(), + meta={"custom": "value"}, + ) + + assert mot.sec_level.level_type == SecLevelType.NONE + assert not mot.sec_level.is_tainted() + assert not mot.sec_level.is_classified() + + +class TestSecurityIntegration: + """Test integration between security components.""" + + def test_security_flow_through_generation(self): + """Test security metadata flows through generation pipeline.""" + from mellea.security.core import SecLevelType + + # Create tainted input + tainted_input = CBlock("user input", sec_level=SecLevel.tainted_by(None)) + + # Simulate generation with taint sources + sources = taint_sources(tainted_input, None) + sec_level = SecLevel.tainted_by(sources) if sources else SecLevel.none() + mot = ModelOutputThunk(value="model response", sec_level=sec_level) + + # Verify output is tainted + assert mot.sec_level.is_tainted() + + # Declassify the output + safe_mot = declassify(mot) + assert not safe_mot.sec_level.is_tainted() + assert not safe_mot.sec_level.is_classified() + assert safe_mot.sec_level.level_type == SecLevelType.NONE + + # Verify original is unchanged + assert mot.sec_level.is_tainted() + + def test_privileged_function_with_generated_content(self): + """Test privileged function with generated content.""" + + @privileged + def process_response(mot: ModelOutputThunk) -> str: + return f"Processed: {mot.value}" + + # Generate tainted content + taint_source = CBlock("taint source", sec_level=SecLevel.tainted_by(None)) + + sec_level = SecLevel.tainted_by([taint_source]) + mot = ModelOutputThunk(value="tainted response", sec_level=sec_level) + + # Privileged function should reject tainted content + with pytest.raises(SecurityError): + process_response(mot) + + # Declassify and try again + safe_mot = declassify(mot) + result = process_response(safe_mot) + assert result == "Processed: tainted response"