# BuggingSpace / app.py
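"""Gradio demo: write a prompt or sample one from the RealToxicityPrompts
dataset, generate a continuation with one of three small causal LMs, then
score the output for toxicity with Detoxify."""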
import torch
import numpy as np
import gradio as gr
from random import sample
from detoxify import Detoxify
from datasets import load_dataset
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM
from transformers import BloomTokenizerFast, BloomForCausalLM
DATASET = "allenai/real-toxicity-prompts"

CHECKPOINTS = {
    "DistilGPT2 by HuggingFace 🤗": "distilgpt2",
    "GPT-Neo 125M by EleutherAI 🤖": "EleutherAI/gpt-neo-125M",
    "BLOOM 560M by BigScience 🌸": "bigscience/bloom-560m",
}

MODEL_CLASSES = {
    "DistilGPT2 by HuggingFace 🤗": (GPT2LMHeadModel, GPT2Tokenizer),
    "GPT-Neo 125M by EleutherAI 🤖": (GPTNeoForCausalLM, GPT2Tokenizer),
    "BLOOM 560M by BigScience 🌸": (BloomForCausalLM, BloomTokenizerFast),
}
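# The display names above double as the radio-button labels in the UI, so the
# keys of CHECKPOINTS and MODEL_CLASSES must stay in sync.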

def load_model(model_name):
    """Load the checkpoint and tokenizer registered for `model_name`."""
    model_class, tokenizer_class = MODEL_CLASSES[model_name]
    model_path = CHECKPOINTS[model_name]
    model = model_class.from_pretrained(model_path)
    tokenizer = tokenizer_class.from_pretrained(model_path)
    # These checkpoints ship without a pad token, so reuse EOS for padding.
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = model.config.eos_token_id
    model.eval()
    return model, tokenizer
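# Usage sketch (any key of CHECKPOINTS works; downloads weights on first call):
#   model, tokenizer = load_model("DistilGPT2 by HuggingFace 🤗")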

MAX_LENGTH = 10000  # hard cap on generation length to avoid an infinite loop

def set_seed(seed, n_gpu):
    np.random.seed(seed)
    torch.manual_seed(seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(seed)

def adjust_length_to_model(length, max_sequence_length):
    if length < 0 and max_sequence_length > 0:
        length = max_sequence_length
    elif 0 < max_sequence_length < length:
        length = max_sequence_length  # no generation longer than the model's context
    elif length < 0:
        length = MAX_LENGTH  # avoid infinite loop
    return length
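# Note: the call to adjust_length_to_model() inside generate() is commented
# out below, so this helper is currently unused.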

def generate(model_name,
             input_sentence,
             length=75,
             temperature=0.7,
             top_k=50,
             top_p=0.95,
             seed=42,
             no_cuda=False,
             num_return_sequences=1,
             stop_token='.'):
    # Select device and seed the RNGs
    device = torch.device("cuda" if torch.cuda.is_available() and not no_cuda else "cpu")
    n_gpu = 0 if no_cuda else torch.cuda.device_count()
    set_seed(seed, n_gpu)

    # Load model
    model, tokenizer = load_model(model_name)
    model.to(device)
    #length = adjust_length_to_model(length, max_sequence_length=model.config.max_position_embeddings)

    # Tokenize input
    encoded_prompt = tokenizer.encode(input_sentence,
                                      add_special_tokens=False,
                                      return_tensors='pt')
    encoded_prompt = encoded_prompt.to(device)
    input_ids = encoded_prompt
    # Generate output
    output_sequences = model.generate(input_ids=input_ids,
                                      max_length=length + len(encoded_prompt[0]),
                                      temperature=temperature,
                                      top_k=top_k,
                                      top_p=top_p,
                                      do_sample=True,
                                      num_return_sequences=num_return_sequences)

    generated_sequences = []
    for generated_sequence in output_sequences:
        generated_sequence = generated_sequence.tolist()
        text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
        # Remove the prompt from the decoded text
        text = text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)):]
        # Truncate after the last occurrence of stop_token, if present
        # (str.rfind returns -1 when not found, which would empty the string)
        stop_idx = text.rfind(stop_token)
        if stop_idx != -1:
            text = text[:stop_idx + 1]
        generated_sequences.append(text)

    return generated_sequences[0]
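# Usage sketch (sampling is stochastic, so output will vary per seed):
#   generate("GPT-Neo 125M by EleutherAI 🤖", "The weather today is")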

def prepare_dataset(dataset):
    dataset = load_dataset(dataset, split='train')
    return dataset

def load_prompts(dataset):
    prompts = [dataset[i]['prompt']['text'] for i in range(len(dataset))]
    return prompts

def random_sample(prompt_list):
    sampled_prompts = sample(prompt_list, 10)
    return sampled_prompts
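# Each record of allenai/real-toxicity-prompts nests the prompt text under
# ['prompt']['text']; load_prompts() flattens those into a plain list that
# random_sample() draws 10 entries from for the dropdown.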

def show_dataset(dataset):
    raw_data = prepare_dataset(dataset)
    prompts = load_prompts(raw_data)
    return (gr.update(choices=random_sample(prompts),
                      label='You can find below a random subset from the RealToxicityPrompts dataset',
                      visible=True),
            gr.update(visible=True),
            prompts,
            )

def update_dropdown(prompts):
    return gr.update(choices=random_sample(prompts))

def show_text(text):
    # Unused debug helper; not wired to any component below.
    new_text = "lol " + text
    return gr.update(visible=True, value=new_text)

def process_user_input(model, user_input):
    # An empty textbox arrives as '' (and a missing value as None),
    # so fall back to the warning prompt in both cases.
    if not user_input:
        user_input = 'Please enter a valid prompt.'
    generated = generate(model, user_input)
    return (
        gr.update(visible=True, value=generated),
        gr.update(visible=True),
    )

def pass_to_textbox(value):
    return gr.update(value=value)

def run_detoxify(text):
    # Note: this instantiates (and on first use downloads) the Detoxify model
    # on every click; a module-level instance would avoid the reload.
    results = Detoxify('original').predict(text)
    json_ready_results = {cat: float(score) for (cat, score) in results.items()}
    return gr.update(value=json_ready_results, visible=True)
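
# UI layout: prompt selection in the left column, model choice, generation,
# and toxicity scoring in the right column.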

with gr.Blocks() as demo:
    gr.Markdown("# Project Interface proposal")

    dataset = gr.Variable(value=DATASET)
    prompts_var = gr.Variable(value=None)

    with gr.Row(equal_height=True):
        with gr.Column():
            gr.Markdown("### 1. Select a prompt")
            input_text = gr.Textbox(label="Write your prompt below.", interactive=True)
            gr.Markdown("— or —")
            inspo_button = gr.Button('Click here if you need some inspiration')
            prompts_drop = gr.Dropdown(visible=False)
            prompts_drop.change(fn=pass_to_textbox, inputs=prompts_drop, outputs=input_text)
            randomize_button = gr.Button('Show another subset', visible=False)
            inspo_button.click(fn=show_dataset,
                               inputs=dataset,
                               outputs=[prompts_drop, randomize_button, prompts_var])
            randomize_button.click(fn=update_dropdown, inputs=prompts_var, outputs=prompts_drop)

        with gr.Column():
            gr.Markdown("### 2. Evaluate output")
            generate_button = gr.Button('Pick a model below and submit your prompt')
            model_radio = gr.Radio(choices=list(CHECKPOINTS.keys()),
                                   label='Model',
                                   interactive=True)
            model_choice = gr.Variable(value=None)
            model_radio.change(fn=lambda value: value, inputs=model_radio, outputs=model_choice)
            output_text = gr.Textbox(label="Generated prompt.", visible=False)
            toxi_button = gr.Button("Run a toxicity analysis of the model's output", visible=False)
            toxi_scores = gr.JSON(visible=False)

            generate_button.click(fn=process_user_input,
                                  inputs=[model_choice, input_text],
                                  outputs=[output_text, toxi_button])
            toxi_button.click(fn=run_detoxify, inputs=output_text, outputs=toxi_scores)

#demo.launch(debug=True)
if __name__ == "__main__":
    demo.launch(enable_queue=False)