giskard-evaluator / text_classification_ui_helpers.py
inoki-giskard's picture
Early return for empty dataset id
4958a71
import collections
import logging
import threading
import uuid
import datasets
import gradio as gr
import pandas as pd
import leaderboard
from io_utils import (
read_column_mapping,
write_column_mapping,
read_scanners,
write_scanners,
)
from run_jobs import save_job_to_pipe
from text_classification import (
check_model_task,
preload_hf_inference_api,
get_example_prediction,
get_labels_and_features_from_dataset,
check_hf_token_validity,
HuggingFaceInferenceAPIResponse,
)
from wordings import (
EXAMPLE_MODEL_ID,
CHECK_CONFIG_OR_SPLIT_RAW,
CONFIRM_MAPPING_DETAILS_FAIL_RAW,
MAPPING_STYLED_ERROR_WARNING,
NOT_FOUND_DATASET_RAW,
NOT_FOUND_MODEL_RAW,
NOT_TEXT_CLASSIFICATION_MODEL_RAW,
UNMATCHED_MODEL_DATASET_STYLED_ERROR,
CHECK_LOG_SECTION_RAW,
VALIDATED_MODEL_DATASET_STYLED,
get_dataset_fetch_error_raw,
)
import os
from app_env import HF_WRITE_TOKEN
# Upper bounds on how many label / feature dropdowns the UI renders; the
# dropdown lists built below are padded with hidden dropdowns up to these counts.
MAX_LABELS = 40
MAX_FEATURES = 20
# NOTE(review): module-level globals that are not read or written in this
# section; presumably used (or left over) elsewhere — confirm before removing.
ds_dict = None
ds_config = None
def get_related_datasets_from_leaderboard(model_id, dataset_id_input):
    """Suggest dataset ids that already appear on the leaderboard for ``model_id``.

    Returns a ``gr.update`` for the dataset-id dropdown whose choices are the
    unique leaderboard dataset ids recorded for this model. The current value
    is left untouched when there are no suggestions or when ``dataset_id_input``
    is one of them; otherwise the value is reset to ``""``.
    """
    all_records = leaderboard.records
    matching = all_records[all_records["model_id"] == model_id]
    candidates = list(matching["dataset_id"].unique())
    if not candidates:
        return gr.update(choices=[])
    if dataset_id_input not in candidates:
        return gr.update(choices=candidates, value="")
    return gr.update(choices=candidates)
# Module logger used by the helpers below.
# NOTE(review): named after ``__file__`` (a filesystem path) rather than the
# conventional ``__name__``; confirm this is intentional before changing it,
# since the logger name appears in log records.
logger = logging.getLogger(__file__)
def get_dataset_splits(dataset_id, dataset_config):
    """Populate the split dropdown for ``dataset_id`` / ``dataset_config``.

    On success the dropdown lists every split name with the first one
    selected. Any failure, including a dataset with no splits (indexing the
    first split then fails), is logged as a warning and hides the dropdown.
    """
    try:
        split_names = datasets.get_dataset_split_names(
            dataset_id, dataset_config, trust_remote_code=True
        )
        first_split = split_names[0]
    except Exception as err:
        logger.warning(
            f"Check your dataset {dataset_id} and config {dataset_config}: {err}"
        )
        return gr.update(visible=False)
    return gr.update(choices=split_names, value=first_split, visible=True)
def check_dataset(dataset_id):
    """Load the config and split choices for ``dataset_id`` into the UI.

    Returns a 3-tuple ``(config_update, split_update, "")``. On success both
    dropdowns are shown with their first entry selected. Both are hidden when
    the id is empty, the dataset has no configs, or fetching its metadata
    fails; in the failure case a Gradio warning explains why.
    """
    hidden = (gr.update(visible=False), gr.update(visible=False), "")
    # Early return for an empty dataset id: there is nothing to fetch.
    if not dataset_id:
        return hidden
    logger.info(f"Loading {dataset_id}")
    try:
        configs = datasets.get_dataset_config_names(dataset_id, trust_remote_code=True)
        if not configs:
            return hidden
        splits = datasets.get_dataset_split_names(
            dataset_id, configs[0], trust_remote_code=True
        )
        return (
            gr.update(choices=configs, value=configs[0], visible=True),
            gr.update(choices=splits, value=splits[0], visible=True),
            "",
        )
    except Exception as e:
        logger.warning(f"Check your dataset {dataset_id}: {e}")
        if "doesn't exist on the Hub or cannot be accessed" in str(e):
            gr.Warning(NOT_FOUND_DATASET_RAW)
        else:
            # Covers GSK-2770 (illegal name, reported as "forbidden") as well
            # as unknown errors: both surface the raw fetch error.
            gr.Warning(get_dataset_fetch_error_raw(e))
        return hidden
