def _evaluate_flatten_argument(
    arg_expr,
    default,
    input_data: Union[TableEmulator, ColumnEmulator],
    analyzer: "MockAnalyzer",
    expr_to_alias: Dict[str, str],
):
    """Reduce one named FLATTEN argument expression to a plain scalar.

    Args:
        arg_expr: The argument expression, or None when the caller did not
            supply the argument.
        default: Value returned when the argument is missing or evaluates to
            an empty column.
        input_data: Data the expression is evaluated against.
        analyzer: Mock analyzer forwarded to ``calculate_expression``.
        expr_to_alias: Alias map forwarded to ``calculate_expression``.

    Returns:
        The evaluated scalar, or ``default``.
    """
    if arg_expr is None:
        return default
    evaluated = calculate_expression(arg_expr, input_data, analyzer, expr_to_alias)
    # The evaluation may come back as a column-like object; these arguments
    # are treated as constants, so only the first element is consulted.
    if hasattr(evaluated, "iloc"):
        return evaluated.iloc[0] if len(evaluated) > 0 else default
    return evaluated


def _reject_unsupported_flatten_options(path, recursive) -> None:
    """Raise NotImplementedError when PATH or RECURSIVE is requested."""
    if path:
        raise NotImplementedError(
            f"FLATTEN with PATH parameter is not supported in local testing. "
            f"Got path='{path}'. Use integration tests for this feature."
        )
    if recursive:
        raise NotImplementedError(
            "FLATTEN with RECURSIVE=True is not supported in local testing. "
            "Use integration tests for this feature."
        )


def _flatten_key_value_pairs(value, mode: str) -> list:
    """Return the (KEY, VALUE) pairs that one input value expands to.

    Lists and tuples expand under mode 'ARRAY' or 'BOTH' with a None key;
    dicts expand under mode 'OBJECT' or 'BOTH' with their own keys. Anything
    else (None, scalars, a container that does not match the mode) yields no
    pairs.
    """
    if isinstance(value, (list, tuple)) and mode in ("ARRAY", "BOTH"):
        return [(None, item) for item in value]
    if isinstance(value, dict) and mode in ("OBJECT", "BOTH"):
        return list(value.items())
    return []


