# long-code-arena/src/get_results_for_task.py
import logging
import os
import re

import pandas as pd  # type: ignore[import]
from datasets import get_dataset_config_names, load_dataset  # type: ignore[import]

from .formatting import model_hyperlink
from .leaderboard_formatting import (
    COLUMNS_PRETTY,
    METRICS_PER_TASK,
    SORT_COLUMN_PER_TASK,
    get_columns_per_task,
)
from .tasks_content import TASKS_PRETTY_REVERSE
from .utils import MD_LINK_PATTERN
try:
    AVAILABLE_TASKS = get_dataset_config_names(os.environ["DATASET_ID"])
except FileNotFoundError:
    AVAILABLE_TASKS = []
    logging.warning("Dataset is not available! Check if the token has expired.")

AVAILABLE_TASKS_STR = "; ".join(AVAILABLE_TASKS)
logging.warning(f"Available tasks: {AVAILABLE_TASKS_STR}")


def _get_results_stub() -> pd.DataFrame:
    """Return a placeholder leaderboard shown when the results dataset cannot be loaded."""
stub_df = pd.DataFrame(
[
{
"Model Name": "GPT-4",
"Availability": "Proprietary",
"Context Size": "16k",
"BLEU": "X",
"ROUGE": "X",
"ChrF": "X",
"BERTScore": "X",
"BERTScore (Normalized)": "X",
"Submitted By": "🏟 Long Code Arena Team",
"Resources": "",
},
{
"Model Name": "CodeLlama-7b (instruct)",
"Availability": "Llama 2 license",
"Context Size": "16k",
"BLEU": "X",
"ROUGE": "X",
"ChrF": "X",
"BERTScore": "X",
"BERTScore (Normalized)": "X",
"Submitted By": "🏟 Long Code Arena Team",
"Resources": "",
},
]
)
return stub_df


def _process_urls(raw_urls: str) -> str:
    """Convert comma-separated Markdown links into HTML hyperlinks (each entry must match MD_LINK_PATTERN)."""
    if not raw_urls:
        return raw_urls
    html_urls = [model_hyperlink(*re.search(MD_LINK_PATTERN, url.strip()).groups()) for url in raw_urls.split(",")]
    return ", ".join(html_urls)


def _extract_dataset_name(raw_urls: str) -> str:
    """Pull the link text out of each comma-separated Markdown link and append " context"."""
    if not raw_urls:
        return raw_urls
    names = [re.search(MD_LINK_PATTERN, url.strip()).group(1) + " context" for url in raw_urls.split(",")]
    return ", ".join(names)


def _get_results_dataset(task_id: str) -> pd.DataFrame:
    """Load the results for the given task from the hub and format them for display."""
    logging.info(f"Loading dataset: {task_id}...")
results_df = load_dataset(
os.environ["DATASET_ID"], task_id, split="test", download_mode="force_redownload"
).to_pandas()
results_df = results_df.rename(columns=COLUMNS_PRETTY, errors="ignore")
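    # Render context sizes compactly, e.g. 16000 -> "16k".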
results_df["Context Size"] = results_df["Context Size"].map(lambda x: f"{int(x) // 1000}k" if int(x) >= 1000 else x)
results_df = results_df.sort_values(by=SORT_COLUMN_PER_TASK[task_id], ascending=False)
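    # BERTScore-based metrics are displayed with five decimal places, all others with two.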
for metric_column in METRICS_PER_TASK[task_id]:
if "BERTScore" in metric_column:
results_df[metric_column] = results_df[metric_column].map(lambda x: f"{x:.5f}")
else:
results_df[metric_column] = results_df[metric_column].map(lambda x: f"{x:.2f}")
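    # Link each model name to its URL when one is provided.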
results_df["Model Name"] = [
model_hyperlink(link=link, model_name=model_name) if link else model_name
for link, model_name in zip(results_df["model_url"], results_df["Model Name"])
]
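    # The project-level code completion task stores Markdown links in its "Dataset" column;
    # surface both a clickable link and a plain-text dataset name.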
    if task_id == "project_code_completion":
results_df["Dataset Name"] = [_extract_dataset_name(urls) for urls in results_df["Dataset"]]
results_df["Dataset"] = [_process_urls(urls) for urls in results_df["Dataset"]]
results_df["Resources"] = [_process_urls(urls) for urls in results_df["Resources"]]
results_df = results_df[get_columns_per_task(task_id)]
return results_df


def get_results_for_task(task_pretty: str) -> pd.DataFrame:
    """Return the leaderboard table for a task given its pretty (human-readable) name."""
task_id = TASKS_PRETTY_REVERSE[task_pretty]
if task_id in AVAILABLE_TASKS:
logging.info(f"Retrieving results for {task_pretty}...")
return _get_results_dataset(task_id)
logging.info(f"Generating leaderboard stub for {task_pretty}...")
return _get_results_stub()
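

# Minimal smoke-test sketch: assumes the DATASET_ID environment variable is set and that
# the module is executed as part of the package, e.g. from the repository root:
#   python -m src.get_results_for_task "<pretty task name>"
if __name__ == "__main__":
    import sys

    logging.basicConfig(level=logging.INFO)
    # Default to the first known task if no pretty name is given on the command line.
    pretty_name = sys.argv[1] if len(sys.argv) > 1 else next(iter(TASKS_PRETTY_REVERSE))
    print(get_results_for_task(pretty_name).to_string(index=False))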