comparator / src /details.py
albertvillanova's picture
Fix wrapping to keep non-str data
e3edf6d verified
raw
history blame
3.43 kB
import asyncio
import textwrap

import gradio as gr
import pandas as pd
from huggingface_hub import HfFileSystem

import src.constants as constants
from src.hub import load_details_file
def update_task_description_component(task):
    """Build the visible "Task Description" textbox for the selected task.

    The task's entry in ``constants.TASK_DESCRIPTIONS`` (when present and
    non-empty) is followed by a blank line and a note on how to read scores;
    tasks without a description get only the note.
    """
    score_note = "A higher score is a better score."
    base = constants.TASK_DESCRIPTIONS.get(task, "")
    if base:
        text = f"{base}\n\n{score_note}"
    else:
        text = score_note
    return gr.Textbox(text, label="Task Description", lines=5, visible=True)
def update_subtasks_component(task):
    """Return a subtask radio whose choices come from ``constants.SUBTASKS[task]``.

    Any previous selection is cleared (``value=None``). A task missing from
    ``constants.SUBTASKS`` yields ``None`` as the choices.
    """
    choices = constants.SUBTASKS.get(task)
    return gr.Radio(choices, info="Evaluation subtasks to be loaded", value=None)
def update_load_details_component(model_id_1, model_id_2, subtask):
    """Return the "Load Details" button, enabled only when it has something to load.

    Loading needs at least one model ID (either slot) and a subtask.
    """
    has_model = bool(model_id_1 or model_id_2)
    enabled = has_model and bool(subtask)
    return gr.Button("Load Details", interactive=enabled)
async def load_details_dataframe(model_id, subtask):
    """Load the details file of one model for one subtask as a DataFrame.

    Args:
        model_id: Hub model ID such as "org/model"; a falsy value means no model
            is selected.
        subtask: Subtask name substituted into the details filename template.

    Returns:
        ``pd.json_normalize`` of the loaded details (nested fields become dotted
        column names) plus a "model_name" column set to ``model_id``; or None
        when ``model_id``/``subtask`` is empty or no matching file exists.
    """
    if not model_id or not subtask:
        return None
    model_name_sanitized = model_id.replace("/", "__")
    # The whole pattern is formatted, so placeholders may live in either the
    # dataset ID constant or the filename constant.
    pattern = f"{constants.DETAILS_DATASET_ID}/**/{constants.DETAILS_FILENAME}".format(
        model_name_sanitized=model_name_sanitized, subtask=subtask
    )
    fs = HfFileSystem()
    # HfFileSystem.glob is a blocking (network) call; run it in a worker thread so
    # it does not stall the event loop, and so loads gathered concurrently by
    # load_details_dataframes can actually overlap.
    loop = asyncio.get_running_loop()
    paths = await loop.run_in_executor(None, fs.glob, pattern)
    if not paths:
        return None
    # Lexicographically greatest path; presumably the file names embed a sortable
    # timestamp so this is the most recent run -- confirm against the dataset layout.
    path = max(paths)
    data = await load_details_file(path)
    df = pd.json_normalize(data)
    df["model_name"] = model_id  # read by display_details to label each column
    return df
async def load_details_dataframes(subtask, *model_ids):
    """Load the details of every given model for ``subtask`` concurrently.

    Returns a list aligned with ``model_ids``; each entry is whatever
    ``load_details_dataframe`` produced for that model (None when nothing loaded).
    """
    pending = (load_details_dataframe(model_id, subtask) for model_id in model_ids)
    return await asyncio.gather(*pending)
def display_details(sample_idx, *dfs):
rows = [df.iloc[sample_idx] for df in dfs if "model_name" in df.columns and sample_idx < len(df)]
if not rows:
return
# Pop model_name and add it to the column name
df = pd.concat([row.rename(row.pop("model_name")) for row in rows], axis="columns")
# Wrap long strings to avoid overflow; e.g. URLs in "doc.Websites visited_NEV_2"
def wrap(row):
try:
result = row.str.wrap(140)
return result if result.notna().all() else row # NaN when data is a list
except AttributeError: # when data is number
return row
df = df.apply(wrap, axis=1)
# Style
return (
df.style
.format(escape="html", na_rep="")
# .hide(axis="index")
.to_html()
)
def update_sample_idx_component(*dfs):
    """Show the sample-index input, bounded by the longest loaded details table.

    Args:
        *dfs: Details DataFrames, one per model; None entries (nothing loaded
            for that model) are ignored.

    Returns:
        A visible ``gr.Number`` reset to 0 whose maximum is the last valid row
        position of the longest DataFrame, or 0 when nothing (or only empty
        frames) is loaded, so ``maximum`` never drops below ``minimum``.
    """
    lengths = [len(df) for df in dfs if df is not None]
    maximum = max(max(lengths, default=0) - 1, 0)
    return gr.Number(
        label="Sample Index",
        info="Index of the sample to be displayed",
        value=0,
        minimum=0,
        maximum=maximum,
        visible=True,
    )
def clear_details():
    """Reset the details tab to its initial state.

    Output order: model_id_1, model_id_2, details_dataframe_1,
    details_dataframe_2, details_task, subtask, load_details_btn, sample_idx.
    The first six are cleared to None, the load button is disabled and the
    sample-index input is reset to 0 and hidden.
    """
    cleared_values = (None,) * 6
    disabled_load_button = gr.Button("Load Details", interactive=False)
    hidden_sample_idx = gr.Number(
        label="Sample Index",
        info="Index of the sample to be displayed",
        value=0,
        minimum=0,
        visible=False,
    )
    return (*cleared_values, disabled_load_button, hidden_sample_idx)
def display_loading_message_for_details():
    """Return the centered HTML placeholder shown while details are loading."""
    loading_html = "<h3 style='text-align: center;'>Loading...</h3>"
    return loading_html