diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py
index 557d621dd4..1043f09da7 100644
--- a/flytekit/core/promise.py
+++ b/flytekit/core/promise.py
@@ -2,6 +2,7 @@
 
 import collections
+import dataclasses
 import inspect
 import typing
 from copy import deepcopy
 from enum import Enum
@@ -80,18 +81,26 @@ def my_wf(in1: int, in2: int) -> int:
     :param flyte_interface_types: One side of an :py:class:`flytekit.models.interface.TypedInterface` basically.
     :param native_types: Map to native Python type.
     """
     if incoming_values is None:
         raise AssertionError("Incoming values cannot be None, must be a dict")
 
     result = {}  # So as to not overwrite the input_kwargs
     for k, v in incoming_values.items():
         if k not in flyte_interface_types:
             raise AssertionError(f"Received unexpected keyword argument {k}")
         var = flyte_interface_types[k]
         t = native_types[k]
         try:
             if type(v) is Promise:
                 v = resolve_attr_path_in_promise(v)
+            if dataclasses.is_dataclass(v) and not isinstance(v, type):
+                # Give promise-valued fields of a dataclass instance the same
+                # attribute-path resolution that a top-level Promise gets above.
+                for field in dataclasses.fields(v):
+                    field_val = getattr(v, field.name)
+                    if isinstance(field_val, Promise):
+                        setattr(v, field.name, resolve_attr_path_in_promise(field_val))
+
             result[k] = TypeEngine.to_literal(ctx, v, t, var.type)
         except TypeTransformerFailedError as exc:
             raise TypeTransformerFailedError(f"Failed argument '{k}': {exc}") from exc
@@ -716,6 +725,26 @@ def __rshift__(self, other: Any):
     return Output(*promises)  # type: ignore
 
 
+def _collection_contains_promise(maybe_collection: Any) -> bool:
+    """
+    Determine whether the given object is a Promise or holds one at any depth.
+
+    Lists, dict values and dataclass fields are searched recursively; anything else is treated as not holding one.
+    """
+    if isinstance(maybe_collection, Promise):
+        return True
+    if isinstance(maybe_collection, list):
+        return any(_collection_contains_promise(x) for x in maybe_collection)
+    if isinstance(maybe_collection, dict):
+        return any(_collection_contains_promise(x) for x in maybe_collection.values())
+    # is_dataclass() is also true for dataclass types, which have no field values to inspect
+    if dataclasses.is_dataclass(maybe_collection) and not isinstance(maybe_collection, type):
+        # Go through the declared fields rather than __dict__, which slotted dataclasses lack
+        fields = dataclasses.fields(maybe_collection)
+        return any(_collection_contains_promise(getattr(maybe_collection, f.name)) for f in fields)
+    return False
+
+
 def binding_data_from_python_std(
     ctx: _flyte_context.FlyteContext,
     expected_literal_type: _type_models.LiteralType,
@@ -768,6 +797,32 @@ def binding_data_from_python_std(
 
         return _literals_models.BindingData(collection=collection)
 
+    elif dataclasses.is_dataclass(t_value):
+        # A dataclass without promises is converted directly. One that holds a promise
+        # (at any depth) is bound field by field, like a dictionary, instead.
+        if not _collection_contains_promise(t_value):
+            scalar = TypeEngine.to_literal(ctx, t_value, t_value_type, expected_literal_type).scalar
+            return _literals_models.BindingData(scalar=scalar)
+
+        if (
+            expected_literal_type.simple != _type_models.SimpleType.STRUCT
+            or expected_literal_type.structure is None
+        ):
+            raise AssertionError(
+                f"this should be a Struct type and it is not: {type(t_value)} vs {expected_literal_type}"
+            )
+        # Iterate the declared fields (not __dict__) so the keys match the dataclass definition
+        attrs = {f.name: getattr(t_value, f.name) for f in dataclasses.fields(t_value)}
+        m = _literals_models.BindingDataMap(
+            bindings={
+                k: binding_data_from_python_std(
+                    ctx, expected_literal_type.structure.dataclass_type[k], v, type(v), nodes
+                )
+                for k, v in attrs.items()
+            }
+        )
+        return _literals_models.BindingData(map=m)
+
     elif isinstance(t_value, dict):
         if (
             expected_literal_type.map_value_type is None
diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py
index 4937703ef4..11693b66d4 100644
--- a/flytekit/core/type_engine.py
+++ b/flytekit/core/type_engine.py
@@ -121,7 +121,8 @@ def modify_literal_uris(lit: Literal):
     )
 
 
-class TypeTransformerFailedError(TypeError, AssertionError, ValueError): ...
+class TypeTransformerFailedError(TypeError, AssertionError, ValueError):
+    """Raised when the type engine cannot convert a value (for example into a Literal)."""
 
 
 class TypeTransformer(typing.Generic[T]):
@@ -484,6 +485,26 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp
 
         self._serialize_flyte_type(python_val, python_type)
 
+        # Imported lazily: promise.py depends on the type engine, so a module-level import would be circular
+        from flytekit.core.promise import Promise, resolve_attr_path_in_promise
+
+        # Any promise stored in a field must be ready; swap it for its evaluated value
+        # before trying to serialize the dataclass
+        for field in dataclasses.fields(python_val):
+            attr = getattr(python_val, field.name)
+            if not isinstance(attr, Promise):
+                continue
+            if not attr.is_ready:
+                # Raised explicitly rather than asserted so the check is not stripped under `python -O`
+                raise TypeTransformerFailedError(f"Promise {attr} is not ready")
+            scalar = attr.val.scalar
+            if scalar is not None and scalar.primitive is not None:
+                resolved = attr.eval()
+            else:
+                # If the promise is not a primitive, we'll need to resolve the path
+                resolved = resolve_attr_path_in_promise(attr).eval()
+            setattr(python_val, field.name, resolved)
+
         # The `to_json` function is integrated through either the `dataclasses_json` decorator or by inheriting from `DataClassJsonMixin`.
         # It serializes a data class into a JSON string.
         if hasattr(python_val, "to_json"):
@@ -797,10 +818,27 @@ def tag(expected_python_type: Type[T]) -> str:
     def get_literal_type(self, t: Type[T]) -> LiteralType:
         return LiteralType(simple=SimpleType.STRUCT, metadata={ProtobufTransformer.PB_FIELD_KEY: self.tag(t)})
 
+    def _handle_list_literal(self, ctx: FlyteContext, elems: list) -> Literal:
+        """
+        Convert the elements of a ListValue protobuf into a collection Literal.
+
+        The Python type of the first element is used to convert every element.
+        """
+        if not elems:
+            return Literal(collection=LiteralCollection(literals=[]))
+        st = type(elems[0])
+        lt = TypeEngine.to_literal_type(st)
+        lits = [TypeEngine.to_literal(ctx, x, st, lt) for x in elems]
+        return Literal(collection=LiteralCollection(literals=lits))
+
     def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
         struct = Struct()
         try:
-            struct.update(_MessageToDict(cast(Message, python_val)))
+            message_dict = _MessageToDict(cast(Message, python_val))
+            if isinstance(message_dict, list):
+                # _MessageToDict will return a `list` on ListValue protobufs
+                return self._handle_list_literal(ctx, message_dict)
+            struct.update(message_dict)
         except Exception:
             raise TypeTransformerFailedError("Failed to convert to generic protobuf struct")
         return Literal(scalar=Scalar(generic=struct))
@@ -1554,7 +1592,7 @@ def __init__(self):
         super().__init__("Typed Union", typing.Union)
 
     @staticmethod
-    def is_optional_type(t: Type) -> bool:
+    def is_optional_type(t: Type[T]) -> bool:
         """Return True if `t` is a Union or Optional type."""
         return _is_union_type(t) or type(None) in get_args(t)
 