Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
188 changes: 188 additions & 0 deletions src/snowflake/snowpark/mock/_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,178 @@ def handle_udaf_expression(
return res


def handle_flatten_function(
exp,
input_data: Union[TableEmulator, ColumnEmulator],
analyzer: "MockAnalyzer",
expr_to_alias: Dict[str, str],
join_with_input_columns: bool = True,
) -> TableEmulator:
"""Handle the built-in FLATTEN table function for local testing.

This is a minimal implementation to support explode() and explode_outer()
functions in local testing mode. It is NOT a complete FLATTEN implementation.

Supported parameters:
- input: The array or object column to flatten (required)
- outer: If True, emit NULL row for empty/null inputs (default: False)
- mode: 'ARRAY', 'OBJECT', or 'BOTH' (default: 'BOTH')

NOT supported (will raise NotImplementedError):
- path: Extracting nested paths from VARIANT
- recursive: Recursive flattening of nested structures

Output columns:
- KEY: For objects, the key name; for arrays, NULL
- VALUE: The flattened value (VARIANT type)

Note: Snowflake's full FLATTEN also outputs SEQ, PATH, INDEX, THIS columns
which are not implemented here. If your tests require these columns,
consider using integration tests against a real Snowflake instance.
"""
from snowflake.snowpark._internal.analyzer.table_function import (
FlattenFunction,
NamedArgumentsTableFunction,
)

# Extract parameters based on the expression type
if isinstance(exp, FlattenFunction):
# Direct FlattenFunction has attributes directly
input_expr = exp.input
outer = exp.outer
mode = exp.mode.upper() if exp.mode else "BOTH"
path = exp.path
recursive = exp.recursive

# Validate unsupported parameters
if path and path != "":
raise NotImplementedError(
f"FLATTEN with PATH parameter is not supported in local testing. "
f"Got path='{path}'. Use integration tests for this feature."
)
if recursive:
raise NotImplementedError(
"FLATTEN with RECURSIVE=True is not supported in local testing. "
"Use integration tests for this feature."
)

elif isinstance(exp, NamedArgumentsTableFunction) and exp.func_name.lower() == "flatten":
# NamedArgumentsTableFunction has args dict
args = exp.args
input_expr = args.get("input")
outer_expr = args.get("outer")
mode_expr = args.get("mode")
path_expr = args.get("path")
recursive_expr = args.get("recursive")

# Check for unsupported path parameter
if path_expr is not None:
path_val = calculate_expression(path_expr, input_data, analyzer, expr_to_alias)
if hasattr(path_val, "iloc"):
path_val = path_val.iloc[0] if len(path_val) > 0 else ""
if path_val and path_val != "":
raise NotImplementedError(
f"FLATTEN with PATH parameter is not supported in local testing. "
f"Got path='{path_val}'. Use integration tests for this feature."
)

# Check for unsupported recursive parameter
if recursive_expr is not None:
recursive_val = calculate_expression(recursive_expr, input_data, analyzer, expr_to_alias)
if hasattr(recursive_val, "iloc"):
recursive_val = recursive_val.iloc[0] if len(recursive_val) > 0 else False
if recursive_val:
raise NotImplementedError(
"FLATTEN with RECURSIVE=True is not supported in local testing. "
"Use integration tests for this feature."
)

# outer is a Literal expression, extract its value
if outer_expr is not None:
outer = calculate_expression(outer_expr, input_data, analyzer, expr_to_alias)
if hasattr(outer, "iloc"):
outer = outer.iloc[0] if len(outer) > 0 else False
else:
outer = False

# mode is also a Literal expression
if mode_expr is not None:
mode = calculate_expression(mode_expr, input_data, analyzer, expr_to_alias)
if hasattr(mode, "iloc"):
mode = mode.iloc[0] if len(mode) > 0 else "BOTH"
mode = str(mode).upper() if mode else "BOTH"
else:
mode = "BOTH"
else:
raise ValueError(f"Unexpected flatten expression type: {type(exp)}")

# Get the input column to flatten
input_col = calculate_expression(input_expr, input_data, analyzer, expr_to_alias)

result_rows = []
input_col_names = list(input_data.columns) if join_with_input_columns else []

for idx, value in enumerate(input_col):
input_row = input_data.iloc[idx] if join_with_input_columns and len(input_data) > 0 else None

if value is None:
if outer:
# For outer=True, produce a row with NULL values
row_prefix = tuple(input_row.values) if input_row is not None else ()
result_rows.append(row_prefix + (None, None))
continue

# Handle different types
is_array = isinstance(value, (list, tuple))
is_object = isinstance(value, dict)

if is_array and mode in ("ARRAY", "BOTH"):
if len(value) == 0:
if outer:
row_prefix = tuple(input_row.values) if input_row is not None else ()
result_rows.append(row_prefix + (None, None))
else:
for item in value:
row_prefix = tuple(input_row.values) if input_row is not None else ()
# For arrays: KEY is None, VALUE is the item
result_rows.append(row_prefix + (None, item))
elif is_object and mode in ("OBJECT", "BOTH"):
if len(value) == 0:
if outer:
row_prefix = tuple(input_row.values) if input_row is not None else ()
result_rows.append(row_prefix + (None, None))
else:
for k, v in value.items():
row_prefix = tuple(input_row.values) if input_row is not None else ()
# For objects: KEY and VALUE
result_rows.append(row_prefix + (k, v))
elif outer:
# Type doesn't match mode, but outer=True means emit null row
row_prefix = tuple(input_row.values) if input_row is not None else ()
result_rows.append(row_prefix + (None, None))

# Build result columns
output_col_names = ["KEY", "VALUE"]
all_col_names = input_col_names + output_col_names

if result_rows:
result_df = TableEmulator(result_rows, columns=all_col_names)
else:
result_df = TableEmulator(columns=all_col_names)

# Set up sf_types
from snowflake.snowpark.types import StringType, VariantType

sf_types = {}
if join_with_input_columns and hasattr(input_data, "sf_types"):
sf_types.update(input_data.sf_types)
sf_types["KEY"] = ColumnType(StringType(), True)
sf_types["VALUE"] = ColumnType(VariantType(), True)
result_df.sf_types = sf_types

return result_df


def handle_udtf_expression(
exp: FunctionExpression,
input_data: Union[TableEmulator, ColumnEmulator],
Expand All @@ -858,6 +1030,22 @@ def handle_udtf_expression(
current_row=None,
join_with_input_columns=True,
):
from snowflake.snowpark._internal.analyzer.table_function import (
FlattenFunction,
NamedArgumentsTableFunction,
)

# Handle built-in table functions first
if isinstance(exp, FlattenFunction):
return handle_flatten_function(
exp, input_data, analyzer, expr_to_alias, join_with_input_columns
)

# Handle flatten via NamedArgumentsTableFunction (used by explode)
if isinstance(exp, NamedArgumentsTableFunction) and exp.func_name.lower() == "flatten":
return handle_flatten_function(
exp, input_data, analyzer, expr_to_alias, join_with_input_columns
)

# TODO: handle and support imports + other udtf attributes.

Expand Down
23 changes: 23 additions & 0 deletions src/snowflake/snowpark/mock/_select_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,29 @@ def execution_plan(self):
self._execution_plan = MockExecutionPlan(self, self._session)
return self._execution_plan

@property
def snowflake_plan(self):
"""Alias for execution_plan to provide API compatibility with SelectStatement.

Why this alias is safe:
-----------------------
The non-mock SelectStatement.snowflake_plan returns a SnowflakePlan with an
`output` property (List[Attribute]) used for schema inference. Our
MockExecutionPlan already provides this same `output` property with identical
semantics (see MockExecutionPlan.output in _plan.py).

This alias exists specifically to support table_function.py line 327:
plan = select_statement.select([...]).snowflake_plan
explode_col_type = plan.output[0].datatype

If SelectStatement.snowflake_plan gains additional functionality in the future
(lazy evaluation, metadata wrapping, etc.), this alias may need to be updated
to maintain parity. The mock test suite should catch such divergence.

See also: GitHub issue #3565 (SNOW-2213161)
"""
return self.execution_plan

@property
def attributes(self):
return self._attributes or self.execution_plan.attributes
Expand Down
Loading
Loading