Spaces:
Sleeping
Sleeping
import collections | |
import logging | |
import threading | |
import uuid | |
import datasets | |
import gradio as gr | |
import pandas as pd | |
import leaderboard | |
from io_utils import read_column_mapping, write_column_mapping | |
from run_jobs import save_job_to_pipe | |
from text_classification import ( | |
check_model_task, | |
get_example_prediction, | |
get_labels_and_features_from_dataset, | |
) | |
from wordings import ( | |
CHECK_CONFIG_OR_SPLIT_RAW, | |
CONFIRM_MAPPING_DETAILS_FAIL_RAW, | |
MAPPING_STYLED_ERROR_WARNING, | |
get_styled_input, | |
) | |
# Upper bounds on the number of label / feature mapping dropdowns; the UI
# always emits exactly MAX_LABELS + MAX_FEATURES dropdown outputs, padding
# unused slots with hidden dropdowns.
MAX_LABELS = 40
MAX_FEATURES = 20

# Module-level dataset state, initialised empty.
# NOTE(review): not read by any function visible in this section — confirm
# before removing.
ds_dict = ds_config = None
def get_related_datasets_from_leaderboard(model_id):
    """Suggest datasets for ``model_id`` based on the leaderboard records.

    Returns a Gradio update for a dataset dropdown. When the leaderboard has
    rows for ``model_id``, the choices are that model's distinct dataset ids
    with the first one preselected. Otherwise every distinct dataset id on the
    leaderboard is offered and the value is cleared to ``""``.
    """
    records = leaderboard.records
    matching_rows = records[records["model_id"] == model_id]
    related_ids = list(matching_rows["dataset_id"].unique())
    if related_ids:
        return gr.update(choices=related_ids, value=related_ids[0])
    # Model not on the leaderboard: offer every dataset, preselect nothing.
    every_id = list(records["dataset_id"].unique())
    return gr.update(choices=every_id, value="")
# Module logger named after the module's import name (the logging module's
# recommended convention) rather than its file path, whose slashes and ".py"
# suffix do not fit the dotted logger hierarchy.
logger = logging.getLogger(__name__)
def check_dataset(dataset_id):
    """Look up the configs and splits of a dataset to populate the UI.

    Returns a 3-tuple ``(config_update, split_update, message)``.

    On success, the first two values are Gradio updates that fill and reveal
    the config and split dropdowns, each defaulting to its first entry. If the
    dataset has no configs, or any lookup fails, they are no-op updates
    (failures are logged as warnings). ``message`` is always ``""``.
    """
    logger.info(f"Loading {dataset_id}")
    try:
        configs = datasets.get_dataset_config_names(dataset_id)
        if len(configs) == 0:
            return (gr.update(), gr.update(), "")
        # Only the split names are needed here, so ask for them directly
        # instead of downloading and materialising the whole dataset.
        splits = list(datasets.get_dataset_split_names(dataset_id, configs[0]))
        return (
            gr.update(choices=configs, value=configs[0], visible=True),
            gr.update(choices=splits, value=splits[0], visible=True),
            "",
        )
    except Exception as e:
        # Best-effort UI lookup: log and leave the dropdowns unchanged.
        # (`Logger.warn` is a deprecated alias of `Logger.warning`.)
        logger.warning(f"Check your dataset {dataset_id}: {e}")
        return (gr.update(), gr.update(), "")
def write_column_mapping_to_config(uid, *labels):
    """Persist the current mapping-dropdown values into ``uid``'s column mapping.

    ``labels`` are the dropdown values in UI order:

    * The first ``MAX_LABELS`` entries are assigned, by position, to the keys
      already present under ``"labels"`` in the stored mapping.
    * The next ``MAX_FEATURES`` entries fill the ``"features"`` section.
      Only the ``"text"`` key is filled for now.

    Nothing is written when no values are given or no stored mapping exists.
    """
    # TODO: Substitute 'text' with more features for zero-shot
    # we are not using ds features because we only support "text" for now

    # `*labels` always yields a tuple (never None), so test for emptiness.
    if not labels:
        return
    all_mappings = read_column_mapping(uid)
    if all_mappings is None:
        logger.warning(f"No column mapping found for {uid}; nothing written")
        return

    label_values = labels[:MAX_LABELS]
    feature_values = labels[MAX_LABELS : (MAX_LABELS + MAX_FEATURES)]
    all_mappings = export_mappings(all_mappings, "labels", None, label_values)
    # Fewer than MAX_LABELS + 1 values means there are no feature values; skip
    # the export, since cycling over an empty tuple would divide by zero.
    if feature_values:
        all_mappings = export_mappings(
            all_mappings, "features", ["text"], feature_values
        )
    write_column_mapping(all_mappings, uid)