def empty_column_mapping(uid):
    """Clear the stored column mapping for ``uid`` by writing ``None`` for it."""
    write_column_mapping(None, uid)
def write_column_mapping_to_config(uid, *labels):
    """Persist the label/feature dropdown selections for ``uid``.

    ``labels`` holds the values of all mapping dropdowns in UI order. The
    first ``MAX_LABELS`` entries are label mappings. The next ``MAX_FEATURES``
    entries are feature mappings, applied to the ``"text"`` feature.
    Does nothing when no dropdown values are given.
    """
    # ``*labels`` is always a tuple (never None). An empty tuple could make
    # export_mappings take a modulo by zero, so bail out before touching the
    # stored mapping.
    if not labels:
        return
    # TODO: Substitute 'text' with more features for zero-shot
    # we are not using ds features because we only support "text" for now
    all_mappings = read_column_mapping(uid)
    all_mappings = export_mappings(all_mappings, "labels", None, labels[:MAX_LABELS])
    all_mappings = export_mappings(
        all_mappings,
        "features",
        ["text"],
        labels[MAX_LABELS : (MAX_LABELS + MAX_FEATURES)],
    )
    write_column_mapping(all_mappings, uid)
def export_mappings(all_mappings, key, subkeys, values):
    """Assign ``values`` to ``all_mappings[key][subkey]`` for each subkey.

    Args:
        all_mappings: Mapping dict; updated in place and also returned.
        key: Top-level section, e.g. ``"labels"`` or ``"features"``. It is
            created as an empty dict when missing.
        subkeys: Entries to assign. ``None`` means every entry already stored
            under ``key``. Falsy subkeys are skipped.
        values: Values assigned positionally. They are reused cyclically when
            there are fewer values than subkeys. An empty sequence leaves the
            section unchanged.

    Returns:
        ``all_mappings``.
    """
    section = all_mappings.setdefault(key, {})
    if subkeys is None:
        subkeys = list(section.keys())
    if not subkeys:
        # Use the module logger: the root-level ``logging.debug`` helper may
        # implicitly configure the root logger.
        logger.debug(f"subkeys is empty for {key}")
        return all_mappings
    if not values:
        # Nothing to assign; this also avoids a modulo by zero below.
        return all_mappings
    for i, subkey in enumerate(subkeys):
        if subkey:
            section[subkey] = values[i % len(values)]
    return all_mappings
def list_labels_and_features_from_dataset(ds_labels, ds_features, model_labels, uid):
    """Build the label/feature mapping dropdowns and persist a default mapping.

    Args:
        ds_labels: Label names found in the dataset.
        ds_features: Dataset column names offered as choices for ``"text"``.
        model_labels: Label names produced by the model.
        uid: Key under which the column mapping is read and written.

    Returns:
        A list of ``MAX_LABELS`` label dropdowns followed by ``MAX_FEATURES``
        feature dropdowns; unused slots are hidden.

    The input lists are not mutated; sorted copies are used internally.
    """
    # Start from an empty mapping when nothing usable is stored yet.
    all_mappings = read_column_mapping(uid) or {}
    # For flattened raw datasets with no labels
    # check if there are shared labels between model and dataset
    shared_labels = set(model_labels).intersection(ds_labels)
    if shared_labels:
        ds_labels = list(shared_labels)
    if len(ds_labels) > MAX_LABELS:
        ds_labels = ds_labels[:MAX_LABELS]
        gr.Warning(
            f"Too many labels to display for this space. We do not support more than {MAX_LABELS} in this space. You can use cli tool at https://github.com/Giskard-AI/cicd."
        )
    # sort labels to make sure the order is consistent
    # prediction gives the order based on probability
    # sorted() copies, so the caller's lists are left untouched.
    ds_labels = sorted(ds_labels)
    model_labels = sorted(model_labels)
    label_dropdowns = [
        gr.Dropdown(
            label=f"{label}",
            choices=model_labels,
            value=model_labels[i % len(model_labels)],
            interactive=True,
            visible=True,
        )
        for i, label in enumerate(ds_labels)
    ]
    label_dropdowns += [
        gr.Dropdown(visible=False) for _ in range(MAX_LABELS - len(label_dropdowns))
    ]
    all_mappings = export_mappings(all_mappings, "labels", ds_labels, model_labels)
    # TODO: Substitute 'text' with more features for zero-shot
    feature_dropdowns = [
        gr.Dropdown(
            label=f"{feature}",
            choices=ds_features,
            value=ds_features[0],
            interactive=True,
            visible=True,
        )
        for feature in ["text"]
    ]
    feature_dropdowns += [
        gr.Dropdown(visible=False)
        for _ in range(MAX_FEATURES - len(feature_dropdowns))
    ]
    all_mappings = export_mappings(all_mappings, "features", ["text"], ds_features)
    write_column_mapping(all_mappings, uid)
    return label_dropdowns + feature_dropdowns
