added all the progress bars for each worker

Karma Riuk
2025-05-17 10:45:57 +02:00
parent 98db478b7b
commit 4d9c47f33a

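The commit's core technique is tqdm's position argument: each worker process pins its progress bar to a fixed terminal row (rows 1..n_workers), the main process keeps an overall "Processing repos" bar, and the per-PR bar inside process_pull is silenced in workers via disable=not show_progress. Below is a minimal, self-contained sketch of that pattern (crunch, N_WORKERS and the fake workload are illustrative, not taken from this repository):

# Sketch only: one outer bar plus one bar per worker, each on its own row.
from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm import tqdm
import time

N_WORKERS = 3  # hypothetical worker count


def crunch(name: str, position: int) -> str:
    # Each worker draws on its own terminal row; row 0 belongs to the outer bar.
    for _ in tqdm(range(40), desc=f"worker {name}", position=position, leave=False):
        time.sleep(0.05)
    return name


if __name__ == "__main__":
    jobs = ["a", "b", "c"]  # one job per row, so rows never collide
    with tqdm(total=len(jobs), desc="all jobs", position=0) as outer, \
            ProcessPoolExecutor(max_workers=N_WORKERS) as pool:
        futures = [pool.submit(crunch, job, i + 1) for i, job in enumerate(jobs)]
        for fut in as_completed(futures):
            fut.result()
            outer.update(1)

The sketch keeps one job per row for simplicity; the diff below goes further and recycles a finished worker's row for the next repository.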

@@ -1,5 +1,6 @@
 from collections import defaultdict
-import argparse, os, subprocess, docker, uuid, concurrent.futures
+import argparse, os, subprocess, docker, uuid
+from concurrent.futures import wait, FIRST_COMPLETED, ProcessPoolExecutor
 from github.Commit import Commit
 from github.ContentFile import ContentFile
 from github.PullRequest import PullRequest
@@ -214,6 +215,7 @@ def process_pull(
     repos_dir: str,
     archive_destination: str,
     cache: dict[str, dict[int, DatasetEntry]] = {},
+    show_progress: bool = True,
 ):
     if pr.number in cache.get(repo.full_name, set()):
         dataset.entries.append(cache[repo.full_name][pr.number])
@@ -279,7 +281,9 @@ def process_pull(
         ),
     ]
-    pbar = tqdm(total=len(setup_steps) + 6, desc="Processing PR", leave=False)
+    pbar = tqdm(
+        total=len(setup_steps) + 6, desc="Processing PR", leave=False, disable=not show_progress
+    )
     for message, action in setup_steps:
         pbar.set_postfix(
             {
@@ -361,6 +365,8 @@ def process_repo(
     repos_dir: str,
     archive_destination: str,
     cache: dict[str, dict[int, DatasetEntry]] = {},
+    position: int = 1,
+    show_progress: bool = True,
 ):
     repo = g.get_repo(repo_name)
     already_seen_prs = set()
@@ -370,7 +376,9 @@ def process_repo(
     prs = repo.get_pulls(state="closed")
     n_good_prs = 0
-    with tqdm(total=prs.totalCount, desc="Processing prs", leave=False) as pbar:
+    with tqdm(
+        total=prs.totalCount, desc=f"Processing prs of {repo_name}", leave=False, position=position
+    ) as pbar:
         for pr in prs:
             pbar.set_postfix({"pr": pr.number, "# new good found": n_good_prs})
             try:
@@ -378,8 +386,10 @@ def process_repo(
                     continue
                 n_good_prs += 1
-                process_pull(repo, pr, dataset, repos_dir, archive_destination, cache)
-                dataset.to_json(args.output)
+                process_pull(
+                    repo, pr, dataset, repos_dir, archive_destination, cache, show_progress
+                )
+                # dataset.to_json(args.output)
             except Exception as e:
                 tqdm.write(f"[ERROR] PR #{pr.number} in {repo.full_name}. {type(e)}: {e}")
             finally:
@@ -388,19 +398,24 @@ def process_repo(
 # Wrapper to run in each worker process
 def process_repo_worker(
-    repo_name: str, repos_dir: str, archive_destination: str, cache: dict
+    repo_name: str, repos_dir: str, archive_destination: str, cache: dict, position: int
 ) -> list:
-    # Initialize GitHub and Docker clients in each process
-    token = os.environ.get("GITHUB_AUTH_TOKEN_CRAB")
-    g_worker = Github(token, seconds_between_requests=0)
-    docker_client_worker = docker.from_env()
     # Local dataset to collect entries for this repo
     local_dataset = Dataset()
     # Call the existing process_repo, but passing the local GitHub and Docker clients
     # You may need to modify process_repo to accept g and docker_client as parameters
-    process_repo(repo_name, local_dataset, repos_dir, archive_destination, cache)
-    return local_dataset.entries
+    try:
+        process_repo(
+            repo_name,
+            local_dataset,
+            repos_dir,
+            archive_destination,
+            cache,
+            position=position,
+            show_progress=False,
+        )
+    finally:
+        return local_dataset.entries
@@ -409,6 +424,7 @@ def process_repos_parallel(
     dataset: Dataset,
     repos_dir: str,
     archive_destination: str,
+    n_workers: int,
     cache: dict[str, dict[int, DatasetEntry]] = {},
 ):
     """
@@ -421,29 +437,62 @@ def process_repos_parallel(
         archive_destination: Directory for archives
         cache: Optional cache of previously processed PR entries
     """
+    for pr2entry in tqdm(cache.values(), desc="Adding cache in dataset"):
+        dataset.entries.extend(pr2entry.values())
+    dataset.to_json(args.output)
     repo_names = df["name"]
-    # Use all CPUs; adjust max_workers as needed
-    with concurrent.futures.ProcessPoolExecutor(max_workers=os.cpu_count()) as executor:
-        # Map each repo to a future
-        future_to_repo = {
-            executor.submit(process_repo_worker, name, repos_dir, archive_destination, cache): name
-            for name in repo_names
-        }
-        # Iterate as repos complete
-        for future in tqdm(
-            concurrent.futures.as_completed(future_to_repo),
-            total=len(future_to_repo),
-            desc="Processing repos",
-        ):
-            repo_name = future_to_repo[future]
-            try:
-                entries = future.result()
-                dataset.entries.extend(entries)
-                dataset.to_json(args.output)
-            except Exception as e:
-                tqdm.write(f"[ERROR] Repo {repo_name}: {e}")
+    free_positions = list(range(1, n_workers + 1))
+    repo_names_iter = iter(repo_names)
+    future_to_repo = {}
+    with tqdm(
+        total=len(repo_names),
+        desc="Processing repos",
+    ) as outer_pb, ProcessPoolExecutor(max_workers=n_workers) as executor:
+        # Map each repo to a future
+        for _ in range(n_workers):
+            try:
+                name = next(repo_names_iter)
+            except StopIteration:
+                break
+            pos = free_positions.pop(0)
+            fut = executor.submit(
+                process_repo_worker, name, repos_dir, archive_destination, cache, pos
+            )
+            future_to_repo[fut] = (name, pos)
+        try:
+            while future_to_repo:
+                done, _ = wait(future_to_repo, return_when=FIRST_COMPLETED)
+                for fut in done:
+                    repo_finished, pos = future_to_repo.pop(fut)
+                    entries = fut.result()
+                    dataset.entries.extend(entries)
+                    dataset.to_json(args.output)
+                    outer_pb.update(1)
+                    try:
+                        name = next(repo_names_iter)
+                    except StopIteration:
+                        # no more tasks: free the slot
+                        free_positions.append(pos)
+                    else:
+                        new_fut = executor.submit(
+                            process_repo_worker, name, repos_dir, archive_destination, cache, pos
+                        )
+                        future_to_repo[new_fut] = (name, pos)
+        except KeyboardInterrupt:
+            print("Saving all the entries up until now")
+            # any futures that happen to be done but not yet popped:
+            for fut in list(future_to_repo):
+                if fut.done():
+                    try:
+                        dataset.entries.extend(fut.result())
+                    except Exception:
+                        pass
+            # re-raise so the toplevel finally block still runs
+            raise
 def process_repos(
@@ -539,6 +588,12 @@ if __name__ == "__main__":
         action="store_true",
         help="Caches GitHub API requests in a SQLite file using 'requests_cache' (see optional-requirements.txt). Useful for faster reruns if the script crashes or youre tweaking it. Might produce stale data.",
     )
+    parser.add_argument(
+        "--max-workers",
+        metavar="N_WORKERS",
+        type=int,
+        help="Parallelize the processing of the repos with the given number of workers. If not given, the script is monothreaded",
+    )
     args = parser.parse_args()
@@ -577,6 +632,18 @@ if __name__ == "__main__":
     dataset = Dataset()
     try:
-        process_repos(df, dataset, args.repos, args.archive_destination, cache)
+        if args.max_workers is not None:
+            process_repos_parallel(
+                df, dataset, args.repos, args.archive_destination, args.max_workers, cache
+            )
+        else:
+            process_repos(df, dataset, args.repos, args.archive_destination, cache)
     finally:
+        print("")
+        print("")
+        print("")
+        print("")
+        print("")
+        print("")
+        print(f"Writing dataset to {args.output}")
         dataset.to_json(args.output)
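For reference, the scheduling loop added to process_repos_parallel above boils down to a slot-recycling pattern: keep at most n_workers futures in flight, and when one completes, hand its progress-bar row to the next repository. A stripped-down sketch under hypothetical names (run_task stands in for process_repo_worker; the commit's error and KeyboardInterrupt handling is omitted):

# Sketch of the slot-recycling submit loop (hypothetical names, no error handling).
from concurrent.futures import ProcessPoolExecutor, wait, FIRST_COMPLETED


def run_task(name: str, position: int) -> str:
    # Placeholder worker; the real one runs process_repo(..., position=position, show_progress=False).
    return name


def run_all(task_names: list[str], n_workers: int) -> list[str]:
    results = []
    free_positions = list(range(1, n_workers + 1))  # one tqdm row per worker
    names = iter(task_names)
    in_flight = {}  # future -> reserved row
    with ProcessPoolExecutor(max_workers=n_workers) as pool:
        # Prime the pool: one task per free row.
        for _ in range(n_workers):
            name = next(names, None)
            if name is None:
                break
            pos = free_positions.pop(0)
            in_flight[pool.submit(run_task, name, pos)] = pos
        # Whenever a task finishes, give its row to the next task.
        while in_flight:
            done, _ = wait(in_flight, return_when=FIRST_COMPLETED)
            for fut in done:
                pos = in_flight.pop(fut)
                results.append(fut.result())
                name = next(names, None)
                if name is None:
                    free_positions.append(pos)  # no more work: retire the row
                else:
                    in_flight[pool.submit(run_task, name, pos)] = pos
    return results


if __name__ == "__main__":
    print(run_all([f"repo-{i}" for i in range(10)], n_workers=4))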