import gradio as gr
import datasets
import huggingface_hub
import os
import time
import subprocess
import logging
import json

from transformers.pipelines import TextClassificationPipeline

from text_classification import text_classification_fix_column_mapping
# Names of the environment variables read at submission time (see try_submit below).
HF_REPO_ID = 'HF_REPO_ID'
HF_SPACE_ID = 'SPACE_ID'
HF_WRITE_TOKEN = 'HF_WRITE_TOKEN'

theme = gr.themes.Soft(
    primary_hue="green",
)

def check_model(model_id):
    try:
        task = huggingface_hub.model_info(model_id).pipeline_tag
    except Exception:
        return None, None

    try:
        from transformers import pipeline
        ppl = pipeline(task=task, model=model_id)

        return model_id, ppl
    except Exception as e:
        return model_id, e
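
# Return contract of check_model, as used by try_validate below:
#   (None, None)      -> the model id is not accessible on the Hub
#   (model_id, ppl)   -> a transformers pipeline was loaded successfully
#   (model_id, error) -> the id exists but the pipeline failed to load
# For example, the placeholder "cardiffnlp/twitter-roberta-base-sentiment-latest" shown in the
# UI below should yield a TextClassificationPipeline (illustrative; assumes public Hub access).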
def check_dataset(dataset_id, dataset_config="default", dataset_split="test"):
    try:
        configs = datasets.get_dataset_config_names(dataset_id)
    except Exception:
        # Dataset may not exist
        return None, dataset_config, dataset_split

    if dataset_config not in configs:
        # Need to choose dataset subset (config)
        return dataset_id, configs, dataset_split

    ds = datasets.load_dataset(dataset_id, dataset_config)

    if isinstance(ds, datasets.DatasetDict):
        # Need to choose dataset split
        if dataset_split not in ds.keys():
            return dataset_id, None, list(ds.keys())
    elif not isinstance(ds, datasets.Dataset):
        # Unknown type
        return dataset_id, None, None
    return dataset_id, dataset_config, dataset_split
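
# Return contract of check_dataset (mirrors check_model):
#   (None, config, split)          -> the dataset id is not accessible
#   (dataset_id, [configs], split) -> the requested config is missing; valid configs are returned
#   (dataset_id, None, [splits])   -> the requested split is missing; valid splits are returned
#   (dataset_id, config, split)    -> everything resolved
# e.g. check_dataset("tweet_eval", "emotion", "test") should fall into the last case
# (illustrative only; assumes the public tweet_eval dataset).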
def try_validate(model_id, dataset_id, dataset_config, dataset_split, column_mapping='{}'):
    # Validate model
    m_id, ppl = check_model(model_id=model_id)
    if m_id is None:
        gr.Warning(f'Model "{model_id}" is not accessible. Please set your HF_TOKEN if it is a private model.')
        return (
            gr.update(interactive=False),  # Submit button
            gr.update(visible=True),  # Loading row
            gr.update(visible=False),  # Preview row
            gr.update(visible=False),  # Model prediction input
            gr.update(visible=False),  # Model prediction preview
            gr.update(visible=False),  # Label mapping preview
        )
    if isinstance(ppl, Exception):
        gr.Warning(f'Failed to load model "{model_id}": {ppl}')
        return (
            gr.update(interactive=False),  # Submit button
            gr.update(visible=True),  # Loading row
            gr.update(visible=False),  # Preview row
            gr.update(visible=False),  # Model prediction input
            gr.update(visible=False),  # Model prediction preview
            gr.update(visible=False),  # Label mapping preview
        )

    # Validate dataset
    d_id, config, split = check_dataset(dataset_id=dataset_id, dataset_config=dataset_config, dataset_split=dataset_split)

    dataset_ok = False
    if d_id is None:
        gr.Warning(f'Dataset "{dataset_id}" is not accessible. Please set your HF_TOKEN if it is a private dataset.')
    elif isinstance(config, list):
        gr.Warning(f'Dataset "{dataset_id}" does not have "{dataset_config}" config. Please choose a valid config.')
        config = gr.update(choices=config, value=config[0])
    elif isinstance(split, list):
        gr.Warning(f'Dataset "{dataset_id}" does not have "{dataset_split}" split. Please choose a valid split.')
        split = gr.update(choices=split, value=split[0])
    else:
        dataset_ok = True

    if not dataset_ok:
        return (
            gr.update(interactive=False),  # Submit button
            gr.update(visible=True),  # Loading row
            gr.update(visible=False),  # Preview row
            gr.update(visible=False),  # Model prediction input
            gr.update(visible=False),  # Model prediction preview
            gr.update(visible=False),  # Label mapping preview
            # gr.update(visible=True),  # Column mapping
        )

    # TODO: Validate column mapping by running once
    prediction_result = None
    id2label_df = None
    if isinstance(ppl, TextClassificationPipeline):
        try:
            print('validating phase, ', column_mapping)
            column_mapping = json.loads(column_mapping)
        except Exception:
            column_mapping = {}

        column_mapping, prediction_input, prediction_result, id2label_df = \
            text_classification_fix_column_mapping(column_mapping, ppl, d_id, config, split)

        column_mapping = json.dumps(column_mapping, indent=2)

    del ppl

    if prediction_result is None:
        gr.Warning('The model failed to predict with the first row in the dataset. Please provide column mappings in "Advanced" settings.')
        return (
            gr.update(interactive=False),  # Submit button
            gr.update(visible=True),  # Loading row
            gr.update(visible=False),  # Preview row
            gr.update(visible=False),  # Model prediction input
            gr.update(visible=False),  # Model prediction preview
            gr.update(visible=False),  # Label mapping preview
            # gr.update(value=column_mapping, visible=True, interactive=True),  # Column mapping
        )
    elif id2label_df is None:
        gr.Warning('The prediction result does not conform to the labels in the dataset. Please provide label mappings in "Advanced" settings.')
        return (
            gr.update(interactive=False),  # Submit button
            gr.update(visible=False),  # Loading row
            gr.update(visible=True),  # Preview row
            gr.update(value=f'**Sample Input**: {prediction_input}', visible=True),  # Model prediction input
            gr.update(value=prediction_result, visible=True),  # Model prediction preview
            gr.update(visible=False),  # Label mapping preview
            # gr.update(value=column_mapping, visible=True, interactive=True),  # Column mapping
        )

    gr.Info("Model and dataset validations passed. You can submit the evaluation task.")

    return (
        gr.update(interactive=True),  # Submit button
        gr.update(visible=False),  # Loading row
        gr.update(visible=True),  # Preview row
        gr.update(value=f'**Sample Input**: {prediction_input}', visible=True),  # Model prediction input
        gr.update(value=prediction_result, visible=True),  # Model prediction preview
        gr.update(value=id2label_df, visible=True, interactive=True),  # Label mapping preview
    )
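
# The six-tuple returned by try_validate is passed through gate_validate_btn below and maps, in
# order, to: run_btn, loading_row, preview_row, example_input, example_labels,
# id2label_mapping_dataframe (the same order used in the .change()/.input() outputs lists).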
def try_submit(m_id, d_id, config, split, column_mappings="{}", local=True):
    # column_mappings and local are given defaults (empty mapping, local execution) because
    # run_btn.click below only wires up the model/dataset/config/split inputs.
    label_mapping = {}
    try:
        column_mapping = json.loads(column_mappings)
        if "label" in column_mapping:
            label_mapping = column_mapping.pop("label", {})
    except Exception:
        column_mapping = {}

    if local:
        command = [
            "python",
            "cli.py",
            "--loader", "huggingface",
            "--model", m_id,
            "--dataset", d_id,
            "--dataset_config", config,
            "--dataset_split", split,
            "--hf_token", os.environ.get(HF_WRITE_TOKEN),
            "--discussion_repo", os.environ.get(HF_REPO_ID) or os.environ.get(HF_SPACE_ID),
            "--output_format", "markdown",
            "--output_portal", "huggingface",
            "--feature_mapping", json.dumps(column_mapping),
            "--label_mapping", json.dumps(label_mapping),
        ]
        eval_str = f"[{m_id}]<{d_id}({config}, {split} set)>"
        start = time.time()
        logging.info(f"Start local evaluation on {eval_str}")

        evaluator = subprocess.Popen(
            command,
            cwd=os.path.join(os.path.dirname(os.path.realpath(__file__)), "cicd"),
            stderr=subprocess.STDOUT,
        )
        result = evaluator.wait()

        logging.info(f"Finished local evaluation exit code {result} on {eval_str}: {time.time() - start:.2f}s")
        gr.Info(f"Finished local evaluation exit code {result} on {eval_str}: {time.time() - start:.2f}s")
    else:
        gr.Info("TODO: Submit task to an endpoint")

    return gr.update(interactive=True)  # Submit button
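
# For illustration, with the placeholder model/dataset used in the UI below, a local run would
# spawn a subprocess roughly like the following (values are examples only; --hf_token and
# --discussion_repo are filled in from the environment variables at the top of this file):
#   python cli.py --loader huggingface \
#       --model cardiffnlp/twitter-roberta-base-sentiment-latest \
#       --dataset tweet_eval --dataset_config emotion --dataset_split test \
#       --output_format markdown --output_portal huggingface \
#       --feature_mapping {} --label_mapping {}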

with gr.Blocks(theme=theme) as iface:
    with gr.Tab("Text Classification"):
        def check_dataset_and_get_config(dataset_id):
            try:
                configs = datasets.get_dataset_config_names(dataset_id)
                print(configs)
                return gr.Dropdown(configs, value=configs[0], visible=True)
            except Exception:
                # Dataset may not exist
                pass

        def check_dataset_and_get_split(dataset_config, dataset_id):
            try:
                splits = list(datasets.load_dataset(dataset_id, dataset_config).keys())
                print('splits: ', splits)
                return gr.Dropdown(splits, value=splits[0], visible=True)
            except Exception as e:
                # Dataset may not exist
                print(e)
                pass

        def gate_validate_btn(model_id, dataset_id, dataset_config, dataset_split, id2label_mapping_dataframe=None):
            print('model_id: ', model_id)
            column_mapping = '{}'
            if id2label_mapping_dataframe is not None:
                column_mapping = id2label_mapping_dataframe.to_json(orient="split")
                print(column_mapping)
            if model_id and dataset_id and dataset_config and dataset_split:
                return try_validate(model_id, dataset_id, dataset_config, dataset_split, column_mapping)
            else:
                # Keep all six outputs in sync with the outputs lists wired up below.
                return (gr.update(interactive=False),
                        gr.update(visible=True),
                        gr.update(visible=False),
                        gr.update(visible=False),
                        gr.update(visible=False),
                        gr.update(visible=False))

        with gr.Row():
            model_id_input = gr.Textbox(
                label="Hugging Face model id",
                placeholder="cardiffnlp/twitter-roberta-base-sentiment-latest",
            )

            dataset_id_input = gr.Textbox(
                label="Hugging Face Dataset id",
                placeholder="tweet_eval",
            )

        with gr.Row():
            dataset_config_input = gr.Dropdown(['default'], value='default', label='Dataset Config', visible=False)
            dataset_split_input = gr.Dropdown(['default'], value='default', label='Dataset Split', visible=False)

        # Cascade: changing the dataset id refreshes the config dropdown, and changing the
        # config refreshes the split dropdown.
        dataset_id_input.change(check_dataset_and_get_config, dataset_id_input, dataset_config_input)
        dataset_config_input.change(
            check_dataset_and_get_split,
            inputs=[dataset_config_input, dataset_id_input],
            outputs=[dataset_split_input])
        with gr.Row(visible=True) as loading_row:
            gr.Markdown('''
            <h1 style="text-align: center;">
                Please validate your model and dataset first...
            </h1>
            ''')

        with gr.Row(visible=False) as preview_row:
            gr.Markdown('''
            <h1 style="text-align: center;">
                Confirm Label Details
            </h1>
            Based on your model and dataset, we inferred this label mapping. **If the mapping is incorrect, please modify it in the table below.**
            ''')

        with gr.Row():
            id2label_mapping_dataframe = gr.DataFrame(label="Preview of label mapping", interactive=True, visible=False)

        with gr.Row():
            example_input = gr.Markdown('Sample Input: ', visible=False)

        with gr.Row():
            example_labels = gr.Label(label='Model Prediction Sample', visible=False)

        run_btn = gr.Button(
            "Get Evaluation Result",
            variant="primary",
            interactive=False,
            size="lg",
        )
        model_id_input.change(gate_validate_btn,
                              inputs=[model_id_input, dataset_id_input, dataset_config_input, dataset_split_input],
                              outputs=[run_btn, loading_row, preview_row, example_input, example_labels, id2label_mapping_dataframe])
        dataset_id_input.change(gate_validate_btn,
                                inputs=[model_id_input, dataset_id_input, dataset_config_input, dataset_split_input],
                                outputs=[run_btn, loading_row, preview_row, example_input, example_labels, id2label_mapping_dataframe])
        dataset_config_input.change(gate_validate_btn,
                                    inputs=[model_id_input, dataset_id_input, dataset_config_input, dataset_split_input],
                                    outputs=[run_btn, loading_row, preview_row, example_input, example_labels, id2label_mapping_dataframe])
        dataset_split_input.change(gate_validate_btn,
                                   inputs=[model_id_input, dataset_id_input, dataset_config_input, dataset_split_input],
                                   outputs=[run_btn, loading_row, preview_row, example_input, example_labels, id2label_mapping_dataframe])
        id2label_mapping_dataframe.input(gate_validate_btn,
                                         inputs=[model_id_input, dataset_id_input, dataset_config_input, dataset_split_input, id2label_mapping_dataframe],
                                         outputs=[run_btn, loading_row, preview_row, example_input, example_labels, id2label_mapping_dataframe])

        run_btn.click(
            try_submit,
            inputs=[
                model_id_input,
                dataset_id_input,
                dataset_config_input,
                dataset_split_input,
            ],
            outputs=[
                run_btn,
            ],
        )

    with gr.Tab("More"):
        pass

if __name__ == "__main__":
    iface.queue(max_size=20).launch()