diff --git a/commit0/harness/run_pytest_ids.py b/commit0/harness/run_pytest_ids.py index 37f011e..166f991 100644 --- a/commit0/harness/run_pytest_ids.py +++ b/commit0/harness/run_pytest_ids.py @@ -108,7 +108,7 @@ def main( if not found_remote_branch: raise Exception(f"Branch {branch} does not exist locally or remotely.") patch = generate_patch_between_commits( - local_repo, example["base_commit"], commit_id + local_repo, example["base_commit"], commit_id, example["src_dir"] ) patch_file = Path(log_dir / "patch.diff") patch_file.write_text(patch) diff --git a/commit0/harness/utils.py b/commit0/harness/utils.py index 8a44685..821bac4 100644 --- a/commit0/harness/utils.py +++ b/commit0/harness/utils.py @@ -163,7 +163,7 @@ def create_repo_on_github( def generate_patch_between_commits( - repo: git.Repo, old_commit: str, new_commit: str + repo: git.Repo, old_commit: str, new_commit: str, src_dir: str ) -> str: """Generate a patch string by comparing two specified commits. @@ -172,6 +172,7 @@ def generate_patch_between_commits( repo (git.Repo): An instance of the git.Repo object representing the repository. old_commit (str): The hash or reference to the old commit. new_commit (str): The hash or reference to the new commit. + src_dir (str): The source directory to exclude from the patch. Returns: -------