From 2808e2dd9ddaf0bb399bf3d3e627ae0d16ed1a1a Mon Sep 17 00:00:00 2001 From: "J. Simon Richard" Date: Wed, 5 Mar 2025 16:32:49 -0500 Subject: [PATCH] Only switch to github api when the access token exists --- src/lean_dojo/data_extraction/lean.py | 35 +++++++++++++++++---------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/src/lean_dojo/data_extraction/lean.py b/src/lean_dojo/data_extraction/lean.py index 9f8ac518..e7127085 100644 --- a/src/lean_dojo/data_extraction/lean.py +++ b/src/lean_dojo/data_extraction/lean.py @@ -57,12 +57,14 @@ _SSH_TO_HTTPS_REGEX = re.compile(r"git@github\.com:(?P.+)/(?P.+?)(\.git)?") +_IS_SSH_REGEX = re.compile(r"^(ssh:\/\/)?(\w+@)?[\w.]+:[^:]+\.git$") + REPO_CACHE_PREFIX = "repos" class RepoType(Enum): - GITHUB = 0 - REMOTE = 1 # Remote but not GitHub. + GITHUB = 0 # Use GitHub API + REMOTE = 1 # Clone directly from URL LOCAL = 2 @@ -71,13 +73,16 @@ def normalize_url(url: str, repo_type: RepoType = RepoType.GITHUB) -> str: return os.path.abspath(url) # Remove trailing `/`. url = _URL_REGEX.fullmatch(url)["url"] # type: ignore - return ssh_to_https(url) + gh_url = os.getenv("GITHUB_ACCESS_TOKEN") and github_ssh_to_https(url) + return gh_url or url -def ssh_to_https(url: str) -> str: +def github_ssh_to_https(url: str) -> str: m = _SSH_TO_HTTPS_REGEX.fullmatch(url) - return f"https://github.com/{m.group('user')}/{m.group('repo')}" if m else url + return f"https://github.com/{m.group('user')}/{m.group('repo')}" if m else None +def is_ssh(url: str) -> bool: + return _IS_SSH_REGEX.fullmatch(url) is not None def get_repo_type(url: str) -> Optional[RepoType]: """Get the type of the repository. @@ -87,22 +92,26 @@ def get_repo_type(url: str) -> Optional[RepoType]: Returns: Optional[str]: The type of the repository (None if the repo cannot be found). """ - url = ssh_to_https(url) + + # Only convert ssh to https if GITHUB_ACCESS_TOKEN is set. + if os.getenv("GITHUB_ACCESS_TOKEN"): + url = github_ssh_to_https(url) + parsed_url = urllib.parse.urlparse(url) # type: ignore if parsed_url.scheme in ["http", "https"]: # Case 1 - GitHub URL. if "github.com" in url: - if not url.startswith("https://"): - logger.warning(f"{url} should start with https://") - return None - else: - return RepoType.GITHUB - # Case 2 - remote URL. + return RepoType.GITHUB + # Case 2 - remote http(s) URL. elif url_exists(url): # Not check whether it is a git URL return RepoType.REMOTE - # Case 3 - local path + # Case 3 - SSH URL. + elif is_ssh(url): + return RepoType.REMOTE + # Case 4 - local path elif is_git_repo(Path(parsed_url.path)): return RepoType.LOCAL + logger.warning(f"{url} is not a valid URL") return None