import collections
import json
import logging
import os
import threading

import datasets
import gradio as gr
from transformers.pipelines import TextClassificationPipeline

from io_utils import (
    read_column_mapping,
    save_job_to_pipe,
    write_column_mapping,
    write_log_to_user_file,
)
from text_classification import (
    check_model,
    get_example_prediction,
    get_labels_and_features_from_dataset,
)
from wordings import CONFIRM_MAPPING_DETAILS_FAIL_RAW

MAX_LABELS = 20
MAX_FEATURES = 20

HF_REPO_ID = "HF_REPO_ID"
HF_SPACE_ID = "SPACE_ID"
HF_WRITE_TOKEN = "HF_WRITE_TOKEN"

CONFIG_PATH = "./config.yaml"

# Single lock shared by every submission and handed to ``save_job_to_pipe``.
# Creating a new ``threading.Lock()`` per call (as before) cannot exclude any
# other caller, because nobody else ever holds that particular lock object.
_PIPE_LOCK = threading.Lock()


def check_dataset_and_get_config(dataset_id):
    """Reset the stored column mapping and list the configs of ``dataset_id``.

    Returns a visible ``gr.Dropdown`` of the dataset's config names with the
    first one selected, or ``None`` if the configs cannot be retrieved.
    """
    try:
        write_column_mapping(None)
        configs = datasets.get_dataset_config_names(dataset_id)
        return gr.Dropdown(configs, value=configs[0], visible=True)
    except Exception as e:
        # Dataset may not exist; keep this best-effort but leave a trace.
        logging.warning(f"Failed to get configs of dataset {dataset_id}: {e}")
        return None


def check_dataset_and_get_split(dataset_id, dataset_config):
    """List the split names of ``dataset_id`` / ``dataset_config``.

    Returns a visible ``gr.Dropdown`` of split names with the first one
    selected, or ``None`` if the dataset cannot be loaded.
    """
    try:
        splits = list(datasets.load_dataset(dataset_id, dataset_config).keys())
        return gr.Dropdown(splits, value=splits[0], visible=True)
    except Exception as e:
        # Dataset may not exist; keep this best-effort but leave a trace.
        logging.warning(
            f"Failed to load dataset {dataset_id} with config {dataset_config}: {e}"
        )
        return None


def _update_mapping(mapping, selections, targets):
    """Set ``mapping[selection] = targets[i]`` for each non-empty selection.

    Selections without a corresponding target index are skipped instead of
    raising ``IndexError``.
    """
    for i, selection in enumerate(selections):
        if selection and i < len(targets):
            mapping[selection] = targets[i]


def write_column_mapping_to_config(dataset_id, dataset_config, dataset_split, *labels):
    """Persist the user's mapping-dropdown selections to the column mapping.

    ``labels`` holds the values of the ``MAX_LABELS`` label dropdowns followed
    by the ``MAX_FEATURES`` feature dropdowns, in the order produced by
    ``list_labels_and_features_from_dataset``. Empty values (e.g. from hidden
    dropdowns) are ignored. Existing mappings read from ``CONFIG_PATH`` are
    updated, not replaced.
    """
    ds_labels, ds_features = get_labels_and_features_from_dataset(
        dataset_id, dataset_config, dataset_split
    )
    selections = list(labels)

    all_mappings = read_column_mapping(CONFIG_PATH) or {}
    label_mapping = all_mappings.setdefault("labels", {})
    feature_mapping = all_mappings.setdefault("features", {})

    # Label dropdown i was built for ds_labels[i]; its value is a model label.
    _update_mapping(label_mapping, selections[:MAX_LABELS], ds_labels)
    # NOTE(review): the feature dropdowns' choices are dataset columns, so this
    # maps a selected dataset column to ds_features[i]; confirm this is the
    # direction the scanner expects for --feature_mapping.
    _update_mapping(
        feature_mapping,
        selections[MAX_LABELS : MAX_LABELS + MAX_FEATURES],
        ds_features,
    )

    write_column_mapping(all_mappings)


def list_labels_and_features_from_dataset(ds_labels, ds_features, model_id2label):
    """Build the label and feature mapping dropdowns.

    Always returns exactly ``MAX_LABELS + MAX_FEATURES`` dropdowns:

    - one visible dropdown per dataset label (up to ``MAX_LABELS``), whose
      choices are the model labels, padded with hidden dropdowns;
    - one visible dropdown for the model's ``text`` input, whose choices are
      the dataset columns, padded with hidden dropdowns.
    """
    model_labels = list(model_id2label.values())
    n_model_labels = len(model_labels)

    label_dropdowns = [
        gr.Dropdown(
            label=f"{label}",
            choices=model_labels,
            # Pre-select model labels round-robin. Assumes id2label keys are
            # 0..n-1 (as the original lookup did); no pre-selection when the
            # model exposes no labels (avoids modulo by zero).
            value=model_id2label[i % n_model_labels] if n_model_labels else None,
            interactive=True,
            visible=True,
        )
        for i, label in enumerate(ds_labels[:MAX_LABELS])
    ]
    label_dropdowns += [
        gr.Dropdown(visible=False) for _ in range(MAX_LABELS - len(label_dropdowns))
    ]

    # TODO: Substitute 'text' with more features for zero-shot
    feature_dropdowns = [
        gr.Dropdown(
            label=f"{feature}",
            choices=ds_features,
            value=ds_features[0] if ds_features else None,
            interactive=True,
            visible=True,
        )
        for feature in ["text"]
    ]
    feature_dropdowns += [
        gr.Dropdown(visible=False)
        for _ in range(MAX_FEATURES - len(feature_dropdowns))
    ]

    return label_dropdowns + feature_dropdowns


