BuggingSpace / app.py
J-Antoine ZAGATO
Added toxicity comparison & flagging + refactoring
import os
import torch
import numpy as np
import gradio as gr
from random import sample
from detoxify import Detoxify
from datasets import load_dataset
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM
from transformers import BloomTokenizerFast, BloomForCausalLM
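# Hugging Face token used by the flagging callback below; assumes the Space
# exposes it as an environment variable / secret named 'hf_token'.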
HF_AUTH_TOKEN = os.environ.get('hf_token')
DATASET = "allenai/real-toxicity-prompts"
CHECKPOINTS = {
    "DistilGPT2 by HuggingFace 🤗": "distilgpt2",
    "GPT-Neo 125M by EleutherAI 🤖": "EleutherAI/gpt-neo-125M",
    "BLOOM 560M by BigScience 🌸": "bigscience/bloom-560m",
}

MODEL_CLASSES = {
    "DistilGPT2 by HuggingFace 🤗": (GPT2LMHeadModel, GPT2Tokenizer),
    "GPT-Neo 125M by EleutherAI 🤖": (GPTNeoForCausalLM, GPT2Tokenizer),
    "BLOOM 560M by BigScience 🌸": (BloomForCausalLM, BloomTokenizerFast),
}
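# Load the selected checkpoint and its tokenizer; the EOS token doubles as the
# padding token so generation works without a dedicated pad token.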
def load_model(model_name):
    model_class, tokenizer_class = MODEL_CLASSES[model_name]
    model_path = CHECKPOINTS[model_name]
    model = model_class.from_pretrained(model_path)
    tokenizer = tokenizer_class.from_pretrained(model_path)
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = model.config.eos_token_id
    model.eval()
    return model, tokenizer
MAX_LENGTH = 10000  # Hardcoded max length to avoid infinite loop
def set_seed(seed, n_gpu):
    np.random.seed(seed)
    torch.manual_seed(seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(seed)
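# Note: currently unused -- the corresponding call in generate() is commented out.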
def adjust_length_to_model(length, max_sequence_length):
    if length < 0 and max_sequence_length > 0:
        length = max_sequence_length
    elif 0 < max_sequence_length < length:
        length = max_sequence_length  # No generation bigger than model size
    elif length < 0:
        length = MAX_LENGTH  # avoid infinite loop
    return length
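# Sampling-based generation: select the device, seed the RNGs, load the chosen
# model, encode the prompt, sample a continuation and strip the prompt from it.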
def generate(model_name,
             input_sentence,
             length=75,
             temperature=0.7,
             top_k=50,
             top_p=0.95,
             seed=42,
             no_cuda=False,
             num_return_sequences=1,
             stop_token='.'
             ):
    # Select device
    device = torch.device("cuda" if torch.cuda.is_available() and not no_cuda else "cpu")
    n_gpu = 0 if no_cuda else torch.cuda.device_count()

    # Set seed
    set_seed(seed, n_gpu)

    # Load model
    model, tokenizer = load_model(model_name)
    model.to(device)

    #length = adjust_length_to_model(length, max_sequence_length=model.config.max_position_embeddings)

    # Tokenize input
    encoded_prompt = tokenizer.encode(input_sentence,
                                      add_special_tokens=False,
                                      return_tensors='pt')
    encoded_prompt = encoded_prompt.to(device)
    input_ids = encoded_prompt

    # Generate output
    output_sequences = model.generate(input_ids=input_ids,
                                      max_length=length + len(encoded_prompt[0]),
                                      temperature=temperature,
                                      top_k=top_k,
                                      top_p=top_p,
                                      do_sample=True,
                                      num_return_sequences=num_return_sequences)

    generated_sequences = []
    for generated_sequence in output_sequences:
        generated_sequence = generated_sequence.tolist()
        text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)

        # Remove the prompt from the generated text
        text = text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)):]

        # Truncate after the last occurrence of stop_token, if present
        if stop_token in text:
            text = text[:text.rfind(stop_token) + 1]

        generated_sequences.append(text)

    return generated_sequences[0]
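# Helpers for browsing the RealToxicityPrompts dataset: load the train split,
# pull out the prompt texts and draw a random subset of 10 to display.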
def prepare_dataset(dataset):
    dataset = load_dataset(dataset, split='train')
    return dataset

def load_prompts(dataset):
    prompts = [dataset[i]['prompt']['text'] for i in range(len(dataset))]
    return prompts

def random_sample(prompt_list):
    return sample(prompt_list, 10)
def show_dataset(dataset):
    raw_data = prepare_dataset(dataset)
    prompts = load_prompts(raw_data)
    return (gr.update(choices=random_sample(prompts),
                      label='You can find below a random subset from the RealToxicityPrompts dataset',
                      visible=True),
            gr.update(visible=True),
            prompts,
            )

def update_dropdown(prompts):
    return gr.update(choices=random_sample(prompts))
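# Generate a continuation for the user's prompt and reveal the output box,
# the toxicity-analysis button and the three flagging buttons.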
def process_user_input(model, input):
    warning = 'Please enter a valid prompt.'
    if not input:
        generated = warning
    else:
        generated = generate(model, input)
    return (
        gr.update(visible=True, value=generated),
        gr.update(visible=True),
        gr.update(visible=True),
        gr.update(visible=True),
        gr.update(visible=True),
        input,
        generated
    )

def pass_to_textbox(input):
    return gr.update(value=input)
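# Toxicity scoring with the Detoxify 'original' checkpoint; scores are cast to
# plain floats so they can be displayed in a gr.JSON component.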
def run_detoxify(text):
    results = Detoxify('original').predict(text)
    json_ready_results = {cat: float(score) for (cat, score) in results.items()}
    return json_ready_results

def compute_toxi_output(output_text):
    scores = run_detoxify(output_text)
    return (
        gr.update(value=scores, visible=True),
        gr.update(visible=True)
    )

def compute_change(input, output):
    change_percent = round(((float(output) - input) / input) * 100, 2)
    return change_percent
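# Score the input prompt with Detoxify and report, per category, the percentage
# change between the input and output toxicity scores.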
def compare_toxi_scores(input_text, output_scores):
    input_scores = run_detoxify(input_text)
    json_ready_results = {cat: float(score) for (cat, score) in input_scores.items()}
    compare_scores = {
        cat: compute_change(json_ready_results[cat], output_scores[cat])
        for cat in json_ready_results
        if cat in output_scores
    }
    return (
        gr.update(value=json_ready_results, visible=True),
        gr.update(value=compare_scores, visible=True)
    )
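# Gradio Blocks UI: prompt selection on the left, model choice and generated
# output on the right, followed by flagging buttons and toxicity panels.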
with gr.Blocks() as demo:

    gr.Markdown("# Project Interface proposal")
    gr.Markdown("### Write description and user instructions here")

    dataset = gr.Variable(value=DATASET)
    prompts_var = gr.Variable(value=None)
    input_var = gr.Variable(label="Input Prompt", value=None)
    output_var = gr.Variable(label="Output", value=None)

    flagging_callback = gr.HuggingFaceDatasetSaver(hf_token=HF_AUTH_TOKEN,
                                                   dataset_name="fsdlredteam/flagged",
                                                   organization="fsdlredteam",
                                                   private=True)
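    # Layout: two columns for prompt selection and model output, then rows for
    # the flagging buttons, the toxicity buttons and the resulting score panels.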
    with gr.Row(equal_height=True):

        with gr.Column():  # input & prompts dataset exploration
            gr.Markdown("### 1. Select a prompt")

            input_text = gr.Textbox(label="Write your prompt below.", interactive=True, lines=4)
            gr.Markdown("— or —")
            inspo_button = gr.Button('Click here if you need some inspiration')

            prompts_drop = gr.Dropdown(visible=False)
            prompts_drop.change(fn=pass_to_textbox, inputs=prompts_drop, outputs=input_text)

            randomize_button = gr.Button('Show another subset', visible=False)

        with gr.Column():  # Model choice & output
            gr.Markdown("### 2. Evaluate output")
            generate_button = gr.Button('Pick a model below and submit your prompt')

            model_radio = gr.Radio(choices=list(CHECKPOINTS.keys()),
                                   label='Model',
                                   interactive=True)
            model_choice = gr.Variable(value=None)
            model_radio.change(fn=lambda value: value, inputs=model_radio, outputs=model_choice)

            output_text = gr.Textbox(label="Generated prompt.", visible=False)

    with gr.Row(equal_height=True):  # Flagging
        flagging_callback.setup([input_text, output_text, model_radio], "flagged_data_points")

        toxi_flag_button = gr.Button("Report toxic output here", visible=False)
        unexpected_flag_button = gr.Button("Report incorrect output here", visible=False)
        other_flag_button = gr.Button("Report other inappropriate output here", visible=False)

    with gr.Row(equal_height=True):  # Toxicity buttons
        toxi_button = gr.Button("Run a toxicity analysis of the model's output", visible=False)
        toxi_button_compare = gr.Button("Compare toxicity on input and output", visible=False)

    with gr.Row(equal_height=True):  # Toxicity scores
        toxi_scores_input = gr.JSON(label="Detoxify classification of your input", visible=False)
        toxi_scores_output = gr.JSON(label="Detoxify classification of the model's output", visible=False)
        toxi_scores_compare = gr.JSON(label="Percentage change between Input and Output", visible=False)
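    # Event wiring: connect the buttons to the callbacks defined above.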
    inspo_button.click(fn=show_dataset,
                       inputs=dataset,
                       outputs=[prompts_drop, randomize_button, prompts_var])

    randomize_button.click(fn=update_dropdown,
                           inputs=prompts_var,
                           outputs=prompts_drop)

    generate_button.click(fn=process_user_input,
                          inputs=[model_choice, input_text],
                          outputs=[output_text,
                                   toxi_button,
                                   toxi_flag_button,
                                   unexpected_flag_button,
                                   other_flag_button,
                                   input_var,
                                   output_var])

    toxi_button.click(fn=compute_toxi_output,
                      inputs=output_text,
                      outputs=[toxi_scores_output, toxi_button_compare])

    toxi_button_compare.click(fn=compare_toxi_scores,
                              inputs=[input_text, toxi_scores_output],
                              outputs=[toxi_scores_input, toxi_scores_compare])

    toxi_flag_button.click(lambda *args: flagging_callback.flag(args, flag_option="toxic"),
                           inputs=[input_text, output_text, model_radio],
                           outputs=None,
                           preprocess=False)

    unexpected_flag_button.click(lambda *args: flagging_callback.flag(args, flag_option="unexpected"),
                                 inputs=[input_text, output_text, model_radio],
                                 outputs=None,
                                 preprocess=False)

    other_flag_button.click(lambda *args: flagging_callback.flag(args, flag_option="other"),
                            inputs=[input_text, output_text, model_radio],
                            outputs=None,
                            preprocess=False)
#demo.launch(debug=True)

if __name__ == "__main__":
    demo.launch(enable_queue=False)