def precheck_model_ds_enable_example_btn(
    model_id, dataset_id, dataset_config, dataset_split
):
    """Validate the model and dataset selection before an example can be run.

    Returns a 7-tuple of Gradio updates:
      0. the example button, interactive only when the model is a
         text-classification model and the split loads with list-typed
         labels and features;
      1. a preview of the first 5 rows of the split (hidden when unavailable);
      2-5. updates that are always hidden here;
      6. an HTML error message (hidden when there is no error).
    """

    def _outputs(button_enabled=False, preview=None, error_html=None):
        # Build the 7 outputs; ``preview`` / ``error_html`` are shown only when given.
        if preview is None:
            preview_update = gr.update(visible=False)
        else:
            preview_update = gr.update(value=preview, visible=True)
        if error_html is None:
            error_update = gr.update(visible=False)
        else:
            error_update = gr.update(value=error_html, visible=True)
        return (
            gr.update(interactive=button_enabled),
            preview_update,
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            error_update,
        )

    model_task = check_model_task(model_id)
    if not model_task:
        # Model might be not found
        if model_id and model_id.startswith(("http://", "https://")):
            error_msg = f"Please input your model id, such as {EXAMPLE_MODEL_ID}, instead of URL"
        else:
            error_msg = NOT_FOUND_MODEL_RAW
        return _outputs(error_html=f"<p style='color: red;'>{error_msg}</p>")

    if model_task != "text-classification":
        gr.Warning(NOT_TEXT_CLASSIFICATION_MODEL_RAW)
        # No dataset has been loaded at this point, so the preview stays hidden.
        return _outputs(
            error_html=f"<p style='color: red;'>{NOT_TEXT_CLASSIFICATION_MODEL_RAW}</p>"
        )

    preload_hf_inference_api(model_id)

    if not dataset_config or dataset_split is None:
        return _outputs()

    try:
        ds = datasets.load_dataset(dataset_id, dataset_config, trust_remote_code=True)
        df: pd.DataFrame = ds[dataset_split].to_pandas().head(5)
        ds_labels, ds_features, _ = get_labels_and_features_from_dataset(
            ds[dataset_split]
        )
    except Exception as e:
        # Config or split wrong
        logger.warning(
            f"Check your dataset {dataset_id} and config {dataset_config} on split {dataset_split}: {e}"
        )
        return _outputs()

    if not isinstance(ds_labels, list) or not isinstance(ds_features, list):
        gr.Warning(CHECK_CONFIG_OR_SPLIT_RAW)
        return _outputs(preview=df)
    return _outputs(button_enabled=True, preview=df)
