diff --git a/src/snowflake/snowpark/_internal/utils.py b/src/snowflake/snowpark/_internal/utils.py index 12b2b13794..e83018180f 100644 --- a/src/snowflake/snowpark/_internal/utils.py +++ b/src/snowflake/snowpark/_internal/utils.py @@ -168,7 +168,7 @@ GENERATED_PY_FILE_EXT = (".pyc", ".pyo", ".pyd", ".pyi") -INFER_SCHEMA_FORMAT_TYPES = ("PARQUET", "ORC", "AVRO", "JSON", "CSV") +INFER_SCHEMA_FORMAT_TYPES = ("PARQUET", "ORC", "AVRO", "JSON", "CSV", "XML") COPY_INTO_TABLE_COPY_OPTIONS = { "ON_ERROR", diff --git a/src/snowflake/snowpark/_internal/xml_reader.py b/src/snowflake/snowpark/_internal/xml_reader.py index 1c6ef06754..095b9ecb1a 100644 --- a/src/snowflake/snowpark/_internal/xml_reader.py +++ b/src/snowflake/snowpark/_internal/xml_reader.py @@ -7,12 +7,33 @@ import html.entities import struct import copy -from typing import Optional, Dict, Any, Iterator, BinaryIO, Union, Tuple +import random +from typing import Optional, Dict, Any, Iterator, BinaryIO, Union, Tuple +from datetime import datetime, date, time from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted from snowflake.snowpark._internal.type_utils import type_string_to_type_object from snowflake.snowpark.files import SnowflakeFile -from snowflake.snowpark.types import StructType, ArrayType, DataType, MapType +from snowflake.snowpark.types import ( + StructType, + ArrayType, + DataType, + MapType, + NullType, + StringType, + BooleanType, + IntegerType, + LongType, + DoubleType, + DecimalType, + DateType, + TimestampType, + TimeType, + StructField, +) + + +_DECIMAL_RE = re.compile(r"^[+-]?(?:\d+(?:\.\d*)?|\.\d+)$") # lxml is only a dev dependency so use try/except to import it if available try: @@ -640,3 +661,406 @@ def process( result_template=result_template, ): yield (element,) + + +def norm_text( + ignore_surrounding_whitespace: bool, text: Optional[str] +) -> Optional[str]: + if text is None: + return None + return text.strip() if ignore_surrounding_whitespace else text + + 
+def infer_type( + text: str, ignore_surrounding_whitespace: bool, null_value: Optional[str] +) -> DataType: + t = norm_text(ignore_surrounding_whitespace, text) + if t is None: + return NullType() + + # Apply null_value rule consistent with element_to_dict_or_str/get_text + if t == null_value: + return NullType() + + # Keep empty string as String unless user explicitly chose null_value="" + # In ElementTree, empty tags often yield text=None; but if we do get "", honor null_value behavior above. + if t == "": + return StringType() + + low = t.lower() + + # boolean + if low in ("true", "false"): + return BooleanType() + + # integer / long (no underscores, no decimals) + try: + # reject things like "01.0" + if all(c.isdigit() for c in (t[1:] if t.startswith(("+", "-")) else t)): + int(t, 10) + return IntegerType() + except Exception: + pass + + # decimal + if _DECIMAL_RE.match(t): + if t[0] in "+-": + t = t[1:] + + if "." in t: + left, right = t.split(".", 1) + scale = len(right) + precision = len(left) + len(right) + if not (0 <= scale <= precision): + scale = 0 + if not (0 <= precision <= 38): + precision = 38 + return DecimalType(precision, scale) + + # time + try: + time.fromisoformat(t) + return TimeType() + except Exception: + pass + + # date + try: + date.fromisoformat(t) + return DateType() + except Exception: + pass + + # timestamp + try: + datetime.fromisoformat(t) + return TimestampType() + except Exception: + pass + + return StringType() + + +def merge_decimal(a: DecimalType, b: DecimalType) -> DecimalType: + # Merge by taking max precision and max scale (clamped). 
+ precision = max(a.precision, b.precision) + scale = max(a.scale, b.scale) + precision = min(38, max(precision, scale)) + scale = min(38, scale) + return DecimalType(precision, scale) + + +def rank(dt: DataType) -> int: + # Lower rank = "narrower"/preferred; higher = "wider"/more general + if isinstance(dt, NullType): + return 0 + if isinstance(dt, BooleanType): + return 1 + if isinstance(dt, IntegerType): + return 2 + if isinstance(dt, LongType): + return 3 + if isinstance(dt, DecimalType): + return 4 + if isinstance(dt, DoubleType): + return 5 + if isinstance(dt, DateType): + return 6 + if isinstance(dt, TimestampType): + return 7 + if isinstance(dt, StringType): + return 100 + if isinstance(dt, StructType): + return 200 + if isinstance(dt, ArrayType): + return 300 + return 1000 + + +def merge_struct(a: StructType, b: StructType) -> StructType: + if a is None: + return b + # Merge fields by name (case-sensitive), preserving first-seen order. + a_fields = {f.name: f.datatype for f in a.fields} + out_order = [f.name for f in a.fields] + + for f in b.fields: + if f.name not in a_fields: + a_fields[f.name] = f.datatype + out_order.append(f.name) + else: + a_fields[f.name] = merge_types(a_fields[f.name], f.datatype) + + return StructType([StructField(name, a_fields[name], True) for name in out_order]) + + +def merge_types(a: DataType, b: DataType) -> DataType: + # Handle arrays first + if isinstance(a, ArrayType) and isinstance(b, ArrayType): + return ArrayType(merge_types(a.element_type, b.element_type)) + if isinstance(a, ArrayType): + return ArrayType(merge_types(a.element_type, b)) + if isinstance(b, ArrayType): + return ArrayType(merge_types(a, b.element_type)) + + # Structs + if isinstance(a, StructType) and isinstance(b, StructType): + return merge_struct(a, b) + + # Nulls + if isinstance(a, NullType): + return b + if isinstance(b, NullType): + return a + + # Date/timestamp promotion + if isinstance(a, TimestampType) and isinstance(b, DateType): + return a + 
if isinstance(a, DateType) and isinstance(b, TimestampType): + return b + + # Numeric merging + if isinstance(a, DecimalType) and isinstance(b, DecimalType): + return merge_decimal(a, b) + if isinstance(a, DecimalType) and isinstance(b, (IntegerType, LongType)): + return a + if isinstance(b, DecimalType) and isinstance(a, (IntegerType, LongType)): + return b + if isinstance(a, DoubleType) and isinstance( + b, (IntegerType, LongType, DecimalType) + ): + return a + if isinstance(b, DoubleType) and isinstance( + a, (IntegerType, LongType, DecimalType) + ): + return b + if isinstance(a, LongType) and isinstance(b, IntegerType): + return a + if isinstance(b, LongType) and isinstance(a, IntegerType): + return b + + # If types are identical, keep one + if type(a) is type(b): + return a + + # Otherwise choose the "wider" by rank; anything conflicting often ends up as string. + # (e.g., boolean + number, date + number, etc.) + if rank(a) == 100 or rank(b) == 100: + return StringType() + return a if rank(a) >= rank(b) else b + + +def infer_schema( + element: ET.Element, + exclude_attributes: bool, + attribute_prefix: str, + null_value: str, + value_tag: str, + ignore_surrounding_whitespace: bool, +): + children = list(element) + + # Case: no children and (no attributes OR attributes excluded) -> scalar + if not children and (not element.attrib or exclude_attributes): + return infer_type(element.text, ignore_surrounding_whitespace, null_value) + + fields = [] + + # Attributes (same rule as element_to_dict_or_str) + if not exclude_attributes: + for attr_name, attr_value in element.attrib.items(): + field = StructField( + f"'{attribute_prefix}{attr_name}'", + infer_type(attr_value, ignore_surrounding_whitespace, null_value), + True, + ) + fields.append(field) + + # Children + if children: + by_tag = {} + for c in children: + by_tag.setdefault(c.tag, []).append(c) + + for tag, elems in by_tag.items(): + dt: Optional[DataType] = None + for child_elem in elems: + child_dt = 
infer_schema( + child_elem, + exclude_attributes, + attribute_prefix, + null_value, + value_tag, + ignore_surrounding_whitespace, + ) + dt = child_dt if dt is None else merge_types(dt, child_dt) + + assert dt is not None + if len(elems) > 1: + dt = ArrayType(dt) + field = StructField(f"'{tag}'", dt, True) + fields.append(field) + else: + # No children, but has attributes -> also include the value_tag for text if present and not null + # (matches element_to_dict_or_str behavior) + t = norm_text(ignore_surrounding_whitespace, element.text) + if t is not None and t != null_value: + field = StructField( + f"'{value_tag}'", + infer_type(t, ignore_surrounding_whitespace, null_value), + True, + ) + fields.append(field) + + return StructType(fields) + + +class XMLSchemaInference: + def process( + self, + file_path: str, + row_tag: str, + num_workers: int, + i: int, + sampling_ratio: float, + charset: str, + ignore_namespace: bool, + attribute_prefix: str, + null_value: str, + value_tag: str, + ignore_surrounding_whitespace: bool, + exclude_attributes: bool, + ): + chunk_size = int(1024) + result = None + + tag_start_1 = f"<{row_tag}>".encode() + tag_start_2 = f"<{row_tag} ".encode() + closing_tag = f"</{row_tag}>".encode() + + file_size = get_file_size(file_path) or 0 + if file_size <= 0: + yield (None,) + return + + # Compute this worker's approximate byte range (same pattern as XMLReader). + if num_workers is None or num_workers <= 0: + num_workers = 1 + if i is None or i < 0: + i = 0 + if i >= num_workers: + # No work for out-of-range worker id. + yield (None,) + return + + approx_chunk_size = file_size // num_workers + approx_start = approx_chunk_size * i + approx_end = approx_chunk_size * (i + 1) if i < num_workers - 1 else file_size + + # Guard against weird tiny files / integer division producing start==end for some workers. 
+ approx_start = max(0, min(approx_start, file_size)) + approx_end = max(0, min(approx_end, file_size)) + if approx_start >= approx_end: + yield (None,) + return + + with SnowflakeFile.open(file_path, "rb", require_scoped_url=False) as f: + f.seek(approx_start) + + while True: + # 1) Find next opening within THIS byte range + try: + open_pos = find_next_opening_tag_pos( + f, tag_start_1, tag_start_2, approx_end, chunk_size + ) + except EOFError: + break + + if open_pos >= approx_end: + break + + record_start = open_pos + f.seek(record_start) + + # 2) Find record_end (self-closing vs closing tag) + try: + is_self_close, tag_end = tag_is_self_closing( + f, chunk_size=chunk_size + ) + if is_self_close: + record_end = tag_end + else: + f.seek(tag_end) + record_end = find_next_closing_tag_pos( + f, closing_tag, chunk_size=chunk_size + ) + except Exception: + # Malformed tag boundaries -> skip and keep scanning forward + try: + f.seek(min(record_start + 1, approx_end)) + except Exception: + break + continue + + # 3) Sampling + if random.random() < sampling_ratio: + + # 4) Read full record bytes and parse (same logic pattern as process_xml_range) + try: + f.seek(record_start) + record_bytes = f.read(record_end - record_start) + record_str = record_bytes.decode(charset, errors="replace") + record_str = re.sub(r"&(\w+);", replace_entity, record_str) + + if lxml_installed: + recover = bool(":" in row_tag) + parser = ET.XMLParser(recover=recover, ns_clean=True) + try: + element = ET.fromstring(record_str, parser) + except ET.XMLSyntaxError: + if ignore_namespace: + cleaned_record = re.sub( + r"\s+(\w+):(\w+)=", r" \2=", record_str + ) + element = ET.fromstring(cleaned_record, parser) + else: + raise + else: + element = ET.fromstring(record_str) + + if ignore_namespace: + element = strip_xml_namespaces(element) + except Exception: + # Malformed record -> ALWAYS skip it + try: + f.seek(min(record_end, approx_end)) + except Exception: + break + if record_end > approx_end: + 
break + continue + + schema = infer_schema( + element, + exclude_attributes, + attribute_prefix, + null_value, + value_tag, + ignore_surrounding_whitespace, + ) + + if not isinstance(schema, StructType): + schema = StructType( + [StructField(f"'{value_tag}'", schema, True)] + ) + result = merge_struct(result, schema) + + # 5) Move to end of record and continue (and stop if record crosses boundary) + if record_end > approx_end: + break + try: + f.seek(min(record_end, approx_end)) + except Exception: + break + + yield (result.simple_string() if result is not None else "",) diff --git a/src/snowflake/snowpark/dataframe_reader.py b/src/snowflake/snowpark/dataframe_reader.py index 2d00d631c9..08dfd60817 100644 --- a/src/snowflake/snowpark/dataframe_reader.py +++ b/src/snowflake/snowpark/dataframe_reader.py @@ -51,6 +51,7 @@ convert_sf_to_sp_type, convert_sp_to_sf_type, most_permissive_type, + type_string_to_type_object, ) from snowflake.snowpark._internal.udf_utils import get_types_from_type_hints from snowflake.snowpark._internal.utils import ( @@ -75,6 +76,7 @@ warning, experimental, ) +from snowflake.snowpark._internal.xml_reader import merge_struct from snowflake.snowpark.column import METADATA_COLUMN_TYPES, Column, _to_col_if_str from snowflake.snowpark.dataframe import DataFrame from snowflake.snowpark.exceptions import ( @@ -89,6 +91,7 @@ TimestampTimeZone, VariantType, StructField, + StringType, ) # Python 3.8 needs to use typing.Iterable because collections.abc.Iterable is not subscriptable @@ -483,6 +486,8 @@ def __init__( @property def _infer_schema(self): # let _cur_options to be the source of truth + + # for POC purpose, xml infer schema is turned on by default if self._file_type in INFER_SCHEMA_FORMAT_TYPES: return self._cur_options.get("INFER_SCHEMA", True) return False @@ -1072,7 +1077,7 @@ def xml(self, path: str, _emit_ast: bool = True) -> DataFrame: # cast to input custom schema type # TODO: SNOW-2923003: remove single quote after server side BCR is 
done - if self._user_schema: + if self._user_schema and not self._infer_schema: cols = [ df[single_quote(field._name)] .cast(field.datatype) @@ -1080,6 +1085,12 @@ def xml(self, path: str, _emit_ast: bool = True) -> DataFrame: for field in self._user_schema.fields ] return df.select(cols) + elif self._infer_schema: + cols = [ + df[field._name].cast(field.datatype).alias(field._name) + for field in self._user_schema.fields + ] + return df.select(cols) else: return df @@ -1148,6 +1159,94 @@ def options( self.option(k, v, _emit_ast=_emit_ast) return self + def _infer_schema_for_xml( + self, + path: str, + ): + ignore_namespace = self._cur_options.get("IGNORENAMESPACE", True) + attribute_prefix = self._cur_options.get("ATTRIBUTEPREFIX", "_") + exclude_attributes = self._cur_options.get("EXCLUDEATTRIBUTES", False) + value_tag = self._cur_options.get("VALUETAG", "_VALUE") + # NULLVALUE will be mapped to NULL_IF in pre-defined mapping in `dataframe_writer.py` + null_value = self._cur_options.get("NULL_IF", "") + ignore_surrounding_whitespace = self._cur_options.get( + "IGNORESURROUNDINGWHITESPACE", False + ) + row_tag = self._cur_options[XML_ROW_TAG_STRING] + charset = self._cur_options.get("CHARSET", "utf-8") + sampling_ratio = self._cur_options.get("SAMPLINGRATIO", 1.0) + if is_in_stored_procedure(): # pragma: no cover + # create a temp stage for udtf import files + # we have to use "temp" object instead of "scoped temp" object in stored procedure + # so we need to upload the file to the temp stage first to use register_from_file + temp_stage = random_name_for_temp_object(TempObjectType.STAGE) + sql_create_temp_stage = ( + f"create temp stage if not exists {temp_stage} {XML_READER_SQL_COMMENT}" + ) + self._session.sql(sql_create_temp_stage, _emit_ast=False).collect( + _emit_ast=False + ) + self._session._conn.upload_file( + XML_READER_FILE_PATH, + temp_stage, + compress_data=False, + overwrite=True, + skip_upload_on_content_match=True, + ) + python_file_path = ( + 
f"{STAGE_PREFIX}{temp_stage}/{os.path.basename(XML_READER_FILE_PATH)}" + ) + else: + python_file_path = XML_READER_FILE_PATH + + # create udtf + handler_name = "XMLSchemaInference" + _, input_types = get_types_from_type_hints( + (XML_READER_FILE_PATH, handler_name), TempObjectType.TABLE_FUNCTION + ) + output_schema = StructType([StructField("schema", StringType(), True)]) + xml_infer_schema_udtf = self._session.udtf.register_from_file( + python_file_path, + handler_name, + output_schema=output_schema, + input_types=input_types, + packages=["snowflake-snowpark-python", "lxml<6"], + replace=True, + _suppress_local_package_warnings=True, + ) + + try: + file_size = int( + self._session.sql(f"ls {path}", _emit_ast=False).collect( + _emit_ast=False + )[0]["size"] + ) # type: ignore + except IndexError: + raise ValueError(f"{path} does not exist") + num_workers = min(16, file_size // 1024 + 1) + worker_column_name = "WORKER" + # Create a range from 0 to N-1 + df = self._session.range(num_workers).to_df(worker_column_name) + df = df.select( + worker_column_name, + xml_infer_schema_udtf( + lit(path), + lit(row_tag), + lit(num_workers), + col(worker_column_name), + lit(sampling_ratio), + lit(charset), + lit(ignore_namespace), + lit(attribute_prefix), + lit(null_value), + lit(value_tag), + lit(ignore_surrounding_whitespace), + lit(exclude_attributes), + ), + ) + + return df.collect() + def _infer_schema_for_file_format( self, path: str, @@ -1460,9 +1559,24 @@ def _read_semi_structured_file(self, path: str, format: str) -> DataFrame: replace=True, _suppress_local_package_warnings=True, ) + else: xml_reader_udtf = None + if ( + format == "XML" + and XML_ROW_TAG_STRING in self._cur_options + and self._infer_schema + ): + res = self._infer_schema_for_xml(path) + result = None + for r in res: + if r[1] == "": + continue + result = merge_struct(result, type_string_to_type_object(r[1])) + schema = StructType._to_attributes(result) + self._user_schema = result + if 
self._session.sql_simplifier_enabled: df = DataFrame( self._session, diff --git a/tests/integ/test_xml_reader_row_tag.py b/tests/integ/test_xml_reader_row_tag.py index e6cb75f9dc..51ff7ce590 100644 --- a/tests/integ/test_xml_reader_row_tag.py +++ b/tests/integ/test_xml_reader_row_tag.py @@ -691,3 +691,12 @@ def test_user_schema_without_rowtag(session): ValueError, match="When reading XML with user schema, rowtag must be set." ): session.read.schema(user_schema).xml(f"@{tmp_stage_name}/{test_file_books_xml}") + + +def test_infer_schema(session): + df = session.read.option("rowTag", "book").xml( + f"@{tmp_stage_name}/{test_file_books2_xml}" + ) + df.show() + df.printSchema() + print(df.collect()) diff --git a/tests/unit/test_xml_reader.py b/tests/unit/test_xml_reader.py index d4b8f038eb..733e3239ef 100644 --- a/tests/unit/test_xml_reader.py +++ b/tests/unit/test_xml_reader.py @@ -25,6 +25,7 @@ DEFAULT_CHUNK_SIZE, struct_type_to_result_template, schema_string_to_result_dict_and_struct_type, + infer_schema, ) from snowflake.snowpark.types import ( StructType, @@ -925,3 +926,34 @@ def test_schema_string_to_result_dict_and_struct_type(session): "description": None, "map_type": None, } + + +def test_schema_inference(): + xml_string = """ + + The Art of Snowflake + Jane Doe + 29.99 + + + tech_guru_87 + 5 + Very insightful and practical. + + + datawizard + 4 + Great read for data engineers. + + + + + + + + """ + element = ET.fromstring(xml_string) + res = infer_schema(element, False, "_", "", "_VALUE", False) + print(res) + for f in res.fields: + print(f)