def check_model_and_show_prediction(
    model_id, dataset_id, dataset_config, dataset_split
):
    """Validate the model against the dataset and show an example prediction.

    Every return path yields ``3 + MAX_LABELS + MAX_FEATURES`` updates:
    example input, example output, the mapping container (updated with
    ``open=``, presumably a ``gr.Accordion``), then the mapping dropdowns.

    - Invalid model or dataset without usable labels/features: everything
      is hidden.
    - Labels or features not aligned: the mapping container is shown open so
      the user can map columns manually.
    - Otherwise: an example prediction is shown.
    """
    hidden_outputs = (
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False, open=False),
        *[gr.Dropdown(visible=False) for _ in range(MAX_LABELS + MAX_FEATURES)],
    )

    ppl = check_model(model_id)
    if ppl is None or not isinstance(ppl, TextClassificationPipeline):
        # Covers "model not found" as well as unsupported pipeline types.
        gr.Warning("Please check your model.")
        return hidden_outputs

    model_id2label = ppl.model.config.id2label
    ds_labels, ds_features = get_labels_and_features_from_dataset(
        dataset_id, dataset_config, dataset_split
    )

    # When the dataset does not have labels or features.
    if (
        not isinstance(ds_labels, list)
        or not isinstance(ds_features, list)
        or not ds_features
    ):
        return hidden_outputs

    column_mappings = list_labels_and_features_from_dataset(
        ds_labels,
        ds_features,
        model_id2label,
    )

    # When labels or features are not aligned, show the manual column mapping.
    if (
        collections.Counter(model_id2label.values()) != collections.Counter(ds_labels)
        or ds_features[0] != "text"
    ):
        gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
        return (
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=True, open=True),
            *column_mappings,
        )

    prediction_input, prediction_output = get_example_prediction(
        ppl, dataset_id, dataset_config, dataset_split
    )
    return (
        gr.update(value=prediction_input, visible=True),
        gr.update(value=prediction_output, visible=True),
        gr.update(visible=True, open=False),
        *column_mappings,
    )


def try_submit(m_id, d_id, config, split, local, uid):
    """Queue a scan job using the stored column mapping.

    Returns a pair of updates, presumably for the submit button and a log
    box. On a missing or incomplete mapping, or when ``local`` is falsy (not
    implemented), the button stays interactive and the log box stays hidden.
    """
    all_mappings = read_column_mapping(CONFIG_PATH)
    if (
        not all_mappings
        or "labels" not in all_mappings
        or "features" not in all_mappings
    ):
        gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
        return (gr.update(interactive=True), gr.update(visible=False))

    label_mapping = all_mappings["labels"]
    feature_mapping = all_mappings["features"]

    # TODO: Set column mapping for some dataset such as `amazon_polarity`
    if not local:
        gr.Info("TODO: Submit task to an endpoint")
        return (gr.update(interactive=True), gr.update(visible=False))

    # NOTE(review): os.environ.get(...) yields None when the variable is
    # unset, and that None is placed in the command list; confirm that
    # save_job_to_pipe / the scanner tolerate it.
    command = [
        "giskard_scanner",
        "--loader",
        "huggingface",
        "--model",
        m_id,
        "--dataset",
        d_id,
        "--dataset_config",
        config,
        "--dataset_split",
        split,
        "--hf_token",
        os.environ.get(HF_WRITE_TOKEN),
        "--discussion_repo",
        os.environ.get(HF_REPO_ID) or os.environ.get(HF_SPACE_ID),
        "--output_format",
        "markdown",
        "--output_portal",
        "huggingface",
        "--feature_mapping",
        json.dumps(feature_mapping),
        "--label_mapping",
        json.dumps(label_mapping),
        "--scan_config",
        "../config.yaml",
    ]

    eval_str = f"[{m_id}]<{d_id}({config}, {split} set)>"
    logging.info(f"Start local evaluation on {eval_str}")
    save_job_to_pipe(uid, command, _PIPE_LOCK)
    write_log_to_user_file(
        uid,
        f"Start local evaluation on {eval_str}. Please wait for your job to start...\n",
    )
    gr.Info(f"Start local evaluation on {eval_str}")

    return (
        gr.update(interactive=False),
        gr.update(lines=5, visible=True, interactive=False),
    )


# Submit button