def export_mappings(all_mappings, key, subkeys, values):
    """Assign ``values`` to ``subkeys`` inside ``all_mappings[key]``.

    The ``key`` section is created as an empty dict if missing.

    Args:
        all_mappings: Mapping dict, updated in place and also returned.
        key: Top-level section name (e.g. ``"labels"`` or ``"features"``).
        subkeys: Entries to set. ``None`` means "the keys already in the
            section". Falsy subkeys (e.g. ``""``) are skipped.
        values: Sequence of values, assigned by position. It is cycled when
            shorter than ``subkeys``.

    Returns:
        ``all_mappings``. When there are no subkeys or no values, only the
        (possibly new) empty section is added.
    """
    section = all_mappings.setdefault(key, {})
    if subkeys is None:
        subkeys = list(section.keys())
    if not subkeys:
        logging.debug(f"subkeys is empty for {key}")
        return all_mappings
    if not values:
        # Cycling through an empty sequence would raise ZeroDivisionError.
        logging.debug(f"values is empty for {key}")
        return all_mappings
    for i, subkey in enumerate(subkeys):
        if subkey:
            section[subkey] = values[i % len(values)]
    return all_mappings
def list_labels_and_features_from_dataset(ds_labels, ds_features, model_labels, uid):
    """Build the label/feature mapping dropdowns and persist their defaults.

    If the model and dataset share any labels, only the shared ones are
    listed; at most ``MAX_LABELS`` labels are shown (with a UI warning).

    Each dataset label gets a visible dropdown whose choices are the model
    labels:

    * If the same label exists on the model side, it is preselected.
    * Otherwise the model labels are assigned by position, cycling as needed.

    A single visible dropdown for the ``"text"`` feature offers the dataset
    features and preselects the first one. Both groups are padded with hidden
    dropdowns, and the default selections are written to ``uid``'s column
    mapping.

    Returns:
        A list of exactly ``MAX_LABELS + MAX_FEATURES`` dropdowns: labels
        first, then features.

    The caller's ``ds_labels`` / ``model_labels`` lists are not mutated.
    """
    all_mappings = read_column_mapping(uid)
    if all_mappings is None:
        all_mappings = {}

    # For flattened raw datasets with no labels
    # check if there are shared labels between model and dataset
    shared_labels = set(model_labels).intersection(ds_labels)
    if shared_labels:
        ds_labels = list(shared_labels)
    if len(ds_labels) > MAX_LABELS:
        ds_labels = ds_labels[:MAX_LABELS]
        gr.Warning(f"The number of labels is truncated to length {MAX_LABELS}")
    # sorted() rather than list.sort() so the caller's lists stay untouched.
    ds_labels = sorted(ds_labels)
    model_labels = sorted(model_labels)

    # Default target per dataset label: the identical model label when there
    # is one (positional pairing would misalign shared labels whenever the
    # model has extra labels), else a positional pick. None if the model
    # exposes no labels.
    default_targets = []
    for i, label in enumerate(ds_labels):
        if label in model_labels:
            default_targets.append(label)
        elif model_labels:
            default_targets.append(model_labels[i % len(model_labels)])
        else:
            default_targets.append(None)

    label_dropdowns = [
        gr.Dropdown(
            label=f"{label}",
            choices=model_labels,
            value=target,
            interactive=True,
            visible=True,
        )
        for label, target in zip(ds_labels, default_targets)
    ]
    label_dropdowns += [
        gr.Dropdown(visible=False) for _ in range(MAX_LABELS - len(label_dropdowns))
    ]
    all_mappings = export_mappings(all_mappings, "labels", ds_labels, default_targets)

    # TODO: Substitute 'text' with more features for zero-shot
    feature_dropdowns = [
        gr.Dropdown(
            label=f"{feature}",
            choices=ds_features,
            value=ds_features[0] if ds_features else None,
            interactive=True,
            visible=True,
        )
        for feature in ["text"]
    ]
    feature_dropdowns += [
        gr.Dropdown(visible=False)
        for _ in range(MAX_FEATURES - len(feature_dropdowns))
    ]
    # With no dataset features there is nothing to map "text" to.
    if ds_features:
        all_mappings = export_mappings(all_mappings, "features", ["text"], ds_features)

    write_column_mapping(all_mappings, uid)
    return label_dropdowns + feature_dropdowns
