From 7ab7854982744c43ca9e306d5a905dc9e66a052d Mon Sep 17 00:00:00 2001 From: Josh McGrath Date: Fri, 21 Jun 2024 11:41:19 -0700 Subject: [PATCH 1/3] apply Jack's list patch --- flytekit/core/type_engine.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 4937703ef4..2c42152f5a 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -796,11 +796,22 @@ def tag(expected_python_type: Type[T]) -> str: def get_literal_type(self, t: Type[T]) -> LiteralType: return LiteralType(simple=SimpleType.STRUCT, metadata={ProtobufTransformer.PB_FIELD_KEY: self.tag(t)}) + + def _handle_list_literal(self, ctx: FlyteContext, elems: list) -> Literal: + if len(elems) == 0: + return Literal(collection=LiteralCollection(literals=[])) + st = type(elems[0]) + lt = TypeEngine.to_literal_type(st) + lits = [TypeEngine.to_literal(ctx, x, st, lt) for x in elems] + return Literal(collection=LiteralCollection(literals=lits)) def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: struct = Struct() try: - struct.update(_MessageToDict(cast(Message, python_val))) + message_dict = _MessageToDict(cast(Message, python_val)) + if isinstance(message_dict, list): + return self._handle_list_literal(ctx, message_dict) + struct.update(message_dict) except Exception: raise TypeTransformerFailedError("Failed to convert to generic protobuf struct") return Literal(scalar=Scalar(generic=struct)) From ccbb5a62f385aeb180580e2458a471cd350d6a47 Mon Sep 17 00:00:00 2001 From: Josh McGrath Date: Fri, 21 Jun 2024 13:11:16 -0700 Subject: [PATCH 2/3] some updates for optionals and lists --- flytekit/core/type_engine.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 2c42152f5a..567778c743 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -15,7 +15,6 @@ from collections import OrderedDict from functools import lru_cache from typing import Dict, List, NamedTuple, Optional, Type, cast - from dataclasses_json import DataClassJsonMixin, dataclass_json from flyteidl.core import literals_pb2 from google.protobuf import json_format as _json_format @@ -153,7 +152,7 @@ def type_assertions_enabled(self) -> bool: return self._type_assertions_enabled def assert_type(self, t: Type[T], v: T): - if not hasattr(t, "__origin__") and not isinstance(v, t): + if not ((get_origin(t) is not None) or isinstance(v, t)): raise TypeTransformerFailedError(f"Expected value of type {t} but got '{v}' of type {type(v)}") @abstractmethod From 7ebae5122aeaafef8aaa20221655d3eedb1728f2 Mon Sep 17 00:00:00 2001 From: JackUrb Date: Mon, 24 Jun 2024 12:38:55 -0400 Subject: [PATCH 3/3] First pass at promise logic in dataclasses --- flytekit/core/promise.py | 54 ++++++++++++++++++++++++++++++++++-- flytekit/core/type_engine.py | 29 ++++++++++++++++--- 2 files changed, 76 insertions(+), 7 deletions(-) diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 557d621dd4..1043f09da7 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -2,6 +2,7 @@ import collections import inspect +import dataclasses import typing from copy import deepcopy from enum import Enum @@ -80,18 +81,24 @@ def my_wf(in1: int, in2: int) -> int: :param flyte_interface_types: One side of an :py:class:`flytekit.models.interface.TypedInterface` basically. :param native_types: Map to native Python type. """ - if incoming_values is None: - raise AssertionError("Incoming values cannot be None, must be a dict") + if incoming_values is None: raise ValueError("Incoming values cannot be None, must be a dict") result = {} # So as to not overwrite the input_kwargs for k, v in incoming_values.items(): if k not in flyte_interface_types: - raise AssertionError(f"Received unexpected keyword argument {k}") + raise ValueError(f"Received unexpected keyword argument {k}") var = flyte_interface_types[k] t = native_types[k] try: if type(v) is Promise: v = resolve_attr_path_in_promise(v) + if dataclasses.is_dataclass(v): + # if the value is a dataclass, we need to check that it isn't + # comprised of promises. If it is, we need to resolve them. + for field in dataclasses.fields(v): + if isinstance(getattr(v, field.name), Promise): + setattr(v, field.name, resolve_attr_path_in_promise(getattr(v, field.name))) + result[k] = TypeEngine.to_literal(ctx, v, t, var.type) except TypeTransformerFailedError as exc: raise TypeTransformerFailedError(f"Failed argument '{k}': {exc}") from exc @@ -716,6 +723,21 @@ def __rshift__(self, other: Any): return Output(*promises) # type: ignore +def _collection_contains_promise(maybe_collection: Any) -> bool: + """ + Determine if there's a collection at any depth below the given object that contains a Promise. + """ + if isinstance(maybe_collection, Promise): + return True + if isinstance(maybe_collection, list): + return any(_collection_contains_promise(x) for x in maybe_collection) + if isinstance(maybe_collection, dict): + return any(_collection_contains_promise(x) for x in maybe_collection.values()) + if dataclasses.is_dataclass(maybe_collection): + return any(_collection_contains_promise(x) for x in maybe_collection.__dict__.values()) + return False + + def binding_data_from_python_std( ctx: _flyte_context.FlyteContext, expected_literal_type: _type_models.LiteralType, @@ -768,6 +790,32 @@ def binding_data_from_python_std( return _literals_models.BindingData(collection=collection) + elif dataclasses.is_dataclass(t_value): + # If any of the attributes are a promise or contain a promise, we must + # convert the dataclass to a dictionary and then convert the dictionary, + # otherwise we can convert the dataclass directly + has_promise_attr = _collection_contains_promise(t_value) + if not has_promise_attr: + scalar = TypeEngine.to_literal(ctx, t_value, t_value_type, expected_literal_type).scalar + return _literals_models.BindingData(scalar=scalar) + else: + if ( + expected_literal_type.simple != _type_models.SimpleType.STRUCT or + expected_literal_type.structure is None + ): + raise AssertionError( + f"this should be a Struct type and it is not: {type(t_value)} vs {expected_literal_type}" + ) + m = _literals_models.BindingDataMap( + bindings={ + k: binding_data_from_python_std( + ctx, expected_literal_type.structure.dataclass_type[k], v, type(v), nodes + ) + for k, v in t_value.__dict__.items() + } + ) + return _literals_models.BindingData(map=m) + elif isinstance(t_value, dict): if ( expected_literal_type.map_value_type is None diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 567778c743..11693b66d4 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -120,7 +120,8 @@ def modify_literal_uris(lit: Literal): ) -class TypeTransformerFailedError(TypeError, AssertionError, ValueError): ... +class TypeTransformerFailedError(TypeError, AssertionError, ValueError): + ... class TypeTransformer(typing.Generic[T]): @@ -152,7 +153,7 @@ def type_assertions_enabled(self) -> bool: return self._type_assertions_enabled def assert_type(self, t: Type[T], v: T): - if not ((get_origin(t) is not None) or isinstance(v, t)): + if not hasattr(t, "__origin__") and not isinstance(v, t): raise TypeTransformerFailedError(f"Expected value of type {t} but got '{v}' of type {type(v)}") @abstractmethod @@ -483,6 +484,25 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp self._serialize_flyte_type(python_val, python_type) + # If there are any promises that are ready to resolve, we should resolve them for each attribute + # to the primitive value before trying to serialize the dataclass + for field in dataclasses.fields(python_val): + # TODO is there something less hacky we can do here? This would + # be a circular import as-is + from flytekit.core.promise import Promise, resolve_attr_path_in_promise + + val = python_val.__getattribute__(field.name) + if isinstance(val, Promise): + promise = val + assert promise.is_ready, f"Promise {promise} is not ready" + import pdb; pdb.set_trace() + if promise.val.scalar.primitive is not None: + val = promise.eval() + else: + # If the promise is not a primitive, we'll need to resolve the path + val = resolve_attr_path_in_promise(promise).eval() + python_val.__setattr__(field.name, val) + # The `to_json` function is integrated through either the `dataclasses_json` decorator or by inheriting from `DataClassJsonMixin`. # It serializes a data class into a JSON string. if hasattr(python_val, "to_json"): @@ -795,7 +815,7 @@ def tag(expected_python_type: Type[T]) -> str: def get_literal_type(self, t: Type[T]) -> LiteralType: return LiteralType(simple=SimpleType.STRUCT, metadata={ProtobufTransformer.PB_FIELD_KEY: self.tag(t)}) - + def _handle_list_literal(self, ctx: FlyteContext, elems: list) -> Literal: if len(elems) == 0: return Literal(collection=LiteralCollection(literals=[])) @@ -809,6 +829,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp try: message_dict = _MessageToDict(cast(Message, python_val)) if isinstance(message_dict, list): + # _MessageToDict will return a `list` on ListValue protobufs return self._handle_list_literal(ctx, message_dict) struct.update(message_dict) except Exception: @@ -1564,7 +1585,7 @@ def __init__(self): super().__init__("Typed Union", typing.Union) @staticmethod - def is_optional_type(t: Type) -> bool: + def is_optional_type(t: Type[T]) -> bool: """Return True if `t` is a Union or Optional type.""" return _is_union_type(t) or type(None) in get_args(t)