# Hugging Face Spaces status at capture time: Runtime error
import os | |
import torch | |
import numpy as np | |
import gradio as gr | |
from random import sample | |
from detoxify import Detoxify | |
from datasets import load_dataset | |
from huggingface_hub import HfApi, ModelFilter, ModelSearchArguments | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM | |
from transformers import BloomTokenizerFast, BloomForCausalLM | |
# Hub token read from the `hf_token` environment variable (None when unset);
# used by the flagging callback to write reports.
HF_AUTH_TOKEN = os.environ.get('hf_token')
# Hub dataset sampled by the "need some inspiration" button.
DATASET = "allenai/real-toxicity-prompts"
# Display name -> Hub checkpoint. "Custom Model" has no fixed checkpoint: its
# path comes from the model search dropdown at run time.
CHECKPOINTS = {
    "DistilGPT2 by HuggingFace 🤗" : "distilgpt2",
    "GPT-Neo 125M by EleutherAI 🤗" : "EleutherAI/gpt-neo-125M",
    "BLOOM 560M by BigScience 🌸" : "bigscience/bloom-560m",
    "Custom Model" : None
}
# Display name -> (model class, tokenizer class). Keys must mirror CHECKPOINTS.
MODEL_CLASSES = {
    "DistilGPT2 by HuggingFace 🤗" : (GPT2LMHeadModel, GPT2Tokenizer),
    "GPT-Neo 125M by EleutherAI 🤗" : (GPTNeoForCausalLM, GPT2Tokenizer),
    "BLOOM 560M by BigScience 🌸" : (BloomForCausalLM, BloomTokenizerFast),
    "Custom Model" : (AutoModelForCausalLM, AutoTokenizer),
}
# Models offered by default in multi-model mode: every entry with a fixed
# checkpoint. This list is mutated at run time when users add searched models.
CHOICES = sorted(name for name, path in CHECKPOINTS.items() if path is not None)
def load_model(model_name, custom_model_path, token):
    """Load a causal language model and its tokenizer, ready for sampling.

    Args:
        model_name: a key of MODEL_CLASSES (a built-in model or "Custom Model"),
            or a raw Hub model id picked from the search dropdown.
        custom_model_path: Hub id chosen from the search dropdown, or None.
        token: Hugging Face auth token for private models, or None.

    Returns:
        (model, tokenizer), with the model in eval mode and the pad token
        mapped to the EOS token.

    Raises:
        ValueError: if no checkpoint can be resolved, e.g. "Custom Model" was
            selected but no model was picked from the search results.
    """
    if model_name in MODEL_CLASSES:
        model_class, tokenizer_class = MODEL_CLASSES[model_name]
        # "Custom Model" maps to None in CHECKPOINTS: use the searched path.
        model_path = CHECKPOINTS[model_name] or custom_model_path
    else:
        # Unknown names are Hub ids, loaded through the Auto* classes.
        model_class, tokenizer_class = MODEL_CLASSES['Custom Model']
        model_path = custom_model_path or model_name
    if not model_path:
        raise ValueError('No model selected: choose a model, or search for one '
                         'and pick it from the dropdown.')
    model = model_class.from_pretrained(model_path, use_auth_token=token)
    tokenizer = tokenizer_class.from_pretrained(model_path, use_auth_token=token)
    # These models are used without a dedicated pad token; reuse EOS for padding.
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = model.config.eos_token_id
    model.eval()
    return model, tokenizer
MAX_LENGTH = 10_000  # Fallback cap used by adjust_length_to_model for negative lengths (avoids an infinite loop).
def set_seed(seed, n_gpu):
    """Seed NumPy and PyTorch, plus every CUDA device when `n_gpu` is positive."""
    for seeder in (np.random.seed, torch.manual_seed):
        seeder(seed)
    if n_gpu <= 0:
        return
    torch.cuda.manual_seed_all(seed)
def adjust_length_to_model(length, max_sequence_length):
    """Clamp a requested generation length to what the model supports.

    A negative `length` means "as long as possible": it becomes the model's
    maximum when that is positive, otherwise MAX_LENGTH. A non-negative
    length is capped at a positive model maximum.
    """
    model_has_limit = max_sequence_length > 0
    if length < 0:
        return max_sequence_length if model_has_limit else MAX_LENGTH
    if model_has_limit and length > max_sequence_length:
        return max_sequence_length  # No generation bigger than model size
    return length
