diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py b/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py index 3b84ad61ee..a0d0f30890 100644 --- a/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py +++ b/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py @@ -8,14 +8,14 @@ import os import sys import tempfile -from typing import Any, Dict, List, Optional, Tuple, Union, Literal, Sequence +from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union from snowflake.connector import ProgrammingError from snowflake.connector.cursor import SnowflakeCursor from snowflake.connector.options import pyarrow from snowflake.connector.pandas_tools import ( - _create_temp_stage, _create_temp_file_format, + _create_temp_stage, build_location_helper, ) from snowflake.snowpark._internal.analyzer.binary_plan_node import ( @@ -55,9 +55,9 @@ # Python 3.9 can use both # Python 3.10 needs to use collections.abc.Iterable because typing.Iterable is removed if sys.version_info <= (3, 9): - from typing import Iterable + from typing import Iterable, Iterator else: - from collections.abc import Iterable + from collections.abc import Iterable, Iterator LEFT_PARENTHESIS = "(" RIGHT_PARENTHESIS = ")" @@ -735,7 +735,9 @@ def schema_query_for_values_statement(output: List[Attribute]) -> str: query = ( SELECT - + COMMA.join([f"{DOLLAR}{i+1}{AS}{attr.name}" for i, attr in enumerate(output)]) + + COMMA.join( + [f"{DOLLAR}{i + 1}{AS}{attr.name}" for i, attr in enumerate(output)] + ) + FROM + VALUES + LEFT_PARENTHESIS @@ -758,7 +760,7 @@ def values_statement(output: List[Attribute], data: List[Row]) -> str: query = ( SELECT - + COMMA.join([f"{DOLLAR}{i+1}{AS}{c}" for i, c in enumerate(names)]) + + COMMA.join([f"{DOLLAR}{i + 1}{AS}{c}" for i, c in enumerate(names)]) + FROM + VALUES + COMMA.join(rows) @@ -1074,7 +1076,7 @@ def batch_insert_into_statement( if paramstyle == "qmark": placeholder_marks = [QUESTION_MARK] * num_cols elif paramstyle 
== "numeric": - placeholder_marks = [f"{SINGLE_COLON}{i+1}" for i in range(num_cols)] + placeholder_marks = [f"{SINGLE_COLON}{i + 1}" for i in range(num_cols)] elif paramstyle in ("format", "pyformat"): placeholder_marks = [PERCENT_S] * num_cols else: @@ -1999,24 +2001,24 @@ def cte_statement(queries: List[str], table_names: List[str]) -> str: return f"{WITH}{result}" -def write_arrow( +def write_parquet( cursor: SnowflakeCursor, - table: "pyarrow.Table", + parquet_files_generator: Iterator[str], table_name: str, + column_names: Optional[List[str]] = None, database: Optional[str] = None, schema: Optional[str] = None, - chunk_size: Optional[int] = None, - compression: str = "gzip", + compression: str = "auto", on_error: str = "abort_statement", use_vectorized_scanner: bool = False, parallel: int = 4, + write_files_in_parallel: bool = True, quote_identifiers: bool = True, auto_create_table: bool = False, overwrite: bool = False, table_type: Literal["", "temp", "temporary", "transient"] = "", use_logical_type: Optional[bool] = None, use_scoped_temp_object: bool = False, - **kwargs: Any, ) -> Tuple[ bool, int, @@ -2036,23 +2038,23 @@ def write_arrow( ] ], ]: - """Writes a pyarrow.Table to a Snowflake table. + """Uploads parquet files to a Snowflake table. - The pyarrow Table is written out to temporary files, uploaded to a temporary stage, and then copied into the final location. + Parquet files are uploaded to a temporary stage and then copied into the final location. Returns whether all files were ingested correctly, number of chunks uploaded, and number of rows ingested with all of the COPY INTO command's output for debugging purposes. Args: cursor: Snowflake connector cursor used to execute queries. - table: The pyarrow Table that is written. + parquet_files_generator: An iterator that yields local parquet file paths to upload. table_name: Table name where we want to insert into. + column_names: List of column names in the order they appear in the parquet files. 
+ If None, column names will be inferred from the first parquet file. (Default value = None). database: Database schema and table is in, if not provided the default one will be used (Default value = None). schema: Schema table is in, if not provided the default one will be used (Default value = None). - chunk_size: Number of elements to be inserted in each batch, if not provided all elements will be dumped - (Default value = None). - compression: The compression used on the Parquet files, can only be gzip, or snappy. Gzip gives a - better compression, while snappy is faster. Use whichever is more appropriate (Default value = 'gzip'). + compression: The compression used on the Parquet files, can only be auto, gzip or snappy. Gzip gives a + better compression, while snappy is faster. Use whichever is more appropriate (Default value = 'auto'). on_error: Action to take when COPY INTO statements fail, default follows documentation at: https://docs.snowflake.com/en/sql-reference/sql/copy-into-table.html#copy-options-copyoptions (Default value = 'abort_statement'). @@ -2060,11 +2062,14 @@ def write_arrow( `copy options `_. parallel: Number of threads to be used when uploading chunks, default follows documentation at: https://docs.snowflake.com/en/sql-reference/sql/put.html#optional-parameters (Default value = 4). + write_files_in_parallel: Whether to parallelize over the files while uploading. + This will pass a glob string to the PUT upload call (does not support divergent directory paths). + https://docs.snowflake.com/en/sql-reference/sql/put#usage-notes (Default value = True) quote_identifiers: By default, identifiers, specifically database, schema, table and column names - (from df.columns) will be quoted. If set to False, identifiers are passed on to Snowflake without quoting. + will be quoted. If set to False, identifiers are passed on to Snowflake without quoting. I.e. identifiers will be coerced to uppercase by Snowflake. 
(Default value = True) auto_create_table: When true, will automatically create a table with corresponding columns for each column in - the passed in DataFrame. The table will not be created if it already exists + the parquet files. The table will not be created if it already exists table_type: The table type of to-be-created table. The supported table types include ``temp``/``temporary`` and ``transient``. Empty means permanent table as per SQL convention. use_logical_type: Boolean that specifies whether to use Parquet logical types. With this file format option, @@ -2072,15 +2077,16 @@ def write_arrow( set use_logical_type as True. Set to None to use Snowflakes default. For more information, see: https://docs.snowflake.com/en/sql-reference/sql/create-file-format """ - # SNOW-1904593: This function mostly copies the functionality of snowflake.connector.pandas_utils.write_pandas. - # It should be pushed down into the connector, but would require a minimum required version bump. - import pyarrow.parquet # type: ignore - if database is not None and schema is None: raise ProgrammingError( - "Schema has to be provided to write_arrow when a database is provided" + "Schema has to be provided to write_arrow or write_parquet when a database is provided" ) - compression_map = {"gzip": "auto", "snappy": "snappy", "none": "none"} + compression_map = { + "gzip": "auto", + "snappy": "snappy", + "none": "none", + "auto": "auto", + } if compression not in compression_map.keys(): raise ProgrammingError( f"Invalid compression '{compression}', only acceptable values are: {compression_map.keys()}" @@ -2091,9 +2097,6 @@ def write_arrow( "Unsupported table type. 
Expected table types: temp/temporary, transient" ) - if chunk_size is None: - chunk_size = len(table) - if use_logical_type is None: sql_use_logical_type = "" elif use_logical_type: @@ -2111,18 +2114,42 @@ def write_arrow( overwrite, use_scoped_temp_object, ) - with tempfile.TemporaryDirectory() as tmp_folder: - for file_number, offset in enumerate(range(0, len(table), chunk_size)): - # write chunk to disk - chunk_path = os.path.join(tmp_folder, f"{table_name}_{file_number}.parquet") - pyarrow.parquet.write_table( - table.slice(offset=offset, length=chunk_size), - chunk_path, - **kwargs, - ) - # upload chunk + + if write_files_in_parallel: + parquet_files = list(parquet_files_generator) + + # Infer column names from first parquet file. + if column_names is None: + import pyarrow.parquet # type: ignore + + column_names = pyarrow.parquet.read_table(parquet_files[0]).schema.names + upload_sql = ( + "PUT /* Python:snowflake.snowpark._internal.analyzer.analyzer_utils.write_arrow() */ " + "'file://{path}' @{stage_location} PARALLEL={parallel}" + ).format( + path=parquet_files[0] + # make a glob string out of the first file + .replace(os.path.basename(parquet_files[0]), "*.parquet") + .replace("\\", "\\\\") + .replace("'", "\\'"), + stage_location=stage_location, + parallel=parallel, + ) + + cursor.execute(upload_sql, _is_internal=True) + num_files_uploaded = len(parquet_files) + else: + # Upload all parquet files from the generator, in sequence + num_files_uploaded = 0 + for chunk_path in parquet_files_generator: + # Infer column names from first parquet file. 
+ if num_files_uploaded == 0 and column_names is None: + import pyarrow.parquet # type: ignore + + column_names = pyarrow.parquet.read_table(chunk_path).schema.names + upload_sql = ( - "PUT /* Python:snowflake.snowpark._internal.analyzer.analyzer_utils.write_arrow() */ " + "PUT /* Python:snowflake.snowpark._internal.analyzer.analyzer_utils.write_parquet() */ " "'file://{path}' @{stage_location} PARALLEL={parallel}" ).format( path=chunk_path.replace("\\", "\\\\").replace("'", "\\'"), @@ -2130,19 +2157,18 @@ def write_arrow( parallel=parallel, ) cursor.execute(upload_sql, _is_internal=True) - # Remove chunk file - os.remove(chunk_path) + num_files_uploaded += 1 if quote_identifiers: quote = '"' - snowflake_column_names = [str(c).replace('"', '""') for c in table.schema.names] + snowflake_column_names = [str(c).replace('"', '""') for c in column_names] else: quote = "" - snowflake_column_names = list(table.schema.names) + snowflake_column_names = list(column_names) columns = quote + f"{quote},{quote}".join(snowflake_column_names) + quote def drop_object(name: str, object_type: str) -> None: - drop_sql = f"DROP {object_type.upper()} IF EXISTS {name} /* Python:snowflake.snowpark._internal.analyzer.analyzer_utils.write_arrow() */" + drop_sql = f"DROP {object_type.upper()} IF EXISTS {name} /* Python:snowflake.snowpark._internal.analyzer.analyzer_utils.write_parquet() */" cursor.execute(drop_sql, _is_internal=True) if auto_create_table or overwrite: @@ -2173,22 +2199,20 @@ def drop_object(name: str, object_type: str) -> None: parquet_columns = "$1:" + ",$1:".join( f"{quote}{snowflake_col}{quote}::{column_type_mapping[col]}" - for snowflake_col, col in zip(snowflake_column_names, table.schema.names) + for snowflake_col, col in zip(snowflake_column_names, column_names) ) if auto_create_table: create_table_columns = ", ".join( [ f"{quote}{snowflake_col}{quote} {column_type_mapping[col]}" - for snowflake_col, col in zip( - snowflake_column_names, table.schema.names - ) + for 
snowflake_col, col in zip(snowflake_column_names, column_names) ] ) create_table_sql = ( f"CREATE {table_type.upper()} TABLE IF NOT EXISTS {target_table_location} " f"({create_table_columns})" - f" /* Python:snowflake.snowpark._internal.analyzer.analyzer_utils.write_arrow() */ " + f" /* Python:snowflake.snowpark._internal.analyzer.analyzer_utils.write_parquet() */ " ) cursor.execute(create_table_sql, _is_internal=True) else: @@ -2204,11 +2228,11 @@ def drop_object(name: str, object_type: str) -> None: try: if overwrite and (not auto_create_table): - truncate_sql = f"TRUNCATE TABLE {target_table_location} /* Python:snowflake.snowpark._internal.analyzer.analyzer_utils.write_arrow() */" + truncate_sql = f"TRUNCATE TABLE {target_table_location} /* Python:snowflake.snowpark._internal.analyzer.analyzer_utils.write_parquet() */" cursor.execute(truncate_sql, _is_internal=True) copy_into_sql = ( - f"COPY INTO {target_table_location} /* Python:snowflake.snowpark._internal.analyzer.analyzer_utils.write_arrow() */ " + f"COPY INTO {target_table_location} /* Python:snowflake.snowpark._internal.analyzer.analyzer_utils.write_parquet() */ " f"({columns}) " f"FROM (SELECT {parquet_columns} FROM @{stage_location}) " f"FILE_FORMAT=(" @@ -2232,7 +2256,7 @@ def drop_object(name: str, object_type: str) -> None: quote_identifiers=quote_identifiers, ) drop_object(original_table_location, "table") - rename_table_sql = f"ALTER TABLE {target_table_location} RENAME TO {original_table_location} /* Python:snowflake.snowpark._internal.analyzer.analyzer_utils.write_arrow() */" + rename_table_sql = f"ALTER TABLE {target_table_location} RENAME TO {original_table_location} /* Python:snowflake.snowpark._internal.analyzer.analyzer_utils.write_parquet() */" cursor.execute(rename_table_sql, _is_internal=True) except ProgrammingError: if overwrite and auto_create_table: @@ -2244,7 +2268,130 @@ def drop_object(name: str, object_type: str) -> None: return ( all(e[1] == "LOADED" for e in copy_results), - 
len(copy_results), + num_files_uploaded, sum(int(e[3]) for e in copy_results), copy_results, # pyright: ignore ) + + +def write_arrow( + cursor: SnowflakeCursor, + table: "pyarrow.Table", + table_name: str, + database: Optional[str] = None, + schema: Optional[str] = None, + chunk_size: Optional[int] = None, + compression: str = "gzip", + on_error: str = "abort_statement", + use_vectorized_scanner: bool = False, + parallel: int = 4, + quote_identifiers: bool = True, + auto_create_table: bool = False, + overwrite: bool = False, + table_type: Literal["", "temp", "temporary", "transient"] = "", + use_logical_type: Optional[bool] = None, + use_scoped_temp_object: bool = False, + **kwargs: Any, +) -> Tuple[ + bool, + int, + int, + Sequence[ + Tuple[ + str, + str, + int, + int, + int, + int, + Optional[str], + Optional[int], + Optional[int], + Optional[str], + ] + ], +]: + """Writes a pyarrow.Table to a Snowflake table. + + The pyarrow Table is written out to temporary files, uploaded to a temporary stage, and then copied into the final location. + + Returns whether all files were ingested correctly, number of chunks uploaded, and number of rows ingested + with all of the COPY INTO command's output for debugging purposes. + + Args: + cursor: Snowflake connector cursor used to execute queries. + table: The pyarrow Table that is written. + table_name: Table name where we want to insert into. + database: Database schema and table is in, if not provided the default one will be used (Default value = None). + schema: Schema table is in, if not provided the default one will be used (Default value = None). + chunk_size: Number of elements to be inserted in each batch, if not provided all elements will be dumped + (Default value = None). + compression: The compression used on the Parquet files, can only be gzip, or snappy. Gzip gives a + better compression, while snappy is faster. Use whichever is more appropriate (Default value = 'gzip'). 
+ on_error: Action to take when COPY INTO statements fail, default follows documentation at: + https://docs.snowflake.com/en/sql-reference/sql/copy-into-table.html#copy-options-copyoptions + (Default value = 'abort_statement'). + use_vectorized_scanner: Boolean that specifies whether to use a vectorized scanner for loading Parquet files. See details at + `copy options `_. + parallel: Number of threads to be used when uploading chunks, default follows documentation at: + https://docs.snowflake.com/en/sql-reference/sql/put.html#optional-parameters (Default value = 4). + quote_identifiers: By default, identifiers, specifically database, schema, table and column names + (from df.columns) will be quoted. If set to False, identifiers are passed on to Snowflake without quoting. + I.e. identifiers will be coerced to uppercase by Snowflake. (Default value = True) + auto_create_table: When true, will automatically create a table with corresponding columns for each column in + the passed in DataFrame. The table will not be created if it already exists + table_type: The table type of to-be-created table. The supported table types include ``temp``/``temporary`` + and ``transient``. Empty means permanent table as per SQL convention. + use_logical_type: Boolean that specifies whether to use Parquet logical types. With this file format option, + Snowflake can interpret Parquet logical types during data loading. To enable Parquet logical types, + set use_logical_type as True. Set to None to use Snowflakes default. For more information, see: + https://docs.snowflake.com/en/sql-reference/sql/create-file-format + """ + # SNOW-1904593: This function mostly copies the functionality of snowflake.connector.pandas_utils.write_pandas. + # It should be pushed down into the connector, but would require a minimum required version bump. 
+ import pyarrow.parquet # type: ignore + + if chunk_size is None: + chunk_size = len(table) + + # Extract column names from the Arrow table + column_names = list(table.schema.names) + + # Create a generator that yields parquet file paths + def parquet_file_generator() -> Iterator[str]: + with tempfile.TemporaryDirectory() as tmp_folder: + for file_number, offset in enumerate(range(0, len(table), chunk_size)): + # Write chunk to disk + chunk_path = os.path.join( + tmp_folder, f"{table_name}_{file_number}.parquet" + ) + pyarrow.parquet.write_table( + table.slice(offset=offset, length=chunk_size), + chunk_path, + **kwargs, + ) + # Yield the path for upload + yield chunk_path + # Remove chunk file + os.remove(chunk_path) + + # Use write_parquet to handle the upload and COPY INTO operations + return write_parquet( + cursor=cursor, + parquet_files_generator=parquet_file_generator(), + table_name=table_name, + column_names=column_names, + database=database, + schema=schema, + compression=compression, + on_error=on_error, + use_vectorized_scanner=use_vectorized_scanner, + parallel=parallel, + write_files_in_parallel=False, + quote_identifiers=quote_identifiers, + auto_create_table=auto_create_table, + overwrite=overwrite, + table_type=table_type, + use_logical_type=use_logical_type, + use_scoped_temp_object=use_scoped_temp_object, + ) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 1afe626720..f7bca8ffc2 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -6,6 +6,7 @@ import atexit import datetime import decimal +import importlib.metadata import inspect import json import os @@ -35,7 +36,6 @@ ) import cloudpickle -import importlib.metadata from packaging.requirements import Requirement from packaging.version import parse as parse_version @@ -49,6 +49,7 @@ from snowflake.snowpark._internal.analyzer.analyzer_utils import ( result_scan_statement, write_arrow, + write_parquet, ) from 
snowflake.snowpark._internal.analyzer.datatype_mapper import str_to_sql from snowflake.snowpark._internal.analyzer.expression import Attribute @@ -108,8 +109,10 @@ MODULE_NAME_TO_PACKAGE_NAME_MAP, STAGE_PREFIX, SUPPORTED_TABLE_TYPES, - XPATH_HANDLERS_FILE_PATH, XPATH_HANDLER_MAP, + XPATH_HANDLERS_FILE_PATH, + AstFlagSource, + AstMode, PythonObjJSONEncoder, TempObjectType, calculate_checksum, @@ -127,6 +130,7 @@ get_temp_type_for_object, get_version, import_or_missing_modin_pandas, + is_ast_enabled, is_in_stored_procedure, normalize_local_file, normalize_remote_file_or_dir, @@ -135,6 +139,7 @@ publicapi, quote_name, random_name_for_temp_object, + set_ast_state, strip_double_quotes_in_like_statement_in_table_name, unwrap_single_quote, unwrap_stage_location_single_quote, @@ -142,10 +147,6 @@ warn_session_config_update_in_multithreaded_mode, warning, zip_file_or_directory_to_stream, - set_ast_state, - is_ast_enabled, - AstFlagSource, - AstMode, ) from snowflake.snowpark.async_job import AsyncJob, _AsyncResultType from snowflake.snowpark.column import Column @@ -154,6 +155,7 @@ _use_scoped_temp_objects, ) from snowflake.snowpark.dataframe import DataFrame +from snowflake.snowpark.dataframe_profiler import DataframeProfiler from snowflake.snowpark.dataframe_reader import DataFrameReader from snowflake.snowpark.exceptions import ( SnowparkClientException, @@ -161,7 +163,6 @@ ) from snowflake.snowpark.file_operation import FileOperation from snowflake.snowpark.functions import ( - to_file, array_agg, col, column, @@ -169,6 +170,7 @@ parse_json, to_date, to_decimal, + to_file, to_geography, to_geometry, to_time, @@ -196,7 +198,6 @@ from snowflake.snowpark.row import Row from snowflake.snowpark.stored_procedure import StoredProcedureRegistration from snowflake.snowpark.stored_procedure_profiler import StoredProcedureProfiler -from snowflake.snowpark.dataframe_profiler import DataframeProfiler from snowflake.snowpark.table import Table from snowflake.snowpark.table_function 
import ( TableFunctionCall, @@ -208,6 +209,7 @@ DateType, DayTimeIntervalType, DecimalType, + FileType, FloatType, GeographyType, GeometryType, @@ -222,7 +224,6 @@ VariantType, VectorType, YearMonthIntervalType, - FileType, _AtomicType, ) from snowflake.snowpark.udaf import UDAFRegistration @@ -231,6 +232,7 @@ if TYPE_CHECKING: import modin.pandas # pragma: no cover + from snowflake.snowpark.udf import UserDefinedFunction # pragma: no cover # Python 3.8 needs to use typing.Iterable because collections.abc.Iterable is not subscriptable @@ -1848,7 +1850,7 @@ def replicate_local_environment( @staticmethod def _parse_packages( - packages: List[Union[str, ModuleType]] + packages: List[Union[str, ModuleType]], ) -> Dict[str, Tuple[str, bool, Requirement]]: package_dict = dict() for package in packages: @@ -2764,7 +2766,6 @@ def table_function( self._conn, NopConnection ): if self._conn._suppress_not_implemented_error: - # TODO: Snowpark does not allow empty dataframes (no schema, no data). Have a dummy schema here. ans = self.createDataFrame( [], @@ -3182,6 +3183,154 @@ def write_arrow( f"Failed to write arrow table to Snowflake. COPY INTO output {ci_output}" ) + @experimental(version="1.41.0") + @publicapi + def write_parquet( + self, + path: str, + table_name: str, + *, + column_names: Optional[List[str]] = None, + database: Optional[str] = None, + schema: Optional[str] = None, + compression: str = "auto", + on_error: str = "abort_statement", + use_vectorized_scanner: bool = False, + parallel: int = 4, + write_files_in_parallel: bool = True, + quote_identifiers: bool = True, + auto_create_table: bool = False, + overwrite: bool = False, + table_type: Literal["", "temp", "temporary", "transient"] = "", + use_logical_type: Optional[bool] = None, + **kwargs: Dict[str, Any], + ) -> Table: + """Writes parquet file(s) to a Snowflake table. + + Parquet files are uploaded to a temporary stage and then copied into the final location. 
+ + Returns a Snowpark Table that references the table referenced by table_name. + + Args: + path: Path to a single parquet file or a directory containing parquet files. + Can be a local file path (e.g., "/path/to/file.parquet" or "/path/to/directory/"). + table_name: Table name where we want to insert into. + column_names: List of column names in the order they appear in the parquet files. + If None, column names will be inferred from the first parquet file. (Default value = None). + database: Database schema and table is in, if not provided the default one will be used (Default value = None). + schema: Schema table is in, if not provided the default one will be used (Default value = None). + compression: The compression used while uploading the Parquet files. Can be "auto", "gzip", "snappy", or "none". + (Default value = "auto"). + on_error: Action to take when COPY INTO statements fail, default follows documentation at: + https://docs.snowflake.com/en/sql-reference/sql/copy-into-table.html#copy-options-copyoptions + (Default value = 'abort_statement'). + use_vectorized_scanner: Boolean that specifies whether to use a vectorized scanner for loading Parquet files. See details at + `copy options `_. + parallel: Number of threads to be used when uploading chunks or files, default follows documentation at: + https://docs.snowflake.com/en/sql-reference/sql/put.html#optional-parameters (Default value = 4). + write_files_in_parallel: Whether to parallelize over the files while uploading. + This will pass a glob string to the PUT upload call (does not support divergent directory paths). + https://docs.snowflake.com/en/sql-reference/sql/put#usage-notes (Default value = True) + quote_identifiers: By default, identifiers, specifically database, schema, and table names + will be quoted. If set to False, identifiers are passed on to Snowflake without quoting. + I.e. identifiers will be coerced to uppercase by Snowflake. 
(Default value = True) + auto_create_table: When true, will automatically create a table with corresponding columns for each column in + the parquet files. The table will not be created if it already exists. + overwrite: If True, the table contents will be overwritten. If False, data will be appended to the table. + (Default value = False). + table_type: The table type of to-be-created table. The supported table types include ``temp``/``temporary`` + and ``transient``. Empty means permanent table as per SQL convention. + use_logical_type: Boolean that specifies whether to use Parquet logical types. With this file format option, + Snowflake can interpret Parquet logical types during data loading. To enable Parquet logical types, + set use_logical_type as True. Set to None to use Snowflakes default. For more information, see: + https://docs.snowflake.com/en/sql-reference/sql/create-file-format + + Example:: + + >>> # Write a single parquet file + >>> session.write_parquet("/path/to/file.parquet", "my_table", auto_create_table=True) # doctest: +SKIP + >>> + >>> # Write all parquet files in a directory + >>> session.write_parquet("/path/to/directory/", "my_table", auto_create_table=True) # doctest: +SKIP + """ + import glob + + cursor = self._conn._conn.cursor() + + if quote_identifiers: + location = ( + (('"' + database + '".') if database else "") + + (('"' + schema + '".') if schema else "") + + ('"' + table_name + '"') + ) + else: + location = ( + (database + "." if database else "") + + (schema + "." 
if schema else "") + + (table_name) + ) + + if os.path.isfile(path): + parquet_files = [path] + elif os.path.isdir(path): + parquet_files = glob.glob(os.path.join(path, "**/*.parquet")) + glob.glob( + os.path.join(path, "*.parquet") + ) + if not parquet_files: + raise SnowparkSessionException( + f"No parquet files found in directory: {path}" + ) + + # Check if write_files_in_parallel is enabled but subdirectories exist + if write_files_in_parallel: + # Check if any parquet files are in subdirectories + files_in_subdirs = [ + f for f in parquet_files if os.path.dirname(f) != path + ] + if files_in_subdirs: + raise ProgrammingError( + "write_files_in_parallel=True is not supported when parquet files exist in subdirectories. " + "Please ensure all parquet files are in the root of the specified path, or set write_files_in_parallel=False." + ) + else: + raise SnowparkSessionException( + f"Path does not exist or is not accessible: {path}" + ) + + # Create generator that yields parquet file paths + def parquet_file_generator(): + yield from parquet_files + + success, _, _, ci_output = write_parquet( + cursor=cursor, + parquet_files_generator=parquet_file_generator(), + column_names=column_names, + table_name=table_name, + database=database, + schema=schema, + compression=compression, + on_error=on_error, + use_vectorized_scanner=use_vectorized_scanner, + parallel=parallel, + write_files_in_parallel=write_files_in_parallel, + quote_identifiers=quote_identifiers, + auto_create_table=auto_create_table, + overwrite=overwrite, + table_type=table_type, + use_logical_type=use_logical_type, + use_scoped_temp_object=self._use_scoped_temp_objects + and is_in_stored_procedure(), + ) + + if success: + table = self.table(location, _emit_ast=False) + set_api_call_source(table, "Session.write_parquet") + return table + else: + raise SnowparkSessionException( + f"Failed to write parquet file(s) to Snowflake. 
COPY INTO output {ci_output}" + ) + def _write_modin_pandas_helper( self, df: Union[ diff --git a/tests/integ/test_df_to_arrow.py b/tests/integ/test_df_to_arrow.py index a35b0def1d..a892f35015 100644 --- a/tests/integ/test_df_to_arrow.py +++ b/tests/integ/test_df_to_arrow.py @@ -3,26 +3,24 @@ # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. # +import math import os -import pytest import re -import math - from datetime import date, datetime from decimal import Decimal from typing import Iterator from unittest import mock -from snowflake.connector.errors import ProgrammingError +import pytest +from snowflake.connector.errors import ProgrammingError from snowflake.snowpark._internal.analyzer.analyzer_utils import write_arrow from snowflake.snowpark.exceptions import SnowparkSessionException from snowflake.snowpark.functions import col from snowflake.snowpark.row import Row -from snowflake.snowpark.types import DecimalType from snowflake.snowpark.session import WRITE_ARROW_CHUNK_SIZE - -from tests.utils import TestData, Utils, TestFiles +from snowflake.snowpark.types import DecimalType +from tests.utils import TestData, TestFiles, Utils try: import pyarrow as pa @@ -361,7 +359,7 @@ def test_misc_settings( def test_write_arrow_negative(session, basic_arrow_table): with pytest.raises( ProgrammingError, - match="Schema has to be provided to write_arrow when a database is provided", + match="Schema has to be provided to write_arrow or write_parquet when a database is provided", ): session.write_arrow(basic_arrow_table, "temp_table", database="foo") @@ -404,6 +402,284 @@ def test_write_arrow_negative(session, basic_arrow_table): session.write_arrow(basic_arrow_table, "temp_table") +@pytest.mark.skipif( + "config.getoption('local_testing_mode', default=False)", + reason="arrow not fully supported by local testing.", +) +def test_write_parquet(session, tmp_path): + """Test the write_parquet method with single file and directory.""" + import 
pyarrow.parquet as pq + + # Create test parquet files + test_data = pa.Table.from_arrays([[1, 2, 3], ["a", "b", "c"]], names=["id", "name"]) + + # Test 1: Write a single parquet file + table_name1 = Utils.random_table_name() + single_file = tmp_path / "test_single.parquet" + pq.write_table(test_data, single_file) + + try: + result_table = session.write_parquet( + str(single_file), table_name1, auto_create_table=True + ) + Utils.check_answer( + result_table, + [Row(id=1, name="a"), Row(id=2, name="b"), Row(id=3, name="c")], + ) + finally: + Utils.drop_table(session, table_name1) + + # Test 2: Write from a directory with multiple parquet files + table_name2 = Utils.random_table_name() + parquet_dir = tmp_path / "parquet_files" + parquet_dir.mkdir() + + # Create multiple parquet files in the directory + pq.write_table(test_data, parquet_dir / "file1.parquet") + pq.write_table(test_data, parquet_dir / "file2.parquet") + + try: + result_table = session.write_parquet( + str(parquet_dir), table_name2, auto_create_table=True + ) + # Should have 6 rows (3 from each file) + assert result_table.count() == 6 + # Check that data contains expected values + result_data = result_table.collect() + assert all(row["id"] in [1, 2, 3] for row in result_data) + assert all(row["name"] in ["a", "b", "c"] for row in result_data) + finally: + Utils.drop_table(session, table_name2) + + # Test 3: Write with explicit column names + table_name3 = Utils.random_table_name() + try: + result_table = session.write_parquet( + str(single_file), + table_name3, + column_names=["id", "name"], + auto_create_table=True, + ) + Utils.check_answer( + result_table, + [Row(id=1, name="a"), Row(id=2, name="b"), Row(id=3, name="c")], + ) + finally: + Utils.drop_table(session, table_name3) + + # Test 4: Test overwrite functionality + table_name4 = Utils.random_table_name() + try: + # Initial write + session.write_parquet(str(single_file), table_name4, auto_create_table=True) + table1 = session.table(table_name4) + 
@pytest.mark.skipif(
    "config.getoption('local_testing_mode', default=False)",
    reason="arrow not fully supported by local testing.",
)
def test_write_parquet_negative(session, tmp_path):
    """Negative-path coverage for ``Session.write_parquet``.

    Exercises four failure modes in sequence: a nonexistent source path,
    a directory containing no parquet files, a database supplied without
    a schema, and parallel upload requested for a nested directory layout.
    """
    import pyarrow.parquet as pq

    # A path that does not exist on disk is rejected up front.
    with pytest.raises(
        SnowparkSessionException,
        match="Path does not exist or is not accessible",
    ):
        session.write_parquet("/nonexistent/path.parquet", "temp_table")

    # An existing directory that holds no *.parquet files is also an error.
    no_parquet_dir = tmp_path / "empty"
    no_parquet_dir.mkdir()
    with pytest.raises(
        SnowparkSessionException,
        match="No parquet files found in directory",
    ):
        session.write_parquet(str(no_parquet_dir), "temp_table")

    # Supplying a database without a schema is rejected by the connector.
    one_column = pa.Table.from_arrays([[1, 2, 3]], names=["a"])
    parquet_path = tmp_path / "test.parquet"
    pq.write_table(one_column, parquet_path)
    with pytest.raises(
        ProgrammingError,
        match="Schema has to be provided to write_arrow or write_parquet when a database is provided",
    ):
        session.write_parquet(str(parquet_path), "temp_table", database="foo")

    # Parallel upload only supports a flat layout: a parquet file sitting
    # inside a subdirectory must raise.
    nested_root = tmp_path / "subdir_test"
    nested_child = nested_root / "subdir"
    nested_child.mkdir(parents=True)
    pq.write_table(one_column, nested_child / "test.parquet")
    with pytest.raises(
        ProgrammingError,
        match="write_files_in_parallel=True is not supported when parquet files exist in subdirectories",
    ):
        session.write_parquet(
            str(nested_root),
            "temp_table",
            write_files_in_parallel=True,
        )
@pytest.mark.skipif(
    "config.getoption('local_testing_mode', default=False)",
    reason="arrow not fully supported by local testing.",
)
def test_write_parquet_parallel_upload(session, tmp_path):
    """Test that parallel upload optimization works for multiple files in a flat directory.

    Patches ``SnowflakeCursor.execute`` to record every SQL command, then
    verifies that writing a flat directory of parquet files issues exactly
    one PUT with a ``*.parquet`` glob (a single parallel upload) and that
    all rows from every file land in the target table.
    """
    # NOTE(review): the original signature also declared a `capfd` fixture
    # that was never used; it has been removed.
    import pyarrow.parquet as pq

    from snowflake.connector.cursor import SnowflakeCursor

    # Three parquet files with disjoint data, all directly inside test_dir
    # (no subdirectories, so the glob/parallel upload path is eligible).
    test_dir = tmp_path / "parquet_files"
    test_dir.mkdir()
    file_tables = {
        "file1.parquet": pa.Table.from_arrays([[1, 2], ["a", "b"]], names=["id", "name"]),
        "file2.parquet": pa.Table.from_arrays([[3, 4], ["c", "d"]], names=["id", "name"]),
        "file3.parquet": pa.Table.from_arrays([[5, 6], ["e", "f"]], names=["id", "name"]),
    }
    for file_name, file_table in file_tables.items():
        pq.write_table(file_table, test_dir / file_name)

    # Record every SQL text passed to SnowflakeCursor.execute while still
    # delegating to the real implementation.
    execute_calls = []
    original_execute = SnowflakeCursor.execute

    def execute_wrapper(self, *args, **kwargs):
        execute_calls.append(args[0] if args else kwargs.get("command"))
        return original_execute(self, *args, **kwargs)

    table_name = Utils.random_table_name()
    try:
        with mock.patch.object(
            SnowflakeCursor, "execute", side_effect=execute_wrapper, autospec=True
        ):
            result_table = session.write_parquet(
                str(test_dir),
                table_name,
                auto_create_table=True,
                overwrite=True,
            )

        # A single PUT with a glob pattern proves the files went up in one
        # parallel batch instead of one PUT per file.
        put_calls = [call for call in execute_calls if call and "PUT" in call]
        assert len(put_calls) == 1, f"Expected 1 PUT call, got {len(put_calls)}"
        assert "*.parquet" in put_calls[0], "Expected *.parquet glob in PUT call."

        # Verify all data was loaded: 2 rows from each of 3 files.
        result_data = result_table.collect()
        assert len(result_data) == 6
        assert {row["id"] for row in result_data} == {1, 2, 3, 4, 5, 6}
        assert {row["name"] for row in result_data} == {"a", "b", "c", "d", "e", "f"}
    finally:
        Utils.drop_table(session, table_name)
@pytest.mark.skipif(
    "config.getoption('local_testing_mode', default=False)",
    reason="arrow not fully supported by local testing.",
)
def test_write_parquet_subfolders_sequential_fallback(session, tmp_path):
    """Test the sequential upload fallback for parquet files in subdirectories.

    With ``write_files_in_parallel=False`` and files spread across nested
    subdirectories, each file must be uploaded with its own PUT (three PUT
    calls here) rather than a single glob upload, and all rows must still
    be loaded into the target table.
    """
    # NOTE(review): the original docstring and leading comments were
    # copy-pasted from the flat-directory parallel-upload test and
    # described the wrong scenario; they are corrected here.
    import pyarrow.parquet as pq

    from snowflake.connector.cursor import SnowflakeCursor

    # One file at the top level and one in each of two subdirectories —
    # the nested layout that rules out the single-glob parallel upload.
    test_dir = tmp_path / "parquet_files"
    sub_dir_1 = test_dir / "subdirectory1"
    sub_dir_2 = test_dir / "subdirectory2"
    test_dir.mkdir()
    sub_dir_1.mkdir()
    sub_dir_2.mkdir()

    file1_data = pa.Table.from_arrays([[1, 2], ["a", "b"]], names=["id", "name"])
    file2_data = pa.Table.from_arrays([[3, 4], ["c", "d"]], names=["id", "name"])
    file3_data = pa.Table.from_arrays([[5, 6], ["e", "f"]], names=["id", "name"])

    pq.write_table(file1_data, test_dir / "file1.parquet")
    pq.write_table(file2_data, sub_dir_1 / "file2.parquet")
    pq.write_table(file3_data, sub_dir_2 / "file3.parquet")

    # Record every SQL text passed to SnowflakeCursor.execute while still
    # delegating to the real implementation.
    execute_calls = []
    original_execute = SnowflakeCursor.execute

    def execute_wrapper(self, *args, **kwargs):
        execute_calls.append(args[0] if args else kwargs.get("command"))
        return original_execute(self, *args, **kwargs)

    table_name = Utils.random_table_name()
    try:
        with mock.patch.object(
            SnowflakeCursor, "execute", side_effect=execute_wrapper, autospec=True
        ):
            result_table = session.write_parquet(
                str(test_dir),
                table_name,
                auto_create_table=True,
                overwrite=True,
                write_files_in_parallel=False,
            )

        # Verify that there were three PUT calls (one for each subdirectory/file)
        put_calls = [call for call in execute_calls if call and "PUT" in call]
        assert len(put_calls) == 3, f"Expected 3 PUT calls, got {len(put_calls)}"

        # Verify all data was loaded: 2 rows from each of 3 files.
        result_data = result_table.collect()
        assert len(result_data) == 6
        assert {row["id"] for row in result_data} == {1, 2, 3, 4, 5, 6}
        assert {row["name"] for row in result_data} == {"a", "b", "c", "d", "e", "f"}
    finally:
        Utils.drop_table(session, table_name)