lewtun's picture
lewtun HF staff
Add template files
958473e
raw history blame
No virus
3.58 kB
import datetime
import re
import subprocess
from pathlib import Path
import pandas as pd
import typer
from datasets import get_dataset_config_names, load_dataset
# Expected (number of rows, number of columns) of each task's predictions.csv,
# keyed by task name (the name of the directory holding the file). Every file
# has the same two columns, so only the row count differs between tasks.
CSV_SCHEMA = {
    task_name: (num_rows, 2)
    for task_name, num_rows in [
        ("banking_77", 5000),
        ("overruling", 2350),
        ("semiconductor_org_types", 449),
        ("ade_corpus_v2", 5000),
        ("twitter_complaints", 3399),
        ("neurips_impact_statement_risks", 150),
        ("systematic_review_inclusion", 2244),
        ("terms_of_service", 5000),
        ("tai_safety_research", 1639),
        ("one_stop_english", 518),
        ("tweet_eval_hate", 2966),
    ]
}
# Typer CLI application. The functions decorated with `@app.command()` in this
# file (`validate`, `submit`) are registered on it as commands, and it is
# invoked by the `__main__` guard when the file is run as a script.
app = typer.Typer()
def _update_submission_name(submission_name: str, readme_path: Path = Path("README.md")) -> None:
    """Set the ``submission_name:`` value in the README, rewriting the file in place.

    Every line that starts with ``submission_name:`` is replaced by
    ``submission_name: <submission_name>``. The line's original ending is kept,
    and all other lines are written back unchanged.

    Args:
        submission_name: New name to write. It is inserted verbatim, so
            backslashes and other special characters are preserved.
        readme_path: File to update (str or Path). Defaults to ``README.md``
            relative to the current working directory.
    """
    path = Path(readme_path)
    updated_lines = []
    for line in path.read_text(encoding="utf-8").splitlines(keepends=True):
        if line.startswith("submission_name:"):
            # Build the line directly instead of using re.sub. re.sub treats the
            # replacement as a template (so "\d" or "\n" in a name would error or
            # be mangled), and its `.+` pattern skipped an empty value entirely.
            ending = line[len(line.rstrip("\r\n")):]
            line = f"submission_name: {submission_name}{ending}"
        updated_lines.append(line)
    path.write_text("".join(updated_lines), encoding="utf-8")
@app.command()
def validate():
    """Check that the local prediction files are ready to be submitted.

    Runs the following checks in order and raises on the first one that fails:

    1. ``data/`` contains exactly one ``predictions.csv`` per task. The task list
       comes from the configs of the ``ought/raft`` dataset, and a file's task is
       the name of its parent directory.
    2. Every file parses as CSV, has the (rows, columns) shape listed in
       ``CSV_SCHEMA``, and has exactly the columns ``ID`` and ``Label``.
    3. Every task's predictions can be loaded with ``datasets.load_dataset``.

    Raises:
        ValueError: naming every offending file or task for the failed check.
    """
    # TODO(lewtun): Consider using great_expectations for the data validation
    tasks = get_dataset_config_names("ought/raft")
    prediction_files = _check_prediction_files(tasks)
    _check_prediction_contents(prediction_files)
    _check_predictions_load(tasks)
    typer.echo("All submission files validated! ✨ 🚀 ✨")
    typer.echo("Now you can make a submission 🤗")


def _check_prediction_files(tasks):
    """Return the ``data/**/predictions.csv`` paths, requiring exactly one per task."""
    prediction_files = list(Path("data").rglob("predictions.csv"))
    found = [f.parent.name for f in prediction_files]
    missing = sorted(set(tasks) - set(found))
    unexpected = sorted(set(found) - set(tasks))
    # Comparing only the sets would accept two files for the same task, so the
    # file count is checked too.
    if missing or unexpected or len(prediction_files) != len(tasks):
        raise ValueError(
            f"Incorrect number of files! Expected {len(tasks)} files, but got {len(prediction_files)}. "
            f"Missing tasks: {missing}; unexpected task directories: {unexpected}."
        )
    return prediction_files