def handle_flatten_function(
    exp,
    input_data: Union[TableEmulator, ColumnEmulator],
    analyzer: "MockAnalyzer",
    expr_to_alias: Dict[str, str],
    join_with_input_columns: bool = True,
) -> TableEmulator:
    """Handle the built-in FLATTEN table function for local testing.

    This is a minimal implementation to support explode() and explode_outer()
    in local testing mode. It is NOT a complete FLATTEN implementation.

    Supported parameters:
        - input: The array or object column to flatten (required)
        - outer: If True, emit one NULL row for every input that expands to
          nothing (null, empty, or not matching ``mode``) (default: False)
        - mode: 'ARRAY', 'OBJECT', or 'BOTH' (default: 'BOTH')

    NOT supported (will raise NotImplementedError):
        - path: Extracting nested paths from VARIANT
        - recursive: Recursive flattening of nested structures

    Output columns:
        - KEY: For objects, the key name; for arrays, NULL
        - VALUE: The flattened value (VARIANT type)

    Note: Snowflake's full FLATTEN also outputs SEQ, PATH, INDEX, THIS columns
    which are not implemented here. If your tests require these columns,
    consider using integration tests against a real Snowflake instance.

    Args:
        exp: A ``FlattenFunction``, or a ``NamedArgumentsTableFunction`` whose
            function name is "flatten" (case-insensitive).
        input_data: Data the flatten input expression is evaluated against.
        analyzer: Mock analyzer forwarded to ``calculate_expression``.
        expr_to_alias: Alias map forwarded to ``calculate_expression``.
        join_with_input_columns: When True, each output row is prefixed with
            the values of the input row it was produced from.

    Returns:
        A ``TableEmulator`` holding the (optional) input columns followed by
        KEY and VALUE, with ``sf_types`` populated for every column.

    Raises:
        NotImplementedError: If a non-empty path or recursive=True is given.
        ValueError: If ``exp`` is not a recognized flatten expression.
    """
    from snowflake.snowpark._internal.analyzer.table_function import (
        FlattenFunction,
        NamedArgumentsTableFunction,
    )

    if isinstance(exp, FlattenFunction):
        # FlattenFunction carries its options as plain attributes.
        input_expr = exp.input
        outer = exp.outer
        mode = exp.mode.upper() if exp.mode else "BOTH"
        _reject_unsupported_flatten_options(exp.path, exp.recursive)
    elif (
        isinstance(exp, NamedArgumentsTableFunction)
        and exp.func_name.lower() == "flatten"
    ):
        # NamedArgumentsTableFunction carries its options as expressions in
        # an args dict; each one has to be evaluated down to a scalar.
        args = exp.args

        def scalar_arg(name, default):
            return _evaluate_flatten_argument(
                args.get(name), default, input_data, analyzer, expr_to_alias
            )

        input_expr = args.get("input")
        _reject_unsupported_flatten_options(
            scalar_arg("path", ""), scalar_arg("recursive", False)
        )
        outer = scalar_arg("outer", False)
        mode_value = scalar_arg("mode", "BOTH")
        mode = str(mode_value).upper() if mode_value else "BOTH"
    else:
        raise ValueError(f"Unexpected flatten expression type: {type(exp)}")

    # Get the input column to flatten
    input_col = calculate_expression(input_expr, input_data, analyzer, expr_to_alias)

    input_col_names = list(input_data.columns) if join_with_input_columns else []
    # Loop-invariant: whether an input row exists to prefix onto each output row.
    carry_input_row = join_with_input_columns and len(input_data) > 0

    result_rows = []
    for idx, value in enumerate(input_col):
        row_prefix = tuple(input_data.iloc[idx].values) if carry_input_row else ()
        pairs = _flatten_key_value_pairs(value, mode)
        if not pairs and outer:
            # outer=True keeps the input row alive with a single NULL pair.
            pairs = [(None, None)]
        result_rows.extend(row_prefix + pair for pair in pairs)

    # Build result columns
    all_col_names = input_col_names + ["KEY", "VALUE"]
    if result_rows:
        result_df = TableEmulator(result_rows, columns=all_col_names)
    else:
        result_df = TableEmulator(columns=all_col_names)

    # Set up sf_types
    from snowflake.snowpark.types import StringType, VariantType

    sf_types = {}
    if join_with_input_columns and hasattr(input_data, "sf_types"):
        sf_types.update(input_data.sf_types)
    sf_types["KEY"] = ColumnType(StringType(), True)
    sf_types["VALUE"] = ColumnType(VariantType(), True)
    result_df.sf_types = sf_types

    return result_df
@property
def snowflake_plan(self):
    """Return ``execution_plan`` under the name the non-mock API uses.

    Code written against the real ``SelectStatement`` reads
    ``.snowflake_plan`` and then inspects ``plan.output`` for schema
    inference; the explode() support in table_function.py does
    ``select_statement.select([...]).snowflake_plan`` followed by
    ``plan.output[0].datatype``. This property lets that code run against the
    mock statement by handing back the mock execution plan unchanged.

    NOTE(review): this relies on the mock plan exposing an ``output``
    attribute compatible with ``SnowflakePlan.output``; if the real
    ``snowflake_plan`` gains extra behavior, this alias needs revisiting.

    See also: GitHub issue #3565 (SNOW-2213161)
    """
    plan = self.execution_plan
    return plan
import json

import pytest

from snowflake.snowpark import Session
from snowflake.snowpark.functions import col, explode, explode_outer, flatten
from snowflake.snowpark.types import VariantType


def parse_variant(value):
    """Decode a VARIANT cell that may arrive JSON-encoded or as a raw value.

    Non-string values (including None) are returned untouched; strings are
    JSON-decoded when possible and returned as-is otherwise.
    """
    if not isinstance(value, str):
        return value
    try:
        decoded = json.loads(value)
    except (json.JSONDecodeError, TypeError):
        return value
    return decoded


@pytest.fixture(scope="module")
def session():
    """Provide one local-testing session shared by every test in the module."""
    local_session = Session.builder.config("local_testing", True).create()
    yield local_session
    local_session.close()


