diff --git a/refactor/ast.py b/refactor/ast.py index 3d23262..d1a5c4f 100644 --- a/refactor/ast.py +++ b/refactor/ast.py @@ -139,6 +139,7 @@ class PreciseUnparser(BaseUnparser): for major statements and child node recovery.""" def __init__(self, *args: Any, **kwargs: Any) -> None: + self.empty_lines = kwargs.pop("empty_lines", None) self._visited_comment_lines: set[int] = set() super().__init__(*args, **kwargs) @@ -222,10 +223,12 @@ def _write_if_unseen_comment( preceding_comments = [] for offset, line in enumerate(reversed(lines[:node_start])): comment_begin = line.find("#") - if comment_begin == -1 or comment_begin != node.col_offset: + not_valid_comment: bool = comment_begin == -1 or comment_begin != node.col_offset + no_empty_lines: bool = not self.empty_lines or line and not line.isspace() + if no_empty_lines and not_valid_comment: break - preceding_comments.append((node_start - offset, line, comment_begin)) + preceding_comments.append((node_start - offset, line, node.col_offset)) for comment_info in reversed(preceding_comments): _write_if_unseen_comment(*comment_info) @@ -240,7 +243,7 @@ def _write_if_unseen_comment( _write_if_unseen_comment( line_no=node_end + offset, line=line, - comment_begin=comment_begin, + comment_begin=node.col_offset, ) def collect_comments(self, node: ast.AST) -> ContextManager[None]: diff --git a/tests/test_ast.py b/tests/test_ast.py index 1e5dc1d..b0eddf7 100644 --- a/tests/test_ast.py +++ b/tests/test_ast.py @@ -69,7 +69,7 @@ def test_split_lines_with_encoding(case): else: start_line = lines[lineno][col_offset:] end_line = lines[end_lineno][:end_col_offset] - match = start_line + lines[lineno + 1 : end_lineno].join() + end_line + match = start_line + lines[lineno + 1: end_lineno].join() + end_line assert str(match) == ast.get_source_segment(case, node) @@ -240,3 +240,113 @@ def foo(): base = PreciseUnparser(source=source) assert base.unparse(tree) + "\n" == expected_src + + +def test_precise_empty_lines_unparser(): + source = textwrap.dedent( + """\ + def func(): + if something: + print( + call(.1), + maybe+something_else, + maybe / other, + thing . a + ) + """ + ) + + expected_src = textwrap.dedent( + """\ + def func(): + if something: + print(call(.1), maybe+something_else, maybe / other, thing . a, 3) + """ + ) + + tree = ast.parse(source) + tree.body[0].body[0].body[0].value.args.append(ast.Constant(3)) + + base = PreciseUnparser(source=source, empty_lines=True) + assert base.unparse(tree) + "\n" == expected_src + + +def test_precise_empty_lines_unparser_indented_literals(): + source = textwrap.dedent( + """\ + def func(): + if something: + print( + "bleh" + "zoom" + ) + """ + ) + + expected_src = textwrap.dedent( + """\ + def func(): + if something: + print("bleh" + "zoom", 3) + """ + ) + + tree = ast.parse(source) + tree.body[0].body[0].body[0].value.args.append(ast.Constant(3)) + + base = PreciseUnparser(source=source, empty_lines=True) + assert base.unparse(tree) + "\n" == expected_src + + +def test_precise_empty_lines_unparser_comments(): + source = textwrap.dedent( + """\ +def foo(): +# unindented comment + # indented but not connected comment + + # a + # a1 + print() + # a2 + print() + # b + + # b2 + print( + c # e + ) + # c + print(d) + # final comment +""" + ) + + expected_src = ( + """\ +def foo(): + # indented but not connected comment + + # a + # a1 + print() + # a2 + print() + # b + + # b2 + print( + c # e + ) + # c +""" + ) + + tree = ast.parse(source) + + # # Remove the print(d) + tree.body[0].body.pop() + + base = PreciseUnparser(source=source, empty_lines=True) + assert base.unparse(tree) + "\n" == expected_src