Skip to content
Open
2 changes: 1 addition & 1 deletion flytekit/core/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def local_execute(
if len(output_names) == 0:
return VoidPromise(self.name)

vals = [Promise(var, outputs_literals[var]) for var in output_names]
vals = [Promise(var, outputs_literals[var], type=self.interface.outputs[var].type) for var in output_names]
return create_task_output(vals, self.python_interface)

def __call__(self, *args: object, **kwargs: object) -> Union[Tuple[Promise], Promise, VoidPromise, Tuple, None]:
Expand Down
50 changes: 48 additions & 2 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import collections
import inspect
import typing
import dataclasses
from copy import deepcopy
from enum import Enum
from typing import Any, Coroutine, Dict, Hashable, List, Optional, Set, Tuple, Union, cast, get_args
Expand Down Expand Up @@ -90,15 +91,31 @@ def my_wf(in1: int, in2: int) -> int:
var = flyte_interface_types[k]
t = native_types[k]
try:
if type(v) is Promise:
v = resolve_attr_path_in_promise(v)
v = resolve_any_nested_promises(v)
result[k] = TypeEngine.to_literal(ctx, v, t, var.type)
except TypeTransformerFailedError as exc:
raise TypeTransformerFailedError(f"Failed argument '{k}': {exc}") from exc

return result


def resolve_any_nested_promises(v: Any):
    """
    Recursively replace every Promise found inside ``v`` with its resolved form.

    Promises are passed through ``resolve_attr_path_in_promise``. Lists, dicts and tuples are
    rebuilt with their members resolved (for dicts only the values are visited; keys are kept
    as they are). Dataclass instances have every field resolved and written back in place, and
    the same instance is returned. Any other value is returned unchanged.

    :param v: A Promise, a container that may hold promises, or any other value.
    :return: ``v`` with all nested promises resolved.
    """
    if isinstance(v, Promise):
        return resolve_attr_path_in_promise(v)
    if isinstance(v, list):
        return [resolve_any_nested_promises(item) for item in v]
    if isinstance(v, dict):
        # Distinct loop names: the comprehension must not rebind ``v`` while iterating over it.
        return {key: resolve_any_nested_promises(item) for key, item in v.items()}
    if isinstance(v, tuple):
        return tuple(resolve_any_nested_promises(item) for item in v)
    # is_dataclass() is also true for dataclass *types*; only instances carry field values,
    # so a class passed as a value must fall through and be returned untouched.
    if dataclasses.is_dataclass(v) and not isinstance(v, type):
        # Set the fields of the dataclass to the resolved values
        for field in dataclasses.fields(v):
            setattr(v, field.name, resolve_any_nested_promises(getattr(v, field.name)))
    return v


def resolve_attr_path_in_promise(p: Promise) -> Promise:
"""
resolve_attr_path_in_promise resolves the attribute path in a promise and returns a new promise with the resolved value
Expand Down Expand Up @@ -141,6 +158,7 @@ def resolve_attr_path_in_promise(p: Promise) -> Promise:
):
st = curr_val.value.value
new_st = resolve_attr_path_in_pb_struct(st, attr_path=p.attr_path[used:])
new_st = _maybe_fix_deserialized_ints(p, new_st)
literal_type = TypeEngine.to_literal_type(type(new_st))
# Reconstruct the resolved result to flyte literal (because the resolved result might not be struct)
curr_val = TypeEngine.to_literal(FlyteContextManager.current_context(), new_st, type(new_st), literal_type)
Expand All @@ -149,6 +167,28 @@ def resolve_attr_path_in_promise(p: Promise) -> Promise:
return p


def _maybe_fix_deserialized_ints(p: Promise, new_st: Any) -> Any:
    """
    Convert a resolved value back to ``int`` when the promise is typed as an integer
    but the value was deserialized as a float.

    :param p: The promise whose ``_type`` (if set) describes the expected literal type.
    :param new_st: The value obtained after resolving the promise's attribute path.
    :return: ``new_st`` unchanged when the promise is untyped or not an integer, or when the
        value is already an ``int``; otherwise the integral float converted to ``int``.
    :raises ValueError: If the promise is an integer but the value is a non-integral float
        (including nan and infinities) or is of any other type.
    """
    if p._type is None:
        # No typing, nothing to do
        return new_st

    if p._type.simple != SimpleType.INTEGER:
        # Not an integer, nothing to do
        return new_st

    if type(new_st) is int:
        return new_st

    if type(new_st) is float:
        # float.is_integer() is False for nan and +/-inf, so those values reach the ValueError
        # below instead of failing inside int() with an OverflowError or an unrelated message.
        if new_st.is_integer():
            return int(new_st)
        raise ValueError(f"Resolved value {new_st} is a float, but the promise is an integer")

    raise ValueError(f"Resolved value {new_st} is not an integer, but the promise is an integer")


def resolve_attr_path_in_pb_struct(st: _struct.Struct, attr_path: List[Union[str, int]]) -> _struct.Struct:
curr_val = st
for attr in attr_path:
Expand Down Expand Up @@ -596,6 +636,12 @@ def _append_attr(self, key) -> Promise:
# The attr_path on the ref node is for remote execute
new_promise._ref = new_promise.ref.with_attr(key)

if self._type is not None:
if self._type.simple == SimpleType.STRUCT and self._type.structure is not None:
# We should specify the type of this node, such that if it's used alone
# it can be resolved correctly.
new_promise._type = self._type.structure.dataclass_type[key]

return new_promise


Expand Down
75 changes: 48 additions & 27 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from abc import ABC, abstractmethod
from collections import OrderedDict
from functools import lru_cache
from types import NoneType
from typing import Dict, List, NamedTuple, Optional, Type, cast

from dataclasses_json import DataClassJsonMixin, dataclass_json
from flyteidl.core import literals_pb2
from google.protobuf import json_format as _json_format
Expand Down Expand Up @@ -149,7 +149,7 @@ def type_assertions_enabled(self) -> bool:
return self._type_assertions_enabled

def assert_type(self, t: Type[T], v: T):
if not hasattr(t, "__origin__") and not isinstance(v, t):
if not ((get_origin(t) is not None) or isinstance(v, t)):
raise TypeTransformerFailedError(f"Expected value of type {t} but got '{v}' of type {type(v)}")

@abstractmethod
Expand Down Expand Up @@ -493,22 +493,27 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp

self._make_dataclass_serializable(python_val, python_type)

# The function looks up or creates a JSONEncoder specifically designed for the object's type.
# This encoder is then used to convert a data class into a JSON string.
try:
encoder = self._encoder[python_type]
except KeyError:
encoder = JSONEncoder(python_type)
self._encoder[python_type] = encoder
# The `to_json` function is integrated through either the `dataclasses_json` decorator or by inheriting from `DataClassJsonMixin`.
# It serializes the data class into a JSON string, and provides additional functionality over JSONEncoder
if hasattr(python_val, "to_json"):
json_str = python_val.to_json()
else:
# The function looks up or creates a JSONEncoder specifically designed for the object's type.
# This encoder is then used to convert a data class into a JSON string.
try:
encoder = self._encoder[python_type]
except KeyError:
encoder = JSONEncoder(python_type)
self._encoder[python_type] = encoder

try:
json_str = encoder.encode(python_val)
except NotImplementedError:
# you can refer FlyteFile, FlyteDirectory and StructuredDataset to see how flyte types can be implemented.
raise NotImplementedError(
f"{python_type} should inherit from mashumaro.types.SerializableType"
f" and implement _serialize and _deserialize methods."
)
try:
json_str = encoder.encode(python_val)
except NotImplementedError:
# you can refer FlyteFile, FlyteDirectory and StructuredDataset to see how flyte types can be implemented.
raise NotImplementedError(
f"{python_type} should inherit from mashumaro.types.SerializableType"
f" and implement _serialize and _deserialize methods."
)

return Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct()))) # type: ignore

Expand Down Expand Up @@ -652,15 +657,20 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:

json_str = _json_format.MessageToJson(lv.scalar.generic)

# The function looks up or creates a JSONDecoder specifically designed for the object's type.
# This decoder is then used to convert a JSON string into a data class.
try:
decoder = self._decoder[expected_python_type]
except KeyError:
decoder = JSONDecoder(expected_python_type)
self._decoder[expected_python_type] = decoder
# The `from_json` function is integrated through either the `dataclasses_json` decorator or by inheriting from `DataClassJsonMixin`.
# It deserializes a JSON string into a data class, and supports additional functionality over JSONDecoder
if hasattr(expected_python_type, "from_json"):
dc = expected_python_type.from_json(json_str) # type: ignore
else:
# The function looks up or creates a JSONDecoder specifically designed for the object's type.
# This decoder is then used to convert a JSON string into a data class.
try:
decoder = self._decoder[expected_python_type]
except KeyError:
decoder = JSONDecoder(expected_python_type)
self._decoder[expected_python_type] = decoder

dc = decoder.decode(json_str)
dc = decoder.decode(json_str)

dc = self._fix_structured_dataset_type(expected_python_type, dc)
return self._fix_dataclass_int(expected_python_type, dc)
Expand Down Expand Up @@ -696,11 +706,22 @@ def tag(expected_python_type: Type[T]) -> str:

def get_literal_type(self, t: Type[T]) -> LiteralType:
return LiteralType(simple=SimpleType.STRUCT, metadata={ProtobufTransformer.PB_FIELD_KEY: self.tag(t)})

def _handle_list_literal(self, ctx: FlyteContext, elems: list) -> Literal:
    """
    Build a collection Literal from a plain Python list.

    An empty list yields an empty LiteralCollection. Otherwise the Python type of the first
    element is used to look up the literal type, and every element is converted using that
    same type.

    :param ctx: The Flyte context handed to ``TypeEngine.to_literal``.
    :param elems: The list of Python values to convert.
    :return: A Literal wrapping a LiteralCollection of the converted elements.
    """
    if not elems:
        return Literal(collection=LiteralCollection(literals=[]))

    # The first element decides the type used for the whole list.
    elem_type = type(elems[0])
    elem_literal_type = TypeEngine.to_literal_type(elem_type)

    converted = []
    for elem in elems:
        converted.append(TypeEngine.to_literal(ctx, elem, elem_type, elem_literal_type))
    return Literal(collection=LiteralCollection(literals=converted))

def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
struct = Struct()
try:
struct.update(_MessageToDict(cast(Message, python_val)))
message_dict = _MessageToDict(cast(Message, python_val))
if isinstance(message_dict, list):
return self._handle_list_literal(ctx, message_dict)
struct.update(message_dict)
except Exception:
raise TypeTransformerFailedError("Failed to convert to generic protobuf struct")
return Literal(scalar=Scalar(generic=struct))
Expand Down Expand Up @@ -1051,7 +1072,7 @@ def to_literal(cls, ctx: FlyteContext, python_val: typing.Any, python_type: Type
"actual attribute that you want to use. For example, in NamedTuple('OP', x=int) then"
"return v.x, instead of v, even if this has a single element"
)
if python_val is None and expected and expected.union_type is None:
if (python_val is None and python_type != NoneType) and expected and expected.union_type is None:
raise TypeTransformerFailedError(f"Python value cannot be None, expected {python_type}/{expected}")
transformer = cls.get_transformer(python_type)
if transformer.type_assertions_enabled:
Expand Down
76 changes: 76 additions & 0 deletions test_dataclass_elem_list_construction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from flytekit import task, dynamic, workflow
from dataclasses import dataclass
from mashumaro.mixins.json import DataClassJSONMixin


@dataclass
class IntWrapper(DataClassJSONMixin):
    """Dataclass holding a single integer field, ``x``."""

    x: int

@task
def get_int() -> int:
    """Task that returns the constant integer 3."""
    value = 3
    return value

@task
def get_wrapped_int() -> IntWrapper:
    """Task that returns an ``IntWrapper`` whose ``x`` field is 3."""
    wrapped = IntWrapper(x=3)
    return wrapped

@task
def sum_list(input_list: list[int]) -> int:
    """Task that returns the sum of the integers in ``input_list``."""
    total = sum(input_list)
    return total


@dataclass
class StrWrapper(DataClassJSONMixin):
    """Dataclass holding a single string field, ``x``."""

    x: str

@task
def get_str() -> str:
    """Task that returns the constant string "5"."""
    value = "5"
    return value

@task
def get_wrapped_str() -> StrWrapper:
    """Task that returns a ``StrWrapper`` whose ``x`` field is "3"."""
    wrapped = StrWrapper(x="3")
    return wrapped

@task
def concat_list(input_list: list[str]) -> str:
    """Task that joins the strings in ``input_list`` with no separator."""
    joined = "".join(input_list)
    return joined



@workflow
def convert_list_workflow1() -> int:
    """Build a list from a literal int and the output of ``get_int``, then sum it with ``sum_list``."""
    from_task = get_int()
    return sum_list(input_list=[4, from_task])

@workflow
def convert_list_workflow2() -> int:
    """Build a list from a literal int and the ``x`` attribute of ``get_wrapped_int``'s output, then sum it."""
    wrapper = get_wrapped_int()
    return sum_list(input_list=[4, wrapper.x])

@workflow
def convert_list_workflow3() -> str:
    """Build a list from a literal string and the output of ``get_str``, then join it with ``concat_list``."""
    from_task = get_str()
    return concat_list(input_list=["4", from_task])

@workflow
def convert_list_workflow4() -> str:
    """Build a list from a literal string and the ``x`` attribute of ``get_wrapped_str``'s output, then join it."""
    wrapper = get_wrapped_str()
    return concat_list(input_list=["4", wrapper.x])


if __name__ == "__main__":
    # Run each workflow locally in order, printing a label followed by its result.
    runs = (
        ("Run 1", convert_list_workflow1),
        ("Run 2", convert_list_workflow2),
        ("Run 3", convert_list_workflow3),
        ("Run 4", convert_list_workflow4),
    )
    for label, run_workflow in runs:
        print(label)
        print(run_workflow())