# BuggingSpace app.py - FSDL 2022 Red-Teaming Open-Source Models Project
import os
import torch
import numpy as np
import gradio as gr
from random import sample
from detoxify import Detoxify
from datasets import load_dataset
from huggingface_hub import HfApi, ModelFilter, ModelSearchArguments
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM
from transformers import BloomTokenizerFast, BloomForCausalLM
HF_AUTH_TOKEN = os.environ.get('hf_token')
DATASET = "allenai/real-toxicity-prompts"
CHECKPOINTS = {
"DistilGPT2 by HuggingFace πŸ€—" : "distilgpt2",
"GPT-Neo 125M by EleutherAI πŸ€–" : "EleutherAI/gpt-neo-125M",
"BLOOM 560M by BigScience 🌸" : "bigscience/bloom-560m",
"Custom Model" : None
}
MODEL_CLASSES = {
"DistilGPT2 by HuggingFace πŸ€—" : (GPT2LMHeadModel, GPT2Tokenizer),
"GPT-Neo 125M by EleutherAI πŸ€–" : (GPTNeoForCausalLM, GPT2Tokenizer),
"BLOOM 560M by BigScience 🌸" : (BloomForCausalLM, BloomTokenizerFast),
"Custom Model" : (AutoModelForCausalLM, AutoTokenizer),
}
CHOICES = sorted(list(CHECKPOINTS.keys())[:3])  # the three built-in models, "Custom Model" excluded
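# Resolve the UI model choice to its (model class, tokenizer class) pair and
# checkpoint path; unknown names fall back to AutoModel/AutoTokenizer so any
# Hugging Face hub model id or custom path can be loaded.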
def load_model(model_name, custom_model_path, token):
try:
model_class, tokenizer_class = MODEL_CLASSES[model_name]
model_path = CHECKPOINTS[model_name]
except KeyError:
model_class, tokenizer_class = MODEL_CLASSES['Custom Model']
model_path = custom_model_path or model_name
model = model_class.from_pretrained(model_path, use_auth_token=token)
tokenizer = tokenizer_class.from_pretrained(model_path, use_auth_token=token)
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id
model.eval()
return model, tokenizer
MAX_LENGTH = 10000  # Hardcoded max length to avoid infinite loop
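# Seed NumPy and PyTorch (all CUDA devices included) so sampling is reproducible.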
def set_seed(seed, n_gpu):
np.random.seed(seed)
torch.manual_seed(seed)
if n_gpu > 0:
torch.cuda.manual_seed_all(seed)
def adjust_length_to_model(length, max_sequence_length):
if length < 0 and max_sequence_length > 0:
length = max_sequence_length
elif 0 < max_sequence_length < length:
length = max_sequence_length # No generation bigger than model size
elif length < 0:
length = MAX_LENGTH # avoid infinite loop
return length
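# End-to-end generation: load model and tokenizer, encode the prompt, sample
# with temperature / top-k / top-p, strip the prompt from the decoded text and
# truncate at the last stop_token. Only the first sequence is returned.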
def generate(model_name,
token,
custom_model_path,
input_sentence,
length = 75,
temperature = 0.7,
top_k = 50,
top_p = 0.95,
seed = 42,
no_cuda = False,
num_return_sequences = 1,
stop_token = '.'
):
    # Select the device
    device = torch.device("cuda" if torch.cuda.is_available() and not no_cuda else "cpu")
n_gpu = 0 if no_cuda else torch.cuda.device_count()
# Set seed
set_seed(seed, n_gpu)
# Load model
model, tokenizer = load_model(model_name, custom_model_path, token)
model.to(device)
#length = adjust_length_to_model(length, max_sequence_length=model.config.max_position_embeddings)
# Tokenize input
encoded_prompt = tokenizer.encode(input_sentence,
add_special_tokens=False,
return_tensors='pt')
encoded_prompt = encoded_prompt.to(device)
input_ids = encoded_prompt
# Generate output
output_sequences = model.generate(input_ids=input_ids,
max_length=length + len(encoded_prompt[0]),
temperature=temperature,
top_k=top_k,
top_p=top_p,
do_sample=True,
num_return_sequences=num_return_sequences
)
generated_sequences = list()
for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
generated_sequence = generated_sequence.tolist()
text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
#remove prompt
text = text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)) :]
        # remove all text after the last occurrence of stop_token (keep everything if absent)
        if stop_token in text:
            text = text[:text.rfind(stop_token) + 1]
generated_sequences.append(text)
return generated_sequences[0]
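# Gradio event handlers
# Toggle visibility between the single-model and multi-model panels.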
def show_mode(mode):
if mode == 'Single Model':
return (
gr.update(visible=True),
gr.update(visible=False)
)
if mode == 'Multi-Model':
return (
gr.update(visible=False),
gr.update(visible=True)
)
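# Fetch the RealToxicityPrompts train split and extract the raw prompt texts.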
def prepare_dataset(dataset):
dataset = load_dataset(dataset, split='train')
return dataset
def load_prompts(dataset):
prompts = [dataset[i]['prompt']['text'] for i in range(len(dataset))]
return prompts
def random_sample(prompt_list):
    # Draw 10 prompts at random, without replacement
    return sample(prompt_list, 10)
def show_dataset(dataset):
raw_data = prepare_dataset(dataset)
prompts = load_prompts(raw_data)
return (gr.update(choices=random_sample(prompts),
label='You can find below a random subset from the RealToxicityPrompts dataset',
visible=True),
gr.update(visible=True),
prompts,
)
def update_dropdown(prompts):
return gr.update(choices=random_sample(prompts))
def show_search_bar(value):
if value == 'Custom Model':
return (value,
gr.update(visible=True)
)
else:
return (value,
gr.update(visible=False)
)
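# Query the Hugging Face Hub for PyTorch text-generation models matching the
# search string, and fill the model dropdown with the results.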
def search_model(model_name, token):
api = HfApi()
model_args = ModelSearchArguments()
filt = ModelFilter(
task=model_args.pipeline_tag.TextGeneration,
library=model_args.library.PyTorch)
results = api.list_models(filter=filt, search=model_name, use_auth_token=token)
model_list = [model.modelId for model in results]
return gr.update(visible=True,
choices=model_list,
label='Choose the model',
)
def show_api_key_textbox(checkbox):
    # Show the token field only when the private-model checkbox is ticked
    return gr.update(visible=checkbox)
def forward_model_choice(model_choice_path):
return (model_choice_path,
model_choice_path)
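# Concatenate prompt and completion, marking the generated span so that
# gr.HighlightedText renders the model's contribution distinctly.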
def auto_complete(input, generated):
output = input + ' ' + generated
output_spans = [{'entity': 'OUTPUT', 'start': len(input), 'end': len(output)}]
completed_prompt = {"text": output, "entities": output_spans}
return completed_prompt
def process_user_input(model,
token,
custom_model_path,
input,
length,
temperature,
top_p,
top_k):
warning = 'Please enter a valid prompt.'
    if not input:
        input = ''  # keeps auto_complete from concatenating None
        generated = warning
else:
generated = generate(model_name=model,
token=token,
custom_model_path=custom_model_path,
input_sentence=input,
length=length,
temperature=temperature,
top_p=top_p,
top_k=top_k)
generated_with_spans = auto_complete(input=input, generated=generated)
return (
gr.update(value=generated_with_spans),
gr.update(visible=True),
gr.update(visible=True),
input,
generated
)
def pass_to_textbox(input):
return gr.update(value=input)
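# Score a text with the Detoxify 'original' model, casting the scores to plain
# floats so the result is JSON-serialisable.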
def run_detoxify(text):
results = Detoxify('original').predict(text)
json_ready_results = {cat:float(score) for (cat,score) in results.items()}
return json_ready_results
def compute_toxi_output(output_text):
scores = run_detoxify(output_text)
return (
gr.update(value=scores, visible=True),
gr.update(visible=True)
)
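# Percentage change between input and output toxicity scores (assumes the
# input score is non-zero, which holds for Detoxify outputs in practice).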
def compute_change(input, output):
change_percent = round(((float(output)-input)/input)*100, 2)
return change_percent
def compare_toxi_scores(input_text, output_scores):
input_scores = run_detoxify(input_text)
json_ready_results = {cat:float(score) for (cat,score) in input_scores.items()}
    compare_scores = {
        cat: compute_change(json_ready_results[cat], output_scores[cat])
        for cat in json_ready_results
        if cat in output_scores
    }
return (
gr.update(value=json_ready_results, visible=True),
gr.update(value=compare_scores, visible=True)
)
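# Flagging flow: reveal the reason radio, collect an optional comment, upload.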
def show_flag_choices():
return gr.update(visible=True)
def update_flag(flag_value):
return (flag_value,
gr.update(visible=True),
gr.update(visible=True),
gr.update(visible=False)
)
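# Push the flagged sample (prompt, output, model, comment, reason) to the
# Hugging Face dataset configured on flagging_callback.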
def upload_flag(*args):
if flagging_callback.flag(list(args), flag_option = None):
return gr.update(visible=True)
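# Append a hub model found via the search bar to the global CHOICES list and
# refresh the multi-model checkbox group.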
def forward_model_choice_multi(model_choice_path):
CHOICES.append(model_choice_path)
return gr.update(choices = CHOICES)
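# Run the prompt through every selected model and fill the corresponding
# HighlightedText outputs, hiding the unused slots (10 are pre-allocated).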
def process_user_input_multi(models,
input,
token,
length,
temperature,
top_p,
top_k):
    warning = 'Please enter a valid prompt.'
    if not input:
        input = ''
        generated_dict = {model: warning for model in sorted(models)}
    else:
        generated_dict = {model: generate(model_name=model,
token=token,
custom_model_path=None,
input_sentence=input,
length=length,
temperature=temperature,
top_p=top_p,
top_k=top_k) for model in sorted(models)}
generated_with_spans_dict = {model:auto_complete(input, generated) for model,generated in generated_dict.items()}
update_outputs = [gr.HighlightedText.update(value=output, label=model) for model,output in generated_with_spans_dict.items()]
update_hide = [gr.HighlightedText.update(visible=False) for i in range(10-len(models))]
return update_outputs + update_hide
def show_choices_multi(models):
update_show = [gr.HighlightedText.update(visible=True) for model in sorted(models)]
update_hide = [gr.HighlightedText.update(visible=False,value=None, label=None) for i in range(10-len(models))]
return update_show + update_hide
def show_params(checkbox):
    # Toggle the custom-parameters box with the checkbox state
    return gr.update(visible=checkbox)
CSS = """
#inside_group {
padding-top: 0.6em;
padding-bottom: 0.6em;
}
#pw textarea {
-webkit-text-security: disc;
}
"""
with gr.Blocks(css=CSS) as demo:
dataset = gr.Variable(value=DATASET)
prompts_var = gr.Variable(value=None)
input_var = gr.Variable(label="Input Prompt", value=None)
output_var = gr.Variable(label="Output",value=None)
model_choice = gr.Variable(label="Model", value=None)
custom_model_path = gr.Variable(value=None)
flag_choice = gr.Variable(label = "Flag", value=None)
flagging_callback = gr.HuggingFaceDatasetSaver(hf_token = HF_AUTH_TOKEN,
dataset_name = "fsdlredteam/flagged_2",
organization = "fsdlredteam",
private = True )
gr.Markdown("<p align='center'><img src='https://i.imgur.com/ZxbbLUQ.png>'/></p>")
gr.Markdown("<h1 align='center'>BuggingSpace</h1>")
gr.Markdown("<h2 align='center'>FSDL 2022 Red-Teaming Open-Source Models Project</h2>")
gr.Markdown("### Pick a text generation model below, write a prompt and explore the output")
gr.Markdown("### Or compare the output of multiple models at the same time")
choose_mode = gr.Radio(choices=['Single Model', "Multi-Model"],
value='Single Model',
interactive=True,
visible=True,
show_label=False)
with gr.Group() as single_model:
gr.Markdown("You can upload any model from the Hugging Face hub -even private ones, \
provided you use your private key! "
"Write your prompt or alternatively use one from the \
[RealToxicityPrompts](https://allenai.org/data/real-toxicity-prompts) dataset.")
gr.Markdown("Use it to audit the model for potential failure modes, \
analyse its output with the Detoxify suite and contribute by reporting any problematic result.")
gr.Markdown("Beware ! Generation can take up to a few minutes with very large models.")
with gr.Row():
with gr.Column(scale=1): # input & prompts dataset exploration
gr.Markdown("### 1. Select a prompt", elem_id="inside_group")
input_text = gr.Textbox(label="Write your prompt below.",
interactive=True,
lines=4,
elem_id="inside_group")
gr.Markdown("β€” or β€”", elem_id="inside_group")
inspo_button = gr.Button('Click here if you need some inspiration', elem_id="inside_group")
prompts_drop = gr.Dropdown(visible=False, elem_id="inside_group")
randomize_button = gr.Button('Show another subset', visible=False, elem_id="inside_group")
show_params_checkbox_single = gr.Checkbox(label='Set custom params',
interactive=True,
value=False)
with gr.Box(visible=False) as params_box_single:
length_single = gr.Slider(label='Output length',
visible=True,
interactive=True,
minimum=50,
maximum=200,
value=75)
top_k_single = gr.Slider(label='top_k',
visible=True,
interactive=True,
minimum=1,
maximum=100,
value=50)
top_p_single = gr.Slider(label='top_p',
visible=True,
interactive=True,
minimum=0.1,
maximum=1,
value=0.95)
temperature_single = gr.Slider(label='temperature',
visible=True,
interactive=True,
minimum=0.1,
maximum=1,
value=0.7)
with gr.Column(scale=1): # Model choice & output
gr.Markdown("### 2. Evaluate output")
model_radio = gr.Radio(choices=list(CHECKPOINTS.keys()),
label='Model',
interactive=True,
elem_id="inside_group")
search_bar = gr.Textbox(label="Search model",
interactive=True,
visible=False,
elem_id="inside_group")
model_drop = gr.Dropdown(visible=False)
                private_checkbox = gr.Checkbox(visible=True, label="Private model?", elem_id="inside_group")
api_key_textbox = gr.Textbox(label="Enter your AUTH TOKEN below",
value=None,
interactive=True,
visible=False,
elem_id="pw")
generate_button = gr.Button('Submit your prompt', elem_id="inside_group")
output_spans = gr.HighlightedText(visible=True, label="Generated text")
flag_button = gr.Button("Report output here", visible=False, elem_id="inside_group")
with gr.Row(): # Flagging
with gr.Column(scale=1):
                flag_radio = gr.Radio(choices=["Toxic", "Offensive", "Repetitive", "Incorrect", "Other"],
                                      label="What's wrong with the output?",
                                      interactive=True,
                                      visible=False,
                                      elem_id="inside_group")
user_comment = gr.Textbox(label="(Optional) Briefly describe the issue",
visible=False,
interactive=True,
elem_id="inside_group")
confirm_flag_button = gr.Button("Confirm report", visible=False, elem_id="inside_group")
with gr.Row(): # Flagging success
success_message = gr.Markdown("Your report has been successfully registered. Thank you!",
visible=False,
elem_id="inside_group")
with gr.Row(): # Toxicity buttons
toxi_button = gr.Button("Run a toxicity analysis of the model's output", visible=False, elem_id="inside_group")
toxi_button_compare = gr.Button("Compare toxicity on input and output", visible=False, elem_id="inside_group")
with gr.Row(): # Toxicity scores
toxi_scores_input = gr.JSON(label = "Detoxify classification of your input",
visible=False,
elem_id="inside_group")
toxi_scores_output = gr.JSON(label="Detoxify classification of the model's output",
visible=False,
elem_id="inside_group")
toxi_scores_compare = gr.JSON(label = "Percentage change between Input and Output",
visible=False,
elem_id="inside_group")
with gr.Group(visible=False) as multi_model:
model_list = list()
gr.Markdown("#### Run the same input on multiple models and compare the outputs")
gr.Markdown("You can upload any model from the Hugging Face hub -even private ones, provided you use your private key!")
gr.Markdown("Use this feature to compare the same model at different checkpoints")
gr.Markdown('Or to benchmark your model against another one as a reference.')
gr.Markdown("Beware ! Generation can take up to a few minutes with very large models.")
with gr.Row(elem_id="inside_group"):
with gr.Column():
models_multi = gr.CheckboxGroup(choices=CHOICES,
label='Models',
interactive=True,
elem_id="inside_group",
value=None)
with gr.Column():
generate_button_multi = gr.Button('Submit your prompt',elem_id="inside_group")
show_params_checkbox_multi = gr.Checkbox(label='Set custom params',
interactive=True,
value=False)
with gr.Box(visible=False) as params_box_multi:
length_multi = gr.Slider(label='Output length',
visible=True,
interactive=True,
minimum=50,
maximum=200,
value=75)
top_k_multi = gr.Slider(label='top_k',
visible=True,
interactive=True,
minimum=1,
maximum=100,
value=50)
top_p_multi = gr.Slider(label='top_p',
visible=True,
interactive=True,
minimum=0.1,
maximum=1,
value=0.95)
temperature_multi = gr.Slider(label='temperature',
visible=True,
interactive=True,
minimum=0.1,
maximum=1,
value=0.7)
with gr.Row(elem_id="inside_group"):
with gr.Column(elem_id="inside_group", scale=1):
input_text_multi = gr.Textbox(label="Write your prompt below.",
interactive=True,
lines=4,
elem_id="inside_group")
with gr.Column(elem_id="inside_group", scale=1):
search_bar_multi = gr.Textbox(label="Search another model",
interactive=True,
visible=True,
elem_id="inside_group")
                model_drop_multi = gr.Dropdown(visible=False,
                                               elem_id="inside_group")
                private_checkbox_multi = gr.Checkbox(visible=True, label="Private model?")
api_key_textbox_multi = gr.Textbox(label="Enter your AUTH TOKEN below",
value=None,
interactive=True,
visible=False,
elem_id="pw")
with gr.Row() as outputs_row:
for i in range(10):
output_spans_multi = gr.HighlightedText(visible=False, elem_id="inside_group")
model_list.append(output_spans_multi)
with gr.Row():
gr.Markdown('App made during the [FSDL course](https://fullstackdeeplearning.com) \
by Team53: Jean-Antoine, Sajenthan, Sashank, Kemp, Srihari, Astitwa')
# Single Model
choose_mode.change(fn=show_mode,
inputs=choose_mode,
outputs=[single_model, multi_model])
inspo_button.click(fn=show_dataset,
inputs=dataset,
outputs=[prompts_drop, randomize_button, prompts_var])
prompts_drop.change(fn=pass_to_textbox,
inputs=prompts_drop,
outputs=input_text)
    randomize_button.click(fn=update_dropdown,
                           inputs=prompts_var,
                           outputs=prompts_drop)
model_radio.change(fn=show_search_bar,
inputs=model_radio,
outputs=[model_choice,search_bar])
search_bar.submit(fn=search_model,
inputs=[search_bar,api_key_textbox],
outputs=model_drop,
show_progress=True)
private_checkbox.change(fn=show_api_key_textbox,
inputs=private_checkbox,
outputs=api_key_textbox)
model_drop.change(fn=forward_model_choice,
inputs=model_drop,
outputs=[model_choice,custom_model_path])
generate_button.click(fn=process_user_input,
inputs=[model_choice,
api_key_textbox,
custom_model_path,
input_text,
length_single,
temperature_single,
top_p_single,
top_k_single],
outputs=[output_spans,
toxi_button,
flag_button,
input_var,
output_var],
show_progress=True)
toxi_button.click(fn=compute_toxi_output,
inputs=output_var,
outputs=[toxi_scores_output, toxi_button_compare],
show_progress=True)
toxi_button_compare.click(fn=compare_toxi_scores,
inputs=[input_text, toxi_scores_output],
outputs=[toxi_scores_input, toxi_scores_compare],
show_progress=True)
flag_button.click(fn=show_flag_choices,
inputs=None,
outputs=flag_radio)
flag_radio.change(fn=update_flag,
inputs=flag_radio,
outputs=[flag_choice, confirm_flag_button, user_comment, flag_button])
flagging_callback.setup([input_var, output_var, model_choice, user_comment, flag_choice], "flagged_data_points")
confirm_flag_button.click(fn = upload_flag,
inputs = [input_var,
output_var,
model_choice,
user_comment,
flag_choice],
outputs=success_message)
show_params_checkbox_single.change(fn=show_params,
inputs=show_params_checkbox_single,
outputs=params_box_single)
# Model comparison
search_bar_multi.submit(fn=search_model,
inputs=[search_bar_multi, api_key_textbox_multi],
outputs=model_drop_multi,
show_progress=True)
show_params_checkbox_multi.change(fn=show_params,
inputs=show_params_checkbox_multi,
outputs=params_box_multi)
private_checkbox_multi.change(fn=show_api_key_textbox,
inputs=private_checkbox_multi,
outputs=api_key_textbox_multi)
model_drop_multi.change(fn=forward_model_choice_multi,
inputs=model_drop_multi,
outputs=[models_multi])
models_multi.change(fn=show_choices_multi,
inputs=models_multi,
outputs=model_list)
generate_button_multi.click(fn=process_user_input_multi,
inputs=[models_multi,
input_text_multi,
api_key_textbox_multi,
length_multi,
temperature_multi,
top_p_multi,
top_k_multi],
outputs=model_list,
show_progress=True)
if __name__ == "__main__":
demo.launch(enable_queue=False, debug=True)