import logging
import os
import re

import pandas as pd
from datasets import get_dataset_config_names, load_dataset

from .formatting import model_hyperlink
from .leaderboard_formatting import (
    COLUMNS_PRETTY,
    METRICS_PER_TASK,
    SORT_COLUMN_PER_TASK,
    get_columns_per_task,
)
from .tasks_content import TASKS_PRETTY_REVERSE
from .utils import MD_LINK_PATTERN

try:
    AVAILABLE_TASKS = get_dataset_config_names(os.environ["DATASET_ID"])
except FileNotFoundError:
    # The results dataset could not be reached, e.g. because the HF token is missing or expired.
    AVAILABLE_TASKS = []
    logging.warning("Dataset is not available! Check if the token is expired.")

AVAILABLE_TASKS_STR = " ; ".join(AVAILABLE_TASKS)
logging.warning(f"Available tasks: {AVAILABLE_TASKS_STR}")


def _get_results_stub() -> pd.DataFrame:
    """Return a placeholder leaderboard shown when the results dataset is unavailable."""
    stub_df = pd.DataFrame(
        [
            {
                "Model Name": "GPT-4",
                "Availability": "Proprietary",
                "Context Size": "16k",
                "BLEU": "X",
                "ROUGE": "X",
                "ChrF": "X",
                "BERTScore": "X",
                "BERTScore (Normalized)": "X",
                "Submitted By": "🏟️ Long Code Arena Team",
                "Resources": "",
            },
            {
                "Model Name": "CodeLlama-7b (instruct)",
                "Availability": "Llama 2 license",
                "Context Size": "16k",
                "BLEU": "X",
                "ROUGE": "X",
                "ChrF": "X",
                "BERTScore": "X",
                "BERTScore (Normalized)": "X",
                "Submitted By": "🏟️ Long Code Arena Team",
                "Resources": "",
            },
        ]
    )
    return stub_df


def _process_urls(raw_urls: str) -> str:
    """Turn a comma-separated list of Markdown links into HTML hyperlinks.

    Assumes every entry matches MD_LINK_PATTERN, i.e. has the ``[name](url)`` form.
    """
    if not raw_urls:
        return raw_urls
    html_urls = [model_hyperlink(*re.search(MD_LINK_PATTERN, url.strip()).groups()) for url in raw_urls.split(",")]
    return ", ".join(html_urls)
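
# Illustrative sketch (the exact markup depends on model_hyperlink and MD_LINK_PATTERN):
#   _process_urls("[code](https://example.com/repo)") would yield a single HTML
#   anchor with the text "code" pointing at https://example.com/repo.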


def _extract_dataset_name(raw_urls: str) -> str:
    """Extract the display names from a comma-separated list of Markdown links.

    Each extracted name gets a " context" suffix.
    """
    if not raw_urls:
        return raw_urls
    names = [re.search(MD_LINK_PATTERN, url.strip()).group(1) + " context" for url in raw_urls.split(",")]
    return ", ".join(names)
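
# Illustrative sketch (assuming MD_LINK_PATTERN's first group captures the link text):
#   _extract_dataset_name("[small](https://example.com/ds)") -> "small context"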


def _get_results_dataset(task_id: str) -> pd.DataFrame:
    """Load the results for the given task and format them for display."""
    logging.info(f"Loading dataset: {task_id}...")
    # force_redownload bypasses the local cache, so leaderboard updates show up immediately.
    results_df = load_dataset(
        os.environ["DATASET_ID"], task_id, split="test", download_mode="force_redownload"
    ).to_pandas()
    results_df = results_df.rename(columns=COLUMNS_PRETTY, errors="ignore")
    # Render context sizes of 1000+ tokens in the compact "16k" style.
    results_df["Context Size"] = results_df["Context Size"].map(
        lambda x: f"{int(x) // 1000}k" if int(x) >= 1000 else x
    )

    results_df = results_df.sort_values(by=SORT_COLUMN_PER_TASK[task_id], ascending=False)

    # BERTScore columns get extra precision; all other metrics are rounded to two decimals.
    for metric_column in METRICS_PER_TASK[task_id]:
        if "BERTScore" in metric_column:
            results_df[metric_column] = results_df[metric_column].map(lambda x: f"{x:.5f}")
        else:
            results_df[metric_column] = results_df[metric_column].map(lambda x: f"{x:.2f}")

    results_df["Model Name"] = [
        model_hyperlink(link=link, model_name=model_name) if link else model_name
        for link, model_name in zip(results_df["model_url"], results_df["Model Name"])
    ]
    # Only the project-level code completion task has per-row dataset links.
    if task_id == "project_code_completion":
        results_df["Dataset Name"] = [_extract_dataset_name(urls) for urls in results_df["Dataset"]]
        results_df["Dataset"] = [_process_urls(urls) for urls in results_df["Dataset"]]
    results_df["Resources"] = [_process_urls(urls) for urls in results_df["Resources"]]
    results_df = results_df[get_columns_per_task(task_id)]
    return results_df


def get_results_for_task(task_pretty: str) -> pd.DataFrame:
    """Return the leaderboard for a task, falling back to a stub if results are unavailable."""
    task_id = TASKS_PRETTY_REVERSE[task_pretty]
    if task_id in AVAILABLE_TASKS:
        logging.info(f"Retrieving results for {task_pretty}...")
        return _get_results_dataset(task_id)
    logging.info(f"Generating leaderboard stub for {task_pretty}...")
    return _get_results_stub()
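
if __name__ == "__main__":
    # A minimal smoke test (illustrative sketch, not part of the leaderboard app).
    # Assumes DATASET_ID is set and the file is run as a module, e.g.
    # `python -m <package>.<this_module>`, so that the relative imports resolve.
    logging.basicConfig(level=logging.INFO)
    some_task = next(iter(TASKS_PRETTY_REVERSE))
    print(get_results_for_task(some_task).head())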
|