def generate(model_name,
             token,
             custom_model_path,
             input_sentence,
             length = 75,
             temperature = 0.7,
             top_k = 50,
             top_p = 0.95,
             seed = 42,
             no_cuda = False,
             num_return_sequences = 1,
             stop_token = '.'
             ):
    """Sample a continuation of `input_sentence` from the selected model.

    Args:
        model_name, token, custom_model_path: forwarded to load_model.
        input_sentence: the prompt text.
        length: number of new tokens to sample on top of the prompt.
        temperature, top_k, top_p: sampling parameters passed to model.generate.
        seed: RNG seed, for reproducible sampling.
        no_cuda: force CPU even when CUDA is available.
        num_return_sequences: number of samples drawn; only the first is returned.
        stop_token: the continuation is cut just after its last occurrence.
            If the token does not occur, the continuation is returned whole.

    Returns:
        The first sampled continuation, without the prompt and special tokens.
    """
    device = torch.device("cuda" if torch.cuda.is_available() and not no_cuda else "cpu")
    n_gpu = 0 if no_cuda else torch.cuda.device_count()
    set_seed(seed, n_gpu)

    model, tokenizer = load_model(model_name, custom_model_path, token)
    model.to(device)

    encoded_prompt = tokenizer.encode(input_sentence,
                                      add_special_tokens=False,
                                      return_tensors='pt').to(device)
    prompt_length = encoded_prompt.shape[-1]

    # Slider values may arrive as floats; top-k filtering and max_length need ints.
    output_sequences = model.generate(input_ids=encoded_prompt,
                                      max_length=int(length) + prompt_length,
                                      temperature=temperature,
                                      top_k=int(top_k),
                                      top_p=top_p,
                                      do_sample=True,
                                      num_return_sequences=num_return_sequences
                                      )

    generated_sequences = []
    for sequence in output_sequences:
        # Decode only the newly sampled tokens, so the prompt is removed exactly.
        # Also drop EOS/pad markers from the text.
        text = tokenizer.decode(sequence[prompt_length:].tolist(),
                                skip_special_tokens=True,
                                clean_up_tokenization_spaces=True)
        # Keep everything up to and including the last stop token. Without this
        # guard, a missing stop token (rfind == -1) would empty the text.
        stop_idx = text.rfind(stop_token) if stop_token else -1
        if stop_idx != -1:
            text = text[:stop_idx + len(stop_token)]
        generated_sequences.append(text)
    return generated_sequences[0]
def show_mode(mode):
    """Toggle between the single-model and multi-model groups.

    Returns visibility updates for (single-model group, multi-model group),
    or None when `mode` is neither 'Single Model' nor 'Multi-Model'.
    """
    if mode not in ('Single Model', 'Multi-Model'):
        return None
    single = mode == 'Single Model'
    return (
        gr.update(visible=single),
        gr.update(visible=not single)
    )
def prepare_dataset(dataset):
    """Load and return the 'train' split of the Hub dataset named `dataset`."""
    return load_dataset(dataset, split='train')
def load_prompts(dataset):
    """Return every prompt text of a RealToxicityPrompts split.

    Args:
        dataset: a `datasets.Dataset` (or any mapping supporting column access)
            whose "prompt" column holds dicts with a "text" field.

    Returns:
        list of prompt strings, in dataset order.
    """
    # One column read instead of materialising a row dict per index via
    # dataset[i], which is very slow on a ~100k-row dataset.
    return [prompt['text'] for prompt in dataset['prompt']]
def random_sample(prompt_list, k=10):
    """Return up to `k` distinct items of `prompt_list`, chosen at random.

    A list shorter than `k` is returned whole (in random order), instead of
    making `random.sample` raise ValueError.
    """
    return sample(prompt_list, min(k, len(prompt_list)))
def show_dataset(dataset):
    """Load the prompt dataset and reveal a random subset of its prompts.

    Returns updates for (prompt dropdown, "show another subset" button,
    prompts state). The full prompt list goes into state for later re-sampling.
    """
    prompts = load_prompts(prepare_dataset(dataset))
    dropdown_update = gr.update(
        choices=random_sample(prompts),
        label='You can find below a random subset from the RealToxicityPrompts dataset',
        visible=True,
    )
    return dropdown_update, gr.update(visible=True), prompts
def update_dropdown(prompts):
    """Refill the prompt dropdown with a fresh random subset of `prompts`."""
    fresh_subset = random_sample(prompts)
    return gr.update(choices=fresh_subset)
