added a dedicated progress bar for each parallel worker

Karma Riuk
2025-05-17 10:45:57 +02:00
parent 98db478b7b
commit 4d9c47f33a
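What the diff below boils down to: the outer bar (one tick per repo) stays on position 0, each worker process draws its own repo-level bar pinned to a distinct terminal line via tqdm's position argument, and the per-PR bars inside workers are silenced with disable=not show_progress. A minimal standalone sketch of that idea follows (illustrative only: fake_worker and the task names are made up, not code from this repo; output can flicker on some terminals since the processes do not share tqdm's write lock here).

# Minimal sketch: one tqdm bar per worker process, pinned to its own line.
import time
from concurrent.futures import ProcessPoolExecutor

from tqdm import tqdm


def fake_worker(name: str, position: int) -> str:
    # Each worker pins its bar to its own terminal line via `position`,
    # so the bars don't overwrite each other or the outer bar at position 0.
    for _ in tqdm(range(20), desc=name, position=position, leave=False):
        time.sleep(0.1)
    return name


if __name__ == "__main__":
    n_workers = 4
    names = [f"task-{i}" for i in range(n_workers)]
    with tqdm(total=len(names), desc="overall", position=0) as outer:
        with ProcessPoolExecutor(max_workers=n_workers) as pool:
            futures = [
                pool.submit(fake_worker, name, pos)
                for pos, name in enumerate(names, start=1)
            ]
            for fut in futures:
                fut.result()
                outer.update(1)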


@@ -1,5 +1,6 @@
from collections import defaultdict
import argparse, os, subprocess, docker, uuid, concurrent.futures
import argparse, os, subprocess, docker, uuid
from concurrent.futures import wait, FIRST_COMPLETED, ProcessPoolExecutor
from github.Commit import Commit
from github.ContentFile import ContentFile
from github.PullRequest import PullRequest
@@ -214,6 +215,7 @@ def process_pull(
repos_dir: str,
archive_destination: str,
cache: dict[str, dict[int, DatasetEntry]] = {},
show_progress: bool = True,
):
if pr.number in cache.get(repo.full_name, set()):
dataset.entries.append(cache[repo.full_name][pr.number])
@@ -279,7 +281,9 @@ def process_pull(
),
]
pbar = tqdm(total=len(setup_steps) + 6, desc="Processing PR", leave=False)
pbar = tqdm(
total=len(setup_steps) + 6, desc="Processing PR", leave=False, disable=not show_progress
)
for message, action in setup_steps:
pbar.set_postfix(
{
@@ -361,6 +365,8 @@ def process_repo(
repos_dir: str,
archive_destination: str,
cache: dict[str, dict[int, DatasetEntry]] = {},
position: int = 1,
show_progress: bool = True,
):
repo = g.get_repo(repo_name)
already_seen_prs = set()
@@ -370,7 +376,9 @@ def process_repo(
prs = repo.get_pulls(state="closed")
n_good_prs = 0
with tqdm(total=prs.totalCount, desc="Processing prs", leave=False) as pbar:
with tqdm(
total=prs.totalCount, desc=f"Processing prs of {repo_name}", leave=False, position=position
) as pbar:
for pr in prs:
pbar.set_postfix({"pr": pr.number, "# new good found": n_good_prs})
try:
@@ -378,8 +386,10 @@ def process_repo(
continue
n_good_prs += 1
process_pull(repo, pr, dataset, repos_dir, archive_destination, cache)
dataset.to_json(args.output)
process_pull(
repo, pr, dataset, repos_dir, archive_destination, cache, show_progress
)
# dataset.to_json(args.output)
except Exception as e:
tqdm.write(f"[ERROR] PR #{pr.number} in {repo.full_name}. {type(e)}: {e}")
finally:
@@ -388,20 +398,25 @@ def process_repo(
# Wrapper to run in each worker process
def process_repo_worker(
repo_name: str, repos_dir: str, archive_destination: str, cache: dict
repo_name: str, repos_dir: str, archive_destination: str, cache: dict, position: int
) -> list:
# Initialize GitHub and Docker clients in each process
token = os.environ.get("GITHUB_AUTH_TOKEN_CRAB")
g_worker = Github(token, seconds_between_requests=0)
docker_client_worker = docker.from_env()
# Local dataset to collect entries for this repo
local_dataset = Dataset()
# Call the existing process_repo, but passing the local GitHub and Docker clients
# You may need to modify process_repo to accept g and docker_client as parameters
process_repo(repo_name, local_dataset, repos_dir, archive_destination, cache)
return local_dataset.entries
try:
process_repo(
repo_name,
local_dataset,
repos_dir,
archive_destination,
cache,
position=position,
show_progress=False,
)
finally:
return local_dataset.entries
def process_repos_parallel(
@@ -409,6 +424,7 @@ def process_repos_parallel(
dataset: Dataset,
repos_dir: str,
archive_destination: str,
n_workers: int,
cache: dict[str, dict[int, DatasetEntry]] = {},
):
"""
@@ -421,29 +437,62 @@ def process_repos_parallel(
archive_destination: Directory for archives
cache: Optional cache of previously processed PR entries
"""
for pr2entry in tqdm(cache.values(), desc="Adding cache in dataset"):
dataset.entries.extend(pr2entry.values())
dataset.to_json(args.output)
repo_names = df["name"]
# Use all CPUs; adjust max_workers as needed
with concurrent.futures.ProcessPoolExecutor(max_workers=os.cpu_count()) as executor:
free_positions = list(range(1, n_workers + 1))
repo_names_iter = iter(repo_names)
future_to_repo = {}
with tqdm(
total=len(repo_names),
desc="Processing repos",
) as outer_pb, ProcessPoolExecutor(max_workers=n_workers) as executor:
# Map each repo to a future
future_to_repo = {
executor.submit(process_repo_worker, name, repos_dir, archive_destination, cache): name
for name in repo_names
}
# Iterate as repos complete
for future in tqdm(
concurrent.futures.as_completed(future_to_repo),
total=len(future_to_repo),
desc="Processing repos",
):
repo_name = future_to_repo[future]
for _ in range(n_workers):
try:
entries = future.result()
dataset.entries.extend(entries)
dataset.to_json(args.output)
except Exception as e:
tqdm.write(f"[ERROR] Repo {repo_name}: {e}")
name = next(repo_names_iter)
except StopIteration:
break
pos = free_positions.pop(0)
fut = executor.submit(
process_repo_worker, name, repos_dir, archive_destination, cache, pos
)
future_to_repo[fut] = (name, pos)
try:
while future_to_repo:
done, _ = wait(future_to_repo, return_when=FIRST_COMPLETED)
for fut in done:
repo_finished, pos = future_to_repo.pop(fut)
entries = fut.result()
dataset.entries.extend(entries)
dataset.to_json(args.output)
outer_pb.update(1)
try:
name = next(repo_names_iter)
except StopIteration:
# no more tasks: free the slot
free_positions.append(pos)
else:
new_fut = executor.submit(
process_repo_worker, name, repos_dir, archive_destination, cache, pos
)
future_to_repo[new_fut] = (name, pos)
except KeyboardInterrupt:
print("Saving all the entries up until now")
# any futures that happen to be done but not yet popped:
for fut in list(future_to_repo):
if fut.done():
try:
dataset.entries.extend(fut.result())
except Exception:
pass
# re-raise so the toplevel finally block still runs
raise
def process_repos(
@@ -539,6 +588,12 @@ if __name__ == "__main__":
action="store_true",
help="Caches GitHub API requests in a SQLite file using 'requests_cache' (see optional-requirements.txt). Useful for faster reruns if the script crashes or youre tweaking it. Might produce stale data.",
)
parser.add_argument(
"--max-workers",
metavar="N_WORKERS",
type=int,
help="Parallelize the processing of the repos with the given number of workers. If not given, the script is monothreaded",
)
args = parser.parse_args()
@@ -577,6 +632,18 @@ if __name__ == "__main__":
dataset = Dataset()
try:
process_repos(df, dataset, args.repos, args.archive_destination, cache)
if args.max_workers is not None:
process_repos_parallel(
df, dataset, args.repos, args.archive_destination, args.max_workers, cache
)
else:
process_repos(df, dataset, args.repos, args.archive_destination, cache)
finally:
print("")
print("")
print("")
print("")
print("")
print("")
print(f"Writing dataset to {args.output}")
dataset.to_json(args.output)
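For reference, the scheduling pattern that process_repos_parallel now follows, reduced to a self-contained sketch (work and the task names are placeholders, not code from this repo): seed the pool with one task per free display slot, then each time a future completes, hand its slot to the next task, so a given terminal line is only ever owned by one running worker.

# Self-contained sketch of the slot-recycling scheduler (illustrative only).
from concurrent.futures import FIRST_COMPLETED, ProcessPoolExecutor, wait


def work(task: str, position: int) -> str:
    # A real worker would feed `position` to its tqdm bar; here we just return.
    return f"{task} finished in slot {position}"


if __name__ == "__main__":
    tasks = iter([f"task-{i}" for i in range(10)])
    n_workers = 3
    free_positions = list(range(1, n_workers + 1))
    running = {}  # future -> (task name, display slot)
    with ProcessPoolExecutor(max_workers=n_workers) as pool:
        # Seed the pool: one task per free display slot.
        for _ in range(n_workers):
            try:
                task = next(tasks)
            except StopIteration:
                break
            pos = free_positions.pop(0)
            running[pool.submit(work, task, pos)] = (task, pos)
        # Whenever a future finishes, recycle its slot for the next task.
        while running:
            done, _ = wait(running, return_when=FIRST_COMPLETED)
            for fut in done:
                task, pos = running.pop(fut)
                print(fut.result())
                try:
                    nxt = next(tasks)
                except StopIteration:
                    free_positions.append(pos)
                else:
                    running[pool.submit(work, nxt, pos)] = (nxt, pos)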