def precheck_model_ds_enable_example_btn(
    model_id, dataset_id, dataset_config, dataset_split
):
    """Check the model and dataset choice, and preview the dataset.

    Every branch returns the same 3-tuple
    ``(example_button_update, preview_update, message)``:

    * Model is not text-classification: warn, disable the button and hide the
      preview.
    * Config or split not chosen yet: no-op updates.
    * Dataset loads: show its first 5 rows of the chosen split. The button is
      enabled only when labels and features can both be extracted as lists;
      otherwise a warning is shown and the button is disabled.
    * Dataset fails to load: warn, disable the button and hide the preview.

    ``message`` is always ``""``.
    """
    model_task = check_model_task(model_id)
    if model_task != "text-classification":
        gr.Warning("Please check your model.")
        # Must return one value per output, like every other branch.
        return (
            gr.update(interactive=False),
            gr.update(value=pd.DataFrame(), visible=False),
            "",
        )
    if not dataset_config or dataset_split is None:
        return (gr.update(), gr.update(), "")
    try:
        ds = datasets.load_dataset(dataset_id, dataset_config)
        df: pd.DataFrame = ds[dataset_split].to_pandas().head(5)
        ds_labels, ds_features = get_labels_and_features_from_dataset(ds[dataset_split])
        if not isinstance(ds_labels, list) or not isinstance(ds_features, list):
            gr.Warning(CHECK_CONFIG_OR_SPLIT_RAW)
            return (gr.update(interactive=False), gr.update(value=df, visible=True), "")
        return (gr.update(interactive=True), gr.update(value=df, visible=True), "")
    except Exception as e:
        # Config or split wrong
        gr.Warning(f"Failed to load dataset {dataset_id} with config {dataset_config}: {e}")
        return (gr.update(interactive=False), gr.update(value=pd.DataFrame(), visible=False), "")
def _hidden_alignment_outputs(dropdowns):
    """Build align_columns_and_show_prediction's "hide everything" output tuple.

    The tuple contains, in order:

    * three hidden component updates (the third also collapsed with
      ``open=False``);
    * a non-interactive update for the fourth output;
    * an empty message;
    * the given ``dropdowns``.
    """
    return (
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False, open=False),
        gr.update(interactive=False),
        "",
        *dropdowns,
    )


def align_columns_and_show_prediction(
    model_id,
    dataset_id,
    dataset_config,
    dataset_split,
    uid,
    run_inference,
    inference_token,
):
    """Run an example prediction and prepare the column-mapping UI.

    Returns ``5 + MAX_LABELS + MAX_FEATURES`` outputs, in this order:

    1. input preview
    2. prediction preview
    3. mapping panel (visible / open)
    4. interactive toggle
    5. message
    6. the mapping dropdowns

    Everything is hidden when:

    * the model is not text-classification,
    * no example prediction is produced, or
    * the split yields no usable labels/features.

    When the model labels differ from the dataset labels, or the first dataset
    feature is not ``"text"``, the mapping panel is opened with a warning.
    Otherwise the example input and prediction are shown.
    """
    model_task = check_model_task(model_id)
    if model_task != "text-classification":
        gr.Warning("Please check your model.")
        return _hidden_alignment_outputs(
            [gr.update(visible=False) for _ in range(MAX_LABELS + MAX_FEATURES)]
        )

    dropdown_placement = [
        gr.Dropdown(visible=False) for _ in range(MAX_LABELS + MAX_FEATURES)
    ]

    prediction_input, prediction_output = get_example_prediction(
        model_id, dataset_id, dataset_config, dataset_split
    )
    # get_example_prediction may return no output; bail out instead of
    # crashing on `None.keys()`.
    if prediction_output is None:
        gr.Warning("Failed to get an example prediction from the model.")
        return _hidden_alignment_outputs(dropdown_placement)
    model_labels = list(prediction_output.keys())

    ds = datasets.load_dataset(dataset_id, dataset_config)[dataset_split]
    ds_labels, ds_features = get_labels_and_features_from_dataset(ds)

    # when dataset does not have labels or features (an empty feature list
    # leaves nothing to map "text" to and would break `ds_features[0]`)
    if (
        not isinstance(ds_labels, list)
        or not isinstance(ds_features, list)
        or not ds_features
    ):
        gr.Warning(CHECK_CONFIG_OR_SPLIT_RAW)
        return _hidden_alignment_outputs(dropdown_placement)

    column_mappings = list_labels_and_features_from_dataset(
        ds_labels,
        ds_features,
        model_labels,
        uid,
    )

    # Enable only when inference is requested AND a non-empty token is given
    # (a None token previously counted as "given").
    can_run = bool(run_inference and inference_token)

    # when labels or features are not aligned
    # show manually column mapping
    if (
        collections.Counter(model_labels) != collections.Counter(ds_labels)
        or ds_features[0] != "text"
    ):
        return (
            gr.update(value=MAPPING_STYLED_ERROR_WARNING, visible=True),
            gr.update(visible=False),
            gr.update(visible=True, open=True),
            gr.update(interactive=can_run),
            "",
            *column_mappings,
        )
    return (
        gr.update(value=get_styled_input(prediction_input), visible=True),
        gr.update(value=prediction_output, visible=True),
        gr.update(visible=True, open=False),
        gr.update(interactive=can_run),
        "",
        *column_mappings,
    )
