long-code-arena / src /submission_uploader.py
saridormi's picture
Remove previous attempt at adding folders, as it doesn't work :( also, remove dataset name from exceptions
afb227e
raw
history blame
13.2 kB
import json
import logging
import os
import time
from tempfile import TemporaryDirectory
from typing import Dict, List, Optional
import jsonlines
from huggingface_hub import CommitOperationAdd # type: ignore[import]
from huggingface_hub import Discussion, HfApi, HfFileSystem
from tqdm import tqdm
from .evaluation import METRICS
from .formatting import styled_error, styled_message, styled_warning
from .tasks_content import TASKS_PRETTY_REVERSE
class AlreadyExists(Exception):
    """Exception type for a submission that is already present.

    NOTE(review): nothing in this module raises it; presumably kept for use by
    callers elsewhere — confirm before removing.
    """
class SubmissionUploader:
    """Class for adding new files to a dataset on a Hub and opening a PR.

    A single submission opens two pull requests:

    * one to the public results dataset (``dataset_id``) with the uploaded
      prediction files and a results file (averaged metrics + model metadata);
    * one to the private requests dataset (``private_dataset_id``) with the
      request metadata, including contact information and the comment.

    Heavily influenced by these amazing spaces:
    * https://huggingface.co/spaces/safetensors/convert
    * https://huggingface.co/spaces/gaia-benchmark/leaderboard
    """

    def __init__(self, dataset_id: str, private_dataset_id: str):
        # Both Hub clients authenticate with the token from the HF_TOKEN environment variable.
        self._api = HfApi(token=os.environ["HF_TOKEN"])
        self._fs = HfFileSystem(token=os.environ["HF_TOKEN"])
        self._dataset_id = dataset_id  # public results dataset
        self._private_dataset_id = private_dataset_id  # private requests dataset

    def _list_entry_names(self, path: str) -> List[str]:
        """Return the base names of the entries directly under ``path`` on the Hub.

        A directory that does not exist yet is treated as empty. The listing is
        refreshed on every call because the same filesystem object may be reused
        for several submissions, and a cached listing would miss new entries.
        """
        try:
            entries = self._fs.ls(path, detail=False, refresh=True)
        except FileNotFoundError:
            return []
        # With detail=False the entries are full repo paths, e.g. "datasets/<owner>/<repo>/<dir>/<name>".
        return [os.path.basename(str(entry).rstrip("/")) for entry in entries]

    def _get_previous_pr(self, pr_title: str) -> Optional[Discussion]:
        """Searches among discussions of dataset repo for an open PR with the given title.

        This is a best-effort check: any error while querying the Hub is logged and
        treated as "no such PR" (``None`` is returned).
        """
        try:
            # The discussions may be fetched lazily while iterating, so the loop must stay inside `try`.
            for discussion in self._api.get_repo_discussions(repo_id=self._dataset_id, repo_type="dataset"):
                if discussion.status == "open" and discussion.is_pull_request and discussion.title == pr_title:
                    return discussion
        except Exception:
            logging.warning("Could not fetch discussions of %s.", self._dataset_id, exc_info=True)
        return None

    def _get_metadata(
        self,
        model_name_pretty: str,
        model_availability: str,
        urls: Optional[str],
        context_size: str,
        submitted_by: str,
    ) -> Dict[str, Optional[str]]:
        """Collect the model metadata that is published next to the metrics."""
        return {
            "model_name": model_name_pretty,
            "model_availability": model_availability,
            "urls": urls,
            "context_size": context_size,
            "submitted_by": submitted_by,
        }

    def _upload_request(
        self,
        task_id: str,
        model_folder: str,
        model_name_pretty: str,
        model_availability: str,
        urls: Optional[str],
        context_size: str,
        submitted_by: str,
        contact_information: str,
        comment: Optional[str],
        pr_url: str,
        temp_directory: str,
    ) -> List[CommitOperationAdd]:
        """Prepare the commit operation that records this request in the private dataset.

        Writes ``request_metadata.json`` into ``temp_directory`` and returns an
        operation adding it as ``{task_id}/{N}_{model_folder}.json``, where ``N`` is
        the number of entries currently under ``{task_id}/`` in the private dataset
        (0 when that directory does not exist yet).
        """
        request_metadata = {
            "model_folder": model_folder,
            "model_name_pretty": model_name_pretty,
            "model_availability": model_availability,
            "urls": urls,
            "context_size": context_size,
            "submitted_by": submitted_by,
            "contact_information": contact_information,
            "comment": comment,
            "timestamp": time.time(),
            "pr_url": pr_url,
        }
        request_path = os.path.join(temp_directory, "request_metadata.json")
        with open(request_path, "w") as f:
            json.dump(request_metadata, f)

        num_requests_already_present = len(
            self._list_entry_names(f"datasets/{self._private_dataset_id}/{task_id}/")
        )
        return [
            CommitOperationAdd(
                path_in_repo=f"{task_id}/{num_requests_already_present}_{model_folder}.json",
                path_or_fileobj=request_path,
            )
        ]

    def _upload_predictions(
        self,
        task_id: str,
        model_folder: str,
        filenames: List[str],
    ) -> List[CommitOperationAdd]:
        """Prepare one commit operation per predictions file under ``{task_id}/predictions/{model_folder}/``."""
        return [
            CommitOperationAdd(
                path_in_repo=f"{task_id}/predictions/{model_folder}/{os.path.basename(filename)}",
                path_or_fileobj=filename,
            )
            for filename in filenames
        ]

    def _compute_metrics_for_predictions(self, task_id: str, filenames: List[str], temp_directory: str) -> None:
        """Score every predictions file and write the averaged metrics to ``final_metrics.json``.

        Each file is scored separately. The per-file scores are then averaged with
        equal weight per file, not per example. Every line of a predictions file must
        be a JSON object with ``"prediction"`` and ``"reference"`` keys.

        Raises:
            ValueError: if computing metrics for ``task_id`` is not supported or no files are given.
        """
        metrics_module = METRICS[task_id]
        if metrics_module is None:
            raise ValueError(f"Computing metrics for {task_id} is not supported.")
        if not filenames:
            raise ValueError("Please, attach at least one file with predictions.")

        # compute the metrics for each submitted file
        per_file_metrics = []
        metrics_module.reset()
        for filename in filenames:
            with jsonlines.open(filename, "r") as reader:
                for example in tqdm(reader, desc=f"Computing metrics for {os.path.basename(filename)}"):
                    metrics_module.add_batch(
                        predictions=[example["prediction"]],
                        references=[example["reference"]],
                    )
            per_file_metrics.append(metrics_module.compute())
            metrics_module.reset()

        # aggregate the metrics over submitted files
        final_metrics_results = {
            key: sum(entry[key] for entry in per_file_metrics) / len(per_file_metrics)
            for key in per_file_metrics[0]
        }
        with open(os.path.join(temp_directory, "final_metrics.json"), "w") as f:
            json.dump(final_metrics_results, f)

    def _upload_results(
        self,
        task_id: str,
        model_folder: str,
        model_name_pretty: str,
        model_availability: str,
        urls: Optional[str],
        context_size: str,
        submitted_by: str,
        temp_directory: str,
    ) -> List[CommitOperationAdd]:
        """Prepare the commit operation adding ``{task_id}/results/{model_folder}.jsonl``.

        The results file is a single JSON line: the metrics from ``final_metrics.json``
        in ``temp_directory`` merged with the model metadata. Metadata keys win on
        collision.
        """
        with open(os.path.join(temp_directory, "final_metrics.json"), "r") as f:
            metrics = json.load(f)
        metadata_dict = self._get_metadata(
            model_name_pretty=model_name_pretty,
            model_availability=model_availability,
            urls=urls,
            context_size=context_size,
            submitted_by=submitted_by,
        )
        final_results = {**metrics, **metadata_dict}

        results_path = os.path.join(temp_directory, "final_results.jsonl")
        with jsonlines.open(results_path, "w") as writer:
            writer.write(final_results)
        return [
            CommitOperationAdd(
                path_in_repo=f"{task_id}/results/{model_folder}.jsonl",
                path_or_fileobj=results_path,
            )
        ]

    def _verify_arguments(
        self,
        task_pretty: str,
        model_folder: str,
        model_name_pretty: str,
        model_availability: str,
        urls: Optional[str],
        context_size: str,
        submitted_by: str,
        contact_information: str,
        comment: Optional[str],
        filenames: Optional[List[str]],
    ):
        """Validate the user-provided submission fields.

        ``urls`` and ``comment`` are optional and are not checked.

        Raises:
            ValueError: with a user-facing message describing the first invalid field.
        """

        def require(condition: object, message: str) -> None:
            # Explicit raise instead of `assert`: assertions are stripped under `python -O`.
            if not condition:
                raise ValueError(message)

        require(task_pretty and task_pretty in TASKS_PRETTY_REVERSE, "Please, select one of the supported tasks.")
        require(model_folder, "Please, specify non-empty name for a directory with a model's results.")
        require(model_name_pretty, "Please, specify non-empty name for a model.")
        require(model_availability, "Please, specify non-empty information about a model's availability.")
        require(context_size, "Please, specify non-empty information about a model's context size.")
        try:
            int(context_size)
        except (TypeError, ValueError):
            raise ValueError("Please, specify a model's context size as an integer (e.g., 16000).") from None
        require(submitted_by, "Please, specify non-empty information about a submission's author(s).")
        require(filenames, "Please, attach at least one file with predictions.")
        require(contact_information, "Please, fill in the field with contact information.")

    def _format_error(self, error: Exception) -> str:
        """Turn an exception into a user-facing message that does not reveal the private dataset id."""
        message = str(error)
        for private_id in (self._private_dataset_id, os.environ.get("PRIVATE_DATASET_ID")):
            if private_id:
                message = message.replace(private_id, "{private_dataset}")
        return message or "An exception occurred. Please, try again."

    def upload_files(
        self,
        task_pretty: str,
        model_folder: str,
        model_name_pretty: str,
        model_availability: str,
        urls: Optional[str],
        context_size: str,
        submitted_by: str,
        contact_information: str,
        comment: Optional[str],
        filenames: Optional[List[str]],
        force: bool = False,
    ) -> str:
        """Validate a submission, compute its metrics and open the two PRs.

        Unless ``force`` is set, the submission is rejected (with a warning) when
        ``model_folder`` already exists in the public dataset's predictions for the
        task, or when an open PR with the same title already exists.

        Returns:
            A message built with ``styled_message`` on success, ``styled_warning``
            for a duplicate, or ``styled_error`` for any exception raised along the way.
        """
        try:
            self._verify_arguments(
                task_pretty=task_pretty,
                model_folder=model_folder,
                model_name_pretty=model_name_pretty,
                model_availability=model_availability,
                urls=urls,
                context_size=context_size,
                submitted_by=submitted_by,
                contact_information=contact_information,
                comment=comment,
                filenames=filenames,
            )
            pr_title = f"🚀 New submission to {task_pretty} task: {model_name_pretty} with {context_size} context size from {submitted_by}"
            logging.info("Start processing %s", pr_title)

            task_id = TASKS_PRETTY_REVERSE[task_pretty]

            logging.info("Checking if this request has already been submitted...")
            if not force:
                existing_folders = self._list_entry_names(f"datasets/{self._dataset_id}/{task_id}/predictions")
                if model_folder in existing_folders:
                    return styled_warning(
                        f"{model_folder} is already present in {self._dataset_id}, please, select another folder name."
                    )

                prev_pr = self._get_previous_pr(pr_title)
                if prev_pr is not None:
                    url = f"https://huggingface.co/datasets/{self._dataset_id}/discussions/{prev_pr.num}"
                    return styled_warning(f"{self._dataset_id} already has an open PR for this submission: {url}.")

            logging.info("Processing predictions...")
            predictions_commit_operations = self._upload_predictions(
                task_id=task_id,
                model_folder=model_folder,
                filenames=filenames,
            )

            # Shared by both PRs; the private one appends request-only details.
            public_description = (
                f"New submission to {task_pretty} task in 🏟️ Long Code Arena benchmark!\n"
                f"* Model name: {model_name_pretty}\n"
                f"* Model availability: {model_availability}\n"
                f"* Context Size: {context_size}\n"
                f"* Relevant URLs: {urls}\n"
                f"* Submitted By: {submitted_by}"
            )

            with TemporaryDirectory() as d:
                temp_directory = str(d)

                logging.info("Computing metrics...")
                self._compute_metrics_for_predictions(task_id=task_id, filenames=filenames, temp_directory=temp_directory)

                logging.info("Processing results...")
                results_commit_operations = self._upload_results(
                    task_id=task_id,
                    model_folder=model_folder,
                    model_name_pretty=model_name_pretty,
                    model_availability=model_availability,
                    urls=urls,
                    context_size=context_size,
                    submitted_by=submitted_by,
                    temp_directory=temp_directory,
                )

                logging.info("Creating commit to results dataset...")
                new_pr = self._api.create_commit(
                    repo_id=self._dataset_id,
                    operations=predictions_commit_operations + results_commit_operations,
                    commit_message=pr_title,
                    commit_description=public_description,
                    create_pr=True,
                    repo_type="dataset",
                )

                logging.info("Creating commit to requests dataset...")
                request_commit_operations = self._upload_request(
                    task_id=task_id,
                    model_folder=model_folder,
                    temp_directory=temp_directory,
                    model_name_pretty=model_name_pretty,
                    model_availability=model_availability,
                    urls=urls,
                    context_size=context_size,
                    submitted_by=submitted_by,
                    contact_information=contact_information,
                    comment=comment,
                    pr_url=new_pr.pr_url,
                )
                self._api.create_commit(
                    repo_id=self._private_dataset_id,
                    operations=request_commit_operations,
                    commit_message=pr_title,
                    commit_description=(
                        f"{public_description}\n"
                        f"* PR: {new_pr.pr_url}\n"
                        f"* Contact information: {contact_information}\n"
                        f"* Comment: {comment}"
                    ),
                    create_pr=True,
                    repo_type="dataset",
                )
                return styled_message(f"🎉 PR created at {new_pr.pr_url}.")

        except Exception as e:
            logging.exception(e)
            return styled_error(self._format_error(e))