Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 22 additions & 13 deletions src/lean_dojo/data_extraction/lean.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,14 @@

_SSH_TO_HTTPS_REGEX = re.compile(r"git@github\.com:(?P<user>.+)/(?P<repo>.+?)(\.git)?")

_IS_SSH_REGEX = re.compile(r"^(ssh:\/\/)?(\w+@)?[\w.]+:[^:]+\.git$")

REPO_CACHE_PREFIX = "repos"


class RepoType(Enum):
GITHUB = 0
REMOTE = 1 # Remote but not GitHub.
GITHUB = 0 # Use GitHub API
REMOTE = 1 # Clone directly from URL
LOCAL = 2


Expand All @@ -71,13 +73,16 @@ def normalize_url(url: str, repo_type: RepoType = RepoType.GITHUB) -> str:
return os.path.abspath(url)
# Remove trailing `/`.
url = _URL_REGEX.fullmatch(url)["url"] # type: ignore
return ssh_to_https(url)
gh_url = os.getenv("GITHUB_ACCESS_TOKEN") and github_ssh_to_https(url)
return gh_url or url


def ssh_to_https(url: str) -> str:
def github_ssh_to_https(url: str) -> str:
m = _SSH_TO_HTTPS_REGEX.fullmatch(url)
return f"https://github.com/{m.group('user')}/{m.group('repo')}" if m else url
return f"https://github.com/{m.group('user')}/{m.group('repo')}" if m else None

def is_ssh(url: str) -> bool:
return _IS_SSH_REGEX.fullmatch(url) is not None

def get_repo_type(url: str) -> Optional[RepoType]:
"""Get the type of the repository.
Expand All @@ -87,22 +92,26 @@ def get_repo_type(url: str) -> Optional[RepoType]:
Returns:
Optional[str]: The type of the repository (None if the repo cannot be found).
"""
url = ssh_to_https(url)

# Only convert ssh to https if GITHUB_ACCESS_TOKEN is set.
if os.getenv("GITHUB_ACCESS_TOKEN"):
url = github_ssh_to_https(url)

parsed_url = urllib.parse.urlparse(url) # type: ignore
if parsed_url.scheme in ["http", "https"]:
# Case 1 - GitHub URL.
if "github.com" in url:
if not url.startswith("https://"):
logger.warning(f"{url} should start with https://")
return None
else:
return RepoType.GITHUB
# Case 2 - remote URL.
return RepoType.GITHUB
# Case 2 - remote http(s) URL.
elif url_exists(url): # Not check whether it is a git URL
return RepoType.REMOTE
# Case 3 - local path
# Case 3 - SSH URL.
elif is_ssh(url):
return RepoType.REMOTE
# Case 4 - local path
elif is_git_repo(Path(parsed_url.path)):
return RepoType.LOCAL

logger.warning(f"{url} is not a valid URL")
return None

Expand Down