class TestExplodeBasic:
    """explode() over array columns."""

    def test_explode_simple_array(self, session):
        """Reproduce the exact failing case reported in issue #3565."""
        frame = session.create_dataframe(
            [{"foo": "bar", "my_array": ["one", "two"]}],
            schema=["foo", "my_array"],
        )
        rows = frame.select("foo", explode("my_array")).collect()

        assert len(rows) == 2
        # VALUE is VARIANT type, so it may come back JSON-encoded.
        for row, expected in zip(rows, ("one", "two")):
            assert row["FOO"] == "bar"
            assert parse_variant(row["VALUE"]) == expected

    def test_explode_multiple_rows(self, session):
        """Each input row expands independently of the others."""
        frame = session.create_dataframe(
            [
                {"idx": 1, "arr": ["a", "b"]},
                {"idx": 2, "arr": ["x", "y", "z"]},
            ],
            schema=["idx", "arr"],
        )
        rows = frame.select("idx", explode("arr")).collect()

        assert len(rows) == 5
        # Group the exploded values by the row they came from.
        values_by_idx = {}
        for row in rows:
            values_by_idx.setdefault(row["IDX"], []).append(
                parse_variant(row["VALUE"])
            )

        assert sorted(values_by_idx[1]) == ["a", "b"]
        assert sorted(values_by_idx[2]) == ["x", "y", "z"]

    def test_explode_with_integers(self, session):
        """Integer array elements survive the explode."""
        frame = session.create_dataframe(
            [{"id": 1, "nums": [10, 20, 30]}], schema=["id", "nums"]
        )
        rows = frame.select("id", explode("nums")).collect()

        assert len(rows) == 3
        assert sorted(parse_variant(row["VALUE"]) for row in rows) == [10, 20, 30]

    def test_explode_schema_has_value_column(self, session):
        """The exploded dataframe exposes a VARIANT-typed VALUE column."""
        frame = session.create_dataframe([{"arr": ["a", "b"]}], schema=["arr"])
        exploded = frame.select(explode("arr"))

        fields = exploded.schema.fields
        assert any(field.name == "VALUE" for field in fields)

        value_field = next(field for field in fields if field.name == "VALUE")
        assert isinstance(value_field.datatype, VariantType)


class TestExplodeWithMaps:
    """explode() over map/dictionary columns."""

    def test_explode_simple_map(self, session):
        """A dict column yields one KEY/VALUE row per entry."""
        expected_scores = {"math": 90, "science": 85}
        frame = session.create_dataframe(
            [{"name": "Alice", "scores": {"math": 90, "science": 85}}],
            schema=["name", "scores"],
        )
        rows = frame.select("name", explode("scores")).collect()

        assert len(rows) == 2
        assert {row["KEY"] for row in rows} == set(expected_scores)

        for row in rows:
            assert row["NAME"] == "Alice"
            assert parse_variant(row["VALUE"]) == expected_scores[row["KEY"]]


class TestExplodeOuter:
    """explode_outer(), which keeps rows whose array is empty or NULL."""

    def test_explode_outer_with_empty_array(self, session):
        """An empty array still contributes exactly one output row."""
        frame = session.create_dataframe(
            [
                {"idx": 1, "arr": [1, 2]},
                {"idx": 2, "arr": []},
                {"idx": 3, "arr": [3]},
            ],
            schema=["idx", "arr"],
        )
        rows = frame.select("idx", explode_outer("arr")).sort("idx").collect()

        # idx=1 expands to 2 rows, idx=2 to a single NULL row, idx=3 to 1 row.
        seen_idx = [row["IDX"] for row in rows]
        assert seen_idx.count(1) == 2
        assert seen_idx.count(2) == 1  # NULL row for empty array
        assert seen_idx.count(3) == 1

    def test_explode_outer_with_null_array(self, session):
        """A NULL array contributes one row whose VALUE is NULL."""
        frame = session.create_dataframe(
            [
                {"idx": 1, "arr": ["a"]},
                {"idx": 2, "arr": None},
            ],
            schema=["idx", "arr"],
        )
        rows = frame.select("idx", explode_outer("arr")).sort("idx").collect()

        assert len(rows) == 2
        populated, null_row = rows
        # The first row carries the single array element.
        assert populated["IDX"] == 1
        assert parse_variant(populated["VALUE"]) == "a"
        # The second row is the placeholder for the NULL array.
        assert null_row["IDX"] == 2
        assert null_row["VALUE"] is None


