diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index 86b4bd799f..badd914cb3 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -145,6 +145,71 @@ class Decorator: ) __wrap_exception_regex_sub = re.compile(r"""^"|"$""") + @staticmethod + def get_quoted_identifiers( + sf_plan_in_arg: Optional["SnowflakePlan"], + ) -> List[str]: + """ + Get all quoted identifiers from the children plan nodes of the SnowflakePlan. + """ + if sf_plan_in_arg is None: + return [] + + from snowflake.snowpark._internal.analyzer.select_statement import ( + Selectable, + ) + + quoted_identifiers = [] + plan_nodes = sf_plan_in_arg.children_plan_nodes + for node in plan_nodes: + if isinstance(node, Selectable): + quoted_identifiers.extend(node.snowflake_plan.quoted_identifiers) + else: + quoted_identifiers.extend(node.quoted_identifiers) + return quoted_identifiers + + @staticmethod + def get_debug_message_for_invalid_identifier( + col: str, quoted_identifiers: List[str] + ) -> str: + """ + Get the debug message and suggest the closest valid identifier for invalid identifier failure. + """ + + def add_single_quote(string: str) -> str: + return f"'{string}'" + + # We can't display all column identifiers in the error message + if len(quoted_identifiers) > 10: + quoted_identifiers_str = f"[{', '.join(add_single_quote(q) for q in quoted_identifiers[:10])}, ...]" + else: + quoted_identifiers_str = ( + f"[{', '.join(add_single_quote(q) for q in quoted_identifiers)}]" + ) + + msg = ( + f"There are existing quoted column identifiers: {quoted_identifiers_str}. " + f"Please use one of them to reference the column. See more details on Snowflake identifier requirements " + f"https://docs.snowflake.com/en/sql-reference/identifiers-syntax" + ) + + # Currently, when Snowpark user a Python string as identifier to access a column: + # 1) if a column name is unquoted and + # a) contains no special characters, it is automatically uppercased and quoted in SQL, or + # b) if it includes special characters, it is simply quoted without uppercasing. + # 2) If the name is explicitly quoted by the user, Snowpark preserves it as-is. + # Therefore, if `col` is an invalid identifier, it is most likely due to 1a) above. + # We attempt to provide a more helpful error message by suggesting the closest valid identifier. + if UNQUOTED_CASE_INSENSITIVE.match(col): + identifier = quote_name_without_upper_casing(col.lower()) + match = difflib.get_close_matches(identifier, quoted_identifiers) + if match: + # if there is an exact match, just remind users this one + if identifier in match: + match = [identifier] + msg = f"{msg}\nDo you mean {' or '.join(add_single_quote(q) for q in match)}?" + return msg + @staticmethod def wrap_exception(func): def wrap(*args, **kwargs): @@ -157,6 +222,16 @@ def wrap(*args, **kwargs): query = getattr(e, "query", None) tb = sys.exc_info()[2] + sf_plan_in_arg = None + for arg in args: + if isinstance(arg, SnowflakePlan): + # this wrapper is triggered through collect or describe queries through + # ServerConnection.get_result_set, SnowflakePlan._analyze_attributes, + # or Selectable._analyze_attributes. In all these cases, there can be at + # most one SnowflakePlan in the args. + sf_plan_in_arg = arg + break + assert e.msg is not None if "unexpected 'as'" in e.msg.lower(): ne = SnowparkClientExceptionMessages.SQL_PYTHON_REPORT_UNEXPECTED_ALIAS( @@ -175,25 +250,33 @@ def wrap(*args, **kwargs): ) raise ne.with_traceback(tb) from None col = match.group(1) - children = [ - arg for arg in args if isinstance(arg, SnowflakePlan) - ] - remapped = [ - SnowflakePlan.Decorator.__wrap_exception_regex_sub.sub( - "", val + remapped = [] + if sf_plan_in_arg is not None: + remapped = [ + SnowflakePlan.Decorator.__wrap_exception_regex_sub.sub( + "", val + ) + for val in sf_plan_in_arg.expr_to_alias.values() + ] + quoted_identifiers = ( + SnowflakePlan.Decorator.get_quoted_identifiers( + sf_plan_in_arg ) - for child in children - for val in child.expr_to_alias.values() - ] + ) if col in remapped: unaliased_cols = ( snowflake.snowpark.dataframe._get_unaliased(col) ) - orig_col_name = ( - unaliased_cols[0] if unaliased_cols else "" + orig_col_name = unaliased_cols[0] if unaliased_cols else col + debug_msg = ( + SnowflakePlan.Decorator.get_debug_message_for_invalid_identifier( + col, quoted_identifiers + ) + if quoted_identifiers + else None ) ne = SnowparkClientExceptionMessages.SQL_PYTHON_REPORT_INVALID_ID( - orig_col_name, query + orig_col_name, query, debug_context=debug_msg ) raise ne.with_traceback(tb) from None elif ( @@ -225,19 +308,11 @@ def wrap(*args, **kwargs): raise ne.with_traceback(tb) from None col = match.group(1) - quoted_identifiers = [] - for child in children: - plan_nodes = child.children_plan_nodes - for node in plan_nodes: - if isinstance(node, Selectable): - quoted_identifiers.extend( - node.snowflake_plan.quoted_identifiers - ) - else: - quoted_identifiers.extend( - node.quoted_identifiers - ) - + quoted_identifiers = ( + SnowflakePlan.Decorator.get_quoted_identifiers( + sf_plan_in_arg + ) + ) # No context available to enhance error message if not quoted_identifiers: ne = SnowparkClientExceptionMessages.SQL_EXCEPTION_FROM_PROGRAMMING_ERROR( @@ -245,44 +320,11 @@ def wrap(*args, **kwargs): ) raise ne.with_traceback(tb) from None - def add_single_quote(string: str) -> str: - return f"'{string}'" - - # We can't display all column identifiers in the error message - if len(quoted_identifiers) > 10: - quoted_identifiers_str = f"[{', '.join(add_single_quote(q) for q in quoted_identifiers[:10])}, ...]" - else: - quoted_identifiers_str = f"[{', '.join(add_single_quote(q) for q in quoted_identifiers)}]" - - msg = ( - f"There are existing quoted column identifiers: {quoted_identifiers_str}. " - f"Please use one of them to reference the column. See more details on Snowflake identifier requirements " - f"https://docs.snowflake.com/en/sql-reference/identifiers-syntax" + msg = SnowflakePlan.Decorator.get_debug_message_for_invalid_identifier( + col, quoted_identifiers ) - - # Currently, when Snowpark user a Python string as identifier to access a column: - # 1) if a column name is unquoted and - # a) contains no special characters, it is automatically uppercased and quoted in SQL, or - # b) if it includes special characters, it is simply quoted without uppercasing. - # 2) If the name is explicitly quoted by the user, Snowpark preserves it as-is. - # Therefore, if `col` is an invalid identifier, it is most likely due to 1a) above. - # We attempt to provide a more helpful error message by suggesting the closest valid identifier. - if UNQUOTED_CASE_INSENSITIVE.match(col): - identifier = quote_name_without_upper_casing( - col.lower() - ) - match = difflib.get_close_matches( - identifier, quoted_identifiers - ) - if match: - # if there is an exact match, just remind users this one - if identifier in match: - match = [identifier] - msg = f"{msg}\nDo you mean {' or '.join(add_single_quote(q) for q in match)}?" - - e.msg = f"{e.msg}\n{msg}" ne = SnowparkClientExceptionMessages.SQL_EXCEPTION_FROM_PROGRAMMING_ERROR( - e + e, debug_context=msg ) raise ne.with_traceback(tb) from None elif e.sqlstate == "42601" and "SELECT with no columns" in e.msg: @@ -305,28 +347,18 @@ def search_read_file_node( return result return None - for arg in args: - if isinstance(arg, SnowflakePlan): - read_file_node = search_read_file_node(arg) - if ( - read_file_node - and read_file_node.xml_reader_udtf is not None - ): - row_tag = read_file_node.options.get( - XML_ROW_TAG_STRING - ) - file_path = read_file_node.path - ne = SnowparkClientExceptionMessages.DF_XML_ROW_TAG_NOT_FOUND( - row_tag, file_path - ) - raise ne.with_traceback(tb) from None - # when the describe query fails, the arg is a query string - elif isinstance(arg, str): - if f'"{XML_ROW_DATA_COLUMN_NAME}"' in arg: - ne = ( - SnowparkClientExceptionMessages.DF_XML_ROW_TAG_NOT_FOUND() - ) - raise ne.with_traceback(tb) from None + if sf_plan_in_arg is not None: + read_file_node = search_read_file_node(sf_plan_in_arg) + if ( + read_file_node + and read_file_node.xml_reader_udtf is not None + ): + row_tag = read_file_node.options.get(XML_ROW_TAG_STRING) + file_path = read_file_node.path + ne = SnowparkClientExceptionMessages.DF_XML_ROW_TAG_NOT_FOUND( + row_tag, file_path + ) + raise ne.with_traceback(tb) from None ne = SnowparkClientExceptionMessages.SQL_EXCEPTION_FROM_PROGRAMMING_ERROR( e diff --git a/src/snowflake/snowpark/exceptions.py b/src/snowflake/snowpark/exceptions.py index 1142e9545e..c989e92529 100644 --- a/src/snowflake/snowpark/exceptions.py +++ b/src/snowflake/snowpark/exceptions.py @@ -99,11 +99,12 @@ def __init__( self.sql_error_code = sql_error_code or getattr(self.conn_error, "errno", None) self.raw_message = raw_message or getattr(self.conn_error, "raw_msg", None) self.debug_context = debug_context + debug_message = f"\n{self.debug_context}" if self.debug_context else "" pretty_error_code = f"({self.error_code}): " if self.error_code else "" pretty_sfqid = f"{self.sfqid}: " if self.sfqid else "" self._pretty_msg = ( - f"{pretty_error_code}{pretty_sfqid}{self.message}{self.debug_context or ''}" + f"{pretty_error_code}{pretty_sfqid}{self.message}{debug_message}" ) def __repr__(self): diff --git a/tests/unit/test_error_message.py b/tests/unit/test_error_message.py index ece69d4267..0d1690dca4 100644 --- a/tests/unit/test_error_message.py +++ b/tests/unit/test_error_message.py @@ -28,8 +28,9 @@ def test_sql_exception_from_programming_error(debug_context): assert ex.sql_error_code == 123 assert ex.raw_message == "test message" assert ex.debug_context == debug_context + debug_message = f"\n{debug_context}" if debug_context else "" - assert str(ex) == f"(1304): 0000-1111: 000123: test message{debug_context or ''}" + assert str(ex) == f"(1304): 0000-1111: 000123: test message{debug_message}" def test_sql_exception_from_operational_error():