# BuggingSpace / app.py
import os
import torch
import numpy as np
import gradio as gr
from random import sample
from detoxify import Detoxify
from datasets import load_dataset
from huggingface_hub import HfApi, ModelFilter, ModelSearchArguments
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM
from transformers import BloomTokenizerFast, BloomForCausalLM
HF_AUTH_TOKEN = os.environ.get('hf_token')
DATASET = "allenai/real-toxicity-prompts"
CHECKPOINTS = {
    "DistilGPT2 by HuggingFace 🤗": "distilgpt2",
    "GPT-Neo 125M by EleutherAI 🤖": "EleutherAI/gpt-neo-125M",
    "BLOOM 560M by BigScience 🌸": "bigscience/bloom-560m",
    "Custom Model": None,
}

MODEL_CLASSES = {
    "DistilGPT2 by HuggingFace 🤗": (GPT2LMHeadModel, GPT2Tokenizer),
    "GPT-Neo 125M by EleutherAI 🤖": (GPTNeoForCausalLM, GPT2Tokenizer),
    "BLOOM 560M by BigScience 🌸": (BloomForCausalLM, BloomTokenizerFast),
    "Custom Model": (AutoModelForCausalLM, AutoTokenizer),
}
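# The two dicts above line up by display name: CHECKPOINTS gives the hub
# checkpoint and MODEL_CLASSES the (model class, tokenizer class) pair.
# "Custom Model" falls back to the Auto* classes so any causal-LM hub path
# can be loaded instead of a fixed checkpoint.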
def load_model(model_name, custom_model_path):
try:
model_class, tokenizer_class = MODEL_CLASSES[model_name]
model_path = CHECKPOINTS[model_name]
except KeyError:
model_class, tokenizer_class = MODEL_CLASSES['Custom Model']
model_path = custom_model_path
model = model_class.from_pretrained(model_path)
tokenizer = tokenizer_class.from_pretrained(model_path)
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id
model.eval()
return model, tokenizer
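# Usage sketch (not run at import time); "gpt2" is an illustrative hub path:
#   model, tokenizer = load_model("DistilGPT2 by HuggingFace 🤗", None)
#   model, tokenizer = load_model("Custom Model", "gpt2")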
MAX_LENGTH = 10000  # Hardcoded max length to avoid infinite loop
def set_seed(seed, n_gpu):
np.random.seed(seed)
torch.manual_seed(seed)
if n_gpu > 0:
torch.cuda.manual_seed_all(seed)
def adjust_length_to_model(length, max_sequence_length):
if length < 0 and max_sequence_length > 0:
length = max_sequence_length
elif 0 < max_sequence_length < length:
length = max_sequence_length # No generation bigger than model size
elif length < 0:
length = MAX_LENGTH # avoid infinite loop
return length
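# Worked examples of the clamping above, assuming a 1024-token model:
#   adjust_length_to_model(-1, 1024)   -> 1024 (negative: use model max)
#   adjust_length_to_model(2000, 1024) -> 1024 (clamped to model size)
#   adjust_length_to_model(75, 1024)   -> 75   (already within bounds)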
def generate(model_name,
custom_model_path,
input_sentence,
length = 75,
temperature = 0.7,
top_k = 50,
top_p = 0.95,
seed = 42,
no_cuda = False,
num_return_sequences = 1,
stop_token = '.'
):
    # Select device: GPU if available and not explicitly disabled
    device = torch.device("cuda" if torch.cuda.is_available() and not no_cuda else "cpu")
n_gpu = 0 if no_cuda else torch.cuda.device_count()
# Set seed
set_seed(seed, n_gpu)
# Load model
model, tokenizer = load_model(model_name, custom_model_path)
model.to(device)
#length = adjust_length_to_model(length, max_sequence_length=model.config.max_position_embeddings)
# Tokenize input
encoded_prompt = tokenizer.encode(input_sentence,
add_special_tokens=False,
return_tensors='pt')
encoded_prompt = encoded_prompt.to(device)
input_ids = encoded_prompt
# Generate output
output_sequences = model.generate(input_ids=input_ids,
max_length=length + len(encoded_prompt[0]),
temperature=temperature,
top_k=top_k,
top_p=top_p,
do_sample=True,
num_return_sequences=num_return_sequences
)
generated_sequences = list()
for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
generated_sequence = generated_sequence.tolist()
text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
        # Remove the prompt from the decoded text
        text = text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)):]
        # Truncate at the last occurrence of stop_token, if present
        if stop_token and stop_token in text:
            text = text[:text.rfind(stop_token) + 1]
generated_sequences.append(text)
return generated_sequences[0]
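# Standalone call sketch (downloads the checkpoint on first use; the prompt
# is illustrative):
#   text = generate("GPT-Neo 125M by EleutherAI 🤖", None,
#                   "The weather today is", no_cuda=True)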
def show_mode(mode):
if mode == 'Single Model':
return (
gr.update(visible=True),
gr.update(visible=False)
)
if mode == 'Multi-Model':
return (
gr.update(visible=False),
gr.update(visible=True)
)
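# Each gr.update(...) in the tuples above maps positionally onto the outputs
# of choose_mode.change (wired up at the bottom), toggling the visibility of
# the single_model and multi_model groups.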
def prepare_dataset(dataset):
dataset = load_dataset(dataset, split='train')
return dataset
def load_prompts(dataset):
prompts = [dataset[i]['prompt']['text'] for i in range(len(dataset))]
return prompts
def random_sample(prompt_list):
    # Draw 10 prompts at random (avoid shadowing the function name)
    return sample(prompt_list, 10)
def show_dataset(dataset):
raw_data = prepare_dataset(dataset)
prompts = load_prompts(raw_data)
return (gr.update(choices=random_sample(prompts),
label='You can find below a random subset from the RealToxicityPrompts dataset',
visible=True),
gr.update(visible=True),
prompts,
)
def update_dropdown(prompts):
return gr.update(choices=random_sample(prompts))
def show_search_bar(value):
if value == 'Custom Model':
return (value,
gr.update(visible=True)
)
else:
return (value,
gr.update(visible=False)
)
def search_model(model_name):
api = HfApi()
model_args = ModelSearchArguments()
filt = ModelFilter(
task=model_args.pipeline_tag.TextGeneration,
library=model_args.library.PyTorch)
results = api.list_models(filter=filt, search=model_name)
model_list = [model.modelId for model in results]
return gr.update(visible=True,
choices=model_list,
label='Choose the model',
)
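# Free-text Hub search sketch: search_model("opt") might populate the dropdown
# with ids such as "facebook/opt-125m" (illustrative; results depend on the
# live Hub), restricted to PyTorch text-generation models by the filter above.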
def forward_model_choice(model_choice_path):
return (model_choice_path,
model_choice_path)
def auto_complete(input, generated):
output = input + ' ' + generated
output_spans = [{'entity': 'OUTPUT', 'start': len(input), 'end': len(output)}]
completed_prompt = {"text": output, "entities": output_spans}
return completed_prompt
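# The dict built above follows gr.HighlightedText's expected format: only the
# generated continuation is tagged, e.g.
#   auto_complete("Hello", "world.") ->
#   {"text": "Hello world.",
#    "entities": [{'entity': 'OUTPUT', 'start': 5, 'end': 12}]}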
def process_user_input(model, custom_model_path, input):
    warning = 'Please enter a valid prompt.'
    if not input:  # catches both None and the empty string
        input, generated = '', warning
    else:
        generated = generate(model, custom_model_path, input)
generated_with_spans = auto_complete(input, generated)
return (
generated_with_spans,
gr.update(visible=True),
gr.update(visible=True),
input,
generated
)
def pass_to_textbox(input):
return gr.update(value=input)
def run_detoxify(text):
results = Detoxify('original').predict(text)
json_ready_results = {cat:float(score) for (cat,score) in results.items()}
return json_ready_results
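# run_detoxify returns one probability per category, e.g. (values illustrative;
# the exact label set depends on the detoxify checkpoint):
#   {'toxicity': 0.001, 'severe_toxicity': 0.0001, 'obscene': 0.0002,
#    'threat': 0.0001, 'insult': 0.0002, 'identity_attack': 0.0001}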
def compute_toxi_output(output_text):
scores = run_detoxify(output_text)
return (
gr.update(value=scores, visible=True),
gr.update(visible=True)
)
def compute_change(input, output):
change_percent = round(((float(output)-input)/input)*100, 2)
return change_percent
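# Worked example: compute_change(0.02, 0.04) -> 100.0, i.e. the output score
# doubled relative to the input score. The input score is the denominator, so
# near-zero input scores can produce very large percentages.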
def compare_toxi_scores(input_text, output_scores):
input_scores = run_detoxify(input_text)
json_ready_results = {cat:float(score) for (cat,score) in input_scores.items()}
    # Only compare categories present in both score dicts
    compare_scores = {
        cat: compute_change(json_ready_results[cat], output_scores[cat])
        for cat in json_ready_results
        if cat in output_scores
    }
return (
gr.update(value=json_ready_results, visible=True),
gr.update(value=compare_scores, visible=True)
)
def show_flag_choices():
return gr.update(visible=True)
def update_flag(flag_value):
return (flag_value,
gr.update(visible=True),
gr.update(visible=True),
gr.update(visible=False)
)
def upload_flag(*args):
if flagging_callback.flag(list(args), flag_option = None):
return gr.update(visible=True)
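# Note: the callback's flag() is expected to return the running count of
# flagged rows (a truthy int on success), which is what reveals the success
# message above.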
CSS = """
#inside_group {
padding-top: 0.6em;
padding-bottom: 0.6em;
}
"""
with gr.Blocks(css=CSS) as demo:
dataset = gr.Variable(value=DATASET)
prompts_var = gr.Variable(value=None)
input_var = gr.Variable(label="Input Prompt", value=None)
output_var = gr.Variable(label="Output",value=None)
model_choice = gr.Variable(label="Model", value=None)
custom_model_path = gr.Variable(value=None)
flag_choice = gr.Variable(label = "Flag", value=None)
flagging_callback = gr.HuggingFaceDatasetSaver(hf_token = HF_AUTH_TOKEN,
dataset_name = "fsdlredteam/flagged_2",
organization = "fsdlredteam",
private = True )
gr.Markdown("# Project Interface proposal")
gr.Markdown("### Pick a text generation model below, write a prompt and explore the output")
gr.Markdown("### Or compare multiple models")
choose_mode = gr.Radio(choices=['Single Model', "Multi-Model"],
value='Single Model',
interactive=True,
visible=True,
show_label=False)
with gr.Group() as single_model:
with gr.Row():
with gr.Column(scale=1): # input & prompts dataset exploration
gr.Markdown("### 1. Select a prompt", elem_id="inside_group")
input_text = gr.Textbox(label="Write your prompt below.",
interactive=True,
lines=4,
elem_id="inside_group")
gr.Markdown("β€” or β€”", elem_id="inside_group")
inspo_button = gr.Button('Click here if you need some inspiration', elem_id="inside_group")
prompts_drop = gr.Dropdown(visible=False, elem_id="inside_group")
randomize_button = gr.Button('Show another subset', visible=False, elem_id="inside_group")
with gr.Column(scale=1): # Model choice & output
gr.Markdown("### 2. Evaluate output")
model_radio = gr.Radio(choices=list(CHECKPOINTS.keys()),
label='Model',
interactive=True,
elem_id="inside_group")
search_bar = gr.Textbox(label="Search model",
interactive=True,
visible=False,
elem_id="inside_group")
model_drop = gr.Dropdown(visible=False)
generate_button = gr.Button('Submit your prompt')
output_spans = gr.HighlightedText(visible=True, label="Generated text", elem_id="inside_group")
flag_button = gr.Button("Report output here", visible=False)
with gr.Row(): # Flagging
with gr.Column(scale=1):
flag_radio = gr.Radio(choices=["Toxic", "Offensive", "Repetitive", "Incorrect", "Other",],
label="What's wrong with the output ?",
interactive=True,
visible=False,
elem_id="inside_group")
user_comment = gr.Textbox(label="(Optional) Briefly describe the issue",
visible=False,
interactive=True,
elem_id="inside_group")
confirm_flag_button = gr.Button("Confirm report", visible=False, elem_id="inside_group")
with gr.Row(): # Flagging success
success_message = gr.Markdown("Your report has been successfully registered. Thank you!",
visible=False,
elem_id="inside_group")
with gr.Row(): # Toxicity buttons
toxi_button = gr.Button("Run a toxicity analysis of the model's output", visible=False, elem_id="inside_group")
toxi_button_compare = gr.Button("Compare toxicity on input and output", visible=False, elem_id="inside_group")
with gr.Row(): # Toxicity scores
toxi_scores_input = gr.JSON(label = "Detoxify classification of your input",
visible=False,
elem_id="inside_group")
toxi_scores_output = gr.JSON(label="Detoxify classification of the model's output",
visible=False,
elem_id="inside_group")
toxi_scores_compare = gr.JSON(label = "Percentage change between Input and Output",
visible=False,
elem_id="inside_group")
with gr.Group() as multi_model:
gr.Markdown("Model comparison will be here")
choose_mode.change(fn=show_mode,
inputs=choose_mode,
outputs=[single_model, multi_model])
inspo_button.click(fn=show_dataset,
inputs=dataset,
outputs=[prompts_drop, randomize_button, prompts_var])
prompts_drop.change(fn=pass_to_textbox,
inputs=prompts_drop,
outputs=input_text)
    randomize_button.click(fn=update_dropdown,
                           inputs=prompts_var,
                           outputs=prompts_drop)
model_radio.change(fn=show_search_bar,
inputs=model_radio,
outputs=[model_choice,search_bar])
search_bar.submit(fn=search_model,
inputs=search_bar,
outputs=model_drop,
show_progress=True)
model_drop.change(fn=forward_model_choice,
inputs=model_drop,
outputs=[model_choice,custom_model_path])
generate_button.click(fn=process_user_input,
inputs=[model_choice, custom_model_path, input_text],
outputs=[output_spans,
toxi_button,
flag_button,
input_var,
output_var],
show_progress=True)
toxi_button.click(fn=compute_toxi_output,
inputs=output_var,
outputs=[toxi_scores_output, toxi_button_compare],
show_progress=True)
toxi_button_compare.click(fn=compare_toxi_scores,
inputs=[input_text, toxi_scores_output],
outputs=[toxi_scores_input, toxi_scores_compare],
show_progress=True)
flag_button.click(fn=show_flag_choices,
inputs=None,
outputs=flag_radio)
flag_radio.change(fn=update_flag,
inputs=flag_radio,
outputs=[flag_choice, confirm_flag_button, user_comment, flag_button])
flagging_callback.setup([input_var, output_var, model_choice, user_comment, flag_choice], "flagged_data_points")
confirm_flag_button.click(fn = upload_flag,
inputs = [input_var,
output_var,
model_choice,
user_comment,
flag_choice],
outputs=success_message)
#demo.launch(debug=True)
if __name__ == "__main__":
demo.launch(enable_queue=False)