SNOW-2084165 Add dataframe operation lineage on SnowparkSQLException #3339
Changes from all commits
@@ -0,0 +1,200 @@
#
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
#

from functools import cached_property
import os
import sys
from typing import Dict, List, Optional
import itertools

from snowflake.snowpark._internal.ast.batch import get_dependent_bind_ids
from snowflake.snowpark._internal.ast.utils import __STRING_INTERNING_MAP__
import snowflake.snowpark._internal.proto.generated.ast_pb2 as proto

UNKNOWN_FILE = "__UNKNOWN_FILE__"
SNOWPARK_PYTHON_DATAFRAME_TRANSFORM_TRACE_LENGTH = (
    "SNOWPARK_PYTHON_DATAFRAME_TRANSFORM_TRACE_LENGTH"
)


class DataFrameTraceNode:
    """A node representing a single dataframe operation in the DAG that models
    the lineage of a DataFrame."""

    def __init__(self, batch_id: int, stmt_cache: Dict[int, proto.Stmt]) -> None:
        self.batch_id = batch_id
Collaborator
Nit: I would argue that this isn't meant to be a batch ID anymore. Within each Python session that imports the Snowpark module, each AST ID for a Table or DataFrame will be a UID.
        self.stmt_cache = stmt_cache

    @cached_property
    def children(self) -> set[int]:
        """Returns the batch_ids of the children of this node."""
        return get_dependent_bind_ids(self.stmt_cache[self.batch_id])

    def get_src(self) -> Optional[proto.SrcPosition]:
Contributor
In the hybrid client prototype we use a slightly different method to get the source location: we simply use inspect to walk the stack to the appropriate frame. We have to do this because modin does not use any of the AST machinery, but it is also relatively straightforward. I would like to use your debugging tool for Snowpark pandas as well, so we may want to refactor this so it does not require any of the protobuf work.

Contributor
What about using this function?
Essentially we have three approaches to this problem. I'm less of a fan of the AST approach because it doesn't help pandas with this type of debugging, but it seems like we might be able to consolidate with the OpenTelemetry approach.
        """The source Stmt of the DataFrame described by the batch_id."""
        stmt = self.stmt_cache[self.batch_id]
        api_call = stmt.bind.expr.WhichOneof("variant")
        return (
            getattr(stmt.bind.expr, api_call).src
            if api_call and getattr(stmt.bind.expr, api_call).HasField("src")
            else None
        )

    def _read_file(
        self, filename, start_line, end_line, start_column, end_column
    ) -> str:
        """Read the relevant code snippet showing where the DataFrame was
        created. The executing user must have read permission on the given
        filename."""
        with open(filename) as f:
            code_lines = []
            if sys.version_info >= (3, 11):
                # Skip to start_line and read only the required lines.
                lines = itertools.islice(f, start_line - 1, end_line)
                code_lines = list(lines)
                if start_line == end_line:
                    code_lines[0] = code_lines[0][start_column:end_column]
                else:
                    code_lines[0] = code_lines[0][start_column:]
                    code_lines[-1] = code_lines[-1][:end_column]
            else:
                # For Python 3.9/3.10 we do not extract the end line from the
                # source code, so we just read the start line and return.
                for line in itertools.islice(f, start_line - 1, start_line):
                    code_lines.append(line)

        code_lines = [line.rstrip() for line in code_lines]
        return "\n".join(code_lines)
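The line/column slicing used here can be exercised in isolation. Below is a minimal standalone sketch of the same `itertools.islice` technique; the `read_snippet` helper and the sample file contents are hypothetical, not part of the PR.

```python
import itertools
import os
import tempfile

def read_snippet(filename, start_line, end_line, start_column, end_column):
    # Hypothetical standalone version of the slicing in _read_file:
    # skip to the span lazily with islice instead of reading the whole file.
    with open(filename) as f:
        lines = list(itertools.islice(f, start_line - 1, end_line))
    if start_line == end_line:
        lines[0] = lines[0][start_column:end_column]
    else:
        lines[0] = lines[0][start_column:]
        lines[-1] = lines[-1][:end_column]
    return "\n".join(line.rstrip() for line in lines)

with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as tmp:
    tmp.write("a = 1\ndf = session.table('T').filter(col > 1)\nb = 2\n")
    path = tmp.name

# Columns 5-39 of line 2 cover the expression after "df = ".
one_line = read_snippet(path, 2, 2, 5, 39)
multi_line = read_snippet(path, 1, 2, 0, 39)
os.unlink(path)
print(one_line)
```

Note that the Python 3.9/3.10 fallback in the PR skips the end-of-span handling entirely, since older interpreters do not expose end positions in code objects.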
    @cached_property
    def source_id(self) -> str:
        """Unique identifier of the location of the DataFrame creation in the source code."""
        src = self.get_src()
        if src is None:  # pragma: no cover
            return ""

        fileno = src.file
        start_line = src.start_line
        start_column = src.start_column
        end_line = src.end_line
        end_column = src.end_column
        return f"{fileno}:{start_line}:{start_column}-{end_line}:{end_column}"
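For illustration, `source_id` is just a compact span string in the form `fileno:start_line:start_column-end_line:end_column`. The values below are hypothetical stand-ins for the fields of a `proto.SrcPosition`:

```python
# Hypothetical position values standing in for fields of a proto.SrcPosition.
fileno, start_line, start_column, end_line, end_column = 1, 12, 4, 14, 20

# Same format string as DataFrameTraceNode.source_id.
source_id = f"{fileno}:{start_line}:{start_column}-{end_line}:{end_column}"
print(source_id)  # 1:12:4-14:20
```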
    def get_source_snippet(self) -> str:
        """Read the source file and extract the snippet where the dataframe is created."""
        src = self.get_src()
        if src is None:  # pragma: no cover
            return "No source"

        # Get the latest mapping of fileno to filename.
        _fileno_to_filename_map = {v: k for k, v in __STRING_INTERNING_MAP__.items()}
        fileno = src.file
        filename = _fileno_to_filename_map.get(fileno, UNKNOWN_FILE)

        start_line = src.start_line
        end_line = src.end_line
        start_column = src.start_column
        end_column = src.end_column

        # Build the code identifier locating the operation where the DataFrame
        # was created.
        if sys.version_info >= (3, 11):
            code_identifier = (
                f"(unknown)|{start_line}:{start_column}-{end_line}:{end_column}"
            )
        else:
            code_identifier = f"(unknown)|{start_line}"

        if filename != UNKNOWN_FILE and os.access(filename, os.R_OK):
            # If the file is readable, read the code snippet.
            code = self._read_file(
                filename, start_line, end_line, start_column, end_column
            )
            return f"{code_identifier}: {code}"
        return code_identifier  # pragma: no cover
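The fileno-to-filename lookup above is a plain dict inversion with a sentinel fallback. A toy stand-in for `__STRING_INTERNING_MAP__` (the filenames here are made up) behaves like this:

```python
UNKNOWN_FILE = "__UNKNOWN_FILE__"

# Toy stand-in for __STRING_INTERNING_MAP__, which maps filename -> interned id.
interning_map = {"/app/pipeline.py": 1, "/app/queries.py": 2}

# Inverted on each call so files interned after this node was built are found.
fileno_to_filename = {v: k for k, v in interning_map.items()}

print(fileno_to_filename.get(1, UNKNOWN_FILE))   # known fileno -> filename
print(fileno_to_filename.get(99, UNKNOWN_FILE))  # unknown fileno -> sentinel
```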
def _get_df_transform_trace(
    batch_id: int,
    stmt_cache: Dict[int, proto.Stmt],
) -> List[DataFrameTraceNode]:
    """Helper function to get the transform trace of the dataframe involved in
    the exception. It gathers the lineage as follows:

    1. Create a DataFrameTraceNode for the given batch_id.
    2. Traverse the lineage with BFS, using the node created in step 1 as the
       first layer.
    3. In each iteration, check whether the node's source_id has been visited.
       If not, add it to the visited set and append the node to the trace.
       This step avoids adding the same source_id to the lineage multiple
       times when the dataframe is built in a loop.
    4. Explore the next layer by adding the children of the current node. For
       each child ID that has not been visited, add it to the visited set and
       append a DataFrameTraceNode for it to the next layer.
    5. Repeat until there are no more nodes to explore.

    Args:
        batch_id: The batch ID of the dataframe involved in the exception.
        stmt_cache: The statement cache of the session.

    Returns:
        A list of DataFrameTraceNode objects representing the transform trace of the dataframe.
    """
    visited_batch_id = set()
    visited_source_id = set()

    visited_batch_id.add(batch_id)
    curr = [DataFrameTraceNode(batch_id, stmt_cache)]
    lineage = []

    while curr:
        next: List[DataFrameTraceNode] = []
        for node in curr:
            # Tracing updates.
            source_id = node.source_id
            if source_id not in visited_source_id:
                visited_source_id.add(source_id)
                lineage.append(node)

            # Explore the next layer.
            for child_id in node.children:
                if child_id in visited_batch_id:
                    continue
                visited_batch_id.add(child_id)
                next.append(DataFrameTraceNode(child_id, stmt_cache))

        curr = next

    return lineage
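The dedup-by-source BFS can be seen on a toy DAG. All node ids, source strings, and the `trace` helper below are made up for illustration; the graph has a diamond where two dataframes were built on the same source line, so only one of them is reported.

```python
from collections import namedtuple

# Toy model of _get_df_transform_trace: ids form a DAG and several ids may
# share one source location (e.g. a dataframe rebuilt inside a loop).
Node = namedtuple("Node", ["node_id", "source_id", "children"])

graph = {
    4: Node(4, "app.py:10", [2, 3]),  # df4 = df2.join(df3)
    3: Node(3, "app.py:7", [1]),      # created on the same line as node 2
    2: Node(2, "app.py:7", [1]),
    1: Node(1, "app.py:3", []),       # root dataframe
}

def trace(root_id):
    visited_ids = {root_id}
    visited_sources = set()
    lineage, curr = [], [graph[root_id]]
    while curr:
        nxt = []
        for node in curr:
            # Report each distinct source location only once.
            if node.source_id not in visited_sources:
                visited_sources.add(node.source_id)
                lineage.append(node.node_id)
            # Queue unvisited children for the next BFS layer.
            for child in node.children:
                if child not in visited_ids:
                    visited_ids.add(child)
                    nxt.append(graph[child])
        curr = nxt
    return lineage

print(trace(4))  # [4, 2, 1] -- node 3 is deduplicated by source location
```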
def get_df_transform_trace_message(
    df_ast_id: int, stmt_cache: Dict[int, proto.Stmt]
) -> Optional[str]:
    """Get the transform trace message for the dataframe involved in the exception.

    Args:
        df_ast_id: The AST ID of the dataframe involved in the exception.
        stmt_cache: The statement cache of the session.

    Returns:
        A string representing the transform trace message.
    """
    df_transform_trace_nodes = _get_df_transform_trace(df_ast_id, stmt_cache)
    if len(df_transform_trace_nodes) == 0:  # pragma: no cover
        return None

    df_transform_trace_length = len(df_transform_trace_nodes)
    show_trace_length = int(
        os.environ.get(SNOWPARK_PYTHON_DATAFRAME_TRANSFORM_TRACE_LENGTH, 5)
    )

    debug_info_lines = [
        "\n\n--- Additional Debug Information ---\n",
        f"Trace of the most recent dataframe operations associated with the error (total {df_transform_trace_length}):\n",
    ]
    for node in df_transform_trace_nodes[:show_trace_length]:
        debug_info_lines.append(node.get_source_snippet())
    if df_transform_trace_length > show_trace_length:
        debug_info_lines.append(
            f"... and {df_transform_trace_length - show_trace_length} more.\nYou can increase "
            f"the lineage length by setting the {SNOWPARK_PYTHON_DATAFRAME_TRANSFORM_TRACE_LENGTH} "
            "environment variable."
        )
    return "\n".join(debug_info_lines)
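The environment-variable-driven truncation can be sketched standalone; the `format_trace` helper and the snippet strings below are hypothetical stand-ins for the real trace nodes.

```python
import os

SNOWPARK_PYTHON_DATAFRAME_TRANSFORM_TRACE_LENGTH = (
    "SNOWPARK_PYTHON_DATAFRAME_TRANSFORM_TRACE_LENGTH"
)

def format_trace(snippets):
    # Mirrors the truncation logic: show at most N entries, where N comes from
    # the environment variable and defaults to 5.
    limit = int(os.environ.get(SNOWPARK_PYTHON_DATAFRAME_TRANSFORM_TRACE_LENGTH, 5))
    lines = snippets[:limit]
    if len(snippets) > limit:
        lines.append(f"... and {len(snippets) - limit} more.")
    return "\n".join(lines)

os.environ[SNOWPARK_PYTHON_DATAFRAME_TRANSFORM_TRACE_LENGTH] = "2"
print(format_trace([f"op{i}" for i in range(4)]))  # op0, op1, then a summary
```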
Reviewer
Is it possible that there are multiple plans in the args? And if so, is the left-most plan guaranteed to be the most recent plan?

Reviewer
We have some code below, which means there will be multiple plans?

Author
It is possible for args to have multiple SnowflakePlans, but only the first arg is relevant to the failure. This is not obvious, and we have to look at the usages of this decorator to be sure. The decorator is used for the following functions:

- SnowflakePlan._analyze_attributes
- Selectable._analyze_attributes
- SnowflakePlanBuilder.build
- SnowflakePlanBuilder.build_binary
- ServerConnection.get_result_set

Of these, only build and build_binary can have multiple Snowflake plans, and for each of them the source of failure can only come from describe queries. But we have already wrapped SnowflakePlan._analyze_attributes and Selectable._analyze_attributes, and they can have at most one SnowflakePlan. Outside of describe queries, the wrapper is triggered to wrap a programming exception when fetching the full result, for example for .show() or .collect(). For those cases we have wrapped get_result_set, which has only one SnowflakePlan arg. IMO the wrappers of build and build_binary are redundant. @sfc-gh-jdu did you notice any other case where parsing plans in these functions actually gives you more information?

Reviewer
Nice, thanks for the context. I think we can put a brief version of this into a code comment?

Reviewer
#3437