def show_search_bar(value):
    """Forward the chosen model name; show the search bar only for "Custom Model"."""
    wants_search = value == 'Custom Model'
    return value, gr.update(visible=wants_search)
def search_model(model_name, token):
    """Search the Hub for PyTorch text-generation models matching `model_name`.

    Args:
        model_name: free-text search string.
        token: Hugging Face auth token (lets private models show up), or None.

    Returns:
        A dropdown update listing the ids of the matching models.
    """
    api = HfApi()
    # Literal tag values instead of ModelSearchArguments(), which fetches the
    # Hub's whole tag catalogue over the network on every instantiation.
    filt = ModelFilter(task='text-generation', library='pytorch')
    results = api.list_models(filter=filt, search=model_name, use_auth_token=token)
    model_list = [model.modelId for model in results]
    return gr.update(visible=True,
                     choices=model_list,
                     label='Choose the model',
                     )
def show_api_key_textbox(checkbox):
    """Reveal the auth-token textbox while the "Private Model ?" box is ticked."""
    return gr.update(visible=bool(checkbox))
def forward_model_choice(model_choice_path):
    """Copy the picked Hub id into both the model-choice and custom-path states."""
    return (model_choice_path,) * 2
def auto_complete(input, generated):
    """Join prompt and continuation, tagging the appended part as OUTPUT.

    Returns the dict form accepted by gr.HighlightedText. It holds the full
    text plus one entity that runs from the end of the prompt (the joining
    space included) to the end of the text.
    """
    prompt_end = len(input)
    full_text = ' '.join([input, generated])
    return {
        "text": full_text,
        "entities": [{'entity': 'OUTPUT', 'start': prompt_end, 'end': len(full_text)}],
    }
def process_user_input(model,
                       token,
                       custom_model_path,
                       input,
                       length,
                       temperature,
                       top_p,
                       top_k):
    """Run generation for the single-model view and update the UI.

    Returns:
        Updates for (highlighted output, toxicity button, report button,
        input state, output state). A blank prompt shows a warning instead of
        calling the model. In that case the analysis and report buttons stay
        hidden and the output state is cleared.
    """
    warning = 'Please enter a valid prompt.'
    # An empty Textbox sends "" rather than None, so test for blank text.
    if input is None or not input.strip():
        return (
            gr.update(value={"text": warning, "entities": []}),
            gr.update(visible=False),
            gr.update(visible=False),
            input,
            None
        )
    generated = generate(model_name=model,
                         token=token,
                         custom_model_path=custom_model_path,
                         input_sentence=input,
                         length=length,
                         temperature=temperature,
                         top_p=top_p,
                         top_k=top_k)
    generated_with_spans = auto_complete(input=input, generated=generated)
    return (
        gr.update(value=generated_with_spans),
        gr.update(visible=True),
        gr.update(visible=True),
        input,
        generated
    )
def pass_to_textbox(input):
    """Copy the prompt picked in the dropdown into the prompt textbox."""
    textbox_update = gr.update(value=input)
    return textbox_update
import functools


@functools.lru_cache(maxsize=None)
def _get_detoxify_model(variant):
    """Load a Detoxify checkpoint once per process and reuse it (keyed by variant name only)."""
    return Detoxify(variant)


def run_detoxify(text):
    """Score `text` with Detoxify's 'original' model.

    Returns:
        dict mapping each toxicity category to a plain Python float, so the
        result can be shown by gr.JSON.
    """
    results = _get_detoxify_model('original').predict(text)
    return {cat: float(score) for cat, score in results.items()}
def compute_toxi_output(output_text):
    """Score the model output and reveal the input/output comparison button."""
    scores_update = gr.update(value=run_detoxify(output_text), visible=True)
    compare_button_update = gr.update(visible=True)
    return scores_update, compare_button_update
def compute_change(input, output):
    """Return the percentage change from `input` to `output`, rounded to 2 decimals.

    Both scores are coerced to float. When the input score is exactly 0 the
    relative change is undefined. In that case the function returns 0.0 if
    the output is also 0, otherwise None (JSON null), instead of raising
    ZeroDivisionError.
    """
    before, after = float(input), float(output)
    if before == 0:
        return 0.0 if after == 0 else None
    return round((after - before) / before * 100, 2)
