diff --git a/src/lean_dojo/data_extraction/ast.py b/src/lean_dojo/data_extraction/ast.py index 2b886408..95a456a8 100644 --- a/src/lean_dojo/data_extraction/ast.py +++ b/src/lean_dojo/data_extraction/ast.py @@ -1,8 +1,24 @@ +from typing_extensions import TypeGuard + from lxml import etree from pathlib import Path from dataclasses import dataclass, field from xml.sax.saxutils import escape, unescape -from typing import List, Dict, Any, Optional, Callable, Tuple, Generator +from typing import ( + List, + Dict, + Any, + Optional, + Callable, + Tuple, + Generator, + TypeVar, + Union, + cast, + Protocol, + Type, + Sequence, +) from ..utils import ( camel_case, @@ -13,13 +29,33 @@ ) from .lean import Pos, LeanFile +N = TypeVar("N", bound="Node", covariant=True) +T = TypeVar("T") + + +def cast_away_optional(x: Optional[T]) -> T: + assert x is not None + return x + + +def cast_list_str(x: Union[None, str, List[str]]) -> List[str]: + if isinstance(x, list) and all(isinstance(i, str) for i in x): + return x # Return the list if all elements are strings + raise TypeError(f"Expected None, str, or List[str], but got {type(x).__name__}") + + +def cast_str(x: Union[None, str, List[str]]) -> str: + if isinstance(x, str): + return x + raise TypeError(f"Expected str, but got {type(x).__name__}") + @dataclass(frozen=True) class Node: lean_file: LeanFile start: Optional[Pos] end: Optional[Pos] - children: List["Node"] = field(repr=False) + children: Sequence["Node"] = field(repr=False) @classmethod def from_data(cls, node_data: Dict[str, Any], lean_file: LeanFile) -> "Node": @@ -27,7 +63,7 @@ def from_data(cls, node_data: Dict[str, Any], lean_file: LeanFile) -> "Node": return subcls.from_data(node_data, lean_file) @classmethod - def _kind_to_node_type(cls, kind: str) -> type["Node"]: + def _kind_to_node_type(cls, kind: str) -> Type["Node"]: prefix = "Lean.Parser." if kind.startswith(prefix): kind = kind[len(prefix) :] @@ -46,9 +82,9 @@ def kind(cls: type) -> str: def traverse_preorder( self, callback: Callable[["Node", List["Node"]], Any], - node_cls: Optional[type], + node_cls: Optional[type["Node"]], parents: List["Node"] = [], - ) -> None: + ): if node_cls is None or isinstance(self, node_cls): if callback(self, parents): return @@ -78,7 +114,7 @@ def to_xml(self, parent: etree.Element) -> None: child.to_xml(tree) @classmethod - def from_xml(cls, tree: etree.Element, lean_file: LeanFile) -> "Node": + def from_xml(cls: Type[N], tree: etree.Element, lean_file: LeanFile) -> N: subcls = globals()[tree.tag] start = Pos.from_str(tree.attrib["start"]) if "start" in tree.attrib else None end = Pos.from_str(tree.attrib["end"]) if "end" in tree.attrib else None @@ -146,12 +182,12 @@ class AtomNode(Node): val: str @classmethod - def from_data( - cls, atom_data: Dict[str, Any], lean_file: LeanFile - ) -> Optional["AtomNode"]: + def from_data(cls, atom_data: Dict[str, Any], lean_file: LeanFile) -> "AtomNode": info = atom_data["info"] - start, end = _parse_pos(info, lean_file) - + pos_pair = _parse_pos(info, lean_file) + if pos_pair is None: + raise ValueError("Synthetic atom nodes are not supported") + start, end = pos_pair if "original" in info: leading = info["original"]["leading"] trailing = info["original"]["trailing"] @@ -177,11 +213,12 @@ class IdentNode(Node): def_end: Optional[Pos] = None @classmethod - def from_data( - cls, ident_data: Dict[str, Any], lean_file: LeanFile - ) -> Optional["IdentNode"]: + def from_data(cls, ident_data: Dict[str, Any], lean_file: LeanFile) -> "IdentNode": info = ident_data["info"] - start, end = _parse_pos(info, lean_file) + pos_pair = _parse_pos(info, lean_file) + if pos_pair is None: + raise ValueError("Synthetic ident nodes are not supported") + start, end = pos_pair assert ident_data["preresolved"] == [] if "original" in info: @@ -212,10 +249,14 @@ def is_leaf(node: Node) -> bool: return isinstance(node, AtomNode) or isinstance(node, IdentNode) +def filter_leaf(nodes: Sequence[Node]) -> Sequence[Union[AtomNode | IdentNode]]: + return [node for node in nodes if isinstance(node, (AtomNode, IdentNode))] + + @dataclass(frozen=True) class FileNode(Node): @classmethod - def from_data(cls, data: Dict[str, Any], lean_file: LeanFile) -> "FileNode": + def from_data(cls: Type[N], data: Dict[str, Any], lean_file: LeanFile) -> N: children = [] def _get_closure(node: Node, child_spans: List[Tuple[Pos, Pos]]): @@ -316,7 +357,7 @@ def from_data( return cls(lean_file, start, end, children) def get_ident(self) -> str: - return "".join(gc.val for gc in self.children if is_leaf(gc)) + return "".join(gc.val for gc in filter_leaf(self.children)) @dataclass(frozen=True) @@ -363,9 +404,9 @@ def from_data(cls, node_data: Dict[str, Any], lean_file: LeanFile) -> "GroupNode class MathlibTacticLemmaNode(Node): name: str full_name: Optional[str] = None - _is_private_decl: Optional[bool] = ( - False # `_is_private` doesn't play well with lxml. - ) + # If the Private usage state is False, the user will not explicitly use the parameter None. + # And False and None are handled in the same way, so there is no need to use Optional. + _is_private_decl: bool = False # `_is_private` doesn't play well with lxml. @classmethod def from_data( @@ -412,11 +453,11 @@ def is_mutual(self) -> bool: @dataclass(frozen=True) class LemmaNode(Node): - name: str + name: Optional[str] full_name: Optional[str] = None - _is_private_decl: Optional[bool] = ( - False # `_is_private` doesn't play well with lxml. - ) + # If the Private usage state is False, the user will not explicitly use the parameter None. + # And False and None are handled in the same way, so there is no need to use Optional. + _is_private_decl: bool = False # `_is_private` doesn't play well with lxml. @classmethod def from_data(cls, node_data: Dict[str, Any], lean_file: LeanFile) -> "LemmaNode": @@ -468,7 +509,7 @@ def is_mutual(self) -> bool: @dataclass(frozen=True) class CommandDeclarationNode(Node): - name: str + name: Optional[str] full_name: Optional[str] = None @classmethod @@ -514,7 +555,7 @@ def is_theorem(self) -> bool: def get_theorem_node(self) -> "CommandTheoremNode": assert self.is_theorem - return self.children[1] + return cast(CommandTheoremNode, self.children[1]) @property def is_example(self) -> bool: @@ -547,7 +588,9 @@ def from_data( def is_private(self) -> bool: result = False - def _callback(node: CommandPrivateNode, _) -> bool: + def _callback(node: Node, _) -> bool: + if not isinstance(node, CommandDeclmodifiersNode): + raise TypeError("Excepted CommandDeclmodifiersNode") nonlocal result result = True return True @@ -630,7 +673,7 @@ def from_data( @dataclass(frozen=True) class CommandStructureNode(Node): - name: str + name: Optional[str] @classmethod def from_data( @@ -682,7 +725,7 @@ def from_data( @dataclass(frozen=True) class CommandClassinductiveNode(Node): - name: str + name: Optional[str] @classmethod def from_data( @@ -776,7 +819,7 @@ def get_ident(self) -> Optional[str]: @dataclass(frozen=True) class StdTacticAliasAliasNode(Node): - name: str + name: Optional[str] full_name: Optional[str] = None @classmethod @@ -826,9 +869,13 @@ def from_data( children[5], (LeanBinderidentNode, LeanBinderidentAntiquotNode) ) name.append(children[5].get_ident()) - name = [n for n in name if n is not None] - return cls(lean_file, start, end, children, name) + def _filter_names(name_lst: List[str | None]) -> List[str]: + return [n for n in name_lst if n is not None] + + names = _filter_names(name) + + return cls(lean_file, start, end, children, names) @property def is_mutual(self) -> bool: @@ -837,7 +884,7 @@ def is_mutual(self) -> bool: @dataclass(frozen=True) class CommandAbbrevNode(Node): - name: str + name: Optional[str] @classmethod def from_data( @@ -862,7 +909,7 @@ def from_data( @dataclass(frozen=True) class CommandOpaqueNode(Node): - name: str + name: Optional[str] @classmethod def from_data( @@ -887,7 +934,7 @@ def from_data( @dataclass(frozen=True) class CommandAxiomNode(Node): - name: str + name: Optional[str] @classmethod def from_data( @@ -912,7 +959,7 @@ def from_data( @dataclass(frozen=True) class CommandExampleNode(Node): - name: str + name: Optional[str] @classmethod def from_data( @@ -928,7 +975,7 @@ def from_data( @dataclass(frozen=True) class CommandInstanceNode(Node): - name: str + name: Optional[str] @classmethod def from_data( @@ -957,7 +1004,7 @@ def from_data( @dataclass(frozen=True) class CommandDefNode(Node): - name: str + name: Optional[str] @classmethod def from_data( @@ -988,7 +1035,7 @@ def from_data( @dataclass(frozen=True) class CommandDefinitionNode(Node): - name: str + name: Optional[str] @classmethod def from_data( @@ -1127,11 +1174,11 @@ def from_data( @dataclass(frozen=True) class CommandTheoremNode(Node): - name: str + name: Optional[str] full_name: Optional[str] = None - _is_private_decl: Optional[bool] = ( - False # `_is_private` doesn't play well with lxml. - ) + # If the Private usage state is False, the user will not explicitly use the parameter None. + # And False and None are handled in the same way, so there is no need to use Optional. + _is_private_decl: bool = False # `_is_private` doesn't play well with lxml. @classmethod def from_data( @@ -1216,6 +1263,15 @@ def from_data( return cls(lean_file, start, end, children) +class TacticProtocol(Protocol): + def get_tactic_nodes( + self, atomic_only: bool = False + ) -> Generator[Node, None, None]: ... + + +class TacticNode(Node, TacticProtocol): ... + + @dataclass(frozen=True) class TacticTacticseq1IndentedAntiquotNode(Node): @classmethod @@ -1227,9 +1283,7 @@ def from_data( children = _parse_children(node_data, lean_file) return cls(lean_file, start, end, children) - def get_tactic_nodes( - self, atomic_only: bool = False - ) -> Generator[Node, None, None]: + def get_tactic_nodes(self, atomic_only: bool = False) -> None: return @@ -1255,7 +1309,8 @@ def from_data( def get_tactic_nodes( self, atomic_only: bool = False ) -> Generator[Node, None, None]: - yield from self.children[0].get_tactic_nodes(atomic_only) + child = cast(TacticNode, self.children[0]) + yield from child.get_tactic_nodes(atomic_only) @dataclass(frozen=True) @@ -1344,6 +1399,7 @@ def _callback(x, _) -> bool: nonlocal result result = True return True + return False node.traverse_preorder(_callback, node_cls=None) return result @@ -1373,18 +1429,6 @@ def from_data( return cls(lean_file, start, end, children) -@dataclass(frozen=True) -class ModulePreludeNode(Node): - @classmethod - def from_data( - cls, node_data: Dict[str, Any], lean_file: LeanFile - ) -> "ModulePreludeNode": - assert node_data["info"] == "none" - start, end = None, None - children = _parse_children(node_data, lean_file) - return cls(lean_file, start, end, children) - - @dataclass(frozen=True) class ModuleImportNode(Node): module: Optional[str] @@ -1419,8 +1463,9 @@ def from_data( start, end = None, None children = _parse_children(node_data, lean_file) assert len(children) == 2 and all(isinstance(_, AtomNode) for _ in children) - assert children[0].val == "/-!" - comment = children[1].val + cast_children = cast(List[AtomNode], children) + assert cast_children[0].val == "/-!" + comment = cast_children[1].val return cls(lean_file, start, end, children, comment) @@ -1436,14 +1481,15 @@ def from_data( start, end = None, None children = _parse_children(node_data, lean_file) assert len(children) == 2 and all(isinstance(_, AtomNode) for _ in children) - assert children[0].val == "/--" - comment = children[1].val + cast_children = cast(List[AtomNode], children) + assert cast_children[0].val == "/--" + comment = cast_children[1].val return cls(lean_file, start, end, children, comment) @dataclass(frozen=True) class CommandNamespaceNode(Node): - name: str + name: Optional[str] @classmethod def from_data( @@ -1470,7 +1516,7 @@ class CommandSectionNode(Node): @classmethod def from_data( cls, node_data: Dict[str, Any], lean_file: LeanFile - ) -> "CommandNamespaceNode": + ) -> "CommandSectionNode": assert node_data["info"] == "none" start, end = None, None children = _parse_children(node_data, lean_file) @@ -1557,7 +1603,16 @@ def from_data(cls, node_data: Dict[str, Any], lean_file: LeanFile) -> "OtherNode return cls(lean_file, start, end, children, node_data["kind"]) -def is_potential_premise_lean4(node: Node) -> bool: +def is_potential_premise_lean4(node: Node) -> TypeGuard[ + Union[ + CommandDeclarationNode, + LemmaNode, + MathlibTacticLemmaNode, + LeanElabCommandCommandIrreducibleDefNode, + StdTacticAliasAliasNode, + StdTacticAliasAliaslrNode, + ] +]: """Check if ``node`` is a theorem/definition that can be used as a premise.""" if (isinstance(node, CommandDeclarationNode) and not node.is_example) or isinstance( node, @@ -1574,7 +1629,9 @@ def is_potential_premise_lean4(node: Node) -> bool: return False -def is_mutual_lean4(node: Node) -> bool: +def is_mutual_lean4( + node: Node, +) -> TypeGuard[Union[IdentNode, CommandTheoremNode, StdTacticAliasAliaslrNode]]: return ( isinstance(node, (IdentNode, CommandTheoremNode, StdTacticAliasAliaslrNode)) and node.is_mutual diff --git a/src/lean_dojo/data_extraction/lean.py b/src/lean_dojo/data_extraction/lean.py index 45648d01..021d4f67 100644 --- a/src/lean_dojo/data_extraction/lean.py +++ b/src/lean_dojo/data_extraction/lean.py @@ -245,6 +245,9 @@ def __lt__(self, other): def __le__(self, other): return self < other or self == other + def __index__(self): + return self.line_nb, self.column_nb + @dataclass(frozen=True) class LeanFile: diff --git a/src/lean_dojo/data_extraction/trace.py b/src/lean_dojo/data_extraction/trace.py index eeea2967..c8a94969 100644 --- a/src/lean_dojo/data_extraction/trace.py +++ b/src/lean_dojo/data_extraction/trace.py @@ -69,7 +69,10 @@ def launch_progressbar(paths: List[Path]) -> Generator[None, None, None]: def get_lean_version() -> str: """Get the version of Lean.""" - output = execute("lean --version", capture_output=True)[0].strip() + res = execute("lean --version", capture_output=True) + if res is None: + raise CalledProcessError(1, "lean --version") + output = res[0].strip() m = re.match(r"Lean \(version (?P\S+?),", output) return m["version"] # type: ignore @@ -140,7 +143,10 @@ def _trace(repo: LeanGitRepo, build_deps: bool) -> None: execute("lake build") # Copy the Lean 4 stdlib into the path of packages. - lean_prefix = execute(f"lean --print-prefix", capture_output=True)[0].strip() + lean_prefix_output = execute("lean --print-prefix", capture_output=True) + if lean_prefix_output is None: + raise CalledProcessError(1, "lean --print-prefix") + lean_prefix = Path(lean_prefix_output[0].strip()) if is_new_version(get_lean_version()): packages_path = Path(".lake/packages") build_path = Path(".lake/build") diff --git a/src/lean_dojo/data_extraction/traced_data.py b/src/lean_dojo/data_extraction/traced_data.py index c7fa9b50..5c828458 100644 --- a/src/lean_dojo/data_extraction/traced_data.py +++ b/src/lean_dojo/data_extraction/traced_data.py @@ -14,7 +14,7 @@ from pathlib import Path from loguru import logger from dataclasses import dataclass, field -from typing import List, Optional, Dict, Any, Tuple, Union +from typing import List, Optional, Dict, Any, Tuple, Union, cast from ..utils import ( is_git_repo, @@ -47,6 +47,12 @@ is_leaf, is_mutual_lean4, is_potential_premise_lean4, + cast_away_optional, + StdTacticAliasAliaslrNode, + LeanElabCommandCommandIrreducibleDefNode, + StdTacticAliasAliasNode, + cast_list_str, + cast_str, ) from .lean import LeanFile, LeanGitRepo, Theorem, Pos from ..constants import NUM_WORKERS, LOAD_USED_PACKAGES_ONLY, LEAN4_PACKAGES_DIR @@ -144,7 +150,8 @@ class TracedTactic: its AST and the states before/after the tactic. """ - ast: Node = field(repr=False) + # ast: Node = field(repr=False) + ast: OtherNode | TacticTacticseqbracketedNode = field(repr=False) """AST of the tactic. """ @@ -160,7 +167,7 @@ def __getstate__(self) -> Dict[str, Any]: return d @property - def tactic(self) -> str: + def tactic(self) -> Optional[str]: """The raw tactic string.""" return self.ast.tactic @@ -179,11 +186,15 @@ def state_after(self) -> str: @property def start(self) -> Pos: """Start position in :file:`*.lean` file.""" + if not isinstance(self.ast.start, Pos): + raise TypeError("start expected to be Pos") return self.ast.start @property def end(self) -> Pos: """End position in :file:`*.lean` file.""" + if not isinstance(self.ast.end, Pos): + raise TypeError("end expected to be Pos") return self.ast.end def to_string(self) -> str: @@ -212,10 +223,16 @@ def get_annotated_tactic(self) -> Tuple[str, List[Dict[str, Any]]]: ) lean_file = self.traced_theorem.traced_file.lean_file annot_tac = [] - provenances = [] + provenances: List[Dict[str, Any]] = [] cur = self.start - def _callback4(node: IdentNode, _): + def _callback4(node: Node, _): + if not isinstance(node, IdentNode): + raise TypeError("Excepted IdentNode") + if node.start is None or node.end is None: + raise TypeError( + "start/end expected to be Pos, Unsupported left operand type for <= ('None')" + ) nonlocal cur if ( @@ -227,10 +244,12 @@ def _callback4(node: IdentNode, _): if cur <= node.start: annot_tac.append(lean_file[cur : node.start]) annot_tac.append("" + lean_file[node.start : node.end] + "") - prov = {"full_name": node.full_name} - prov["def_path"] = node.def_path - prov["def_pos"] = list(node.def_start) - prov["def_end_pos"] = list(node.def_end) + prov = { + "full_name": node.full_name, + "def_path": node.def_path, + "def_pos": list(node.def_start), + "def_end_pos": list(node.def_end), + } provenances.append(prov) cur = node.end @@ -281,11 +300,15 @@ def __getstate__(self) -> Dict[str, Any]: @property def start(self) -> Pos: """Start position in :file:`*.lean` file.""" + if not isinstance(self.ast.start, Pos): + raise TypeError("start expected to be Pos") return self.ast.start @property def end(self) -> Pos: """End position in :file:`*.lean` file.""" + if not isinstance(self.ast.end, Pos): + raise TypeError("end expected to be Pos") return self.ast.end @property @@ -333,8 +356,11 @@ def get_proof_node(self) -> Node: def locate_proof(self) -> Tuple[Pos, Pos]: """Return the start/end positions of the proof.""" start, end = self.get_proof_node().get_closure() - if end < self.end: - end = self.end + start = cast_away_optional(start) + self_end = cast_away_optional(self.end) + end = cast_away_optional(end) + if end < self_end: + end = self_end return start, end def get_tactic_proof(self) -> Optional[str]: @@ -343,6 +369,8 @@ def get_tactic_proof(self) -> Optional[str]: return None node = self.get_proof_node() start, end = node.get_closure() + start = cast_away_optional(start) + end = cast_away_optional(end) proof = get_code_without_comments(node.lean_file, start, end, self.comments) if not re.match(r"^(by|begin)\s", proof): return None @@ -353,15 +381,18 @@ def get_theorem_statement(self) -> str: """Return the theorem statement.""" proof_start, _ = self.locate_proof() assert self.traced_file is not None + start = cast_away_optional(self.ast.start) return get_code_without_comments( - self.traced_file.lean_file, self.ast.start, proof_start, self.comments + self.traced_file.lean_file, start, proof_start, self.comments ) def get_premise_full_names(self) -> List[str]: """Return the fully qualified names of all premises used in the proof.""" names = [] - def _callback(node: IdentNode, _: List[Node]): + def _callback(node: Node, _: List[Node]): + if not isinstance(node, IdentNode): + raise TypeError("Excepted IdentNode") if node.full_name is not None: names.append(node.full_name) @@ -503,7 +534,7 @@ def has_prelude(self) -> bool: """ result = False - def _callback(node: ModulePreludeNode, _: List[Node]): + def _callback(node: Node, _: List[Node]): nonlocal result result = True return True # Stop traversing. @@ -624,19 +655,35 @@ def _callback(node: Node, _): ): inside_sections_namespaces.pop() elif is_potential_premise_lean4(node): - prefix = ".".join( - ns.name - for ns in inside_sections_namespaces - if isinstance(ns, CommandNamespaceNode) - ) - full_name = ( - [_qualify_name(name, prefix) for name in node.name] - if is_mutual_lean4(node) - else _qualify_name(node.name, prefix) - ) - object.__setattr__(node, "full_name", full_name) - if isinstance(node, CommandDeclarationNode) and node.is_theorem: - object.__setattr__(node.get_theorem_node(), "full_name", full_name) + assert node.name is not None + names = [] + for ns in inside_sections_namespaces: + if isinstance(ns, CommandNamespaceNode): + if ns.name is None: + raise TypeError("Expected ns.name to be str") + names.append(ns.name) + prefix = ".".join(names) + + full_name: Union[str, List[str]] + if isinstance(node, StdTacticAliasAliaslrNode) and node.is_mutual: + full_name = [_qualify_name(name, prefix) for name in node.name] + object.__setattr__(node, "full_name", full_name) + elif isinstance( + node, + ( + CommandDeclarationNode, + LemmaNode, + MathlibTacticLemmaNode, + LeanElabCommandCommandIrreducibleDefNode, + StdTacticAliasAliasNode, + ), + ): + full_name = _qualify_name(node.name, prefix) + object.__setattr__(node, "full_name", full_name) + if isinstance(node, CommandDeclarationNode) and node.is_theorem: + object.__setattr__( + node.get_theorem_node(), "full_name", full_name + ) elif isinstance( node, ( @@ -650,11 +697,16 @@ def _callback(node: Node, _): ) if (tac_node.start, tac_node.end) not in pos2tactics: continue - t = pos2tactics[(tac_node.start, tac_node.end)] + tac_node_start = cast_away_optional(tac_node.start) + tac_node_end = cast_away_optional(tac_node.end) + t = pos2tactics[(tac_node_start, tac_node_end)] + tac_start = cast_away_optional(tac_node.start) + tac_end = cast_away_optional(tac_node.end) tac = get_code_without_comments( - lean_file, tac_node.start, tac_node.end, comments + lean_file, tac_start, tac_end, comments ) - tac = _fix_indentation(tac, tac_node.start.column_nb - 1) + tac_start = cast_away_optional(tac_node.start) + tac = _fix_indentation(tac, tac_start.column_nb - 1) object.__setattr__(tac_node, "state_before", t["stateBefore"]) object.__setattr__(tac_node, "state_after", t["stateAfter"]) object.__setattr__(tac_node, "tactic", tac) @@ -753,27 +805,24 @@ def get_traced_theorem( result = None private_result = None - def _callback( - node: Union[CommandTheoremNode, LemmaNode, MathlibTacticLemmaNode], _ - ) -> bool: + def _callback(node: Node, _) -> bool: nonlocal result, private_result - if ( - isinstance( - node, - ( - CommandTheoremNode, - LemmaNode, - MathlibTacticLemmaNode, - ), - ) - and node.full_name == thm.full_name + if not isinstance( + node, (CommandTheoremNode, LemmaNode, MathlibTacticLemmaNode) ): - comments = self._filter_comments(node.start, node.end) - t = TracedTheorem(self.root_dir, thm, node, comments, self) - if t.is_private: - private_result = t - else: - result = t + raise TypeError( + "Except CommandTheoremNode, LemmaNode, MathlibTacticLemmaNode" + ) + if not node.full_name == thm.full_name: + return False + start = cast_away_optional(node.start) + end = cast_away_optional(node.end) + comments = self._filter_comments(start, end) + t = TracedTheorem(self.root_dir, thm, node, comments, self) + if t.is_private: + private_result = t + else: + result = t return False self.ast.traverse_preorder(_callback, node_cls=None) @@ -800,8 +849,11 @@ def _callback( ): return False repo, path = self._get_repo_and_relative_path() - thm = Theorem(repo, path, node.full_name) - comments = self._filter_comments(node.start, node.end) + full_name = cast_away_optional(node.full_name) + thm = Theorem(repo, path, full_name) + start = cast_away_optional(node.start) + end = cast_away_optional(node.end) + comments = self._filter_comments(start, end) traced_theorems.append( TracedTheorem(self.root_dir, thm, node, comments, self) ) @@ -859,34 +911,41 @@ def _callback4(node: Node, _) -> None: proof_start, _ = ( node.get_theorem_node().get_proof_node().get_closure() ) + start = cast_away_optional(start) + proof_start = cast_away_optional(proof_start) code = get_code_without_comments( self.lean_file, start, proof_start, self.comments ) if code.endswith(":="): code = code[:-2].strip() else: + start = cast_away_optional(start) + end = cast_away_optional(end) code = get_code_without_comments( self.lean_file, start, end, self.comments ) # TODO: For alias, restate_axiom, etc., the code is not very informative. + full_name: Union[None, str, List[str]] if is_mutual_lean4(node): - for s in node.full_name: + full_name = cast_list_str(node.full_name) + for s in full_name: results.append( { "full_name": s, "code": code, - "start": list(start), - "end": list(end), + "start": list(start) if start is not None else [], + "end": list(end) if end is not None else [], "kind": node.kind(), } ) else: + full_name = cast_str(node.full_name) results.append( { - "full_name": node.full_name, + "full_name": full_name, "code": code, - "start": list(start), - "end": list(end), + "start": list(start) if start is not None else [], + "end": list(end) if end is not None else [], "kind": node.kind(), } ) diff --git a/src/lean_dojo/interaction/dojo.py b/src/lean_dojo/interaction/dojo.py index 09121dac..fc0921ec 100644 --- a/src/lean_dojo/interaction/dojo.py +++ b/src/lean_dojo/interaction/dojo.py @@ -8,7 +8,7 @@ from pathlib import Path from loguru import logger from dataclasses import dataclass, field -from typing import Union, Tuple, List, Dict, Any, Optional, TextIO +from typing import Union, Tuple, List, Dict, Any, Optional, TextIO, cast from .parse_goals import parse_goals, Goal from ..utils import to_json_path, working_directory @@ -260,6 +260,7 @@ def _modify_file(self, traced_file: TracedFile) -> None: else: # Interaction through commands (via CommandElabM). lean_file = traced_file.lean_file + self.entry = cast(Tuple[LeanGitRepo, Path, int], self.entry) pos = Pos(line_nb=self.entry[2], column_nb=1) code_before = get_code_without_comments( lean_file, lean_file.start_pos, pos, traced_file.comments