Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 51 additions & 3 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import collections
import inspect
import dataclasses
import typing
from copy import deepcopy
from enum import Enum
Expand Down Expand Up @@ -80,18 +81,24 @@ def my_wf(in1: int, in2: int) -> int:
:param flyte_interface_types: One side of an :py:class:`flytekit.models.interface.TypedInterface` basically.
:param native_types: Map to native Python type.
"""
if incoming_values is None:
raise AssertionError("Incoming values cannot be None, must be a dict")
if incoming_values is None: raise ValueError("Incoming values cannot be None, must be a dict")

result = {} # So as to not overwrite the input_kwargs
for k, v in incoming_values.items():
if k not in flyte_interface_types:
raise AssertionError(f"Received unexpected keyword argument {k}")
raise ValueError(f"Received unexpected keyword argument {k}")
var = flyte_interface_types[k]
t = native_types[k]
try:
if type(v) is Promise:
v = resolve_attr_path_in_promise(v)
if dataclasses.is_dataclass(v):
# if the value is a dataclass, we need to check that it isn't
# comprised of promises. If it is, we need to resolve them.
for field in dataclasses.fields(v):
if isinstance(getattr(v, field.name), Promise):
setattr(v, field.name, resolve_attr_path_in_promise(getattr(v, field.name)))

result[k] = TypeEngine.to_literal(ctx, v, t, var.type)
except TypeTransformerFailedError as exc:
raise TypeTransformerFailedError(f"Failed argument '{k}': {exc}") from exc
Expand Down Expand Up @@ -716,6 +723,21 @@ def __rshift__(self, other: Any):
return Output(*promises) # type: ignore


def _collection_contains_promise(maybe_collection: Any) -> bool:
    """
    Determine whether the given object is a Promise or contains one at any depth.

    Lists, dict values, and the declared fields of dataclass instances are searched
    recursively; every other kind of object is treated as a leaf.

    :param maybe_collection: Arbitrary object to inspect.
    :return: True if the object itself, or anything nested inside it, is a Promise.
    """
    if isinstance(maybe_collection, Promise):
        return True
    if isinstance(maybe_collection, list):
        return any(_collection_contains_promise(x) for x in maybe_collection)
    if isinstance(maybe_collection, dict):
        return any(_collection_contains_promise(x) for x in maybe_collection.values())
    # is_dataclass() is also True for dataclass *types*; only instances carry field values.
    if dataclasses.is_dataclass(maybe_collection) and not isinstance(maybe_collection, type):
        # Read the declared fields through getattr instead of __dict__, which does not
        # exist on dataclasses defined with slots. A field that was never assigned
        # (init=False without a default) is treated as holding no Promise.
        return any(
            _collection_contains_promise(getattr(maybe_collection, f.name, None))
            for f in dataclasses.fields(maybe_collection)
        )
    return False


def binding_data_from_python_std(
ctx: _flyte_context.FlyteContext,
expected_literal_type: _type_models.LiteralType,
Expand Down Expand Up @@ -768,6 +790,32 @@ def binding_data_from_python_std(

return _literals_models.BindingData(collection=collection)

elif dataclasses.is_dataclass(t_value):
# If any of the attributes are a promise or contain a promise, we must
# convert the dataclass to a dictionary and then convert the dictionary,
# otherwise we can convert the dataclass directly
has_promise_attr = _collection_contains_promise(t_value)
if not has_promise_attr:
scalar = TypeEngine.to_literal(ctx, t_value, t_value_type, expected_literal_type).scalar
return _literals_models.BindingData(scalar=scalar)
else:
if (
expected_literal_type.simple != _type_models.SimpleType.STRUCT or
expected_literal_type.structure is None
):
raise AssertionError(
f"this should be a Struct type and it is not: {type(t_value)} vs {expected_literal_type}"
)
m = _literals_models.BindingDataMap(
bindings={
k: binding_data_from_python_std(
ctx, expected_literal_type.structure.dataclass_type[k], v, type(v), nodes
)
for k, v in t_value.__dict__.items()
}
)
return _literals_models.BindingData(map=m)

elif isinstance(t_value, dict):
if (
expected_literal_type.map_value_type is None
Expand Down
39 changes: 35 additions & 4 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from collections import OrderedDict
from functools import lru_cache
from typing import Dict, List, NamedTuple, Optional, Type, cast

from dataclasses_json import DataClassJsonMixin, dataclass_json
from flyteidl.core import literals_pb2
from google.protobuf import json_format as _json_format
Expand Down Expand Up @@ -121,7 +120,8 @@ def modify_literal_uris(lit: Literal):
)


class TypeTransformerFailedError(TypeError, AssertionError, ValueError):
    """
    Error raised when a value cannot be converted by a type transformer.

    It derives from TypeError, AssertionError and ValueError, so handlers written
    for any of those exception types also catch it.
    """


class TypeTransformer(typing.Generic[T]):
Expand Down Expand Up @@ -484,6 +484,25 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp

self._serialize_flyte_type(python_val, python_type)

# If there are any promises that are ready to resolve, we should resolve them for each attribute
# to the primitive value before trying to serialize the dataclass
for field in dataclasses.fields(python_val):
# TODO is there something less hacky we can do here? This would
# be a circular import as-is
from flytekit.core.promise import Promise, resolve_attr_path_in_promise

val = python_val.__getattribute__(field.name)
if isinstance(val, Promise):
promise = val
assert promise.is_ready, f"Promise {promise} is not ready"
import pdb; pdb.set_trace()
if promise.val.scalar.primitive is not None:
val = promise.eval()
else:
# If the promise is not a primitive, we'll need to resolve the path
val = resolve_attr_path_in_promise(promise).eval()
python_val.__setattr__(field.name, val)

# The `to_json` function is integrated through either the `dataclasses_json` decorator or by inheriting from `DataClassJsonMixin`.
# It serializes a data class into a JSON string.
if hasattr(python_val, "to_json"):
Expand Down Expand Up @@ -797,10 +816,22 @@ def tag(expected_python_type: Type[T]) -> str:
def get_literal_type(self, t: Type[T]) -> LiteralType:
    """
    Describe the given type as a STRUCT literal type whose metadata stores the
    type's tag under ``ProtobufTransformer.PB_FIELD_KEY``.
    """
    pb_metadata = {ProtobufTransformer.PB_FIELD_KEY: self.tag(t)}
    return LiteralType(simple=SimpleType.STRUCT, metadata=pb_metadata)

def _handle_list_literal(self, ctx: FlyteContext, elems: list) -> Literal:
    """
    Convert a plain Python list into a collection Literal.

    The Python type of the first element is used to look up the literal type, and
    every element is converted with that same pair of types. An empty list becomes
    an empty collection.
    """
    if not elems:
        return Literal(collection=LiteralCollection(literals=[]))

    # All elements are converted as if they share the first element's type.
    elem_type = type(elems[0])
    elem_literal_type = TypeEngine.to_literal_type(elem_type)

    converted = []
    for item in elems:
        converted.append(TypeEngine.to_literal(ctx, item, elem_type, elem_literal_type))
    return Literal(collection=LiteralCollection(literals=converted))

def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
    """
    Convert a protobuf message into a Literal.

    The message is first turned into plain Python data with ``_MessageToDict``. A list
    result is converted into a collection Literal; anything else is loaded into a
    generic ``Struct`` scalar.

    :param ctx: Context forwarded to the list conversion.
    :param python_val: The protobuf message to convert.
    :param python_type: Declared Python type of the value (not used by this conversion).
    :param expected: Expected literal type (not used by this conversion).
    :return: A collection Literal for list results, otherwise a generic-struct scalar Literal.
    :raises TypeTransformerFailedError: If any step of the conversion fails.
    """
    struct = Struct()
    try:
        # Convert the message exactly once, and check for a list before touching the
        # Struct, so a list result is routed to the collection path.
        message_dict = _MessageToDict(cast(Message, python_val))
        if isinstance(message_dict, list):
            # _MessageToDict will return a `list` on ListValue protobufs
            return self._handle_list_literal(ctx, message_dict)
        struct.update(message_dict)
    except Exception as exc:
        # Chain the original error so the root cause stays visible in the traceback.
        raise TypeTransformerFailedError("Failed to convert to generic protobuf struct") from exc
    return Literal(scalar=Scalar(generic=struct))
Expand Down Expand Up @@ -1554,7 +1585,7 @@ def __init__(self):
super().__init__("Typed Union", typing.Union)

@staticmethod
def is_optional_type(t: Type[T]) -> bool:
    """
    Return True if `t` is a Union or Optional type.

    A type qualifies when `_is_union_type` accepts it or when `type(None)` is one
    of its type arguments.
    """
    return _is_union_type(t) or type(None) in get_args(t)

Expand Down