def check_column_mapping_keys_validity(all_mappings):
    """Check that a stored column mapping is complete enough to submit.

    Returns ``None`` when ``all_mappings`` exists and has both a ``"labels"``
    and a ``"features"`` section. Both are required to build the label and
    feature mappings.

    Otherwise a warning is shown and a 2-tuple of Gradio updates is returned:
    re-enable the first output, hide the second.
    """
    if (
        all_mappings is None
        or "labels" not in all_mappings
        or "features" not in all_mappings
    ):
        gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
        return (gr.update(interactive=True), gr.update(visible=False))
    return None
def construct_label_and_feature_mapping(all_mappings):
    """Derive the submission label/feature mappings from a stored mapping.

    The label mapping maps the string position ``"0"``, ``"1"``, ... to each
    key of ``all_mappings["labels"]``, in insertion order. The feature mapping
    is ``all_mappings["features"]`` itself.

    If ``"features"`` is missing, a warning is shown and a pair of Gradio
    updates is returned in place of the two mappings.

    NOTE(review): only the keys of ``all_mappings["labels"]`` are used; the
    values stored for them are not part of the label mapping — confirm this
    is intended.
    """
    label_mapping = {
        str(position): label
        for position, label in enumerate(all_mappings["labels"].keys())
    }
    if "features" in all_mappings.keys():
        return label_mapping, all_mappings["features"]
    gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
    return (gr.update(interactive=True), gr.update(visible=False))
def try_submit(m_id, d_id, config, split, inference, inference_token, uid):
    """Validate ``uid``'s stored column mapping and queue an evaluation job.

    Returns a 3-tuple of outputs.

    On success the job is handed to ``save_job_to_pipe`` and the tuple is:

    1. disable the first output (the submit button);
    2. show a read-only 5-line text component;
    3. a fresh ``uuid.uuid4()``.

    If the stored mapping is missing, or lacks ``"labels"`` or ``"features"``,
    a warning is shown and the tuple is the two validation updates followed by
    the unchanged ``uid``. This lets the user fix the mapping and retry.
    """
    all_mappings = read_column_mapping(uid)
    # Stop on an incomplete mapping. Continuing would make
    # construct_label_and_feature_mapping raise, or return Gradio update
    # dicts that would then be submitted as mappings.
    invalid = check_column_mapping_keys_validity(all_mappings)
    if invalid is None and "features" not in all_mappings:
        gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
        invalid = (gr.update(interactive=True), gr.update(visible=False))
    if invalid is not None:
        return (*invalid, uid)

    label_mapping, feature_mapping = construct_label_and_feature_mapping(all_mappings)
    eval_str = f"[{m_id}]<{d_id}({config}, {split} set)>"
    # NOTE(review): a brand-new Lock per call is never shared with another
    # submission, so by itself it cannot serialise concurrent submits —
    # confirm how save_job_to_pipe uses it before switching to a shared lock.
    save_job_to_pipe(
        uid,
        (
            m_id,
            d_id,
            config,
            split,
            inference,
            inference_token,
            uid,
            label_mapping,
            feature_mapping,
        ),
        eval_str,
        threading.Lock(),
    )
    gr.Info("Your evaluation is submitted")
    return (
        gr.update(interactive=False),  # Submit button
        gr.update(lines=5, visible=True, interactive=False),
        uuid.uuid4(),  # Allocate a new uuid
    )