def _check_prediction_contents(prediction_files):
    """Check each predictions CSV parses and matches CSV_SCHEMA and the ID/Label columns."""
    # TODO(lewtun): Add a check for the specific IDs per file
    read_errors = []
    shape_errors = []
    column_errors = []
    for prediction_file in prediction_files:
        try:
            df = pd.read_csv(prediction_file)
        except (pd.errors.EmptyDataError, pd.errors.ParserError) as e:
            read_errors.append(f"{prediction_file} ({e})")
            continue
        # .get() so that a task missing from CSV_SCHEMA is reported instead of
        # crashing with a KeyError.
        expected_shape = CSV_SCHEMA.get(prediction_file.parent.name)
        if df.shape != expected_shape:
            shape_errors.append(f"{prediction_file} (expected {expected_shape}, got {df.shape})")
        if sorted(df.columns) != ["ID", "Label"]:
            column_errors.append(f"{prediction_file} (got {list(df.columns)})")
    if read_errors:
        raise ValueError(f"Could not parse CSV files: {read_errors}")
    if shape_errors:
        raise ValueError(f"Incorrect CSV shapes in files: {shape_errors}")
    if column_errors:
        raise ValueError(f"Incorrect CSV columns in files: {column_errors}")


def _check_predictions_load(tasks):
    """Check every task's predictions load with datasets.load_dataset, reporting all failures."""
    load_errors = []
    for task in tasks:
        try:
            # Cookiecutter placeholder, filled in when the template is rendered.
            load_dataset("../{{cookiecutter.repo_name}}", task)
        except Exception as e:
            # Collect every failure (tagged with its task) instead of stopping
            # at the first one.
            load_errors.append(f"{task}: {e!r}")
    if load_errors:
        raise ValueError(f"Could not load predictions! Errors: {load_errors}")
@app.command()
def submit(submission_name: str = typer.Option(..., prompt="Please provide a name for your submission, e.g. GPT-4 😁")):
    """Commit and push the local predictions as a new submission.

    The steps are:
    1. Pull ``main`` from ``origin``.
    2. Write ``submission_name`` into README.md.
    3. Stage the prediction CSVs and README.md.
    4. Commit and push.
    5. Print the date of the next Sunday after today, when the submission is
       evaluated.

    Every git exit status is checked. A failed pull, add or push aborts with
    git's exit code instead of reporting success. A failed commit only warns:
    re-running ``submit`` when there is nothing new to commit (for example
    after an earlier push failed) makes ``git commit`` fail, while the push
    can still succeed.
    """
    _run_git("pull", "origin", "main")
    _update_submission_name(submission_name)
    # No shell is involved, so this glob is not expanded before git runs. Git
    # matches it as a pathspec, in which `*` may also match `/`, so it covers
    # data/<task>/predictions.csv.
    _run_git("add", "data/*predictions.csv", "README.md")
    _run_git("commit", "-m", f"Submission: {submission_name} ", allow_failure=True)
    _run_git("push")
    today = datetime.date.today()
    # weekday(): Monday = 0 ... Sunday = 6. This gives the days until the next
    # Sunday strictly after today: 6 on Monday ... 1 on Saturday, 7 on a Sunday.
    days_until_sunday = 7 - (today.weekday() + 1) % 7
    sun = today + datetime.timedelta(days=days_until_sunday)
    typer.echo("Submission successful! 🎉 🥳 🎉")
    typer.echo(f"Your submission will be evaluated on {sun:%A %d %B %Y} ⏳")


def _run_git(*args: str, allow_failure: bool = False) -> int:
    """Run ``git`` with ``args`` and return its exit code.

    A non-zero exit code prints an error to stderr and aborts the CLI with that
    code via ``typer.Exit``. If ``allow_failure`` is set, only a warning is
    printed and the code is returned.
    """
    cmd = ["git", *args]
    returncode = subprocess.call(cmd)
    if returncode != 0:
        message = f"`{' '.join(cmd)}` failed with exit code {returncode}"
        if allow_failure:
            typer.echo(f"Warning: {message}; continuing.", err=True)
        else:
            typer.echo(f"Error: {message}; submission aborted.", err=True)
            raise typer.Exit(code=returncode)
    return returncode
# Entry point when run as a script: hand control to the Typer app, which
# dispatches to the registered commands (`validate`, `submit`).
if __name__ == "__main__":
    app()