# Scraped from a Hugging Face Spaces page; Space status at capture time: "Runtime error".
from functools import lru_cache
from random import sample

import gradio as gr
import numpy as np
import torch
from datasets import load_dataset
from detoxify import Detoxify
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM
from transformers import BloomTokenizerFast, BloomForCausalLM
# Hugging Face Hub id of the prompt dataset offered for "inspiration".
DATASET = "allenai/real-toxicity-prompts"

# One row per selectable model:
#   UI display name -> (Hub checkpoint, model class, tokenizer class).
# The public CHECKPOINTS / MODEL_CLASSES views below are derived from it, so the
# two can never disagree on which display names exist.
_MODEL_REGISTRY = {
    "DistilGPT2 by HuggingFace π€": ("distilgpt2", GPT2LMHeadModel, GPT2Tokenizer),
    "GPT-Neo 125M by EleutherAI π€": ("EleutherAI/gpt-neo-125M", GPTNeoForCausalLM, GPT2Tokenizer),
    "BLOOM 560M by BigScience πΈ": ("bigscience/bloom-560m", BloomForCausalLM, BloomTokenizerFast),
}

# Display name -> Hub checkpoint id.
CHECKPOINTS = {name: spec[0] for name, spec in _MODEL_REGISTRY.items()}

# Display name -> (model class, tokenizer class) used with from_pretrained.
MODEL_CLASSES = {name: (spec[1], spec[2]) for name, spec in _MODEL_REGISTRY.items()}
@lru_cache(maxsize=None)
def load_model(model_name):
    """Return ``(model, tokenizer)`` for a UI display name, loading it only once.

    The classes come from MODEL_CLASSES and the Hub checkpoint from CHECKPOINTS.
    Both objects are created with ``from_pretrained``. The EOS token is reused as
    the padding token on the tokenizer and on the model config. The model is put
    in eval mode.

    Results are memoised per ``model_name``, so weights are read once instead of
    on every generation request. The cache is unbounded on purpose: only names
    present in CHECKPOINTS can succeed (any other name raises KeyError, and
    ``lru_cache`` does not store exceptions), so it never holds more than one
    entry per configured model. Every model that has been used stays resident in
    memory.

    Callers share the returned objects. For example, ``model.to(device)`` moves
    the single cached instance.

    Raises:
        KeyError: if ``model_name`` is not a key of MODEL_CLASSES / CHECKPOINTS.
    """
    model_class, tokenizer_class = MODEL_CLASSES[model_name]
    model_path = CHECKPOINTS[model_name]
    model = model_class.from_pretrained(model_path)
    tokenizer = tokenizer_class.from_pretrained(model_path)
    # Reuse EOS for padding so generate() has a pad id to work with.
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = model.config.eos_token_id
    model.eval()
    return model, tokenizer
# Hard upper bound on generated length, used when neither the caller nor the
# model supplies a usable limit (guards against an effectively endless loop).
MAX_LENGTH = 10_000
def set_seed(seed, n_gpu):
    """Seed the NumPy and PyTorch RNGs, plus every CUDA device when ``n_gpu > 0``.

    Args:
        seed: integer seed applied to each generator.
        n_gpu: number of visible GPUs. CUDA seeding is skipped when it is not positive.
    """
    for seeder in (np.random.seed, torch.manual_seed):
        seeder(seed)
    if n_gpu <= 0:
        return
    torch.cuda.manual_seed_all(seed)
def adjust_length_to_model(length, max_sequence_length):
    """Clamp a requested generation length to what the model can handle.

    Args:
        length: requested number of tokens. A negative value means "as long as allowed".
        max_sequence_length: the model's positional limit. Values ``<= 0`` mean
            "unknown / no limit".

    Returns:
        * ``max_sequence_length`` when ``length`` is negative and a limit is known,
          or when ``length`` exceeds a known limit;
        * ``MAX_LENGTH`` when ``length`` is negative and no limit is known;
        * ``length`` unchanged otherwise.
    """
    has_model_limit = max_sequence_length > 0
    if length < 0:
        # No explicit request: use the model limit, or the global safety cap.
        return max_sequence_length if has_model_limit else MAX_LENGTH
    if has_model_limit and length > max_sequence_length:
        # Never ask for more than the model can attend to.
        return max_sequence_length
    return length
