From 1be90b712be7681733fccee5c19dbd7f7d7f1acd Mon Sep 17 00:00:00 2001 From: "J. Simon Richard" Date: Wed, 5 Mar 2025 16:37:45 -0500 Subject: [PATCH] Enable lake project in subdir of repo --- src/lean_dojo/data_extraction/lean.py | 13 +++++++++---- src/lean_dojo/data_extraction/trace.py | 6 +++--- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/src/lean_dojo/data_extraction/lean.py b/src/lean_dojo/data_extraction/lean.py index 9f8ac518..33d76942 100644 --- a/src/lean_dojo/data_extraction/lean.py +++ b/src/lean_dojo/data_extraction/lean.py @@ -527,6 +527,11 @@ class LeanGitRepo: You can also use tags such as ``v3.5.0``. They will be converted to commit hashes. """ + subdir: str = field(default="") + """The subdirectory of the repo containing the Lean project. Default is the repo's root directory. + This cannot start with a ``/``. + """ + repo: Union[Repository, Repo] = field(init=False, repr=False) """A :class:`github.Repository` object for GitHub repos or a :class:`git.Repo` object for local or remote Git repos. @@ -773,7 +778,7 @@ def _get_config_url(self, filename: str) -> str: assert self.repo_type == RepoType.GITHUB assert "github.com" in self.url, f"Unsupported URL: {self.url}" url = self.url.replace("github.com", "raw.githubusercontent.com") - return f"{url}/{self.commit}/{filename}" + return f"{url}/{self.commit}/{self.subdir}/{filename}" def get_config(self, filename: str, num_retries: int = 2) -> Dict[str, Any]: """Return the repo's files.""" @@ -782,7 +787,7 @@ def get_config(self, filename: str, num_retries: int = 2) -> Dict[str, Any]: content = read_url(config_url, num_retries) else: working_dir = self.repo.working_dir - with open(os.path.join(working_dir, filename), "r") as f: + with open(os.path.join(working_dir, self.subdir, filename), "r") as f: content = f.read() if filename.endswith(".toml"): return toml.loads(content) @@ -797,7 +802,7 @@ def uses_lakefile_lean(self) -> bool: url = self._get_config_url("lakefile.lean") return url_exists(url) else: - lakefile_path = Path(self.repo.working_dir) / "lakefile.lean" + lakefile_path = Path(self.repo.working_dir) / self.subdir / "lakefile.lean" return lakefile_path.exists() def uses_lakefile_toml(self) -> bool: @@ -806,7 +811,7 @@ def uses_lakefile_toml(self) -> bool: url = self._get_config_url("lakefile.toml") return url_exists(url) else: - lakefile_path = Path(self.repo.working_dir) / "lakefile.toml" + lakefile_path = Path(self.repo.working_dir) / self.subdir / "lakefile.toml" return lakefile_path.exists() diff --git a/src/lean_dojo/data_extraction/trace.py b/src/lean_dojo/data_extraction/trace.py index 379b7bc0..5b1f34e1 100644 --- a/src/lean_dojo/data_extraction/trace.py +++ b/src/lean_dojo/data_extraction/trace.py @@ -130,7 +130,7 @@ def _trace(repo: LeanGitRepo, build_deps: bool) -> None: repo.clone_and_checkout() logger.debug(f"Tracing {repo}") - with working_directory(repo.name): + with working_directory(os.path.join(repo.name, repo.subdir)): # Build the repo using lake. if not build_deps: try: @@ -204,14 +204,14 @@ def get_traced_repo_path(repo: LeanGitRepo, build_deps: bool = True) -> Path: Returns: Path: The path of the traced repo in the cache, e.g. :file:`/home/kaiyu/.cache/lean_dojo/leanprover-community-mathlib-2196ab363eb097c008d4497125e0dde23fb36db2` """ - rel_cache_dir = repo.get_cache_dirname() / repo.name + rel_cache_dir = repo.get_cache_dirname() / repo.name / repo.subdir path = cache.get(rel_cache_dir) if path is None: logger.info(f"Tracing {repo}") with working_directory() as tmp_dir: logger.debug(f"Working in the temporary directory {tmp_dir}") _trace(repo, build_deps) - src_dir = tmp_dir / repo.name + src_dir = tmp_dir / repo.name / repo.subdir traced_repo = TracedRepo.from_traced_files(src_dir, build_deps) traced_repo.save_to_disk() path = cache.store(src_dir, rel_cache_dir)