def align_columns_and_show_prediction(
    model_id,
    dataset_id,
    dataset_config,
    dataset_split,
    uid,
    inference_token,
):
    """Run an example prediction and prepare the label/feature mapping UI.

    Every return path yields 6 leading outputs followed by
    ``MAX_LABELS + MAX_FEATURES`` dropdowns:
      0. a styled status message;
      1. the example input;
      2. the example prediction;
      3. the column-mapping section (``open`` when manual mapping is needed);
      4. a button that is interactive only when ``inference_token != ""``;
      5. a plain text message (e.g. Inference API loading notice).
    """
    dropdown_placement = [
        gr.Dropdown(visible=False) for _ in range(MAX_LABELS + MAX_FEATURES)
    ]

    def _hidden_outputs(message=""):
        # Everything hidden/disabled; keeps the arity identical on every path.
        return (
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False, open=False),
            gr.update(interactive=False),
            message,
            *dropdown_placement,
        )

    model_task = check_model_task(model_id)
    if model_task != "text-classification":
        gr.Warning(NOT_TEXT_CLASSIFICATION_MODEL_RAW)
        return _hidden_outputs()

    hf_token = os.environ.get(HF_WRITE_TOKEN, default="")
    prediction_input, prediction_response = get_example_prediction(
        model_id, dataset_id, dataset_config, dataset_split, hf_token
    )
    if prediction_input is None or prediction_response is None:
        return _hidden_outputs()
    if isinstance(prediction_response, HuggingFaceInferenceAPIResponse):
        return _hidden_outputs(
            f"Hugging Face Inference API is loading your model. {prediction_response.message}"
        )

    model_labels = list(prediction_response.keys())
    ds = datasets.load_dataset(
        dataset_id, dataset_config, split=dataset_split, trust_remote_code=True
    )
    ds_labels, ds_features, _ = get_labels_and_features_from_dataset(ds)
    # when dataset does not have labels or features
    if not isinstance(ds_labels, list) or not isinstance(ds_features, list):
        gr.Warning(CHECK_CONFIG_OR_SPLIT_RAW)
        return _hidden_outputs()

    if len(ds_labels) != len(model_labels):
        return (
            gr.update(value=UNMATCHED_MODEL_DATASET_STYLED_ERROR, visible=True),
            *_hidden_outputs()[1:],
        )

    column_mappings = list_labels_and_features_from_dataset(
        ds_labels,
        ds_features,
        model_labels,
        uid,
    )
    # when labels or features are not aligned
    # show manually column mapping
    needs_manual_mapping = (
        collections.Counter(model_labels) != collections.Counter(ds_labels)
        or ds_features[0] != "text"
    )
    status = (
        MAPPING_STYLED_ERROR_WARNING
        if needs_manual_mapping
        else VALIDATED_MODEL_DATASET_STYLED
    )
    return (
        gr.update(value=status, visible=True),
        gr.update(
            value=prediction_input,
            lines=min(len(prediction_input) // 225 + 1, 5),
            visible=True,
        ),
        gr.update(value=prediction_response, visible=True),
        gr.update(visible=True, open=needs_manual_mapping),
        gr.update(interactive=(inference_token != "")),
        "",
        *column_mappings,
    )
def check_column_mapping_keys_validity(all_mappings):
    """Return True when ``all_mappings`` has both "labels" and "features" sections.

    A missing mapping (``None``) also shows a Gradio warning to the user.
    construct_label_and_feature_mapping indexes both sections, so a mapping
    lacking either one would otherwise only fail later, at submit time.
    """
    if all_mappings is None:
        logger.warning("all_mapping is None")
        gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
        return False
    if "labels" not in all_mappings:
        logger.warning(f"Label mapping is not valid, all_mappings: {all_mappings}")
        return False
    if "features" not in all_mappings:
        logger.warning(f"Feature mapping is not valid, all_mappings: {all_mappings}")
        return False
    return True
def enable_run_btn(
    uid, inference_token, model_id, dataset_id, dataset_config, dataset_split
):
    """Decide whether the run button should be clickable.

    The checks below run in order and stop at the first failure, which is
    logged and disables the button:
      1. an inference token was entered;
      2. none of model/dataset/config/split is the empty string;
      3. the stored column mapping for ``uid`` has the required sections;
      4. the inference token passes ``check_hf_token_validity``.
    """
    # Each predicate is lazy so later (costlier) checks only run when needed.
    checks = (
        (lambda: inference_token != "", "Inference API is not enabled"),
        (
            lambda: "" not in (model_id, dataset_id, dataset_config, dataset_split),
            "Model id or dataset id is not selected",
        ),
        (
            lambda: check_column_mapping_keys_validity(read_column_mapping(uid)),
            "Column mapping is not valid",
        ),
        (lambda: check_hf_token_validity(inference_token), "HF token is not valid"),
    )
    for passes, failure_message in checks:
        if not passes():
            logger.warning(failure_message)
            return gr.update(interactive=False)
    return gr.update(interactive=True)
def construct_label_and_feature_mapping(
    all_mappings, ds_labels, ds_features, label_keys=None
):
    """Translate the stored column mapping into job-ready label/feature maps.

    Args:
        all_mappings: Stored mapping with a "labels" section (dataset label ->
            model label) and a "features" section.
        ds_labels: Dataset labels; their order defines the output indices.
        ds_features: Dataset features, only used to sanity-check the mapping.
        label_keys: Optional sequence; when non-empty its first item is added
            to the feature mapping under ``"label"`` (presumably the dataset's
            label column name — confirm against get_labels_and_features_from_dataset).

    Returns:
        ``(label_mapping, feature_mapping)``. ``label_mapping`` maps
        ``str(index)`` of each dataset label to its mapped model label.
        ``feature_mapping`` is a copy of the "features" section, so
        ``all_mappings`` is not modified.

    Raises:
        KeyError: if "labels" or "features" is missing, or a dataset label has
            no entry in the "labels" section.
    """
    # Warn before the section is first indexed so the user sees why it fails.
    if "features" not in all_mappings:
        logger.warning("features not in all_mappings")
        gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
    if len(all_mappings["labels"]) != len(ds_labels):
        logger.warning(
            f"""Label mapping corrupted: {CONFIRM_MAPPING_DETAILS_FAIL_RAW}.
            \nall_mappings: {all_mappings}\nds_labels: {ds_labels}"""
        )
    if len(all_mappings["features"]) != len(ds_features):
        logger.warning(
            f"""Feature mapping corrupted: {CONFIRM_MAPPING_DETAILS_FAIL_RAW}.
            \nall_mappings: {all_mappings}\nds_features: {ds_features}"""
        )
    # align the saved labels with dataset labels order
    label_mapping = {
        str(i): all_mappings["labels"][label] for i, label in enumerate(ds_labels)
    }
    # Copy so that adding "label" does not leak into the caller's mapping.
    feature_mapping = dict(all_mappings["features"])
    if label_keys:
        feature_mapping["label"] = label_keys[0]
    return label_mapping, feature_mapping
def show_hf_token_info(token):
    """Show the HF-token hint component only when ``token`` fails validation."""
    is_valid = check_hf_token_validity(token)
    return gr.update(visible=not is_valid)
# Shared by every submission so that calls to save_job_to_pipe can actually
# serialize; a lock created per call can never be contended.
# NOTE(review): assumes save_job_to_pipe acquires and releases the lock it is
# given (e.g. ``with lock:``) — confirm in run_jobs.
_JOB_PIPE_LOCK = threading.Lock()
def try_submit(m_id, d_id, config, split, inference_token, uid, verbose):
    """Queue an evaluation job for ``uid`` and reset the form for a new one.

    Always returns 7 outputs:
      0. the submit button;
      1. a read-only textbox with the job id;
      2. the uid for the next submission;
      3-6. updates hidden after a successful submission.
    When the stored column mapping is invalid, nothing is queued. The submit
    button stays enabled, the textbox is hidden, the uid is kept, and the
    remaining outputs are left unchanged.
    """
    all_mappings = read_column_mapping(uid)
    if not check_column_mapping_keys_validity(all_mappings):
        # Same arity as the success path below (7 outputs).
        return (
            gr.update(interactive=True),
            gr.update(visible=False),
            uid,
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
        )
    # get ds labels and features again for alignment
    ds = datasets.load_dataset(d_id, config, split=split, trust_remote_code=True)
    ds_labels, ds_features, label_keys = get_labels_and_features_from_dataset(ds)
    label_mapping, feature_mapping = construct_label_and_feature_mapping(
        all_mappings, ds_labels, ds_features, label_keys
    )
    eval_str = f"[{m_id}]<{d_id}({config}, {split} set)>"
    save_job_to_pipe(
        uid,
        (
            m_id,
            d_id,
            config,
            split,
            inference_token,
            uid,
            label_mapping,
            feature_mapping,
            verbose,
        ),
        eval_str,
        _JOB_PIPE_LOCK,
    )
    gr.Info("Your evaluation has been submitted")
    # Carry the scanner selection over to the uid used by the next submission.
    new_uid = uuid.uuid4()
    scanners = read_scanners(uid)
    write_scanners(scanners, new_uid)
    return (
        gr.update(interactive=False),  # Submit button
        gr.update(
            value=f"{CHECK_LOG_SECTION_RAW}Your job id is: {uid}. ",
            lines=5,
            visible=True,
            interactive=False,
        ),
        new_uid,  # Allocate a new uuid
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False),
    )