def compare_toxi_scores(input_text, output_scores):
    """Score the prompt and compare it against the model output's scores.

    Args:
        input_text: the prompt that was sent to the model.
        output_scores: category -> score dict previously computed for the output.

    Returns:
        Updates for (input scores JSON, percentage-change JSON). The change is
        computed once per category present in both score dicts.
    """
    input_scores = run_detoxify(input_text)  # already category -> float
    output_scores = output_scores or {}
    compare_scores = {
        cat: compute_change(score, output_scores[cat])
        for cat, score in input_scores.items()
        if cat in output_scores
    }
    return (
        gr.update(value=input_scores, visible=True),
        gr.update(value=compare_scores, visible=True)
    )
def show_flag_choices():
    """Reveal the "What's wrong with the output ?" radio."""
    radio_update = gr.update(visible=True)
    return radio_update
def update_flag(flag_value):
    """Store the chosen flag category and move the UI to the confirm step.

    Returns (flag value, confirm-button update, comment-box update,
    report-button update): the first two widgets are shown and the report
    button is hidden.
    """
    confirm_button = gr.update(visible=True)
    comment_box = gr.update(visible=True)
    report_button = gr.update(visible=False)
    return flag_value, confirm_button, comment_box, report_button
def upload_flag(*args):
    """Send the flagged sample to the flagging callback and show the success message.

    `args` are the values of (input, output, model, comment, flag), in the
    order wired in the UI. Returns None when `flag()` returns a falsy value.
    """
    saved = flagging_callback.flag(list(args), flag_option = None)
    if not saved:
        return None
    return gr.update(visible=True)
def forward_model_choice_multi(model_choice_path):
    """Add a searched Hub model to the multi-model checkbox choices.

    Empty selections and models already listed are ignored, so re-picking a
    model does not create duplicate checkboxes. CHOICES is a module-level
    list, so additions are shared by every session served by this process.
    """
    if model_choice_path and model_choice_path not in CHOICES:
        CHOICES.append(model_choice_path)
    return gr.update(choices = CHOICES)
def process_user_input_multi(models,
                             input,
                             token,
                             length,
                             temperature,
                             top_p,
                             top_k):
    """Run the same prompt through every selected model for side-by-side comparison.

    Returns:
        One gr.HighlightedText update per output slot (10, matching the
        layout). Slots are filled one per selected model in sorted order, and
        the remaining slots are hidden. Models beyond the number of slots are
        skipped. A blank prompt shows a warning in the first slot instead of
        generating.
    """
    n_slots = 10  # must match the number of HighlightedText outputs in the layout
    warning = 'Please enter a valid prompt.'
    models = sorted(models or [])[:n_slots]
    if input is None or not input.strip():
        shown = [gr.HighlightedText.update(value={"text": warning, "entities": []}, visible=True)]
    else:
        shown = []
        for model in models:
            generated = generate(model_name=model,
                                 token=token,
                                 custom_model_path=None,
                                 input_sentence=input,
                                 length=length,
                                 temperature=temperature,
                                 top_p=top_p,
                                 top_k=top_k)
            shown.append(gr.HighlightedText.update(value=auto_complete(input, generated),
                                                   label=model,
                                                   visible=True))
    update_hide = [gr.HighlightedText.update(visible=False) for _ in range(n_slots - len(shown))]
    return shown + update_hide
def show_choices_multi(models):
    """Show one output slot per selected model and hide (and clear) the rest.

    The layout has 10 output slots. Selections beyond that are not shown,
    so the number of updates always matches the number of outputs.
    """
    n_slots = 10  # must match the number of HighlightedText outputs in the layout
    n_shown = min(len(models or []), n_slots)
    update_show = [gr.HighlightedText.update(visible=True) for _ in range(n_shown)]
    update_hide = [gr.HighlightedText.update(visible=False, value=None, label=None)
                   for _ in range(n_slots - n_shown)]
    return update_show + update_hide
def show_params(checkbox):
    """Show the custom sampling-parameters box while its checkbox is ticked."""
    show_box = bool(checkbox == True)  # noqa: E712 -- same comparison as before
    return gr.update(visible=show_box)