def generate(model_name,
             input_sentence,
             length=75,
             temperature=0.7,
             top_k=50,
             top_p=0.95,
             seed=42,
             no_cuda=False,
             num_return_sequences=1,
             stop_token='.'):
    """Sample a continuation of ``input_sentence`` with the model named ``model_name``.

    Args:
        model_name: display name understood by ``load_model``.
        input_sentence: prompt text. It is encoded without special tokens.
        length: number of tokens to generate after the prompt.
        temperature, top_k, top_p: sampling controls forwarded to ``model.generate``.
        seed: RNG seed passed to ``set_seed`` before generating.
        no_cuda: force CPU even when CUDA is available.
        num_return_sequences: samples drawn by ``model.generate``. Only the first
            one is returned.
        stop_token: the continuation is cut right after its LAST occurrence of
            this string. If the string never occurs (or is empty), the whole
            continuation is returned.

    Returns:
        The generated text with the prompt and any special tokens removed.
    """
    use_cuda = torch.cuda.is_available() and not no_cuda
    device = torch.device("cuda" if use_cuda else "cpu")
    n_gpu = 0 if no_cuda else torch.cuda.device_count()
    set_seed(seed, n_gpu)

    model, tokenizer = load_model(model_name)
    model.to(device)

    input_ids = tokenizer.encode(input_sentence,
                                 add_special_tokens=False,
                                 return_tensors='pt').to(device)
    prompt_length = input_ids.shape[-1]

    output_sequences = model.generate(
        input_ids=input_ids,
        # A single unpadded prompt: every position is real. Pass the mask
        # explicitly, because inferring it is unreliable when pad == eos
        # (as load_model configures).
        attention_mask=torch.ones_like(input_ids),
        max_length=length + prompt_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=True,
        num_return_sequences=num_return_sequences,
    )

    # Returned sequences start with the prompt tokens. Decode only what follows
    # them, rather than slicing the decoded string by the prompt's decoded
    # length. Only the first sample is returned, so only it is decoded.
    text = tokenizer.decode(output_sequences[0][prompt_length:].tolist(),
                            skip_special_tokens=True,
                            clean_up_tokenization_spaces=True)

    # Trim after the last stop token. The old `text[:rfind + 1]` returned "" when
    # the token was absent (rfind == -1) and assumed a one-character token.
    if stop_token:
        cut = text.rfind(stop_token)
        if cut != -1:
            text = text[:cut + len(stop_token)]
    return text
def prepare_dataset(dataset):
    """Load and return the ``train`` split of the Hugging Face dataset named ``dataset``."""
    return load_dataset(dataset, split='train')
def load_prompts(dataset):
    """Extract the prompt text of every row, in order.

    Args:
        dataset: an iterable of rows, each mapping ``'prompt'`` to a dict with a
            ``'text'`` entry (e.g. the RealToxicityPrompts train split). Iterating
            directly replaces the former ``range(len(...))`` index loop, so any
            iterable of such rows is accepted, not only indexable ones.

    Returns:
        list[str]: the ``row['prompt']['text']`` values.
    """
    return [row['prompt']['text'] for row in dataset]
def random_sample(prompt_list, k=10):
    """Return up to ``k`` distinct prompts chosen uniformly at random.

    Args:
        prompt_list: sequence to draw from (not modified).
        k: maximum number of items to return (default 10). It is clamped to
            ``len(prompt_list)``, so shorter lists return all their items in
            random order instead of raising ``ValueError`` from ``random.sample``.

    Returns:
        list: the sampled items.
    """
    return sample(prompt_list, min(k, len(prompt_list)))
def show_dataset(dataset):
    """Load the prompt dataset and reveal the inspiration widgets.

    Args:
        dataset: Hugging Face dataset id to load.

    Returns:
        A 3-tuple of:
        * an update showing a dropdown with a random subset of prompts,
        * an update making the "another subset" button visible,
        * the full prompt list (kept in session state for later resampling).
    """
    all_prompts = load_prompts(prepare_dataset(dataset))
    dropdown_update = gr.update(
        choices=random_sample(all_prompts),
        label='You can find below a random subset from the RealToxicityPrompts dataset',
        visible=True,
    )
    button_update = gr.update(visible=True)
    return dropdown_update, button_update, all_prompts
def update_dropdown(prompts):
    """Return a dropdown update whose choices are a fresh random subset of ``prompts``."""
    fresh_choices = random_sample(prompts)
    return gr.update(choices=fresh_choices)
def show_text(text):
    """Return a visible textbox update whose value is ``text`` prefixed with "lol ".

    NOTE(review): nothing in this file wires this function to an event. It looks
    like a leftover debugging helper; confirm before relying on it.
    """
    prefixed = "lol " + text
    return gr.update(value=prefixed, visible=True)
