Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 115 additions & 83 deletions src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,71 @@ class Decorator:
)
__wrap_exception_regex_sub = re.compile(r"""^"|"$""")

@staticmethod
def get_quoted_identifiers(
sf_plan_in_arg: Optional["SnowflakePlan"],
) -> List[str]:
"""
Get all quoted identifiers from the children plan nodes of the SnowflakePlan.
"""
if sf_plan_in_arg is None:
return []

from snowflake.snowpark._internal.analyzer.select_statement import (
Selectable,
)

quoted_identifiers = []
plan_nodes = sf_plan_in_arg.children_plan_nodes
for node in plan_nodes:
if isinstance(node, Selectable):
quoted_identifiers.extend(node.snowflake_plan.quoted_identifiers)
else:
quoted_identifiers.extend(node.quoted_identifiers)
return quoted_identifiers

@staticmethod
def get_debug_message_for_invalid_identifier(
col: str, quoted_identifiers: List[str]
) -> str:
"""
Get the debug message and suggest the closest valid identifier for invalid identifier failure.
"""

def add_single_quote(string: str) -> str:
return f"'{string}'"

# We can't display all column identifiers in the error message
if len(quoted_identifiers) > 10:
quoted_identifiers_str = f"[{', '.join(add_single_quote(q) for q in quoted_identifiers[:10])}, ...]"
else:
quoted_identifiers_str = (
f"[{', '.join(add_single_quote(q) for q in quoted_identifiers)}]"
)

msg = (
f"There are existing quoted column identifiers: {quoted_identifiers_str}. "
f"Please use one of them to reference the column. See more details on Snowflake identifier requirements "
f"https://docs.snowflake.com/en/sql-reference/identifiers-syntax"
)

# Currently, when Snowpark user a Python string as identifier to access a column:
# 1) if a column name is unquoted and
# a) contains no special characters, it is automatically uppercased and quoted in SQL, or
# b) if it includes special characters, it is simply quoted without uppercasing.
# 2) If the name is explicitly quoted by the user, Snowpark preserves it as-is.
# Therefore, if `col` is an invalid identifier, it is most likely due to 1a) above.
# We attempt to provide a more helpful error message by suggesting the closest valid identifier.
if UNQUOTED_CASE_INSENSITIVE.match(col):
identifier = quote_name_without_upper_casing(col.lower())
match = difflib.get_close_matches(identifier, quoted_identifiers)
if match:
# if there is an exact match, just remind users this one
if identifier in match:
match = [identifier]
msg = f"{msg}\nDo you mean {' or '.join(add_single_quote(q) for q in match)}?"
return msg

@staticmethod
def wrap_exception(func):
def wrap(*args, **kwargs):
Expand All @@ -157,6 +222,16 @@ def wrap(*args, **kwargs):

query = getattr(e, "query", None)
tb = sys.exc_info()[2]
sf_plan_in_arg = None
for arg in args:
if isinstance(arg, SnowflakePlan):
# this wrapper is triggered through collect or describe queries through
# ServerConnection.get_result_set, SnowflakePlan._analyze_attributes,
# or Selectable._analyze_attributes. In all these cases, there can be at
# most one SnowflakePlan in the args.
sf_plan_in_arg = arg
break

