Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions src/lean_dojo/data_extraction/lean.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,11 @@ class LeanGitRepo:
You can also use tags such as ``v3.5.0``. They will be converted to commit hashes.
"""

subdir: str = field(default="")
"""The subdirectory of the repo containing the Lean project. Default is the repo's root directory.
This cannot start with a ``/``.
"""

repo: Union[Repository, Repo] = field(init=False, repr=False)
"""A :class:`github.Repository` object for GitHub repos or
a :class:`git.Repo` object for local or remote Git repos.
Expand Down Expand Up @@ -773,7 +778,7 @@ def _get_config_url(self, filename: str) -> str:
assert self.repo_type == RepoType.GITHUB
assert "github.com" in self.url, f"Unsupported URL: {self.url}"
url = self.url.replace("github.com", "raw.githubusercontent.com")
return f"{url}/{self.commit}/{filename}"
return f"{url}/{self.commit}/{self.subdir}/{filename}"

def get_config(self, filename: str, num_retries: int = 2) -> Dict[str, Any]:
"""Return the repo's files."""
Expand All @@ -782,7 +787,7 @@ def get_config(self, filename: str, num_retries: int = 2) -> Dict[str, Any]:
content = read_url(config_url, num_retries)
else:
working_dir = self.repo.working_dir
with open(os.path.join(working_dir, filename), "r") as f:
with open(os.path.join(working_dir, self.subdir, filename), "r") as f:
content = f.read()
if filename.endswith(".toml"):
return toml.loads(content)
Expand All @@ -797,7 +802,7 @@ def uses_lakefile_lean(self) -> bool:
url = self._get_config_url("lakefile.lean")
return url_exists(url)
else:
lakefile_path = Path(self.repo.working_dir) / "lakefile.lean"
lakefile_path = Path(self.repo.working_dir) / self.subdir / "lakefile.lean"
return lakefile_path.exists()

def uses_lakefile_toml(self) -> bool:
Expand All @@ -806,7 +811,7 @@ def uses_lakefile_toml(self) -> bool:
url = self._get_config_url("lakefile.toml")
return url_exists(url)
else:
lakefile_path = Path(self.repo.working_dir) / "lakefile.toml"
lakefile_path = Path(self.repo.working_dir) / self.subdir / "lakefile.toml"
return lakefile_path.exists()


Expand Down
6 changes: 3 additions & 3 deletions src/lean_dojo/data_extraction/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def _trace(repo: LeanGitRepo, build_deps: bool) -> None:
repo.clone_and_checkout()
logger.debug(f"Tracing {repo}")

with working_directory(repo.name):
with working_directory(os.path.join(repo.name, repo.subdir)):
# Build the repo using lake.
if not build_deps:
try:
Expand Down Expand Up @@ -204,14 +204,14 @@ def get_traced_repo_path(repo: LeanGitRepo, build_deps: bool = True) -> Path:
Returns:
Path: The path of the traced repo in the cache, e.g. :file:`/home/kaiyu/.cache/lean_dojo/leanprover-community-mathlib-2196ab363eb097c008d4497125e0dde23fb36db2`
"""
rel_cache_dir = repo.get_cache_dirname() / repo.name
rel_cache_dir = repo.get_cache_dirname() / repo.name / repo.subdir
path = cache.get(rel_cache_dir)
if path is None:
logger.info(f"Tracing {repo}")
with working_directory() as tmp_dir:
logger.debug(f"Working in the temporary directory {tmp_dir}")
_trace(repo, build_deps)
src_dir = tmp_dir / repo.name
src_dir = tmp_dir / repo.name / repo.subdir
traced_repo = TracedRepo.from_traced_files(src_dir, build_deps)
traced_repo.save_to_disk()
path = cache.store(src_dir, rel_cache_dir)
Expand Down