# Note: Alias tests are skipped as alias handling for table functions requires
# additional work in the mock framework.


class TestExplodeWithOriginalColumn:
    """explode() keeps the other selected columns of the source dataframe."""

    def test_explode_preserves_original_columns(self, session):
        """Exploded rows stay joined to the columns of their source row."""
        frame = session.create_dataframe(
            [
                {"id": 1, "name": "Alice", "items": ["a", "b"]},
                {"id": 2, "name": "Bob", "items": ["x"]},
            ],
            schema=["id", "name", "items"],
        )
        rows = frame.select("id", "name", explode("items")).collect()

        # Two items for Alice plus one for Bob.
        assert len(rows) == 3

        # All selected columns are present on the output rows.
        for column in ("ID", "NAME", "VALUE"):
            assert column in rows[0].as_dict()

        # Each source row produced the expected number of output rows.
        names = [row["NAME"] for row in rows]
        assert names.count("Alice") == 2
        assert names.count("Bob") == 1


class TestFlattenDirect:
    """flatten() called directly rather than through explode()."""

    def test_flatten_basic(self, session):
        """flatten() over an array yields one VALUE per element."""
        frame = session.create_dataframe([{"arr": [1, 2, 3]}], schema=["arr"])
        rows = frame.select(flatten(col("arr"))).select("value").collect()

        assert sorted(parse_variant(row["VALUE"]) for row in rows) == [1, 2, 3]

    def test_flatten_with_outer(self, session):
        """outer=True keeps a row for the empty array."""
        frame = session.create_dataframe(
            [
                {"arr": [1]},
                {"arr": []},
            ],
            schema=["arr"],
        )
        rows = (
            frame.select(flatten(col("arr"), outer=True)).select("value").collect()
        )

        # One row for the value, one NULL row for the empty array.
        assert len(rows) == 2


class TestFlattenUnsupportedParameters:
    """Unsupported FLATTEN parameters must fail loudly."""

    def test_flatten_with_path_raises_error(self, session):
        """A path argument raises NotImplementedError."""
        frame = session.create_dataframe(
            [{"obj": {"nested": {"value": 1}}}], schema=["obj"]
        )

        with pytest.raises(NotImplementedError) as exc_info:
            frame.select(flatten(col("obj"), path="nested")).collect()

        message = str(exc_info.value)
        assert "PATH parameter is not supported" in message
        assert "local testing" in message

    def test_flatten_with_recursive_raises_error(self, session):
        """recursive=True raises NotImplementedError."""
        frame = session.create_dataframe(
            [{"arr": [[1, 2], [3, 4]]}], schema=["arr"]
        )

        with pytest.raises(NotImplementedError) as exc_info:
            frame.select(flatten(col("arr"), recursive=True)).collect()

        message = str(exc_info.value)
        assert "RECURSIVE=True is not supported" in message
        assert "local testing" in message


class TestUDTFNotAffected:
    """Non-flatten UDTFs keep working after the flatten change."""

    def test_custom_udtf_still_works(self, session):
        """A user-defined UDTF still runs through the existing UDTF path.

        Guards against the flatten early-return breaking the handling of
        ordinary UDTFs.
        """
        from snowflake.snowpark.functions import lit
        from snowflake.snowpark.types import IntegerType, StructField, StructType

        # A minimal UDTF that emits its first argument `times` times.
        class Repeater:
            def process(self, value, times):
                for _ in range(times):
                    yield (value,)

        repeater_udtf = session.udtf.register(
            Repeater,
            output_schema=StructType([StructField("REPEATED", IntegerType())]),
            input_types=[IntegerType(), IntegerType()],
            name="test_repeater",
        )

        # Literal arguments are the call pattern the mock supports.
        rows = session.table_function(repeater_udtf(lit(42), lit(3))).collect()

        assert len(rows) == 3
        for row in rows:
            assert parse_variant(row["REPEATED"]) == 42