From a719e0c220e5ef9996cbef64d69535f3f8b0a1ce Mon Sep 17 00:00:00 2001 From: nicornk Date: Fri, 17 Oct 2025 16:18:32 +0200 Subject: [PATCH 1/4] Refactor write_arrow into write_arrow and write_parquet --- .../_internal/analyzer/analyzer_utils.py | 226 +++++++++++++----- tests/integ/test_df_to_arrow.py | 2 +- 2 files changed, 167 insertions(+), 61 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py b/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py index 3b84ad61ee..181cc483df 100644 --- a/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py +++ b/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py @@ -8,14 +8,14 @@ import os import sys import tempfile -from typing import Any, Dict, List, Optional, Tuple, Union, Literal, Sequence +from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union from snowflake.connector import ProgrammingError from snowflake.connector.cursor import SnowflakeCursor from snowflake.connector.options import pyarrow from snowflake.connector.pandas_tools import ( - _create_temp_stage, _create_temp_file_format, + _create_temp_stage, build_location_helper, ) from snowflake.snowpark._internal.analyzer.binary_plan_node import ( @@ -55,9 +55,9 @@ # Python 3.9 can use both # Python 3.10 needs to use collections.abc.Iterable because typing.Iterable is removed if sys.version_info <= (3, 9): - from typing import Iterable + from typing import Iterable, Iterator else: - from collections.abc import Iterable + from collections.abc import Iterable, Iterator LEFT_PARENTHESIS = "(" RIGHT_PARENTHESIS = ")" @@ -735,7 +735,9 @@ def schema_query_for_values_statement(output: List[Attribute]) -> str: query = ( SELECT - + COMMA.join([f"{DOLLAR}{i+1}{AS}{attr.name}" for i, attr in enumerate(output)]) + + COMMA.join( + [f"{DOLLAR}{i + 1}{AS}{attr.name}" for i, attr in enumerate(output)] + ) + FROM + VALUES + LEFT_PARENTHESIS @@ -758,7 +760,7 @@ def values_statement(output: List[Attribute], 
data: List[Row]) -> str: query = ( SELECT - + COMMA.join([f"{DOLLAR}{i+1}{AS}{c}" for i, c in enumerate(names)]) + + COMMA.join([f"{DOLLAR}{i + 1}{AS}{c}" for i, c in enumerate(names)]) + FROM + VALUES + COMMA.join(rows) @@ -1074,7 +1076,7 @@ def batch_insert_into_statement( if paramstyle == "qmark": placeholder_marks = [QUESTION_MARK] * num_cols elif paramstyle == "numeric": - placeholder_marks = [f"{SINGLE_COLON}{i+1}" for i in range(num_cols)] + placeholder_marks = [f"{SINGLE_COLON}{i + 1}" for i in range(num_cols)] elif paramstyle in ("format", "pyformat"): placeholder_marks = [PERCENT_S] * num_cols else: @@ -1999,14 +2001,14 @@ def cte_statement(queries: List[str], table_names: List[str]) -> str: return f"{WITH}{result}" -def write_arrow( +def write_parquet( cursor: SnowflakeCursor, - table: "pyarrow.Table", + parquet_files_generator: Iterator[str], + column_names: List[str], table_name: str, database: Optional[str] = None, schema: Optional[str] = None, - chunk_size: Optional[int] = None, - compression: str = "gzip", + compression: str = "auto", on_error: str = "abort_statement", use_vectorized_scanner: bool = False, parallel: int = 4, @@ -2016,7 +2018,6 @@ def write_arrow( table_type: Literal["", "temp", "temporary", "transient"] = "", use_logical_type: Optional[bool] = None, use_scoped_temp_object: bool = False, - **kwargs: Any, ) -> Tuple[ bool, int, @@ -2036,21 +2037,20 @@ def write_arrow( ] ], ]: - """Writes a pyarrow.Table to a Snowflake table. + """Uploads parquet files to a Snowflake table. - The pyarrow Table is written out to temporary files, uploaded to a temporary stage, and then copied into the final location. + Parquet files are uploaded to a temporary stage and then copied into the final location. Returns whether all files were ingested correctly, number of chunks uploaded, and number of rows ingested with all of the COPY INTO command's output for debugging purposes. Args: cursor: Snowflake connector cursor used to execute queries. 
- table: The pyarrow Table that is written. + parquet_files_generator: An iterator that yields local parquet file paths to upload. + column_names: List of column names in the order they appear in the parquet files. table_name: Table name where we want to insert into. database: Database schema and table is in, if not provided the default one will be used (Default value = None). schema: Schema table is in, if not provided the default one will be used (Default value = None). - chunk_size: Number of elements to be inserted in each batch, if not provided all elements will be dumped - (Default value = None). compression: The compression used on the Parquet files, can only be gzip, or snappy. Gzip gives a better compression, while snappy is faster. Use whichever is more appropriate (Default value = 'gzip'). on_error: Action to take when COPY INTO statements fail, default follows documentation at: @@ -2061,10 +2061,10 @@ def write_arrow( parallel: Number of threads to be used when uploading chunks, default follows documentation at: https://docs.snowflake.com/en/sql-reference/sql/put.html#optional-parameters (Default value = 4). quote_identifiers: By default, identifiers, specifically database, schema, table and column names - (from df.columns) will be quoted. If set to False, identifiers are passed on to Snowflake without quoting. + will be quoted. If set to False, identifiers are passed on to Snowflake without quoting. I.e. identifiers will be coerced to uppercase by Snowflake. (Default value = True) auto_create_table: When true, will automatically create a table with corresponding columns for each column in - the passed in DataFrame. The table will not be created if it already exists + the parquet files. The table will not be created if it already exists table_type: The table type of to-be-created table. The supported table types include ``temp``/``temporary`` and ``transient``. Empty means permanent table as per SQL convention. 
use_logical_type: Boolean that specifies whether to use Parquet logical types. With this file format option, @@ -2072,13 +2072,9 @@ def write_arrow( set use_logical_type as True. Set to None to use Snowflakes default. For more information, see: https://docs.snowflake.com/en/sql-reference/sql/create-file-format """ - # SNOW-1904593: This function mostly copies the functionality of snowflake.connector.pandas_utils.write_pandas. - # It should be pushed down into the connector, but would require a minimum required version bump. - import pyarrow.parquet # type: ignore - if database is not None and schema is None: raise ProgrammingError( - "Schema has to be provided to write_arrow when a database is provided" + "Schema has to be provided to write_arrow or write_parquet when a database is provided" ) compression_map = {"gzip": "auto", "snappy": "snappy", "none": "none"} if compression not in compression_map.keys(): @@ -2091,9 +2087,6 @@ def write_arrow( "Unsupported table type. Expected table types: temp/temporary, transient" ) - if chunk_size is None: - chunk_size = len(table) - if use_logical_type is None: sql_use_logical_type = "" elif use_logical_type: @@ -2111,38 +2104,31 @@ def write_arrow( overwrite, use_scoped_temp_object, ) - with tempfile.TemporaryDirectory() as tmp_folder: - for file_number, offset in enumerate(range(0, len(table), chunk_size)): - # write chunk to disk - chunk_path = os.path.join(tmp_folder, f"{table_name}_{file_number}.parquet") - pyarrow.parquet.write_table( - table.slice(offset=offset, length=chunk_size), - chunk_path, - **kwargs, - ) - # upload chunk - upload_sql = ( - "PUT /* Python:snowflake.snowpark._internal.analyzer.analyzer_utils.write_arrow() */ " - "'file://{path}' @{stage_location} PARALLEL={parallel}" - ).format( - path=chunk_path.replace("\\", "\\\\").replace("'", "\\'"), - stage_location=stage_location, - parallel=parallel, - ) - cursor.execute(upload_sql, _is_internal=True) - # Remove chunk file - os.remove(chunk_path) + + # 
Upload all parquet files from the generator + num_files_uploaded = 0 + for chunk_path in parquet_files_generator: + upload_sql = ( + "PUT /* Python:snowflake.snowpark._internal.analyzer.analyzer_utils.write_parquet() */ " + "'file://{path}' @{stage_location} PARALLEL={parallel}" + ).format( + path=chunk_path.replace("\\", "\\\\").replace("'", "\\'"), + stage_location=stage_location, + parallel=parallel, + ) + cursor.execute(upload_sql, _is_internal=True) + num_files_uploaded += 1 if quote_identifiers: quote = '"' - snowflake_column_names = [str(c).replace('"', '""') for c in table.schema.names] + snowflake_column_names = [str(c).replace('"', '""') for c in column_names] else: quote = "" - snowflake_column_names = list(table.schema.names) + snowflake_column_names = list(column_names) columns = quote + f"{quote},{quote}".join(snowflake_column_names) + quote def drop_object(name: str, object_type: str) -> None: - drop_sql = f"DROP {object_type.upper()} IF EXISTS {name} /* Python:snowflake.snowpark._internal.analyzer.analyzer_utils.write_arrow() */" + drop_sql = f"DROP {object_type.upper()} IF EXISTS {name} /* Python:snowflake.snowpark._internal.analyzer.analyzer_utils.write_parquet() */" cursor.execute(drop_sql, _is_internal=True) if auto_create_table or overwrite: @@ -2173,22 +2159,20 @@ def drop_object(name: str, object_type: str) -> None: parquet_columns = "$1:" + ",$1:".join( f"{quote}{snowflake_col}{quote}::{column_type_mapping[col]}" - for snowflake_col, col in zip(snowflake_column_names, table.schema.names) + for snowflake_col, col in zip(snowflake_column_names, column_names) ) if auto_create_table: create_table_columns = ", ".join( [ f"{quote}{snowflake_col}{quote} {column_type_mapping[col]}" - for snowflake_col, col in zip( - snowflake_column_names, table.schema.names - ) + for snowflake_col, col in zip(snowflake_column_names, column_names) ] ) create_table_sql = ( f"CREATE {table_type.upper()} TABLE IF NOT EXISTS {target_table_location} " 
f"({create_table_columns})" - f" /* Python:snowflake.snowpark._internal.analyzer.analyzer_utils.write_arrow() */ " + f" /* Python:snowflake.snowpark._internal.analyzer.analyzer_utils.write_parquet() */ " ) cursor.execute(create_table_sql, _is_internal=True) else: @@ -2204,11 +2188,11 @@ def drop_object(name: str, object_type: str) -> None: try: if overwrite and (not auto_create_table): - truncate_sql = f"TRUNCATE TABLE {target_table_location} /* Python:snowflake.snowpark._internal.analyzer.analyzer_utils.write_arrow() */" + truncate_sql = f"TRUNCATE TABLE {target_table_location} /* Python:snowflake.snowpark._internal.analyzer.analyzer_utils.write_parquet() */" cursor.execute(truncate_sql, _is_internal=True) copy_into_sql = ( - f"COPY INTO {target_table_location} /* Python:snowflake.snowpark._internal.analyzer.analyzer_utils.write_arrow() */ " + f"COPY INTO {target_table_location} /* Python:snowflake.snowpark._internal.analyzer.analyzer_utils.write_parquet() */ " f"({columns}) " f"FROM (SELECT {parquet_columns} FROM @{stage_location}) " f"FILE_FORMAT=(" @@ -2232,7 +2216,7 @@ def drop_object(name: str, object_type: str) -> None: quote_identifiers=quote_identifiers, ) drop_object(original_table_location, "table") - rename_table_sql = f"ALTER TABLE {target_table_location} RENAME TO {original_table_location} /* Python:snowflake.snowpark._internal.analyzer.analyzer_utils.write_arrow() */" + rename_table_sql = f"ALTER TABLE {target_table_location} RENAME TO {original_table_location} /* Python:snowflake.snowpark._internal.analyzer.analyzer_utils.write_parquet() */" cursor.execute(rename_table_sql, _is_internal=True) except ProgrammingError: if overwrite and auto_create_table: @@ -2244,7 +2228,129 @@ def drop_object(name: str, object_type: str) -> None: return ( all(e[1] == "LOADED" for e in copy_results), - len(copy_results), + num_files_uploaded, sum(int(e[3]) for e in copy_results), copy_results, # pyright: ignore ) + + +def write_arrow( + cursor: SnowflakeCursor, + 
table: "pyarrow.Table", + table_name: str, + database: Optional[str] = None, + schema: Optional[str] = None, + chunk_size: Optional[int] = None, + compression: str = "gzip", + on_error: str = "abort_statement", + use_vectorized_scanner: bool = False, + parallel: int = 4, + quote_identifiers: bool = True, + auto_create_table: bool = False, + overwrite: bool = False, + table_type: Literal["", "temp", "temporary", "transient"] = "", + use_logical_type: Optional[bool] = None, + use_scoped_temp_object: bool = False, + **kwargs: Any, +) -> Tuple[ + bool, + int, + int, + Sequence[ + Tuple[ + str, + str, + int, + int, + int, + int, + Optional[str], + Optional[int], + Optional[int], + Optional[str], + ] + ], +]: + """Writes a pyarrow.Table to a Snowflake table. + + The pyarrow Table is written out to temporary files, uploaded to a temporary stage, and then copied into the final location. + + Returns whether all files were ingested correctly, number of chunks uploaded, and number of rows ingested + with all of the COPY INTO command's output for debugging purposes. + + Args: + cursor: Snowflake connector cursor used to execute queries. + table: The pyarrow Table that is written. + table_name: Table name where we want to insert into. + database: Database schema and table is in, if not provided the default one will be used (Default value = None). + schema: Schema table is in, if not provided the default one will be used (Default value = None). + chunk_size: Number of elements to be inserted in each batch, if not provided all elements will be dumped + (Default value = None). + compression: The compression used on the Parquet files, can only be gzip, or snappy. Gzip gives a + better compression, while snappy is faster. Use whichever is more appropriate (Default value = 'gzip'). 
+ on_error: Action to take when COPY INTO statements fail, default follows documentation at: + https://docs.snowflake.com/en/sql-reference/sql/copy-into-table.html#copy-options-copyoptions + (Default value = 'abort_statement'). + use_vectorized_scanner: Boolean that specifies whether to use a vectorized scanner for loading Parquet files. See details at + `copy options `_. + parallel: Number of threads to be used when uploading chunks, default follows documentation at: + https://docs.snowflake.com/en/sql-reference/sql/put.html#optional-parameters (Default value = 4). + quote_identifiers: By default, identifiers, specifically database, schema, table and column names + (from df.columns) will be quoted. If set to False, identifiers are passed on to Snowflake without quoting. + I.e. identifiers will be coerced to uppercase by Snowflake. (Default value = True) + auto_create_table: When true, will automatically create a table with corresponding columns for each column in + the passed in DataFrame. The table will not be created if it already exists + table_type: The table type of to-be-created table. The supported table types include ``temp``/``temporary`` + and ``transient``. Empty means permanent table as per SQL convention. + use_logical_type: Boolean that specifies whether to use Parquet logical types. With this file format option, + Snowflake can interpret Parquet logical types during data loading. To enable Parquet logical types, + set use_logical_type as True. Set to None to use Snowflakes default. For more information, see: + https://docs.snowflake.com/en/sql-reference/sql/create-file-format + """ + # SNOW-1904593: This function mostly copies the functionality of snowflake.connector.pandas_utils.write_pandas. + # It should be pushed down into the connector, but would require a minimum required version bump. 
+ import pyarrow.parquet # type: ignore + + if chunk_size is None: + chunk_size = len(table) + + # Extract column names from the Arrow table + column_names = list(table.schema.names) + + # Create a generator that yields parquet file paths + def parquet_file_generator() -> Iterator[str]: + with tempfile.TemporaryDirectory() as tmp_folder: + for file_number, offset in enumerate(range(0, len(table), chunk_size)): + # Write chunk to disk + chunk_path = os.path.join( + tmp_folder, f"{table_name}_{file_number}.parquet" + ) + pyarrow.parquet.write_table( + table.slice(offset=offset, length=chunk_size), + chunk_path, + **kwargs, + ) + # Yield the path for upload + yield chunk_path + # Remove chunk file + os.remove(chunk_path) + + # Use write_parquet to handle the upload and COPY INTO operations + return write_parquet( + cursor=cursor, + parquet_files_generator=parquet_file_generator(), + column_names=column_names, + table_name=table_name, + database=database, + schema=schema, + compression=compression, + on_error=on_error, + use_vectorized_scanner=use_vectorized_scanner, + parallel=parallel, + quote_identifiers=quote_identifiers, + auto_create_table=auto_create_table, + overwrite=overwrite, + table_type=table_type, + use_logical_type=use_logical_type, + use_scoped_temp_object=use_scoped_temp_object, + ) diff --git a/tests/integ/test_df_to_arrow.py b/tests/integ/test_df_to_arrow.py index a35b0def1d..a24b976591 100644 --- a/tests/integ/test_df_to_arrow.py +++ b/tests/integ/test_df_to_arrow.py @@ -361,7 +361,7 @@ def test_misc_settings( def test_write_arrow_negative(session, basic_arrow_table): with pytest.raises( ProgrammingError, - match="Schema has to be provided to write_arrow when a database is provided", + match="Schema has to be provided to write_arrow or write_parquet when a database is provided", ): session.write_arrow(basic_arrow_table, "temp_table", database="foo") From cd8cffce5474f14c04e7319c7cc1f5c4df460275 Mon Sep 17 00:00:00 2001 From: nicornk Date: Fri, 17 
Oct 2025 16:49:39 +0200 Subject: [PATCH 2/4] add public method to Session class. --- .../_internal/analyzer/analyzer_utils.py | 20 ++- src/snowflake/snowpark/session.py | 155 ++++++++++++++++-- tests/integ/test_df_to_arrow.py | 134 ++++++++++++++- 3 files changed, 287 insertions(+), 22 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py b/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py index 181cc483df..24eb5da09f 100644 --- a/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py +++ b/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py @@ -2004,8 +2004,8 @@ def cte_statement(queries: List[str], table_names: List[str]) -> str: def write_parquet( cursor: SnowflakeCursor, parquet_files_generator: Iterator[str], - column_names: List[str], table_name: str, + column_names: Optional[List[str]] = None, database: Optional[str] = None, schema: Optional[str] = None, compression: str = "auto", @@ -2047,8 +2047,9 @@ def write_parquet( Args: cursor: Snowflake connector cursor used to execute queries. parquet_files_generator: An iterator that yields local parquet file paths to upload. - column_names: List of column names in the order they appear in the parquet files. table_name: Table name where we want to insert into. + column_names: List of column names in the order they appear in the parquet files. + If None, column names will be inferred from the first parquet file. (Default value = None). database: Database schema and table is in, if not provided the default one will be used (Default value = None). schema: Schema table is in, if not provided the default one will be used (Default value = None). compression: The compression used on the Parquet files, can only be gzip, or snappy. 
Gzip gives a @@ -2076,7 +2077,12 @@ def write_parquet( raise ProgrammingError( "Schema has to be provided to write_arrow or write_parquet when a database is provided" ) - compression_map = {"gzip": "auto", "snappy": "snappy", "none": "none"} + compression_map = { + "gzip": "auto", + "snappy": "snappy", + "none": "none", + "auto": "auto", + } if compression not in compression_map.keys(): raise ProgrammingError( f"Invalid compression '{compression}', only acceptable values are: {compression_map.keys()}" @@ -2108,6 +2114,12 @@ def write_parquet( # Upload all parquet files from the generator num_files_uploaded = 0 for chunk_path in parquet_files_generator: + # Infer column names from first parquet file. + if num_files_uploaded == 0 and column_names is None: + import pyarrow.parquet # type: ignore + + column_names = pyarrow.parquet.read_table(chunk_path).schema.names + upload_sql = ( "PUT /* Python:snowflake.snowpark._internal.analyzer.analyzer_utils.write_parquet() */ " "'file://{path}' @{stage_location} PARALLEL={parallel}" @@ -2339,8 +2351,8 @@ def parquet_file_generator() -> Iterator[str]: return write_parquet( cursor=cursor, parquet_files_generator=parquet_file_generator(), - column_names=column_names, table_name=table_name, + column_names=column_names, database=database, schema=schema, compression=compression, diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 1afe626720..2030e206f3 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -6,6 +6,7 @@ import atexit import datetime import decimal +import importlib.metadata import inspect import json import os @@ -35,7 +36,6 @@ ) import cloudpickle -import importlib.metadata from packaging.requirements import Requirement from packaging.version import parse as parse_version @@ -49,6 +49,7 @@ from snowflake.snowpark._internal.analyzer.analyzer_utils import ( result_scan_statement, write_arrow, + write_parquet, ) from 
snowflake.snowpark._internal.analyzer.datatype_mapper import str_to_sql from snowflake.snowpark._internal.analyzer.expression import Attribute @@ -108,8 +109,10 @@ MODULE_NAME_TO_PACKAGE_NAME_MAP, STAGE_PREFIX, SUPPORTED_TABLE_TYPES, - XPATH_HANDLERS_FILE_PATH, XPATH_HANDLER_MAP, + XPATH_HANDLERS_FILE_PATH, + AstFlagSource, + AstMode, PythonObjJSONEncoder, TempObjectType, calculate_checksum, @@ -127,6 +130,7 @@ get_temp_type_for_object, get_version, import_or_missing_modin_pandas, + is_ast_enabled, is_in_stored_procedure, normalize_local_file, normalize_remote_file_or_dir, @@ -135,6 +139,7 @@ publicapi, quote_name, random_name_for_temp_object, + set_ast_state, strip_double_quotes_in_like_statement_in_table_name, unwrap_single_quote, unwrap_stage_location_single_quote, @@ -142,10 +147,6 @@ warn_session_config_update_in_multithreaded_mode, warning, zip_file_or_directory_to_stream, - set_ast_state, - is_ast_enabled, - AstFlagSource, - AstMode, ) from snowflake.snowpark.async_job import AsyncJob, _AsyncResultType from snowflake.snowpark.column import Column @@ -154,6 +155,7 @@ _use_scoped_temp_objects, ) from snowflake.snowpark.dataframe import DataFrame +from snowflake.snowpark.dataframe_profiler import DataframeProfiler from snowflake.snowpark.dataframe_reader import DataFrameReader from snowflake.snowpark.exceptions import ( SnowparkClientException, @@ -161,7 +163,6 @@ ) from snowflake.snowpark.file_operation import FileOperation from snowflake.snowpark.functions import ( - to_file, array_agg, col, column, @@ -169,6 +170,7 @@ parse_json, to_date, to_decimal, + to_file, to_geography, to_geometry, to_time, @@ -196,7 +198,6 @@ from snowflake.snowpark.row import Row from snowflake.snowpark.stored_procedure import StoredProcedureRegistration from snowflake.snowpark.stored_procedure_profiler import StoredProcedureProfiler -from snowflake.snowpark.dataframe_profiler import DataframeProfiler from snowflake.snowpark.table import Table from snowflake.snowpark.table_function 
import ( TableFunctionCall, @@ -208,6 +209,7 @@ DateType, DayTimeIntervalType, DecimalType, + FileType, FloatType, GeographyType, GeometryType, @@ -222,7 +224,6 @@ VariantType, VectorType, YearMonthIntervalType, - FileType, _AtomicType, ) from snowflake.snowpark.udaf import UDAFRegistration @@ -231,6 +232,7 @@ if TYPE_CHECKING: import modin.pandas # pragma: no cover + from snowflake.snowpark.udf import UserDefinedFunction # pragma: no cover # Python 3.8 needs to use typing.Iterable because collections.abc.Iterable is not subscriptable @@ -1848,7 +1850,7 @@ def replicate_local_environment( @staticmethod def _parse_packages( - packages: List[Union[str, ModuleType]] + packages: List[Union[str, ModuleType]], ) -> Dict[str, Tuple[str, bool, Requirement]]: package_dict = dict() for package in packages: @@ -2764,7 +2766,6 @@ def table_function( self._conn, NopConnection ): if self._conn._suppress_not_implemented_error: - # TODO: Snowpark does not allow empty dataframes (no schema, no data). Have a dummy schema here. ans = self.createDataFrame( [], @@ -3182,6 +3183,138 @@ def write_arrow( f"Failed to write arrow table to Snowflake. COPY INTO output {ci_output}" ) + @experimental(version="1.41.0") + @publicapi + def write_parquet( + self, + path: str, + table_name: str, + *, + column_names: Optional[List[str]] = None, + database: Optional[str] = None, + schema: Optional[str] = None, + compression: str = "auto", + on_error: str = "abort_statement", + use_vectorized_scanner: bool = False, + parallel: int = 4, + quote_identifiers: bool = True, + auto_create_table: bool = False, + overwrite: bool = False, + table_type: Literal["", "temp", "temporary", "transient"] = "", + use_logical_type: Optional[bool] = None, + **kwargs: Dict[str, Any], + ) -> Table: + """Writes parquet file(s) to a Snowflake table. + + Parquet files are uploaded to a temporary stage and then copied into the final location. + + Returns a Snowpark Table that references the table referenced by table_name. 
+ + Args: + path: Path to a single parquet file or a directory containing parquet files. + Can be a local file path (e.g., "/path/to/file.parquet" or "/path/to/directory/"). + table_name: Table name where we want to insert into. + column_names: List of column names in the order they appear in the parquet files. + If None, column names will be inferred from the first parquet file. (Default value = None). + database: Database schema and table is in, if not provided the default one will be used (Default value = None). + schema: Schema table is in, if not provided the default one will be used (Default value = None). + compression: The compression used on the Parquet files. Can be "auto", "gzip", "snappy", or "none". + (Default value = "auto"). + on_error: Action to take when COPY INTO statements fail, default follows documentation at: + https://docs.snowflake.com/en/sql-reference/sql/copy-into-table.html#copy-options-copyoptions + (Default value = 'abort_statement'). + use_vectorized_scanner: Boolean that specifies whether to use a vectorized scanner for loading Parquet files. See details at + `copy options `_. + parallel: Number of threads to be used when uploading chunks, default follows documentation at: + https://docs.snowflake.com/en/sql-reference/sql/put.html#optional-parameters (Default value = 4). + quote_identifiers: By default, identifiers, specifically database, schema, and table names + will be quoted. If set to False, identifiers are passed on to Snowflake without quoting. + I.e. identifiers will be coerced to uppercase by Snowflake. (Default value = True) + auto_create_table: When true, will automatically create a table with corresponding columns for each column in + the parquet files. The table will not be created if it already exists. + overwrite: If True, the table contents will be overwritten. If False, data will be appended to the table. + (Default value = False). + table_type: The table type of to-be-created table. 
The supported table types include ``temp``/``temporary`` + and ``transient``. Empty means permanent table as per SQL convention. + use_logical_type: Boolean that specifies whether to use Parquet logical types. With this file format option, + Snowflake can interpret Parquet logical types during data loading. To enable Parquet logical types, + set use_logical_type as True. Set to None to use Snowflakes default. For more information, see: + https://docs.snowflake.com/en/sql-reference/sql/create-file-format + + Example:: + + >>> # Write a single parquet file + >>> session.write_parquet("/path/to/file.parquet", "my_table", auto_create_table=True) # doctest: +SKIP + >>> + >>> # Write all parquet files in a directory + >>> session.write_parquet("/path/to/directory/", "my_table", auto_create_table=True) # doctest: +SKIP + """ + import glob + + cursor = self._conn._conn.cursor() + + if quote_identifiers: + location = ( + (('"' + database + '".') if database else "") + + (('"' + schema + '".') if schema else "") + + ('"' + table_name + '"') + ) + else: + location = ( + (database + "." if database else "") + + (schema + "." 
if schema else "") + + (table_name) + ) + + # Collect parquet files + if os.path.isfile(path): + # Single file + parquet_files = [path] + elif os.path.isdir(path): + # Directory - find all parquet files + parquet_files = glob.glob(os.path.join(path, "*.parquet")) + if not parquet_files: + raise SnowparkSessionException( + f"No parquet files found in directory: {path}" + ) + else: + raise SnowparkSessionException( + f"Path does not exist or is not accessible: {path}" + ) + + # Create generator that yields parquet file paths + def parquet_file_generator(): + yield from parquet_files + + success, _, _, ci_output = write_parquet( + cursor=cursor, + parquet_files_generator=parquet_file_generator(), + column_names=column_names, + table_name=table_name, + database=database, + schema=schema, + compression=compression, + on_error=on_error, + use_vectorized_scanner=use_vectorized_scanner, + parallel=parallel, + quote_identifiers=quote_identifiers, + auto_create_table=auto_create_table, + overwrite=overwrite, + table_type=table_type, + use_logical_type=use_logical_type, + use_scoped_temp_object=self._use_scoped_temp_objects + and is_in_stored_procedure(), + ) + + if success: + table = self.table(location, _emit_ast=False) + set_api_call_source(table, "Session.write_parquet") + return table + else: + raise SnowparkSessionException( + f"Failed to write parquet file(s) to Snowflake. COPY INTO output {ci_output}" + ) + def _write_modin_pandas_helper( self, df: Union[ diff --git a/tests/integ/test_df_to_arrow.py b/tests/integ/test_df_to_arrow.py index a24b976591..852990997f 100644 --- a/tests/integ/test_df_to_arrow.py +++ b/tests/integ/test_df_to_arrow.py @@ -3,26 +3,24 @@ # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. 
# +import math import os -import pytest import re -import math - from datetime import date, datetime from decimal import Decimal from typing import Iterator from unittest import mock -from snowflake.connector.errors import ProgrammingError +import pytest +from snowflake.connector.errors import ProgrammingError from snowflake.snowpark._internal.analyzer.analyzer_utils import write_arrow from snowflake.snowpark.exceptions import SnowparkSessionException from snowflake.snowpark.functions import col from snowflake.snowpark.row import Row -from snowflake.snowpark.types import DecimalType from snowflake.snowpark.session import WRITE_ARROW_CHUNK_SIZE - -from tests.utils import TestData, Utils, TestFiles +from snowflake.snowpark.types import DecimalType +from tests.utils import TestData, TestFiles, Utils try: import pyarrow as pa @@ -404,6 +402,128 @@ def test_write_arrow_negative(session, basic_arrow_table): session.write_arrow(basic_arrow_table, "temp_table") +@pytest.mark.skipif( + "config.getoption('local_testing_mode', default=False)", + reason="arrow not fully supported by local testing.", +) +def test_write_parquet(session, tmp_path): + """Test the write_parquet method with single file and directory.""" + import pyarrow.parquet as pq + + # Create test parquet files + test_data = pa.Table.from_arrays([[1, 2, 3], ["a", "b", "c"]], names=["id", "name"]) + + # Test 1: Write a single parquet file + table_name1 = Utils.random_table_name() + single_file = tmp_path / "test_single.parquet" + pq.write_table(test_data, single_file) + + try: + result_table = session.write_parquet( + str(single_file), table_name1, auto_create_table=True + ) + Utils.check_answer( + result_table, + [Row(id=1, name="a"), Row(id=2, name="b"), Row(id=3, name="c")], + ) + finally: + Utils.drop_table(session, table_name1) + + # Test 2: Write from a directory with multiple parquet files + table_name2 = Utils.random_table_name() + parquet_dir = tmp_path / "parquet_files" + parquet_dir.mkdir() + + # 
Create multiple parquet files in the directory + pq.write_table(test_data, parquet_dir / "file1.parquet") + pq.write_table(test_data, parquet_dir / "file2.parquet") + + try: + result_table = session.write_parquet( + str(parquet_dir), table_name2, auto_create_table=True + ) + # Should have 6 rows (3 from each file) + assert result_table.count() == 6 + # Check that data contains expected values + result_data = result_table.collect() + assert all(row["id"] in [1, 2, 3] for row in result_data) + assert all(row["name"] in ["a", "b", "c"] for row in result_data) + finally: + Utils.drop_table(session, table_name2) + + # Test 3: Write with explicit column names + table_name3 = Utils.random_table_name() + try: + result_table = session.write_parquet( + str(single_file), + table_name3, + column_names=["id", "name"], + auto_create_table=True, + ) + Utils.check_answer( + result_table, + [Row(id=1, name="a"), Row(id=2, name="b"), Row(id=3, name="c")], + ) + finally: + Utils.drop_table(session, table_name3) + + # Test 4: Test overwrite functionality + table_name4 = Utils.random_table_name() + try: + # Initial write + session.write_parquet(str(single_file), table_name4, auto_create_table=True) + table1 = session.table(table_name4) + assert table1.count() == 3 + + # Append (default behavior) + session.write_parquet(str(single_file), table_name4) + table2 = session.table(table_name4) + assert table2.count() == 6 + + # Overwrite + session.write_parquet(str(single_file), table_name4, overwrite=True) + table3 = session.table(table_name4) + assert table3.count() == 3 + finally: + Utils.drop_table(session, table_name4) + + +@pytest.mark.skipif( + "config.getoption('local_testing_mode', default=False)", + reason="arrow not fully supported by local testing.", +) +def test_write_parquet_negative(session, tmp_path): + """Test error cases for write_parquet.""" + # Test non-existent path + with pytest.raises( + SnowparkSessionException, + match="Path does not exist or is not accessible", + ): 
+ session.write_parquet("/nonexistent/path.parquet", "temp_table") + + # Test directory with no parquet files + empty_dir = tmp_path / "empty" + empty_dir.mkdir() + with pytest.raises( + SnowparkSessionException, + match="No parquet files found in directory", + ): + session.write_parquet(str(empty_dir), "temp_table") + + # Test schema/database mismatch + import pyarrow.parquet as pq + + test_data = pa.Table.from_arrays([[1, 2, 3]], names=["a"]) + test_file = tmp_path / "test.parquet" + pq.write_table(test_data, test_file) + + with pytest.raises( + ProgrammingError, + match="Schema has to be provided to write_arrow or write_parquet when a database is provided", + ): + session.write_parquet(str(test_file), "temp_table", database="foo") + + @pytest.mark.skipif( "config.getoption('local_testing_mode', default=False)", reason="arrow not fully supported by local testing.", From 295d11c434d0e3f76ea83bbf9709246ba08abb87 Mon Sep 17 00:00:00 2001 From: nicornk Date: Fri, 17 Oct 2025 17:36:44 +0200 Subject: [PATCH 3/4] intermediate state. TODO: verify the PUT command actually included the globstring. Validate that if write_files_in_parallel is True only a single directory with parquet files is passed. 
--- .../_internal/analyzer/analyzer_utils.py | 48 ++++++++++++++---- src/snowflake/snowpark/session.py | 9 +++- tests/integ/test_df_to_arrow.py | 49 +++++++++++++++++++ 3 files changed, 95 insertions(+), 11 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py b/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py index 24eb5da09f..c77f425250 100644 --- a/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py +++ b/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py @@ -2012,6 +2012,7 @@ def write_parquet( on_error: str = "abort_statement", use_vectorized_scanner: bool = False, parallel: int = 4, + write_files_in_parallel: bool = True, quote_identifiers: bool = True, auto_create_table: bool = False, overwrite: bool = False, @@ -2061,6 +2062,9 @@ def write_parquet( `copy options `_. parallel: Number of threads to be used when uploading chunks, default follows documentation at: https://docs.snowflake.com/en/sql-reference/sql/put.html#optional-parameters (Default value = 4). + write_files_in_parallel: Whether to parallelize over the files while uploading. + This will pass a glob string to the PUT upload call (does not support divergent directory paths). + https://docs.snowflake.com/en/sql-reference/sql/put#usage-notes (Default value = True) quote_identifiers: By default, identifiers, specifically database, schema, table and column names will be quoted. If set to False, identifiers are passed on to Snowflake without quoting. I.e. identifiers will be coerced to uppercase by Snowflake. (Default value = True) @@ -2111,25 +2115,49 @@ def write_parquet( use_scoped_temp_object, ) - # Upload all parquet files from the generator - num_files_uploaded = 0 - for chunk_path in parquet_files_generator: + if write_files_in_parallel: + parquet_files = list(parquet_files_generator) + # Infer column names from first parquet file. 
- if num_files_uploaded == 0 and column_names is None: + if column_names is None: import pyarrow.parquet # type: ignore - column_names = pyarrow.parquet.read_table(chunk_path).schema.names - + column_names = pyarrow.parquet.read_table(parquet_files[0]).schema.names upload_sql = ( - "PUT /* Python:snowflake.snowpark._internal.analyzer.analyzer_utils.write_parquet() */ " + "PUT /* Python:snowflake.snowpark._internal.analyzer.analyzer_utils.write_arrow() */ " "'file://{path}' @{stage_location} PARALLEL={parallel}" ).format( - path=chunk_path.replace("\\", "\\\\").replace("'", "\\'"), + path=parquet_files[0] + # make a glob string out of the first file + .replace(os.path.basename(parquet_files[0]), "*.parquet") + .replace("\\", "\\\\") + .replace("'", "\\'"), stage_location=stage_location, parallel=parallel, ) - cursor.execute(upload_sql, _is_internal=True) - num_files_uploaded += 1 + + cursor.execute(upload_sql, _is_internal=False) + num_files_uploaded = len(parquet_files) + else: + # Upload all parquet files from the generator, in sequence + num_files_uploaded = 0 + for chunk_path in parquet_files_generator: + # Infer column names from first parquet file. 
+ if num_files_uploaded == 0 and column_names is None: + import pyarrow.parquet # type: ignore + + column_names = pyarrow.parquet.read_table(chunk_path).schema.names + + upload_sql = ( + "PUT /* Python:snowflake.snowpark._internal.analyzer.analyzer_utils.write_parquet() */ " + "'file://{path}' @{stage_location} PARALLEL={parallel}" + ).format( + path=chunk_path.replace("\\", "\\\\").replace("'", "\\'"), + stage_location=stage_location, + parallel=parallel, + ) + cursor.execute(upload_sql, _is_internal=True) + num_files_uploaded += 1 if quote_identifiers: quote = '"' diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 2030e206f3..17d770037b 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -3165,6 +3165,8 @@ def write_arrow( on_error=on_error, use_vectorized_scanner=use_vectorized_scanner, parallel=parallel, + # to preserve existing behavior + write_files_in_parallel=False, quote_identifiers=quote_identifiers, auto_create_table=auto_create_table, overwrite=overwrite, @@ -3197,6 +3199,7 @@ def write_parquet( on_error: str = "abort_statement", use_vectorized_scanner: bool = False, parallel: int = 4, + write_files_in_parallel: bool = True, quote_identifiers: bool = True, auto_create_table: bool = False, overwrite: bool = False, @@ -3225,8 +3228,11 @@ def write_parquet( (Default value = 'abort_statement'). use_vectorized_scanner: Boolean that specifies whether to use a vectorized scanner for loading Parquet files. See details at `copy options `_. - parallel: Number of threads to be used when uploading chunks, default follows documentation at: + parallel: Number of threads to be used when uploading chunks or files, default follows documentation at: https://docs.snowflake.com/en/sql-reference/sql/put.html#optional-parameters (Default value = 4). + write_files_in_parallel: Whether to parallelize over the files while uploading. 
+ This will pass a glob string to the PUT upload call (does not support divergent directory paths). + https://docs.snowflake.com/en/sql-reference/sql/put#usage-notes (Default value = True) quote_identifiers: By default, identifiers, specifically database, schema, and table names will be quoted. If set to False, identifiers are passed on to Snowflake without quoting. I.e. identifiers will be coerced to uppercase by Snowflake. (Default value = True) @@ -3297,6 +3303,7 @@ def parquet_file_generator(): on_error=on_error, use_vectorized_scanner=use_vectorized_scanner, parallel=parallel, + write_files_in_parallel=write_files_in_parallel, quote_identifiers=quote_identifiers, auto_create_table=auto_create_table, overwrite=overwrite, diff --git a/tests/integ/test_df_to_arrow.py b/tests/integ/test_df_to_arrow.py index 852990997f..1938269169 100644 --- a/tests/integ/test_df_to_arrow.py +++ b/tests/integ/test_df_to_arrow.py @@ -524,6 +524,55 @@ def test_write_parquet_negative(session, tmp_path): session.write_parquet(str(test_file), "temp_table", database="foo") +@pytest.mark.skipif( + "config.getoption('local_testing_mode', default=False)", + reason="arrow not fully supported by local testing.", +) +def test_write_parquet_parallel_upload(session, tmp_path): + """Test that parallel upload optimization works for multiple files in a flat directory.""" + import pyarrow.parquet as pq + + # Create multiple parquet files in a flat directory + test_dir = tmp_path / "parquet_files" + test_dir.mkdir() + + # Create 3 parquet files with different data + file1_data = pa.Table.from_arrays([[1, 2], ["a", "b"]], names=["id", "name"]) + file2_data = pa.Table.from_arrays([[3, 4], ["c", "d"]], names=["id", "name"]) + file3_data = pa.Table.from_arrays([[5, 6], ["e", "f"]], names=["id", "name"]) + + file1 = test_dir / "file1.parquet" + file2 = test_dir / "file2.parquet" + file3 = test_dir / "file3.parquet" + + pq.write_table(file1_data, file1) + pq.write_table(file2_data, file2) + 
pq.write_table(file3_data, file3) + + # Write parquet files using the parallel upload optimization + table_name = Utils.random_table_name() + try: + result_table = session.write_parquet( + str(test_dir), + table_name, + auto_create_table=True, + overwrite=True, + ) + + # Verify all data was loaded + result_data = result_table.collect() + assert len(result_data) == 6 # 2 rows from each of 3 files + + # Verify data content + ids = {row["id"] for row in result_data} + names = {row["name"] for row in result_data} + assert ids == {1, 2, 3, 4, 5, 6} + assert names == {"a", "b", "c", "d", "e", "f"} + + finally: + Utils.drop_table(session, table_name) + + @pytest.mark.skipif( "config.getoption('local_testing_mode', default=False)", reason="arrow not fully supported by local testing.", From 98f4ea171188a8a330dfe7296053e28761dfd380 Mon Sep 17 00:00:00 2001 From: nicornk Date: Sat, 18 Oct 2025 17:05:49 +0200 Subject: [PATCH 4/4] verify PUT commands, restrict write_files_in_parallel to not allow subdirectories. --- .../_internal/analyzer/analyzer_utils.py | 7 +- src/snowflake/snowpark/session.py | 23 +++- tests/integ/test_df_to_arrow.py | 121 +++++++++++++++++- 3 files changed, 134 insertions(+), 17 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py b/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py index c77f425250..a0d0f30890 100644 --- a/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py +++ b/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py @@ -2053,8 +2053,8 @@ def write_parquet( If None, column names will be inferred from the first parquet file. (Default value = None). database: Database schema and table is in, if not provided the default one will be used (Default value = None). schema: Schema table is in, if not provided the default one will be used (Default value = None). - compression: The compression used on the Parquet files, can only be gzip, or snappy.
Gzip gives a - better compression, while snappy is faster. Use whichever is more appropriate (Default value = 'gzip'). + compression: The compression used on the Parquet files, can only be auto, gzip or snappy. Gzip gives a + better compression, while snappy is faster. Use whichever is more appropriate (Default value = 'auto'). on_error: Action to take when COPY INTO statements fail, default follows documentation at: https://docs.snowflake.com/en/sql-reference/sql/copy-into-table.html#copy-options-copyoptions (Default value = 'abort_statement'). @@ -2136,7 +2136,7 @@ def write_parquet( parallel=parallel, ) - cursor.execute(upload_sql, _is_internal=False) + cursor.execute(upload_sql, _is_internal=True) num_files_uploaded = len(parquet_files) else: # Upload all parquet files from the generator, in sequence @@ -2387,6 +2387,7 @@ def parquet_file_generator() -> Iterator[str]: on_error=on_error, use_vectorized_scanner=use_vectorized_scanner, parallel=parallel, + write_files_in_parallel=False, quote_identifiers=quote_identifiers, auto_create_table=auto_create_table, overwrite=overwrite, diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 17d770037b..f7bca8ffc2 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -3165,8 +3165,6 @@ def write_arrow( on_error=on_error, use_vectorized_scanner=use_vectorized_scanner, parallel=parallel, - # to preserve existing behavior - write_files_in_parallel=False, quote_identifiers=quote_identifiers, auto_create_table=auto_create_table, overwrite=overwrite, @@ -3221,7 +3219,7 @@ def write_parquet( If None, column names will be inferred from the first parquet file. (Default value = None). database: Database schema and table is in, if not provided the default one will be used (Default value = None). schema: Schema table is in, if not provided the default one will be used (Default value = None). - compression: The compression used on the Parquet files. 
Can be "auto", "gzip", "snappy", or "none". + compression: The compression used while uploading the Parquet files. Can be "auto", "gzip", "snappy", or "none". (Default value = "auto"). on_error: Action to take when COPY INTO statements fail, default follows documentation at: https://docs.snowflake.com/en/sql-reference/sql/copy-into-table.html#copy-options-copyoptions @@ -3272,17 +3270,28 @@ def write_parquet( + (table_name) ) - # Collect parquet files if os.path.isfile(path): - # Single file parquet_files = [path] elif os.path.isdir(path): - # Directory - find all parquet files - parquet_files = glob.glob(os.path.join(path, "*.parquet")) + parquet_files = glob.glob(os.path.join(path, "**/*.parquet")) + glob.glob( + os.path.join(path, "*.parquet") + ) if not parquet_files: raise SnowparkSessionException( f"No parquet files found in directory: {path}" ) + + # Check if write_files_in_parallel is enabled but subdirectories exist + if write_files_in_parallel: + # Check if any parquet files are in subdirectories + files_in_subdirs = [ + f for f in parquet_files if os.path.dirname(f) != path + ] + if files_in_subdirs: + raise ProgrammingError( + "write_files_in_parallel=True is not supported when parquet files exist in subdirectories. " + "Please ensure all parquet files are in the root of the specified path, or set write_files_in_parallel=False." 
+ ) else: raise SnowparkSessionException( f"Path does not exist or is not accessible: {path}" diff --git a/tests/integ/test_df_to_arrow.py b/tests/integ/test_df_to_arrow.py index 1938269169..a892f35015 100644 --- a/tests/integ/test_df_to_arrow.py +++ b/tests/integ/test_df_to_arrow.py @@ -523,15 +523,35 @@ def test_write_parquet_negative(session, tmp_path): ): session.write_parquet(str(test_file), "temp_table", database="foo") + # Test write_files_in_parallel with subdirectories + subdir_test = tmp_path / "subdir_test" + subdir_test.mkdir() + subdir = subdir_test / "subdir" + subdir.mkdir() + + # Create parquet file in subdirectory + subdir_file = subdir / "test.parquet" + pq.write_table(test_data, subdir_file) + + with pytest.raises( + ProgrammingError, + match="write_files_in_parallel=True is not supported when parquet files exist in subdirectories", + ): + session.write_parquet( + str(subdir_test), "temp_table", write_files_in_parallel=True + ) + @pytest.mark.skipif( "config.getoption('local_testing_mode', default=False)", reason="arrow not fully supported by local testing.", ) -def test_write_parquet_parallel_upload(session, tmp_path): +def test_write_parquet_parallel_upload(session, tmp_path, capfd): """Test that parallel upload optimization works for multiple files in a flat directory.""" import pyarrow.parquet as pq + from snowflake.connector.cursor import SnowflakeCursor + # Create multiple parquet files in a flat directory test_dir = tmp_path / "parquet_files" test_dir.mkdir() @@ -549,15 +569,102 @@ def test_write_parquet_parallel_upload(session, tmp_path): pq.write_table(file2_data, file2) pq.write_table(file3_data, file3) + # Track execute calls + execute_calls = [] + original_execute = SnowflakeCursor.execute + + def execute_wrapper(self, *args, **kwargs): + execute_calls.append(args[0] if args else kwargs.get("command")) + return original_execute(self, *args, **kwargs) + # Write parquet files using the parallel upload optimization table_name = 
Utils.random_table_name() try: - result_table = session.write_parquet( - str(test_dir), - table_name, - auto_create_table=True, - overwrite=True, - ) + with mock.patch.object( + SnowflakeCursor, "execute", side_effect=execute_wrapper, autospec=True + ): + result_table = session.write_parquet( + str(test_dir), + table_name, + auto_create_table=True, + overwrite=True, + ) + + # Verify that there was a single PUT call with a glob pattern (parallel upload) + put_calls = [call for call in execute_calls if call and "PUT" in call] + assert len(put_calls) == 1, f"Expected 1 PUT call, got {len(put_calls)}" + assert "*.parquet" in put_calls[0], "Expected *.parquet glob in PUT call." + + # Verify all data was loaded + result_data = result_table.collect() + assert len(result_data) == 6 # 2 rows from each of 3 files + + # Verify data content + ids = {row["id"] for row in result_data} + names = {row["name"] for row in result_data} + assert ids == {1, 2, 3, 4, 5, 6} + assert names == {"a", "b", "c", "d", "e", "f"} + + finally: + Utils.drop_table(session, table_name) + + +@pytest.mark.skipif( + "config.getoption('local_testing_mode', default=False)", + reason="arrow not fully supported by local testing.", +) +def test_write_parquet_subfolders_sequential_fallback(session, tmp_path): + """Test that parquet files spread across subdirectories are uploaded one PUT per file when write_files_in_parallel=False.""" + import pyarrow.parquet as pq + + from snowflake.connector.cursor import SnowflakeCursor + + # Create parquet files in the root directory and in two subdirectories + test_dir = tmp_path / "parquet_files" + test_dir.mkdir() + sub_dir_1 = tmp_path / "parquet_files" / "subdirectory1" + sub_dir_2 = tmp_path / "parquet_files" / "subdirectory2" + + sub_dir_1.mkdir() + sub_dir_2.mkdir() + + # Create 3 parquet files with different data + file1_data = pa.Table.from_arrays([[1, 2], ["a", "b"]], names=["id", "name"]) + file2_data = pa.Table.from_arrays([[3, 4], ["c", "d"]], names=["id", "name"]) + file3_data = pa.Table.from_arrays([[5, 6],
["e", "f"]], names=["id", "name"]) + + file1 = test_dir / "file1.parquet" + file2 = sub_dir_1 / "file2.parquet" + file3 = sub_dir_2 / "file3.parquet" + + pq.write_table(file1_data, file1) + pq.write_table(file2_data, file2) + pq.write_table(file3_data, file3) + + # Track execute calls + execute_calls = [] + original_execute = SnowflakeCursor.execute + + def execute_wrapper(self, *args, **kwargs): + execute_calls.append(args[0] if args else kwargs.get("command")) + return original_execute(self, *args, **kwargs) + + table_name = Utils.random_table_name() + try: + with mock.patch.object( + SnowflakeCursor, "execute", side_effect=execute_wrapper, autospec=True + ): + result_table = session.write_parquet( + str(test_dir), + table_name, + auto_create_table=True, + overwrite=True, + write_files_in_parallel=False, + ) + + # Verify that there were three PUT calls (one for each subdirectory/file) + put_calls = [call for call in execute_calls if call and "PUT" in call] + assert len(put_calls) == 3, f"Expected 3 PUT calls, got {len(put_calls)}" # Verify all data was loaded result_data = result_table.collect()