assert e.msg is not None
if "unexpected 'as'" in e.msg.lower():
ne = SnowparkClientExceptionMessages.SQL_PYTHON_REPORT_UNEXPECTED_ALIAS(
Expand All @@ -175,25 +250,33 @@ def wrap(*args, **kwargs):
)
raise ne.with_traceback(tb) from None
col = match.group(1)
children = [
arg for arg in args if isinstance(arg, SnowflakePlan)
]
remapped = [
SnowflakePlan.Decorator.__wrap_exception_regex_sub.sub(
"", val
remapped = []
if sf_plan_in_arg is not None:
remapped = [
SnowflakePlan.Decorator.__wrap_exception_regex_sub.sub(
"", val
)
for val in sf_plan_in_arg.expr_to_alias.values()
]
quoted_identifiers = (
SnowflakePlan.Decorator.get_quoted_identifiers(
sf_plan_in_arg
)
for child in children
for val in child.expr_to_alias.values()
]
)
if col in remapped:
unaliased_cols = (
snowflake.snowpark.dataframe._get_unaliased(col)
)
orig_col_name = (
unaliased_cols[0] if unaliased_cols else "<colname>"
orig_col_name = unaliased_cols[0] if unaliased_cols else col
debug_msg = (
SnowflakePlan.Decorator.get_debug_message_for_invalid_identifier(
col, quoted_identifiers
)
if quoted_identifiers
else None
)
ne = SnowparkClientExceptionMessages.SQL_PYTHON_REPORT_INVALID_ID(
orig_col_name, query
orig_col_name, query, debug_context=debug_msg
)
raise ne.with_traceback(tb) from None
elif (
Expand Down Expand Up @@ -225,64 +308,23 @@ def wrap(*args, **kwargs):
raise ne.with_traceback(tb) from None
col = match.group(1)

quoted_identifiers = []
for child in children:
plan_nodes = child.children_plan_nodes
for node in plan_nodes:
if isinstance(node, Selectable):
quoted_identifiers.extend(
node.snowflake_plan.quoted_identifiers
)
else:
quoted_identifiers.extend(
node.quoted_identifiers
)

quoted_identifiers = (
SnowflakePlan.Decorator.get_quoted_identifiers(
sf_plan_in_arg
)
)
# No context available to enhance error message
if not quoted_identifiers:
ne = SnowparkClientExceptionMessages.SQL_EXCEPTION_FROM_PROGRAMMING_ERROR(
e
)
raise ne.with_traceback(tb) from None

def add_single_quote(string: str) -> str:
return f"'{string}'"

# We can't display all column identifiers in the error message
if len(quoted_identifiers) > 10:
quoted_identifiers_str = f"[{', '.join(add_single_quote(q) for q in quoted_identifiers[:10])}, ...]"
else:
quoted_identifiers_str = f"[{', '.join(add_single_quote(q) for q in quoted_identifiers)}]"

msg = (
f"There are existing quoted column identifiers: {quoted_identifiers_str}. "
f"Please use one of them to reference the column. See more details on Snowflake identifier requirements "
f"https://docs.snowflake.com/en/sql-reference/identifiers-syntax"
msg = SnowflakePlan.Decorator.get_debug_message_for_invalid_identifier(
col, quoted_identifiers
)

# Currently, when Snowpark user a Python string as identifier to access a column:
# 1) if a column name is unquoted and
# a) contains no special characters, it is automatically uppercased and quoted in SQL, or
# b) if it includes special characters, it is simply quoted without uppercasing.
# 2) If the name is explicitly quoted by the user, Snowpark preserves it as-is.
# Therefore, if `col` is an invalid identifier, it is most likely due to 1a) above.
# We attempt to provide a more helpful error message by suggesting the closest valid identifier.
if UNQUOTED_CASE_INSENSITIVE.match(col):
identifier = quote_name_without_upper_casing(
col.lower()
)
match = difflib.get_close_matches(
identifier, quoted_identifiers
)
if match:
# if there is an exact match, just remind users this one
if identifier in match:
match = [identifier]
msg = f"{msg}\nDo you mean {' or '.join(add_single_quote(q) for q in match)}?"

e.msg = f"{e.msg}\n{msg}"
ne = SnowparkClientExceptionMessages.SQL_EXCEPTION_FROM_PROGRAMMING_ERROR(
e
e, debug_context=msg
)
raise ne.with_traceback(tb) from None
elif e.sqlstate == "42601" and "SELECT with no columns" in e.msg:
Expand All @@ -305,28 +347,18 @@ def search_read_file_node(
return result
return None

for arg in args:
if isinstance(arg, SnowflakePlan):
read_file_node = search_read_file_node(arg)
if (
read_file_node
and read_file_node.xml_reader_udtf is not None
):
row_tag = read_file_node.options.get(
XML_ROW_TAG_STRING
)
file_path = read_file_node.path
ne = SnowparkClientExceptionMessages.DF_XML_ROW_TAG_NOT_FOUND(
row_tag, file_path
)
raise ne.with_traceback(tb) from None
# when the describe query fails, the arg is a query string
elif isinstance(arg, str):
if f'"{XML_ROW_DATA_COLUMN_NAME}"' in arg:
ne = (
SnowparkClientExceptionMessages.DF_XML_ROW_TAG_NOT_FOUND()
)
raise ne.with_traceback(tb) from None
if sf_plan_in_arg is not None:
read_file_node = search_read_file_node(sf_plan_in_arg)
if (
read_file_node
and read_file_node.xml_reader_udtf is not None
):
row_tag = read_file_node.options.get(XML_ROW_TAG_STRING)
file_path = read_file_node.path
ne = SnowparkClientExceptionMessages.DF_XML_ROW_TAG_NOT_FOUND(
row_tag, file_path
)
raise ne.with_traceback(tb) from None

ne = SnowparkClientExceptionMessages.SQL_EXCEPTION_FROM_PROGRAMMING_ERROR(
e
Expand Down
3 changes: 2 additions & 1 deletion src/snowflake/snowpark/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,12 @@ def __init__(
self.sql_error_code = sql_error_code or getattr(self.conn_error, "errno", None)
self.raw_message = raw_message or getattr(self.conn_error, "raw_msg", None)
self.debug_context = debug_context
debug_message = f"\n{self.debug_context}" if self.debug_context else ""

pretty_error_code = f"({self.error_code}): " if self.error_code else ""
pretty_sfqid = f"{self.sfqid}: " if self.sfqid else ""
self._pretty_msg = (
f"{pretty_error_code}{pretty_sfqid}{self.message}{self.debug_context or ''}"
f"{pretty_error_code}{pretty_sfqid}{self.message}{debug_message}"
)

def __repr__(self):
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/test_error_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ def test_sql_exception_from_programming_error(debug_context):
assert ex.sql_error_code == 123
assert ex.raw_message == "test message"
assert ex.debug_context == debug_context
debug_message = f"\n{debug_context}" if debug_context else ""

assert str(ex) == f"(1304): 0000-1111: 000123: test message{debug_context or ''}"
assert str(ex) == f"(1304): 0000-1111: 000123: test message{debug_message}"


def test_sql_exception_from_operational_error():
Expand Down
Loading