From 0e8a6eb4789f006e31d217f8523445afa413ef0b Mon Sep 17 00:00:00 2001 From: Yansy Date: Tue, 27 Aug 2024 17:52:17 +0800 Subject: [PATCH 01/13] fix(typing): mypy errors --- src/lean_dojo/data_extraction/ast.py | 124 +++++++++++-------- src/lean_dojo/data_extraction/trace.py | 10 +- src/lean_dojo/data_extraction/traced_data.py | 36 +++--- src/lean_dojo/interaction/dojo.py | 3 +- 4 files changed, 104 insertions(+), 69 deletions(-) diff --git a/src/lean_dojo/data_extraction/ast.py b/src/lean_dojo/data_extraction/ast.py index 2b886408..5ad74345 100644 --- a/src/lean_dojo/data_extraction/ast.py +++ b/src/lean_dojo/data_extraction/ast.py @@ -2,7 +2,9 @@ from pathlib import Path from dataclasses import dataclass, field from xml.sax.saxutils import escape, unescape -from typing import List, Dict, Any, Optional, Callable, Tuple, Generator +from typing import List, Dict, Any, Optional, Callable, Tuple, Generator, TypeVar, Union, cast, Protocol, Type, Generic, \ + TypeGuard +from typing_extensions import Self from ..utils import ( camel_case, @@ -13,6 +15,8 @@ ) from .lean import Pos, LeanFile +T = TypeVar("T", bound="Node") + @dataclass(frozen=True) class Node: @@ -22,7 +26,7 @@ class Node: children: List["Node"] = field(repr=False) @classmethod - def from_data(cls, node_data: Dict[str, Any], lean_file: LeanFile) -> "Node": + def from_data(cls: Type[T], node_data: Dict[str, Any], lean_file: LeanFile) -> T: subcls = cls._kind_to_node_type(node_data["kind"]) return subcls.from_data(node_data, lean_file) @@ -78,7 +82,7 @@ def to_xml(self, parent: etree.Element) -> None: child.to_xml(tree) @classmethod - def from_xml(cls, tree: etree.Element, lean_file: LeanFile) -> "Node": + def from_xml(cls: Type[T], tree: etree.Element, lean_file: LeanFile) -> T: subcls = globals()[tree.tag] start = Pos.from_str(tree.attrib["start"]) if "start" in tree.attrib else None end = Pos.from_str(tree.attrib["end"]) if "end" in tree.attrib else None @@ -148,10 +152,12 @@ class AtomNode(Node): @classmethod def from_data( cls, atom_data: Dict[str, Any], lean_file: LeanFile - ) -> Optional["AtomNode"]: + ) -> "AtomNode": info = atom_data["info"] - start, end = _parse_pos(info, lean_file) - + pos_pair = _parse_pos(info, lean_file) + if pos_pair is None: + raise ValueError("Synthetic atom nodes are not supported") + start, end = pos_pair if "original" in info: leading = info["original"]["leading"] trailing = info["original"]["trailing"] @@ -179,9 +185,12 @@ class IdentNode(Node): @classmethod def from_data( cls, ident_data: Dict[str, Any], lean_file: LeanFile - ) -> Optional["IdentNode"]: + ) -> "IdentNode": info = ident_data["info"] - start, end = _parse_pos(info, lean_file) + pos_pair = _parse_pos(info, lean_file) + if pos_pair is None: + raise ValueError("Synthetic ident nodes are not supported") + start, end = pos_pair assert ident_data["preresolved"] == [] if "original" in info: @@ -212,10 +221,14 @@ def is_leaf(node: Node) -> bool: return isinstance(node, AtomNode) or isinstance(node, IdentNode) +def filter_leaf(nodes: List[Node]) -> List[Union[AtomNode | IdentNode]]: + return [node for node in nodes if isinstance(node, (AtomNode, IdentNode))] + + @dataclass(frozen=True) class FileNode(Node): @classmethod - def from_data(cls, data: Dict[str, Any], lean_file: LeanFile) -> "FileNode": + def from_data(cls: Type[T], data: Dict[str, Any], lean_file: LeanFile) -> T: children = [] def _get_closure(node: Node, child_spans: List[Tuple[Pos, Pos]]): @@ -316,7 +329,7 @@ def from_data( return cls(lean_file, start, end, children) def get_ident(self) -> str: - return "".join(gc.val for gc in self.children if is_leaf(gc)) + return "".join(gc.val for gc in filter_leaf(self.children)) @dataclass(frozen=True) @@ -363,7 +376,9 @@ def from_data(cls, node_data: Dict[str, Any], lean_file: LeanFile) -> "GroupNode class MathlibTacticLemmaNode(Node): name: str full_name: Optional[str] = None - _is_private_decl: Optional[bool] = ( + # If the Private usage state is False, the user will not explicitly use the parameter None. + # And False and None are handled in the same way, so there is no need to use Optional. + _is_private_decl: bool = ( False # `_is_private` doesn't play well with lxml. ) @@ -412,9 +427,11 @@ def is_mutual(self) -> bool: @dataclass(frozen=True) class LemmaNode(Node): - name: str + name: Optional[str] full_name: Optional[str] = None - _is_private_decl: Optional[bool] = ( + # If the Private usage state is False, the user will not explicitly use the parameter None. + # And False and None are handled in the same way, so there is no need to use Optional. + _is_private_decl: bool = ( False # `_is_private` doesn't play well with lxml. ) @@ -468,7 +485,7 @@ def is_mutual(self) -> bool: @dataclass(frozen=True) class CommandDeclarationNode(Node): - name: str + name: Optional[str] full_name: Optional[str] = None @classmethod @@ -514,7 +531,7 @@ def is_theorem(self) -> bool: def get_theorem_node(self) -> "CommandTheoremNode": assert self.is_theorem - return self.children[1] + return cast(CommandTheoremNode, self.children[1]) @property def is_example(self) -> bool: @@ -547,7 +564,7 @@ def from_data( def is_private(self) -> bool: result = False - def _callback(node: CommandPrivateNode, _) -> bool: + def _callback(node: "CommandDeclmodifiersNode", _) -> bool: nonlocal result result = True return True @@ -630,7 +647,7 @@ def from_data( @dataclass(frozen=True) class CommandStructureNode(Node): - name: str + name: Optional[str] @classmethod def from_data( @@ -682,7 +699,7 @@ def from_data( @dataclass(frozen=True) class CommandClassinductiveNode(Node): - name: str + name: Optional[str] @classmethod def from_data( @@ -776,7 +793,7 @@ def get_ident(self) -> Optional[str]: @dataclass(frozen=True) class StdTacticAliasAliasNode(Node): - name: str + name: Optional[str] full_name: Optional[str] = None @classmethod @@ -826,9 +843,13 @@ def from_data( children[5], (LeanBinderidentNode, LeanBinderidentAntiquotNode) ) name.append(children[5].get_ident()) - name = [n for n in name if n is not None] - return cls(lean_file, start, end, children, name) + def _filter_names(name_lst: List[str | None]) -> List[str]: + return [n for n in name_lst if n is not None] + + names = _filter_names(name) + + return cls(lean_file, start, end, children, names) @property def is_mutual(self) -> bool: @@ -837,7 +858,7 @@ def is_mutual(self) -> bool: @dataclass(frozen=True) class CommandAbbrevNode(Node): - name: str + name: Optional[str] @classmethod def from_data( @@ -862,7 +883,7 @@ def from_data( @dataclass(frozen=True) class CommandOpaqueNode(Node): - name: str + name: Optional[str] @classmethod def from_data( @@ -887,7 +908,7 @@ def from_data( @dataclass(frozen=True) class CommandAxiomNode(Node): - name: str + name: Optional[str] @classmethod def from_data( @@ -912,7 +933,7 @@ def from_data( @dataclass(frozen=True) class CommandExampleNode(Node): - name: str + name: Optional[str] @classmethod def from_data( @@ -928,7 +949,7 @@ def from_data( @dataclass(frozen=True) class CommandInstanceNode(Node): - name: str + name: Optional[str] @classmethod def from_data( @@ -957,7 +978,7 @@ def from_data( @dataclass(frozen=True) class CommandDefNode(Node): - name: str + name: Optional[str] @classmethod def from_data( @@ -988,7 +1009,7 @@ def from_data( @dataclass(frozen=True) class CommandDefinitionNode(Node): - name: str + name: Optional[str] @classmethod def from_data( @@ -1127,9 +1148,11 @@ def from_data( @dataclass(frozen=True) class CommandTheoremNode(Node): - name: str + name: Optional[str] full_name: Optional[str] = None - _is_private_decl: Optional[bool] = ( + # If the Private usage state is False, the user will not explicitly use the parameter None. + # And False and None are handled in the same way, so there is no need to use Optional. + _is_private_decl: bool = ( False # `_is_private` doesn't play well with lxml. ) @@ -1216,6 +1239,15 @@ def from_data( return cls(lean_file, start, end, children) +class TacticProtocol(Protocol): + def get_tactic_nodes(self, atomic_only: bool = False) -> Generator[Node, None, None]: + ... + + +class TacticNode(Node, TacticProtocol): + ... + + @dataclass(frozen=True) class TacticTacticseq1IndentedAntiquotNode(Node): @classmethod @@ -1229,7 +1261,7 @@ def from_data( def get_tactic_nodes( self, atomic_only: bool = False - ) -> Generator[Node, None, None]: + ) -> None: return @@ -1255,7 +1287,8 @@ def from_data( def get_tactic_nodes( self, atomic_only: bool = False ) -> Generator[Node, None, None]: - yield from self.children[0].get_tactic_nodes(atomic_only) + child = cast(TacticNode, self.children[0]) + yield from child.get_tactic_nodes(atomic_only) @dataclass(frozen=True) @@ -1344,6 +1377,7 @@ def _callback(x, _) -> bool: nonlocal result result = True return True + return False node.traverse_preorder(_callback, node_cls=None) return result @@ -1373,18 +1407,6 @@ def from_data( return cls(lean_file, start, end, children) -@dataclass(frozen=True) -class ModulePreludeNode(Node): - @classmethod - def from_data( - cls, node_data: Dict[str, Any], lean_file: LeanFile - ) -> "ModulePreludeNode": - assert node_data["info"] == "none" - start, end = None, None - children = _parse_children(node_data, lean_file) - return cls(lean_file, start, end, children) - - @dataclass(frozen=True) class ModuleImportNode(Node): module: Optional[str] @@ -1419,8 +1441,9 @@ def from_data( start, end = None, None children = _parse_children(node_data, lean_file) assert len(children) == 2 and all(isinstance(_, AtomNode) for _ in children) - assert children[0].val == "/-!" - comment = children[1].val + cast_children = cast(List[AtomNode], children) + assert cast_children[0].val == "/-!" + comment = cast_children[1].val return cls(lean_file, start, end, children, comment) @@ -1436,14 +1459,15 @@ def from_data( start, end = None, None children = _parse_children(node_data, lean_file) assert len(children) == 2 and all(isinstance(_, AtomNode) for _ in children) - assert children[0].val == "/--" - comment = children[1].val + cast_children = cast(List[AtomNode], children) + assert cast_children[0].val == "/--" + comment = cast_children[1].val return cls(lean_file, start, end, children, comment) @dataclass(frozen=True) class CommandNamespaceNode(Node): - name: str + name: Optional[str] @classmethod def from_data( @@ -1470,7 +1494,7 @@ class CommandSectionNode(Node): @classmethod def from_data( cls, node_data: Dict[str, Any], lean_file: LeanFile - ) -> "CommandNamespaceNode": + ) -> "CommandSectionNode": assert node_data["info"] == "none" start, end = None, None children = _parse_children(node_data, lean_file) diff --git a/src/lean_dojo/data_extraction/trace.py b/src/lean_dojo/data_extraction/trace.py index eeea2967..c8a94969 100644 --- a/src/lean_dojo/data_extraction/trace.py +++ b/src/lean_dojo/data_extraction/trace.py @@ -69,7 +69,10 @@ def launch_progressbar(paths: List[Path]) -> Generator[None, None, None]: def get_lean_version() -> str: """Get the version of Lean.""" - output = execute("lean --version", capture_output=True)[0].strip() + res = execute("lean --version", capture_output=True) + if res is None: + raise CalledProcessError(1, "lean --version") + output = res[0].strip() m = re.match(r"Lean \(version (?P\S+?),", output) return m["version"] # type: ignore @@ -140,7 +143,10 @@ def _trace(repo: LeanGitRepo, build_deps: bool) -> None: execute("lake build") # Copy the Lean 4 stdlib into the path of packages. - lean_prefix = execute(f"lean --print-prefix", capture_output=True)[0].strip() + lean_prefix_output = execute("lean --print-prefix", capture_output=True) + if lean_prefix_output is None: + raise CalledProcessError(1, "lean --print-prefix") + lean_prefix = Path(lean_prefix_output[0].strip()) if is_new_version(get_lean_version()): packages_path = Path(".lake/packages") build_path = Path(".lake/build") diff --git a/src/lean_dojo/data_extraction/traced_data.py b/src/lean_dojo/data_extraction/traced_data.py index 5e56050b..292f630d 100644 --- a/src/lean_dojo/data_extraction/traced_data.py +++ b/src/lean_dojo/data_extraction/traced_data.py @@ -126,7 +126,7 @@ def get_code_without_comments( for c in comments: if base <= c.start and c.end <= end: - code_segs.append(lean_file[base : c.start]) + code_segs.append(lean_file[base: c.start]) base = c.end code_segs.append(lean_file[base:end]) @@ -144,7 +144,8 @@ class TracedTactic: its AST and the states before/after the tactic. """ - ast: Node = field(repr=False) + # ast: Node = field(repr=False) + ast: OtherNode | TacticTacticseqbracketedNode = field(repr=False) """AST of the tactic. """ @@ -160,7 +161,7 @@ def __getstate__(self) -> Dict[str, Any]: return d @property - def tactic(self) -> str: + def tactic(self) -> Optional[str]: """The raw tactic string.""" return self.ast.tactic @@ -212,7 +213,7 @@ def get_annotated_tactic(self) -> Tuple[str, List[Dict[str, Any]]]: ) lean_file = self.traced_theorem.traced_file.lean_file annot_tac = [] - provenances = [] + provenances: List[Dict[str,Any]] = [] cur = self.start def _callback4(node: IdentNode, _): @@ -225,12 +226,14 @@ def _callback4(node: IdentNode, _): and node.def_end is not None ): if cur <= node.start: - annot_tac.append(lean_file[cur : node.start]) - annot_tac.append("" + lean_file[node.start : node.end] + "") - prov = {"full_name": node.full_name} - prov["def_path"] = node.def_path - prov["def_pos"] = list(node.def_start) - prov["def_end_pos"] = list(node.def_end) + annot_tac.append(lean_file[cur: node.start]) + annot_tac.append("" + lean_file[node.start: node.end] + "") + prov = { + "full_name": node.full_name, + "def_path": node.def_path, + "def_pos": list(node.def_start), + "def_end_pos": list(node.def_end), + } provenances.append(prov) cur = node.end @@ -361,7 +364,7 @@ def get_premise_full_names(self) -> List[str]: """Return the fully qualified names of all premises used in the proof.""" names = [] - def _callback(node: IdentNode, _: List[Node]): + def _callback(node: "IdentNode", _: List[Node]): if node.full_name is not None: names.append(node.full_name) @@ -625,11 +628,12 @@ def _callback(node: Node, _): for ns in inside_sections_namespaces if isinstance(ns, CommandNamespaceNode) ) - full_name = ( - [_qualify_name(name, prefix) for name in node.name] - if is_mutual_lean4(node) - else _qualify_name(node.name, prefix) - ) + + if is_mutual_lean4(node): + full_name = [_qualify_name(name, prefix) for name in node.name] + else: + full_name = _qualify_name(node.name, prefix) + object.__setattr__(node, "full_name", full_name) if isinstance(node, CommandDeclarationNode) and node.is_theorem: object.__setattr__(node.get_theorem_node(), "full_name", full_name) diff --git a/src/lean_dojo/interaction/dojo.py b/src/lean_dojo/interaction/dojo.py index b8ab0fb3..15b34cd0 100644 --- a/src/lean_dojo/interaction/dojo.py +++ b/src/lean_dojo/interaction/dojo.py @@ -8,7 +8,7 @@ from pathlib import Path from loguru import logger from dataclasses import dataclass, field -from typing import Union, Tuple, List, Dict, Any, Optional, TextIO +from typing import Union, Tuple, List, Dict, Any, Optional, TextIO, cast from .parse_goals import parse_goals, Goal from ..utils import to_json_path, working_directory @@ -266,6 +266,7 @@ def _modify_file(self, traced_file: TracedFile) -> None: else: # Interaction through commands (via CommandElabM). lean_file = traced_file.lean_file + self.entry = cast(Tuple[LeanGitRepo, Path, int], self.entry) pos = Pos(line_nb=self.entry[2], column_nb=1) code_before = get_code_without_comments( lean_file, lean_file.start_pos, pos, traced_file.comments From 474acfb9e0d984e3830fc23f0e922eebd1db55f3 Mon Sep 17 00:00:00 2001 From: Yansy Date: Tue, 27 Aug 2024 17:53:54 +0800 Subject: [PATCH 02/13] fix(typing): remove unuse import --- src/lean_dojo/data_extraction/ast.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/lean_dojo/data_extraction/ast.py b/src/lean_dojo/data_extraction/ast.py index 5ad74345..6ad2a824 100644 --- a/src/lean_dojo/data_extraction/ast.py +++ b/src/lean_dojo/data_extraction/ast.py @@ -2,9 +2,7 @@ from pathlib import Path from dataclasses import dataclass, field from xml.sax.saxutils import escape, unescape -from typing import List, Dict, Any, Optional, Callable, Tuple, Generator, TypeVar, Union, cast, Protocol, Type, Generic, \ - TypeGuard -from typing_extensions import Self +from typing import List, Dict, Any, Optional, Callable, Tuple, Generator, TypeVar, Union, cast, Protocol, Type from ..utils import ( camel_case, From d4e51ce2421a7d6ac668988957801a58e67ac2c2 Mon Sep 17 00:00:00 2001 From: Yansy Date: Tue, 27 Aug 2024 17:55:12 +0800 Subject: [PATCH 03/13] fix(typing): callback node type comment --- src/lean_dojo/data_extraction/ast.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lean_dojo/data_extraction/ast.py b/src/lean_dojo/data_extraction/ast.py index 6ad2a824..1c6236bf 100644 --- a/src/lean_dojo/data_extraction/ast.py +++ b/src/lean_dojo/data_extraction/ast.py @@ -562,7 +562,7 @@ def from_data( def is_private(self) -> bool: result = False - def _callback(node: "CommandDeclmodifiersNode", _) -> bool: + def _callback(node: "CommandPrivateNode", _) -> bool: nonlocal result result = True return True From f5820db8282a504a1fc8c996bcc458c8a2fa0c3b Mon Sep 17 00:00:00 2001 From: Yansy Date: Tue, 27 Aug 2024 17:58:11 +0800 Subject: [PATCH 04/13] fix(typing): remove double quote on node IdentNode typing comment --- src/lean_dojo/data_extraction/traced_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lean_dojo/data_extraction/traced_data.py b/src/lean_dojo/data_extraction/traced_data.py index 292f630d..87d6d20b 100644 --- a/src/lean_dojo/data_extraction/traced_data.py +++ b/src/lean_dojo/data_extraction/traced_data.py @@ -364,7 +364,7 @@ def get_premise_full_names(self) -> List[str]: """Return the fully qualified names of all premises used in the proof.""" names = [] - def _callback(node: "IdentNode", _: List[Node]): + def _callback(node: IdentNode, _: List[Node]): if node.full_name is not None: names.append(node.full_name) From 7f9cdcc206442f000313606efc271a7e3767ee60 Mon Sep 17 00:00:00 2001 From: Yansy Date: Thu, 29 Aug 2024 15:41:22 +0800 Subject: [PATCH 05/13] fix(typing): mypy errors --- src/lean_dojo/data_extraction/ast.py | 37 ++++++---- src/lean_dojo/data_extraction/lean.py | 3 + src/lean_dojo/data_extraction/traced_data.py | 76 ++++++++++---------- 3 files changed, 66 insertions(+), 50 deletions(-) diff --git a/src/lean_dojo/data_extraction/ast.py b/src/lean_dojo/data_extraction/ast.py index 1c6236bf..1487166b 100644 --- a/src/lean_dojo/data_extraction/ast.py +++ b/src/lean_dojo/data_extraction/ast.py @@ -1,8 +1,12 @@ +from typing_extensions import TypeGuard + from lxml import etree from pathlib import Path from dataclasses import dataclass, field from xml.sax.saxutils import escape, unescape -from typing import List, Dict, Any, Optional, Callable, Tuple, Generator, TypeVar, Union, cast, Protocol, Type +from typing import List, Dict, Any, Optional, Callable, Tuple, Generator, TypeVar, Union, cast, Protocol, Type, \ + Sequence, Literal +from typing_extensions import Annotated from ..utils import ( camel_case, @@ -13,7 +17,13 @@ ) from .lean import Pos, LeanFile -T = TypeVar("T", bound="Node") +N = TypeVar("N", bound="Node", covariant=True) +T = TypeVar("T") + + +def cast_away_optional(x: Optional[T]) -> T: + assert x is not None + return x @dataclass(frozen=True) @@ -21,15 +31,15 @@ class Node: lean_file: LeanFile start: Optional[Pos] end: Optional[Pos] - children: List["Node"] = field(repr=False) + children: Sequence["Node"] = field(repr=False) @classmethod - def from_data(cls: Type[T], node_data: Dict[str, Any], lean_file: LeanFile) -> T: + def from_data(cls, node_data: Dict[str, Any], lean_file: LeanFile) -> "Node": subcls = cls._kind_to_node_type(node_data["kind"]) return subcls.from_data(node_data, lean_file) @classmethod - def _kind_to_node_type(cls, kind: str) -> type["Node"]: + def _kind_to_node_type(cls: Type[N], kind: str) -> Type["Node"]: prefix = "Lean.Parser." if kind.startswith(prefix): kind = kind[len(prefix) :] @@ -48,9 +58,9 @@ def kind(cls: type) -> str: def traverse_preorder( self, callback: Callable[["Node", List["Node"]], Any], - node_cls: Optional[type], + node_cls: Optional[type["Node"]], parents: List["Node"] = [], - ) -> None: + ): if node_cls is None or isinstance(self, node_cls): if callback(self, parents): return @@ -80,7 +90,7 @@ def to_xml(self, parent: etree.Element) -> None: child.to_xml(tree) @classmethod - def from_xml(cls: Type[T], tree: etree.Element, lean_file: LeanFile) -> T: + def from_xml(cls: Type[N], tree: etree.Element, lean_file: LeanFile) -> N: subcls = globals()[tree.tag] start = Pos.from_str(tree.attrib["start"]) if "start" in tree.attrib else None end = Pos.from_str(tree.attrib["end"]) if "end" in tree.attrib else None @@ -219,14 +229,14 @@ def is_leaf(node: Node) -> bool: return isinstance(node, AtomNode) or isinstance(node, IdentNode) -def filter_leaf(nodes: List[Node]) -> List[Union[AtomNode | IdentNode]]: +def filter_leaf(nodes: Sequence[Node]) -> Sequence[Union[AtomNode | IdentNode]]: return [node for node in nodes if isinstance(node, (AtomNode, IdentNode))] @dataclass(frozen=True) class FileNode(Node): @classmethod - def from_data(cls: Type[T], data: Dict[str, Any], lean_file: LeanFile) -> T: + def from_data(cls: Type[N], data: Dict[str, Any], lean_file: LeanFile) -> N: children = [] def _get_closure(node: Node, child_spans: List[Tuple[Pos, Pos]]): @@ -562,7 +572,7 @@ def from_data( def is_private(self) -> bool: result = False - def _callback(node: "CommandPrivateNode", _) -> bool: + def _callback(node: CommandDeclmodifiersNode, _) -> bool: nonlocal result result = True return True @@ -1579,7 +1589,8 @@ def from_data(cls, node_data: Dict[str, Any], lean_file: LeanFile) -> "OtherNode return cls(lean_file, start, end, children, node_data["kind"]) -def is_potential_premise_lean4(node: Node) -> bool: +def is_potential_premise_lean4(node: Node) -> TypeGuard[Union[ + CommandTheoremNode, LemmaNode, MathlibTacticLemmaNode, LeanElabCommandCommandIrreducibleDefNode, StdTacticAliasAliasNode, StdTacticAliasAliaslrNode,]]: """Check if ``node`` is a theorem/definition that can be used as a premise.""" if (isinstance(node, CommandDeclarationNode) and not node.is_example) or isinstance( node, @@ -1596,7 +1607,7 @@ def is_potential_premise_lean4(node: Node) -> bool: return False -def is_mutual_lean4(node: Node) -> bool: +def is_mutual_lean4(node: Node) -> TypeGuard[Union[IdentNode,CommandTheoremNode, StdTacticAliasAliaslrNode]]: return ( isinstance(node, (IdentNode, CommandTheoremNode, StdTacticAliasAliaslrNode)) and node.is_mutual diff --git a/src/lean_dojo/data_extraction/lean.py b/src/lean_dojo/data_extraction/lean.py index b86afd9a..c0859916 100644 --- a/src/lean_dojo/data_extraction/lean.py +++ b/src/lean_dojo/data_extraction/lean.py @@ -242,6 +242,9 @@ def __lt__(self, other): def __le__(self, other): return self < other or self == other + def __index__(self): + return self.line_nb, self.column_nb + @dataclass(frozen=True) class LeanFile: diff --git a/src/lean_dojo/data_extraction/traced_data.py b/src/lean_dojo/data_extraction/traced_data.py index 87d6d20b..da7ba021 100644 --- a/src/lean_dojo/data_extraction/traced_data.py +++ b/src/lean_dojo/data_extraction/traced_data.py @@ -25,29 +25,11 @@ to_json_path, to_xml_path, ) -from .ast import ( - Node, - FileNode, - OtherNode, - LemmaNode, - IdentNode, - CommandEndNode, - ModuleImportNode, - ModulePreludeNode, - CommandSectionNode, - CommandTheoremNode, - CommandModuledocNode, - CommandNamespaceNode, - CommandDoccommentNode, - CommandDeclarationNode, - MathlibTacticLemmaNode, - TacticTacticseqbracketedNode, - TacticTacticseq1IndentedNode, - CommandNoncomputablesectionNode, - is_leaf, - is_mutual_lean4, - is_potential_premise_lean4, -) +from .ast import (Node, FileNode, OtherNode, LemmaNode, IdentNode, CommandEndNode, ModuleImportNode, ModulePreludeNode, + CommandSectionNode, CommandTheoremNode, CommandModuledocNode, CommandNamespaceNode, + CommandDoccommentNode, CommandDeclarationNode, MathlibTacticLemmaNode, TacticTacticseqbracketedNode, + TacticTacticseq1IndentedNode, CommandNoncomputablesectionNode, is_leaf, is_mutual_lean4, + is_potential_premise_lean4, cast_away_optional) from .lean import LeanFile, LeanGitRepo, Theorem, Pos from ..constants import NUM_WORKERS, LOAD_USED_PACKAGES_ONLY, LEAN4_PACKAGES_DIR @@ -336,8 +318,11 @@ def get_proof_node(self) -> Node: def locate_proof(self) -> Tuple[Pos, Pos]: """Return the start/end positions of the proof.""" start, end = self.get_proof_node().get_closure() - if end < self.end: - end = self.end + start = cast_away_optional(start) + self_end = cast_away_optional(self.end) + end = cast_away_optional(end) + if end < self_end: + end = self_end return start, end def get_tactic_proof(self) -> Optional[str]: @@ -346,6 +331,8 @@ def get_tactic_proof(self) -> Optional[str]: return None node = self.get_proof_node() start, end = node.get_closure() + start = cast_away_optional(start) + end = cast_away_optional(end) proof = get_code_without_comments(node.lean_file, start, end, self.comments) if not re.match(r"^(by|begin)\s", proof): return None @@ -356,8 +343,9 @@ def get_theorem_statement(self) -> str: """Return the theorem statement.""" proof_start, _ = self.locate_proof() assert self.traced_file is not None + start = cast_away_optional(self.ast.start) return get_code_without_comments( - self.traced_file.lean_file, self.ast.start, proof_start, self.comments + self.traced_file.lean_file, start, proof_start, self.comments ) def get_premise_full_names(self) -> List[str]: @@ -506,7 +494,7 @@ def has_prelude(self) -> bool: """ result = False - def _callback(node: ModulePreludeNode, _: List[Node]): + def _callback(node: Node, _: List[Node]): nonlocal result result = True return True # Stop traversing. @@ -650,11 +638,16 @@ def _callback(node: Node, _): ) if (tac_node.start, tac_node.end) not in pos2tactics: continue - t = pos2tactics[(tac_node.start, tac_node.end)] + tac_node_start = cast_away_optional(tac_node.start) + tac_node_end = cast_away_optional(tac_node.end) + t = pos2tactics[(tac_node_start, tac_node_end)] + tac_start = cast_away_optional(tac_node.start) + tac_end = cast_away_optional(tac_node.end) tac = get_code_without_comments( - lean_file, tac_node.start, tac_node.end, comments + lean_file, tac_start, tac_end, comments ) - tac = _fix_indentation(tac, tac_node.start.column_nb - 1) + tac_start = cast_away_optional(tac_node.start) + tac = _fix_indentation(tac, tac_start.column_nb - 1) object.__setattr__(tac_node, "state_before", t["stateBefore"]) object.__setattr__(tac_node, "state_after", t["stateAfter"]) object.__setattr__(tac_node, "tactic", tac) @@ -768,7 +761,9 @@ def _callback( ) and node.full_name == thm.full_name ): - comments = self._filter_comments(node.start, node.end) + start = cast_away_optional(node.start) + end = cast_away_optional(node.end) + comments = self._filter_comments(start, end) t = TracedTheorem(self.root_dir, thm, node, comments, self) if t.is_private: private_result = t @@ -800,8 +795,11 @@ def _callback( ): return False repo, path = self._get_repo_and_relative_path() - thm = Theorem(repo, path, node.full_name) - comments = self._filter_comments(node.start, node.end) + full_name = cast_away_optional(node.full_name) + thm = Theorem(repo, path, full_name) + start = cast_away_optional(node.start) + end = cast_away_optional(node.end) + comments = self._filter_comments(start, end) traced_theorems.append( TracedTheorem(self.root_dir, thm, node, comments, self) ) @@ -858,12 +856,16 @@ def _callback4(node: Node, _) -> None: proof_start, _ = ( node.get_theorem_node().get_proof_node().get_closure() ) + start = cast_away_optional(start) + proof_start = cast_away_optional(proof_start) code = get_code_without_comments( self.lean_file, start, proof_start, self.comments ) if code.endswith(":="): code = code[:-2].strip() else: + start = cast_away_optional(start) + end = cast_away_optional(end) code = get_code_without_comments( self.lean_file, start, end, self.comments ) @@ -874,8 +876,8 @@ def _callback4(node: Node, _) -> None: { "full_name": s, "code": code, - "start": list(start), - "end": list(end), + "start": list(start) if start is not None else [], + "end": list(end) if end is not None else [], "kind": node.kind(), } ) @@ -884,8 +886,8 @@ def _callback4(node: Node, _) -> None: { "full_name": node.full_name, "code": code, - "start": list(start), - "end": list(end), + "start": list(start) if start is not None else [], + "end": list(end) if end is not None else [], "kind": node.kind(), } ) From fe5edfd7714c8c5e3a574c8b5677a0838fa8edef Mon Sep 17 00:00:00 2001 From: Yansy Date: Thu, 29 Aug 2024 15:46:13 +0800 Subject: [PATCH 06/13] fix(typing): mypy errors --- src/lean_dojo/data_extraction/ast.py | 6 ++--- src/lean_dojo/data_extraction/traced_data.py | 28 ++++++++++++++++---- 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/src/lean_dojo/data_extraction/ast.py b/src/lean_dojo/data_extraction/ast.py index 1487166b..3ef57e96 100644 --- a/src/lean_dojo/data_extraction/ast.py +++ b/src/lean_dojo/data_extraction/ast.py @@ -4,9 +4,7 @@ from pathlib import Path from dataclasses import dataclass, field from xml.sax.saxutils import escape, unescape -from typing import List, Dict, Any, Optional, Callable, Tuple, Generator, TypeVar, Union, cast, Protocol, Type, \ - Sequence, Literal -from typing_extensions import Annotated +from typing import List, Dict, Any, Optional, Callable, Tuple, Generator, TypeVar, Union, cast, Protocol, Type,Sequence from ..utils import ( camel_case, @@ -39,7 +37,7 @@ def from_data(cls, node_data: Dict[str, Any], lean_file: LeanFile) -> "Node": return subcls.from_data(node_data, lean_file) @classmethod - def _kind_to_node_type(cls: Type[N], kind: str) -> Type["Node"]: + def _kind_to_node_type(cls, kind: str) -> Type["Node"]: prefix = "Lean.Parser." if kind.startswith(prefix): kind = kind[len(prefix) :] diff --git a/src/lean_dojo/data_extraction/traced_data.py b/src/lean_dojo/data_extraction/traced_data.py index da7ba021..1d7a6a7a 100644 --- a/src/lean_dojo/data_extraction/traced_data.py +++ b/src/lean_dojo/data_extraction/traced_data.py @@ -25,11 +25,29 @@ to_json_path, to_xml_path, ) -from .ast import (Node, FileNode, OtherNode, LemmaNode, IdentNode, CommandEndNode, ModuleImportNode, ModulePreludeNode, - CommandSectionNode, CommandTheoremNode, CommandModuledocNode, CommandNamespaceNode, - CommandDoccommentNode, CommandDeclarationNode, MathlibTacticLemmaNode, TacticTacticseqbracketedNode, - TacticTacticseq1IndentedNode, CommandNoncomputablesectionNode, is_leaf, is_mutual_lean4, - is_potential_premise_lean4, cast_away_optional) +from .ast import ( + Node, + FileNode, + OtherNode, + LemmaNode, + IdentNode, + CommandEndNode, + ModuleImportNode, + ModulePreludeNode, + CommandSectionNode, + CommandTheoremNode, + CommandModuledocNode, + CommandNamespaceNode, + CommandDoccommentNode, + CommandDeclarationNode, + MathlibTacticLemmaNode, + TacticTacticseqbracketedNode, + TacticTacticseq1IndentedNode, + CommandNoncomputablesectionNode, + is_leaf, + is_mutual_lean4, + is_potential_premise_lean4, +) from .lean import LeanFile, LeanGitRepo, Theorem, Pos from ..constants import NUM_WORKERS, LOAD_USED_PACKAGES_ONLY, LEAN4_PACKAGES_DIR From f10ebbed410a4a6c3b2a0666f7cbc6ead1821a87 Mon Sep 17 00:00:00 2001 From: Yansy Date: Fri, 30 Aug 2024 09:44:22 +0800 Subject: [PATCH 07/13] fix(typing): mypy errors - missing import --- src/lean_dojo/data_extraction/traced_data.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lean_dojo/data_extraction/traced_data.py b/src/lean_dojo/data_extraction/traced_data.py index 1d7a6a7a..57c9dc3f 100644 --- a/src/lean_dojo/data_extraction/traced_data.py +++ b/src/lean_dojo/data_extraction/traced_data.py @@ -47,6 +47,7 @@ is_leaf, is_mutual_lean4, is_potential_premise_lean4, + cast_away_optional ) from .lean import LeanFile, LeanGitRepo, Theorem, Pos from ..constants import NUM_WORKERS, LOAD_USED_PACKAGES_ONLY, LEAN4_PACKAGES_DIR From 3e56f6e8b8e3fe5084ad37f10f4d14699c299a23 Mon Sep 17 00:00:00 2001 From: Knot Date: Sun, 15 Sep 2024 22:29:00 +0800 Subject: [PATCH 08/13] fix(typing): mypy errors - black --- src/lean_dojo/data_extraction/ast.py | 64 ++++++++++++-------- src/lean_dojo/data_extraction/traced_data.py | 12 ++-- 2 files changed, 44 insertions(+), 32 deletions(-) diff --git a/src/lean_dojo/data_extraction/ast.py b/src/lean_dojo/data_extraction/ast.py index 3ef57e96..2e42e04a 100644 --- a/src/lean_dojo/data_extraction/ast.py +++ b/src/lean_dojo/data_extraction/ast.py @@ -4,7 +4,21 @@ from pathlib import Path from dataclasses import dataclass, field from xml.sax.saxutils import escape, unescape -from typing import List, Dict, Any, Optional, Callable, Tuple, Generator, TypeVar, Union, cast, Protocol, Type,Sequence +from typing import ( + List, + Dict, + Any, + Optional, + Callable, + Tuple, + Generator, + TypeVar, + Union, + cast, + Protocol, + Type, + Sequence, +) from ..utils import ( camel_case, @@ -156,9 +170,7 @@ class AtomNode(Node): val: str @classmethod - def from_data( - cls, atom_data: Dict[str, Any], lean_file: LeanFile - ) -> "AtomNode": + def from_data(cls, atom_data: Dict[str, Any], lean_file: LeanFile) -> "AtomNode": info = atom_data["info"] pos_pair = _parse_pos(info, lean_file) if pos_pair is None: @@ -189,9 +201,7 @@ class IdentNode(Node): def_end: Optional[Pos] = None @classmethod - def from_data( - cls, ident_data: Dict[str, Any], lean_file: LeanFile - ) -> "IdentNode": + def from_data(cls, ident_data: Dict[str, Any], lean_file: LeanFile) -> "IdentNode": info = ident_data["info"] pos_pair = _parse_pos(info, lean_file) if pos_pair is None: @@ -384,9 +394,7 @@ class MathlibTacticLemmaNode(Node): full_name: Optional[str] = None # If the Private usage state is False, the user will not explicitly use the parameter None. # And False and None are handled in the same way, so there is no need to use Optional. - _is_private_decl: bool = ( - False # `_is_private` doesn't play well with lxml. - ) + _is_private_decl: bool = False # `_is_private` doesn't play well with lxml. @classmethod def from_data( @@ -437,9 +445,7 @@ class LemmaNode(Node): full_name: Optional[str] = None # If the Private usage state is False, the user will not explicitly use the parameter None. # And False and None are handled in the same way, so there is no need to use Optional. - _is_private_decl: bool = ( - False # `_is_private` doesn't play well with lxml. - ) + _is_private_decl: bool = False # `_is_private` doesn't play well with lxml. @classmethod def from_data(cls, node_data: Dict[str, Any], lean_file: LeanFile) -> "LemmaNode": @@ -1158,9 +1164,7 @@ class CommandTheoremNode(Node): full_name: Optional[str] = None # If the Private usage state is False, the user will not explicitly use the parameter None. # And False and None are handled in the same way, so there is no need to use Optional. - _is_private_decl: bool = ( - False # `_is_private` doesn't play well with lxml. - ) + _is_private_decl: bool = False # `_is_private` doesn't play well with lxml. @classmethod def from_data( @@ -1246,12 +1250,12 @@ def from_data( class TacticProtocol(Protocol): - def get_tactic_nodes(self, atomic_only: bool = False) -> Generator[Node, None, None]: - ... + def get_tactic_nodes( + self, atomic_only: bool = False + ) -> Generator[Node, None, None]: ... -class TacticNode(Node, TacticProtocol): - ... +class TacticNode(Node, TacticProtocol): ... @dataclass(frozen=True) @@ -1265,9 +1269,7 @@ def from_data( children = _parse_children(node_data, lean_file) return cls(lean_file, start, end, children) - def get_tactic_nodes( - self, atomic_only: bool = False - ) -> None: + def get_tactic_nodes(self, atomic_only: bool = False) -> None: return @@ -1587,8 +1589,16 @@ def from_data(cls, node_data: Dict[str, Any], lean_file: LeanFile) -> "OtherNode return cls(lean_file, start, end, children, node_data["kind"]) -def is_potential_premise_lean4(node: Node) -> TypeGuard[Union[ - CommandTheoremNode, LemmaNode, MathlibTacticLemmaNode, LeanElabCommandCommandIrreducibleDefNode, StdTacticAliasAliasNode, StdTacticAliasAliaslrNode,]]: +def is_potential_premise_lean4(node: Node) -> TypeGuard[ + Union[ + CommandTheoremNode, + LemmaNode, + MathlibTacticLemmaNode, + LeanElabCommandCommandIrreducibleDefNode, + StdTacticAliasAliasNode, + StdTacticAliasAliaslrNode, + ] +]: """Check if ``node`` is a theorem/definition that can be used as a premise.""" if (isinstance(node, CommandDeclarationNode) and not node.is_example) or isinstance( node, @@ -1605,7 +1615,9 @@ def is_potential_premise_lean4(node: Node) -> TypeGuard[Union[ return False -def is_mutual_lean4(node: Node) -> TypeGuard[Union[IdentNode,CommandTheoremNode, StdTacticAliasAliaslrNode]]: +def is_mutual_lean4( + node: Node, +) -> TypeGuard[Union[IdentNode, CommandTheoremNode, StdTacticAliasAliaslrNode]]: return ( isinstance(node, (IdentNode, CommandTheoremNode, StdTacticAliasAliaslrNode)) and node.is_mutual diff --git a/src/lean_dojo/data_extraction/traced_data.py b/src/lean_dojo/data_extraction/traced_data.py index 442b82a1..d6647298 100644 --- a/src/lean_dojo/data_extraction/traced_data.py +++ b/src/lean_dojo/data_extraction/traced_data.py @@ -47,7 +47,7 @@ is_leaf, is_mutual_lean4, is_potential_premise_lean4, - cast_away_optional + cast_away_optional, ) from .lean import LeanFile, LeanGitRepo, Theorem, Pos from ..constants import NUM_WORKERS, LOAD_USED_PACKAGES_ONLY, LEAN4_PACKAGES_DIR @@ -127,7 +127,7 @@ def get_code_without_comments( for c in comments: if base <= c.start and c.end <= end: - code_segs.append(lean_file[base: c.start]) + code_segs.append(lean_file[base : c.start]) base = c.end code_segs.append(lean_file[base:end]) @@ -214,7 +214,7 @@ def get_annotated_tactic(self) -> Tuple[str, List[Dict[str, Any]]]: ) lean_file = self.traced_theorem.traced_file.lean_file annot_tac = [] - provenances: List[Dict[str,Any]] = [] + provenances: List[Dict[str, Any]] = [] cur = self.start def _callback4(node: IdentNode, _): @@ -227,14 +227,14 @@ def _callback4(node: IdentNode, _): and node.def_end is not None ): if cur <= node.start: - annot_tac.append(lean_file[cur: node.start]) - annot_tac.append("" + lean_file[node.start: node.end] + "") + annot_tac.append(lean_file[cur : node.start]) + annot_tac.append("" + lean_file[node.start : node.end] + "") prov = { "full_name": node.full_name, "def_path": node.def_path, "def_pos": list(node.def_start), "def_end_pos": list(node.def_end), - } + } provenances.append(prov) cur = node.end From a9f5f6eee77ec359628eb112a52ae86ce033d6e8 Mon Sep 17 00:00:00 2001 From: Yansy Date: Wed, 18 Sep 2024 17:32:17 +0800 Subject: [PATCH 09/13] fix(typing): callback type check --- src/lean_dojo/data_extraction/ast.py | 15 ++----- src/lean_dojo/data_extraction/traced_data.py | 41 +++++++++----------- 2 files changed, 23 insertions(+), 33 deletions(-) diff --git a/src/lean_dojo/data_extraction/ast.py b/src/lean_dojo/data_extraction/ast.py index 2e42e04a..00b712df 100644 --- a/src/lean_dojo/data_extraction/ast.py +++ b/src/lean_dojo/data_extraction/ast.py @@ -576,7 +576,9 @@ def from_data( def is_private(self) -> bool: result = False - def _callback(node: CommandDeclmodifiersNode, _) -> bool: + def _callback(node: Node, _) -> bool: + if not isinstance(node, CommandDeclmodifiersNode): + raise TypeError("Excepted CommandDeclmodifiersNode") nonlocal result result = True return True @@ -1589,16 +1591,7 @@ def from_data(cls, node_data: Dict[str, Any], lean_file: LeanFile) -> "OtherNode return cls(lean_file, start, end, children, node_data["kind"]) -def is_potential_premise_lean4(node: Node) -> TypeGuard[ - Union[ - CommandTheoremNode, - LemmaNode, - MathlibTacticLemmaNode, - LeanElabCommandCommandIrreducibleDefNode, - StdTacticAliasAliasNode, - StdTacticAliasAliaslrNode, - ] -]: +def is_potential_premise_lean4(node: Node) -> bool: """Check if ``node`` is a theorem/definition that can be used as a premise.""" if (isinstance(node, CommandDeclarationNode) and not node.is_example) or isinstance( node, diff --git a/src/lean_dojo/data_extraction/traced_data.py b/src/lean_dojo/data_extraction/traced_data.py index d6647298..dbd7a049 100644 --- a/src/lean_dojo/data_extraction/traced_data.py +++ b/src/lean_dojo/data_extraction/traced_data.py @@ -217,7 +217,9 @@ def get_annotated_tactic(self) -> Tuple[str, List[Dict[str, Any]]]: provenances: List[Dict[str, Any]] = [] cur = self.start - def _callback4(node: IdentNode, _): + def _callback4(node: Node, _): + if not isinstance(node, IdentNode): + raise TypeError("Excepted IdentNode") nonlocal cur if ( @@ -371,7 +373,9 @@ def get_premise_full_names(self) -> List[str]: """Return the fully qualified names of all premises used in the proof.""" names = [] - def _callback(node: IdentNode, _: List[Node]): + def _callback(node: Node, _: List[Node]): + if not isinstance(node, IdentNode): + raise TypeError("Excepted IdentNode") if node.full_name is not None: names.append(node.full_name) @@ -770,28 +774,21 @@ def get_traced_theorem( private_result = None def _callback( - node: Union[CommandTheoremNode, LemmaNode, MathlibTacticLemmaNode], _ + node: Node, _ ) -> bool: nonlocal result, private_result - if ( - isinstance( - node, - ( - CommandTheoremNode, - LemmaNode, - MathlibTacticLemmaNode, - ), - ) - and node.full_name == thm.full_name - ): - start = cast_away_optional(node.start) - end = cast_away_optional(node.end) - comments = self._filter_comments(start, end) - t = TracedTheorem(self.root_dir, thm, node, comments, self) - if t.is_private: - private_result = t - else: - result = t + if not isinstance(node, (CommandTheoremNode, LemmaNode, MathlibTacticLemmaNode)): + raise TypeError("Except CommandTheoremNode, LemmaNode, MathlibTacticLemmaNode") + if not node.full_name == thm.full_name: + return False + start = cast_away_optional(node.start) + end = cast_away_optional(node.end) + comments = self._filter_comments(start, end) + t = TracedTheorem(self.root_dir, thm, node, comments, self) + if t.is_private: + private_result = t + else: + result = t return False self.ast.traverse_preorder(_callback, node_cls=None) From bfa7b22c87104be8a20eac5164eab98f3cdc8727 Mon Sep 17 00:00:00 2001 From: Yansy Date: Thu, 19 Sep 2024 11:04:55 +0800 Subject: [PATCH 10/13] chore(foramt): black check fix --- src/lean_dojo/data_extraction/traced_data.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/lean_dojo/data_extraction/traced_data.py b/src/lean_dojo/data_extraction/traced_data.py index dbd7a049..fcd1f699 100644 --- a/src/lean_dojo/data_extraction/traced_data.py +++ b/src/lean_dojo/data_extraction/traced_data.py @@ -773,12 +773,14 @@ def get_traced_theorem( result = None private_result = None - def _callback( - node: Node, _ - ) -> bool: + def _callback(node: Node, _) -> bool: nonlocal result, private_result - if not isinstance(node, (CommandTheoremNode, LemmaNode, MathlibTacticLemmaNode)): - raise TypeError("Except CommandTheoremNode, LemmaNode, MathlibTacticLemmaNode") + if not isinstance( + node, (CommandTheoremNode, LemmaNode, MathlibTacticLemmaNode) + ): + raise TypeError( + "Except CommandTheoremNode, LemmaNode, MathlibTacticLemmaNode" + ) if not node.full_name == thm.full_name: return False start = cast_away_optional(node.start) From 0d94e66086d1fb6480ed405b623e8464d18286d0 Mon Sep 17 00:00:00 2001 From: Yansy Date: Fri, 20 Sep 2024 10:31:44 +0800 Subject: [PATCH 11/13] fix(typing): Unsupported left operand type for <= ('None') --- src/lean_dojo/data_extraction/traced_data.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/lean_dojo/data_extraction/traced_data.py b/src/lean_dojo/data_extraction/traced_data.py index fcd1f699..eb56df1c 100644 --- a/src/lean_dojo/data_extraction/traced_data.py +++ b/src/lean_dojo/data_extraction/traced_data.py @@ -181,11 +181,15 @@ def state_after(self) -> str: @property def start(self) -> Pos: """Start position in :file:`*.lean` file.""" + if not isinstance(self.ast.start, Pos): + raise TypeError("start expected to be Pos") return self.ast.start @property def end(self) -> Pos: """End position in :file:`*.lean` file.""" + if not isinstance(self.ast.end, Pos): + raise TypeError("end expected to be Pos") return self.ast.end def to_string(self) -> str: @@ -220,6 +224,8 @@ def get_annotated_tactic(self) -> Tuple[str, List[Dict[str, Any]]]: def _callback4(node: Node, _): if not isinstance(node, IdentNode): raise TypeError("Excepted IdentNode") + if node.start is None or node.end is None: + raise TypeError("start/end expected to be Pos, Unsupported left operand type for <= ('None')") nonlocal cur if ( @@ -287,11 +293,15 @@ def __getstate__(self) -> Dict[str, Any]: @property def start(self) -> Pos: """Start position in :file:`*.lean` file.""" + if not isinstance(self.ast.start, Pos): + raise TypeError("start expected to be Pos") return self.ast.start @property def end(self) -> Pos: """End position in :file:`*.lean` file.""" + if not isinstance(self.ast.end, Pos): + raise TypeError("end expected to be Pos") return self.ast.end @property From 38bc4996b881ac387757e0069af2750faa59b260 Mon Sep 17 00:00:00 2001 From: Yansy Date: Fri, 20 Sep 2024 10:45:48 +0800 Subject: [PATCH 12/13] fix(typing): join type error --- src/lean_dojo/data_extraction/traced_data.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/lean_dojo/data_extraction/traced_data.py b/src/lean_dojo/data_extraction/traced_data.py index eb56df1c..ec2f0c55 100644 --- a/src/lean_dojo/data_extraction/traced_data.py +++ b/src/lean_dojo/data_extraction/traced_data.py @@ -648,11 +648,13 @@ def _callback(node: Node, _): ): inside_sections_namespaces.pop() elif is_potential_premise_lean4(node): - prefix = ".".join( - ns.name - for ns in inside_sections_namespaces - if isinstance(ns, CommandNamespaceNode) - ) + names = [] + for ns in inside_sections_namespaces: + if isinstance(ns, CommandNamespaceNode): + if ns.name is None: + raise TypeError("Expected ns.name to be str") + names.append(ns.name) + prefix = ".".join(names) if is_mutual_lean4(node): full_name = [_qualify_name(name, prefix) for name in node.name] From 094ee3fa288f0f7b8288a004b13bd43ec5e2a0cf Mon Sep 17 00:00:00 2001 From: Yansy Date: Fri, 20 Sep 2024 14:14:04 +0800 Subject: [PATCH 13/13] fix(typing): mypy errors --- src/lean_dojo/data_extraction/ast.py | 23 ++++++++++- src/lean_dojo/data_extraction/traced_data.py | 43 +++++++++++++++----- 2 files changed, 55 insertions(+), 11 deletions(-) diff --git a/src/lean_dojo/data_extraction/ast.py b/src/lean_dojo/data_extraction/ast.py index 00b712df..95a456a8 100644 --- a/src/lean_dojo/data_extraction/ast.py +++ b/src/lean_dojo/data_extraction/ast.py @@ -38,6 +38,18 @@ def cast_away_optional(x: Optional[T]) -> T: return x +def cast_list_str(x: Union[None, str, List[str]]) -> List[str]: + if isinstance(x, list) and all(isinstance(i, str) for i in x): + return x # Return the list if all elements are strings + raise TypeError(f"Expected None, str, or List[str], but got {type(x).__name__}") + + +def cast_str(x: Union[None, str, List[str]]) -> str: + if isinstance(x, str): + return x + raise TypeError(f"Expected str, but got {type(x).__name__}") + + @dataclass(frozen=True) class Node: lean_file: LeanFile @@ -1591,7 +1603,16 @@ def from_data(cls, node_data: Dict[str, Any], lean_file: LeanFile) -> "OtherNode return cls(lean_file, start, end, children, node_data["kind"]) -def is_potential_premise_lean4(node: Node) -> bool: +def is_potential_premise_lean4(node: Node) -> TypeGuard[ + Union[ + CommandDeclarationNode, + LemmaNode, + MathlibTacticLemmaNode, + LeanElabCommandCommandIrreducibleDefNode, + StdTacticAliasAliasNode, + StdTacticAliasAliaslrNode, + ] +]: """Check if ``node`` is a theorem/definition that can be used as a premise.""" if (isinstance(node, CommandDeclarationNode) and not node.is_example) or isinstance( node, diff --git a/src/lean_dojo/data_extraction/traced_data.py b/src/lean_dojo/data_extraction/traced_data.py index ec2f0c55..5c828458 100644 --- a/src/lean_dojo/data_extraction/traced_data.py +++ b/src/lean_dojo/data_extraction/traced_data.py @@ -14,7 +14,7 @@ from pathlib import Path from loguru import logger from dataclasses import dataclass, field -from typing import List, Optional, Dict, Any, Tuple, Union +from typing import List, Optional, Dict, Any, Tuple, Union, cast from ..utils import ( is_git_repo, @@ -48,6 +48,11 @@ is_mutual_lean4, is_potential_premise_lean4, cast_away_optional, + StdTacticAliasAliaslrNode, + LeanElabCommandCommandIrreducibleDefNode, + StdTacticAliasAliasNode, + cast_list_str, + cast_str, ) from .lean import LeanFile, LeanGitRepo, Theorem, Pos from ..constants import NUM_WORKERS, LOAD_USED_PACKAGES_ONLY, LEAN4_PACKAGES_DIR @@ -225,7 +230,9 @@ def _callback4(node: Node, _): if not isinstance(node, IdentNode): raise TypeError("Excepted IdentNode") if node.start is None or node.end is None: - raise TypeError("start/end expected to be Pos, Unsupported left operand type for <= ('None')") + raise TypeError( + "start/end expected to be Pos, Unsupported left operand type for <= ('None')" + ) nonlocal cur if ( @@ -648,6 +655,7 @@ def _callback(node: Node, _): ): inside_sections_namespaces.pop() elif is_potential_premise_lean4(node): + assert node.name is not None names = [] for ns in inside_sections_namespaces: if isinstance(ns, CommandNamespaceNode): @@ -656,14 +664,26 @@ def _callback(node: Node, _): names.append(ns.name) prefix = ".".join(names) - if is_mutual_lean4(node): + full_name: Union[str, List[str]] + if isinstance(node, StdTacticAliasAliaslrNode) and node.is_mutual: full_name = [_qualify_name(name, prefix) for name in node.name] - else: + object.__setattr__(node, "full_name", full_name) + elif isinstance( + node, + ( + CommandDeclarationNode, + LemmaNode, + MathlibTacticLemmaNode, + LeanElabCommandCommandIrreducibleDefNode, + StdTacticAliasAliasNode, + ), + ): full_name = _qualify_name(node.name, prefix) - - object.__setattr__(node, "full_name", full_name) - if isinstance(node, CommandDeclarationNode) and node.is_theorem: - object.__setattr__(node.get_theorem_node(), "full_name", full_name) + object.__setattr__(node, "full_name", full_name) + if isinstance(node, CommandDeclarationNode) and node.is_theorem: + object.__setattr__( + node.get_theorem_node(), "full_name", full_name + ) elif isinstance( node, ( @@ -905,8 +925,10 @@ def _callback4(node: Node, _) -> None: self.lean_file, start, end, self.comments ) # TODO: For alias, restate_axiom, etc., the code is not very informative. + full_name: Union[None, str, List[str]] if is_mutual_lean4(node): - for s in node.full_name: + full_name = cast_list_str(node.full_name) + for s in full_name: results.append( { "full_name": s, @@ -917,9 +939,10 @@ def _callback4(node: Node, _) -> None: } ) else: + full_name = cast_str(node.full_name) results.append( { - "full_name": node.full_name, + "full_name": full_name, "code": code, "start": list(start) if start is not None else [], "end": list(end) if end is not None else [],