diff --git a/pull_requests.py b/pull_requests.py index 8209b67..93e7d7a 100644 --- a/pull_requests.py +++ b/pull_requests.py @@ -79,20 +79,7 @@ def reset_repo_to_latest_commit(repo_path: str) -> None: run_git_cmd(["reset", "--hard", current_branch], repo_path) -def get_diffs_before(repo: Repository, pr: PullRequest) -> dict[str, str]: - comments = list(pr.get_review_comments()) - comments.sort(key=lambda comment: comment.created_at) - first_comment = comments[0] - try: - return { - file.filename: file.patch - for file in repo.compare(pr.base.sha, first_comment.commit_id).files - } - except GithubException as e: - raise NoDiffsBeforeError(e) - - -def get_diffs_after(repo: Repository, pr: PullRequest) -> dict[str, str]: +def get_last_commit_before_comments(pr: PullRequest) -> Commit: comments = list(pr.get_review_comments()) commits = list(pr.get_commits()) comments.sort(key=lambda comment: comment.created_at) @@ -103,7 +90,22 @@ def get_diffs_after(repo: Repository, pr: PullRequest) -> dict[str, str]: for commit in commits[:]: if commit.commit.author.date > first_comment.created_at: commits.remove(commit) - last_commit_before_comments = commits[-1] + return commits[-1] + + +def get_diffs_before(repo: Repository, pr: PullRequest) -> dict[str, str]: + last_commit_before_comments = get_last_commit_before_comments(pr) + try: + return { + file.filename: file.patch + for file in repo.compare(pr.base.sha, last_commit_before_comments.sha).files + } + except GithubException as e: + raise NoDiffsBeforeError(e) + + +def get_diffs_after(repo: Repository, pr: PullRequest) -> dict[str, str]: + last_commit_before_comments = get_last_commit_before_comments(pr) try: return { file.filename: file.patch