def process_user_input(model, input):
    """Generate a continuation for the UI and toggle the toxicity button.

    Args:
        model: display name chosen in the model radio, or None if nothing is picked yet.
        input: the prompt text. The name shadows the builtin but is kept for
            interface compatibility.

    Returns:
        A 2-tuple of updates for ``(output_text, toxi_button)``. On success the
        output box shows the generated text and the toxicity button becomes
        visible. On invalid input the output box shows a message and the
        toxicity button is hidden, so the message itself is never scored.
    """
    warning = 'Please enter a valid prompt.'
    # A Gradio Textbox yields "" (not None) when empty, and an empty prompt
    # encodes to zero tokens, which generate() cannot start from.
    if input is None or not input.strip():
        return (
            gr.update(visible=True, value=warning),
            gr.update(visible=False),
        )
    # No radio selection yet: load_model would raise KeyError on None.
    if model is None:
        return (
            gr.update(visible=True, value='Please pick a model first.'),
            gr.update(visible=False),
        )
    generated = generate(model, input)
    return (
        gr.update(visible=True, value=generated),
        gr.update(visible=True),
    )
def pass_to_textbox(input):
    """Copy the prompt selected in the dropdown into the prompt textbox."""
    selected_prompt = input
    textbox_update = gr.update(value=selected_prompt)
    return textbox_update
@lru_cache(maxsize=1)
def _get_detoxify_model():
    """Return the shared ``Detoxify('original')`` classifier, constructing it on first use."""
    return Detoxify('original')


def run_detoxify(text):
    """Score ``text`` with Detoxify and return a visible JSON update of the scores.

    The classifier is built once and reused, instead of reloading its weights
    on every button click.

    Args:
        text: the generated text to analyse.

    Returns:
        A Gradio update whose value maps each Detoxify category to a plain
        Python float (so it is JSON-serialisable).
    """
    results = _get_detoxify_model().predict(text)
    json_ready_results = {category: float(score) for category, score in results.items()}
    return gr.update(value=json_ready_results, visible=True)
with gr.Blocks() as demo:
    gr.Markdown("# Project Interface proposal")
    # Per-session state. gr.State is the current name of the component that
    # gr.Variable aliased; the alias is gone in Gradio 4.
    dataset = gr.State(value=DATASET)
    prompts_var = gr.State(value=None)
    with gr.Row(equal_height=True):
        # Left column: write a prompt or pick one from the dataset.
        with gr.Column():
            gr.Markdown("### 1. Select a prompt")
            input_text = gr.Textbox(label="Write your prompt below.", interactive=True)
            # NOTE(review): the glyphs around "or" look mis-encoded in this copy
            # of the source; confirm the intended characters before shipping.
            gr.Markdown("β or β")
            inspo_button = gr.Button('Click here if you need some inspiration')
            prompts_drop = gr.Dropdown(visible=False)
            prompts_drop.change(fn=pass_to_textbox, inputs=prompts_drop, outputs=input_text)
            randomize_button = gr.Button('Show another subset', visible=False)
            inspo_button.click(fn=show_dataset,
                               inputs=dataset,
                               outputs=[prompts_drop, randomize_button, prompts_var])
            randomize_button.click(fn=update_dropdown, inputs=prompts_var, outputs=prompts_drop)
        # Right column: choose a model, generate, then score the output.
        with gr.Column():
            gr.Markdown("### 2. Evaluate output")
            generate_button = gr.Button('Pick a model below and submit your prompt')
            model_radio = gr.Radio(choices=list(CHECKPOINTS.keys()),
                                   label='Model',
                                   interactive=True)
            # Mirrors the radio selection; generate_button reads it.
            model_choice = gr.State(value=None)
            model_radio.change(fn=lambda value: value, inputs=model_radio, outputs=model_choice)
            output_text = gr.Textbox(label="Generated prompt.", visible=False)
            toxi_button = gr.Button("Run a toxicity analysis of the model's output", visible=False)
            toxi_scores = gr.JSON(visible=False)
            generate_button.click(fn=process_user_input,
                                  inputs=[model_choice, input_text],
                                  outputs=[output_text, toxi_button])
            toxi_button.click(fn=run_detoxify, inputs=output_text, outputs=toxi_scores)
if __name__ == "__main__":
    # `launch(enable_queue=...)` was removed in Gradio 4 and raises TypeError there.
    # Queueing is configured with demo.queue() instead. On Gradio 3.x it stays off
    # unless demo.queue() is called, which matches the old enable_queue=False.
    demo.launch()