# Page-level CSS passed to gr.Blocks. "inside_group" adds vertical padding to
# widgets tagged with that elem_id. "pw" masks the text typed into the
# auth-token textboxes (non-standard -webkit-text-security property).
CSS = """
#inside_group {
    padding-top: 0.6em;
    padding-bottom: 0.6em;
}
#pw textarea {
    -webkit-text-security: disc;
}
"""
with gr.Blocks(css=CSS) as demo:
    # Per-session state shared between the event handlers below.
    dataset = gr.Variable(value=DATASET)
    prompts_var = gr.Variable(value=None)  # full prompt list, filled by show_dataset
    input_var = gr.Variable(label="Input Prompt", value=None)  # prompt sent to the model
    output_var = gr.Variable(label="Output",value=None)  # latest generated continuation
    model_choice = gr.Variable(label="Model", value=None)
    custom_model_path = gr.Variable(value=None)
    flag_choice = gr.Variable(label = "Flag", value=None)
    # Reports are written to the "fsdlredteam/flagged_2" Hub dataset.
    flagging_callback = gr.HuggingFaceDatasetSaver(hf_token = HF_AUTH_TOKEN,
                                                   dataset_name = "fsdlredteam/flagged_2",
                                                   organization = "fsdlredteam",
                                                   private = True )
    # Header
    gr.Markdown("<p align='center'><img src='https://i.imgur.com/ZxbbLUQ.png'/></p>")
    gr.Markdown("<h1 align='center'>BuggingSpace</h1>")
    gr.Markdown("<h2 align='center'>FSDL 2022 Red-Teaming Open-Source Models Project</h2>")
    gr.Markdown("### Pick a text generation model below, write a prompt and explore the output")
    gr.Markdown("### Or compare the output of multiple models at the same time")
    choose_mode = gr.Radio(choices=['Single Model', "Multi-Model"],
                           value='Single Model',
                           interactive=True,
                           visible=True,
                           show_label=False)
    # ---------------- Single-model view ----------------
    with gr.Group() as single_model:
        gr.Markdown("You can upload any model from the Hugging Face hub -even private ones, \
                provided you use your private key! "
                "Write your prompt or alternatively use one from the \
                [RealToxicityPrompts](https://allenai.org/data/real-toxicity-prompts) dataset.")
        gr.Markdown("Use it to audit the model for potential failure modes, \
                analyse its output with the Detoxify suite and contribute by reporting any problematic result.")
        gr.Markdown("Beware ! Generation can take up to a few minutes with very large models.")
        with gr.Row():
            with gr.Column(scale=1): # input & prompts dataset exploration
                gr.Markdown("### 1. Select a prompt", elem_id="inside_group")
                input_text = gr.Textbox(label="Write your prompt below.",
                                        interactive=True,
                                        lines=4,
                                        elem_id="inside_group")
                gr.Markdown("— or —", elem_id="inside_group")
                inspo_button = gr.Button('Click here if you need some inspiration', elem_id="inside_group")
                prompts_drop = gr.Dropdown(visible=False, elem_id="inside_group")
                randomize_button = gr.Button('Show another subset', visible=False, elem_id="inside_group")
                show_params_checkbox_single = gr.Checkbox(label='Set custom params',
                                                          interactive=True,
                                                          value=False)
                with gr.Box(visible=False) as params_box_single:
                    length_single = gr.Slider(label='Output length',
                                              visible=True,
                                              interactive=True,
                                              minimum=50,
                                              maximum=200,
                                              value=75)
                    top_k_single = gr.Slider(label='top_k',
                                             visible=True,
                                             interactive=True,
                                             minimum=1,
                                             maximum=100,
                                             value=50)
                    top_p_single = gr.Slider(label='top_p',
                                             visible=True,
                                             interactive=True,
                                             minimum=0.1,
                                             maximum=1,
                                             value=0.95)
                    temperature_single = gr.Slider(label='temperature',
                                                   visible=True,
                                                   interactive=True,
                                                   minimum=0.1,
                                                   maximum=1,
                                                   value=0.7)
            with gr.Column(scale=1): # Model choice & output
                gr.Markdown("### 2. Evaluate output")
                model_radio = gr.Radio(choices=list(CHECKPOINTS.keys()),
                                       label='Model',
                                       interactive=True,
                                       elem_id="inside_group")
                search_bar = gr.Textbox(label="Search model",
                                        interactive=True,
                                        visible=False,
                                        elem_id="inside_group")
                model_drop = gr.Dropdown(visible=False)
                private_checkbox = gr.Checkbox(visible=True,label="Private Model ?", elem_id="inside_group")
                api_key_textbox = gr.Textbox(label="Enter your AUTH TOKEN below",
                                             value=None,
                                             interactive=True,
                                             visible=False,
                                             elem_id="pw")
                generate_button = gr.Button('Submit your prompt', elem_id="inside_group")
                output_spans = gr.HighlightedText(visible=True, label="Generated text")
                flag_button = gr.Button("Report output here", visible=False, elem_id="inside_group")
        with gr.Row(): # Flagging
            with gr.Column(scale=1):
                flag_radio = gr.Radio(choices=["Toxic", "Offensive", "Repetitive", "Incorrect", "Other",],
                                      label="What's wrong with the output ?",
                                      interactive=True,
                                      visible=False,
                                      elem_id="inside_group")
                user_comment = gr.Textbox(label="(Optional) Briefly describe the issue",
                                          visible=False,
                                          interactive=True,
                                          elem_id="inside_group")
                confirm_flag_button = gr.Button("Confirm report", visible=False, elem_id="inside_group")
        with gr.Row(): # Flagging success
            success_message = gr.Markdown("Your report has been successfully registered. Thank you!",
                                          visible=False,
                                          elem_id="inside_group")
        with gr.Row(): # Toxicity buttons
            toxi_button = gr.Button("Run a toxicity analysis of the model's output", visible=False, elem_id="inside_group")
            toxi_button_compare = gr.Button("Compare toxicity on input and output", visible=False, elem_id="inside_group")
        with gr.Row(): # Toxicity scores
            toxi_scores_input = gr.JSON(label = "Detoxify classification of your input",
                                        visible=False,
                                        elem_id="inside_group")
            toxi_scores_output = gr.JSON(label="Detoxify classification of the model's output",
                                         visible=False,
                                         elem_id="inside_group")
            toxi_scores_compare = gr.JSON(label = "Percentage change between Input and Output",
                                          visible=False,
                                          elem_id="inside_group")
    # ---------------- Multi-model comparison view ----------------
    with gr.Group(visible=False) as multi_model:
        model_list = list()  # the 10 output slots, in display order
        gr.Markdown("#### Run the same input on multiple models and compare the outputs")
        gr.Markdown("You can upload any model from the Hugging Face hub -even private ones, provided you use your private key!")
        gr.Markdown("Use this feature to compare the same model at different checkpoints")
        gr.Markdown('Or to benchmark your model against another one as a reference.')
        gr.Markdown("Beware ! Generation can take up to a few minutes with very large models.")
        with gr.Row(elem_id="inside_group"):
            with gr.Column():
                models_multi = gr.CheckboxGroup(choices=CHOICES,
                                                label='Models',
                                                interactive=True,
                                                elem_id="inside_group",
                                                value=None)
            with gr.Column():
                generate_button_multi = gr.Button('Submit your prompt',elem_id="inside_group")
                show_params_checkbox_multi = gr.Checkbox(label='Set custom params',
                                                         interactive=True,
                                                         value=False)
                with gr.Box(visible=False) as params_box_multi:
                    length_multi = gr.Slider(label='Output length',
                                             visible=True,
                                             interactive=True,
                                             minimum=50,
                                             maximum=200,
                                             value=75)
                    top_k_multi = gr.Slider(label='top_k',
                                            visible=True,
                                            interactive=True,
                                            minimum=1,
                                            maximum=100,
                                            value=50)
                    top_p_multi = gr.Slider(label='top_p',
                                            visible=True,
                                            interactive=True,
                                            minimum=0.1,
                                            maximum=1,
                                            value=0.95)
                    temperature_multi = gr.Slider(label='temperature',
                                                  visible=True,
                                                  interactive=True,
                                                  minimum=0.1,
                                                  maximum=1,
                                                  value=0.7)
        with gr.Row(elem_id="inside_group"):
            with gr.Column(elem_id="inside_group", scale=1):
                input_text_multi = gr.Textbox(label="Write your prompt below.",
                                              interactive=True,
                                              lines=4,
                                              elem_id="inside_group")
            with gr.Column(elem_id="inside_group", scale=1):
                search_bar_multi = gr.Textbox(label="Search another model",
                                              interactive=True,
                                              visible=True,
                                              elem_id="inside_group")
                model_drop_multi = gr.Dropdown(visible=False,
                                               show_progress=True,
                                               elem_id="inside_group")
                private_checkbox_multi = gr.Checkbox(visible=True,label="Private Model ?")
                api_key_textbox_multi = gr.Textbox(label="Enter your AUTH TOKEN below",
                                                   value=None,
                                                   interactive=True,
                                                   visible=False,
                                                   elem_id="pw")
        with gr.Row() as outputs_row:
            for i in range(10):
                output_spans_multi = gr.HighlightedText(visible=False, elem_id="inside_group")
                model_list.append(output_spans_multi)
    with gr.Row():
        gr.Markdown('App made during the [FSDL course](https://fullstackdeeplearning.com) \
                        by Team53: Jean-Antoine, Sajenthan, Sashank, Kemp, Srihari, Astitwa')
    # Single Model
    choose_mode.change(fn=show_mode,
                       inputs=choose_mode,
                       outputs=[single_model, multi_model])
    inspo_button.click(fn=show_dataset,
                       inputs=dataset,
                       outputs=[prompts_drop, randomize_button, prompts_var])
    prompts_drop.change(fn=pass_to_textbox,
                        inputs=prompts_drop,
                        outputs=input_text)
    randomize_button.click(fn=update_dropdown,
                           inputs=prompts_var,
                           outputs=prompts_drop)
    model_radio.change(fn=show_search_bar,
                       inputs=model_radio,
                       outputs=[model_choice,search_bar])
    search_bar.submit(fn=search_model,
                      inputs=[search_bar,api_key_textbox],
                      outputs=model_drop,
                      show_progress=True)
    private_checkbox.change(fn=show_api_key_textbox,
                            inputs=private_checkbox,
                            outputs=api_key_textbox)
    model_drop.change(fn=forward_model_choice,
                      inputs=model_drop,
                      outputs=[model_choice,custom_model_path])
    generate_button.click(fn=process_user_input,
                          inputs=[model_choice,
                                  api_key_textbox,
                                  custom_model_path,
                                  input_text,
                                  length_single,
                                  temperature_single,
                                  top_p_single,
                                  top_k_single],
                          outputs=[output_spans,
                                   toxi_button,
                                   flag_button,
                                   input_var,
                                   output_var],
                          show_progress=True)
    toxi_button.click(fn=compute_toxi_output,
                      inputs=output_var,
                      outputs=[toxi_scores_output, toxi_button_compare],
                      show_progress=True)
    # Compare against the prompt that produced the output (input_var), not
    # whatever is currently typed in the textbox.
    toxi_button_compare.click(fn=compare_toxi_scores,
                              inputs=[input_var, toxi_scores_output],
                              outputs=[toxi_scores_input, toxi_scores_compare],
                              show_progress=True)
    flag_button.click(fn=show_flag_choices,
                      inputs=None,
                      outputs=flag_radio)
    flag_radio.change(fn=update_flag,
                      inputs=flag_radio,
                      outputs=[flag_choice, confirm_flag_button, user_comment, flag_button])
    # Column order here must match the inputs of confirm_flag_button below.
    flagging_callback.setup([input_var, output_var, model_choice, user_comment, flag_choice], "flagged_data_points")
    confirm_flag_button.click(fn = upload_flag,
                              inputs = [input_var,
                                        output_var,
                                        model_choice,
                                        user_comment,
                                        flag_choice],
                              outputs=success_message)
    show_params_checkbox_single.change(fn=show_params,
                                       inputs=show_params_checkbox_single,
                                       outputs=params_box_single)
    # Model comparison
    search_bar_multi.submit(fn=search_model,
                            inputs=[search_bar_multi, api_key_textbox_multi],
                            outputs=model_drop_multi,
                            show_progress=True)
    show_params_checkbox_multi.change(fn=show_params,
                                      inputs=show_params_checkbox_multi,
                                      outputs=params_box_multi)
    private_checkbox_multi.change(fn=show_api_key_textbox,
                                  inputs=private_checkbox_multi,
                                  outputs=api_key_textbox_multi)
    model_drop_multi.change(fn=forward_model_choice_multi,
                            inputs=model_drop_multi,
                            outputs=[models_multi])
    models_multi.change(fn=show_choices_multi,
                        inputs=models_multi,
                        outputs=model_list)
    generate_button_multi.click(fn=process_user_input_multi,
                                inputs=[models_multi,
                                        input_text_multi,
                                        api_key_textbox_multi,
                                        length_multi,
                                        temperature_multi,
                                        top_p_multi,
                                        top_k_multi],
                                outputs=model_list,
                                show_progress=True)
if __name__ == "__main__":
    # Launch without request queueing, in Gradio's debug mode.
    demo.launch(enable_queue=False, debug=True)