Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 65 additions & 26 deletions graflo/architecture/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,32 +161,68 @@ def __post_init__(self):
ValueError: If transform configuration is invalid
"""
super().__post_init__()
self.functional_transform = False
if self._foo is not None:
self.functional_transform = True

self.input = self._tuple_it(self.input)
self.functional_transform = self._foo is not None

# Normalize containers
self.fields = self._tuple_it(self.fields)
self.input = self._tuple_it(self.input)
self.output = self._tuple_it(self.output)

self.input = self.fields if self.fields and not self.input else self.input
# Derive relationships between map, input, output, and fields.
self._init_input_from_fields()
self._init_io_from_map()
self._init_from_switch()
self._default_output_from_input()
self._init_map_from_io()

self._validate_configuration()

def _init_input_from_fields(self) -> None:
"""Populate input from fields when provided."""
if self.fields and not self.input:
self.input = self.fields

def _init_io_from_map(self, force_init=False) -> None:
"""Populate input/output tuples from an explicit map."""
if not self.map:
return
if force_init or (not self.input and not self.output):
input_fields, output_fields = zip(*self.map.items())
self.input = tuple(input_fields)
self.output = tuple(output_fields)
elif not self.input:
self.input = tuple(self.map.keys())
elif not self.output:
self.output = tuple(self.map.values())

def _init_from_switch(self) -> None:
"""Fallback initialization using switch definitions."""
if self.switch and not self.input and not self.output:
self.input = tuple(self.switch)
# We rely on the first switch entry to infer the output shape.
first_key = self.input[0]
self.output = self._tuple_it(self.switch[first_key])

def _default_output_from_input(self) -> None:
"""Ensure output mirrors input when not explicitly provided."""
if not self.output:
self.output = self.input
self.output = self._tuple_it(self.output)

if not self.input and not self.output:
if self.map:
items = list(self.map.items())
self.input = tuple(x for x, _ in items)
self.output = tuple(x for _, x in items)
elif self.switch:
self.input = tuple([k for k in self.switch])
self.output = tuple(self.switch[self.input[0]])
elif not self.name:
raise ValueError(
"Either input and output, fields, map or name should be"
" provided in Transform constructor."
)
def _init_map_from_io(self) -> None:
"""Derive map from input/output when possible."""
if self.map or not self.input or not self.output:
return
if len(self.input) != len(self.output):
return
self.map = {src: dst for src, dst in zip(self.input, self.output)}

def _validate_configuration(self) -> None:
"""Validate that the transform has enough information to operate."""
if not self.input and not self.output and not self.name:
raise ValueError(
"Either input/output, fields, map or name must be provided in Transform "
"constructor."
)

def __call__(self, *nargs, **kwargs):
"""Execute the transform.
Expand All @@ -198,9 +234,7 @@ def __call__(self, *nargs, **kwargs):
Returns:
dict: Transformed data
"""
is_mapping = self._foo is None

if is_mapping:
if self.is_mapping:
input_doc = nargs[0]
if isinstance(input_doc, dict):
output_values = [input_doc[k] for k in self.input]
Expand All @@ -219,7 +253,12 @@ def __call__(self, *nargs, **kwargs):
r = output_values
return r

def _dress_as_dict(self, transform_result):
@property
def is_mapping(self) -> bool:
"""True when the transform is pure mapping (no function)."""
return self._foo is None

def _dress_as_dict(self, transform_result) -> dict[str, Any]:
"""Convert transform result to dictionary format.

Args:
Expand All @@ -238,15 +277,15 @@ def _dress_as_dict(self, transform_result):
return upd

@property
def is_dummy(self):
def is_dummy(self) -> bool:
"""Check if this is a dummy transform.

Returns:
bool: True if this is a dummy transform
"""
return (self.name is not None) and (not self.map and self._foo is None)

def update(self, t: Transform):
def update(self, t: Transform) -> Transform:
"""Update this transform with another transform's configuration.

Args:
Expand Down
Loading
Loading