Spaces:
Runtime error
Runtime error
J-Antoine ZAGATO
committed on
Commit
•
0509539
1
Parent(s):
ba936cb
Add app file
Browse files
app.py
ADDED
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import gradio as gr
|
5 |
+
|
6 |
+
from random import sample
|
7 |
+
from detoxify import Detoxify
|
8 |
+
from datasets import load_dataset
|
9 |
+
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM
|
10 |
+
from transformers import BloomTokenizerFast, BloomForCausalLM
|
11 |
+
|
12 |
+
# Hugging Face Hub identifier of the prompt dataset offered as inspiration.
DATASET = "allenai/real-toxicity-prompts"

# UI label -> Hugging Face Hub checkpoint to load for that label.
CHECKPOINTS = {
    "DistilGPT2 by HuggingFace π€": "distilgpt2",
    "GPT-Neo 125M by EleutherAI π€": "EleutherAI/gpt-neo-125M",
    "BLOOM 560M by BigScience πΈ": "bigscience/bloom-560m",
}

# UI label -> (model class, tokenizer class) used to load the checkpoint.
# The keys must stay identical to the keys of CHECKPOINTS.
MODEL_CLASSES = {
    "DistilGPT2 by HuggingFace π€": (GPT2LMHeadModel, GPT2Tokenizer),
    "GPT-Neo 125M by EleutherAI π€": (GPTNeoForCausalLM, GPT2Tokenizer),
    "BLOOM 560M by BigScience πΈ": (BloomForCausalLM, BloomTokenizerFast),
}
|
25 |
+
|
26 |
+
def load_model(model_name):
    """Load the pretrained model and tokenizer registered under ``model_name``.

    Args:
        model_name: A key present in both ``MODEL_CLASSES`` and
            ``CHECKPOINTS``.

    Returns:
        A ``(model, tokenizer)`` tuple. The model has been switched to
        evaluation mode, and both objects use the end-of-sequence token as
        their padding token.

    Raises:
        KeyError: If ``model_name`` is not a registered key.
    """
    model_cls, tokenizer_cls = MODEL_CLASSES[model_name]
    checkpoint = CHECKPOINTS[model_name]

    model = model_cls.from_pretrained(checkpoint)
    tokenizer = tokenizer_cls.from_pretrained(checkpoint)

    # Reuse the EOS token for padding on both the tokenizer and the config.
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = model.config.eos_token_id

    model.eval()
    return model, tokenizer
|
39 |
+
|
40 |
+
# Hard-coded cap on the generation length, to avoid an infinite loop.
MAX_LENGTH = 10000
|
41 |
+
|
42 |
+
def set_seed(seed, n_gpu):
    """Seed the NumPy and PyTorch random number generators.

    Args:
        seed: Seed value handed to each generator.
        n_gpu: Number of GPUs in use; the CUDA generators are seeded only
            when this is greater than zero.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    if n_gpu <= 0:
        return
    torch.cuda.manual_seed_all(seed)
|
47 |
+
|
48 |
+
def adjust_length_to_model(length, max_sequence_length):
    """Clamp a requested generation length to what the model supports.

    Args:
        length: Requested length; a negative value means "no preference".
        max_sequence_length: The model's maximum sequence length; a value
            that is not positive means the model imposes no limit.

    Returns:
        ``max_sequence_length`` when the model has a limit and ``length`` is
        negative or exceeds it; ``MAX_LENGTH`` when ``length`` is negative
        and the model has no limit; otherwise ``length`` unchanged.
    """
    model_has_limit = max_sequence_length > 0

    if length < 0:
        # No explicit request: use the model limit, else the global cap.
        return max_sequence_length if model_has_limit else MAX_LENGTH

    if model_has_limit and length > max_sequence_length:
        # Never generate more than the model can handle.
        return max_sequence_length

    return length
|
56 |
+
|
57 |
+
def _trim_completion(text, stop_token):
    """Cut ``text`` right after the last ``stop_token``; keep it whole if absent."""
    if not stop_token:
        return text
    cut = text.rfind(stop_token)
    if cut == -1:
        # No stop token at all: returning "" would throw the completion away.
        return text
    return text[:cut + len(stop_token)]


def generate(model_name,
             input_sentence,
             length=75,
             temperature=0.7,
             top_k=50,
             top_p=0.95,
             seed=42,
             no_cuda=False,
             num_return_sequences=1,
             stop_token='.'
             ):
    """Sample a continuation of ``input_sentence`` from the selected model.

    The model and tokenizer are loaded through ``load_model`` on every call.

    Args:
        model_name: Key identifying the model, passed to ``load_model``.
        input_sentence: Prompt text to continue.
        length: Number of tokens allowed on top of the prompt length.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_k: Top-k sampling cut-off forwarded to ``model.generate``.
        top_p: Nucleus sampling threshold forwarded to ``model.generate``.
        seed: Seed passed to ``set_seed`` before generating.
        no_cuda: When true, run on CPU even if CUDA is available.
        num_return_sequences: Number of sequences to sample; only the first
            one is returned.
        stop_token: The completion is cut after the last occurrence of this
            string. If it is empty, ``None`` or not present in the
            completion, the completion is returned untrimmed.

    Returns:
        The first generated completion, with the prompt text removed.
    """
    # Select the device.
    use_cuda = torch.cuda.is_available() and not no_cuda
    device = torch.device("cuda" if use_cuda else "cpu")
    n_gpu = 0 if no_cuda else torch.cuda.device_count()

    # Set seed
    set_seed(seed, n_gpu)

    # Load model
    model, tokenizer = load_model(model_name)
    model.to(device)

    # NOTE: ``length`` is not clamped with adjust_length_to_model here.

    # Tokenize input
    encoded_prompt = tokenizer.encode(input_sentence,
                                      add_special_tokens=False,
                                      return_tensors='pt')
    encoded_prompt = encoded_prompt.to(device)

    # An empty prompt encodes to a zero-length tensor; pass None instead so
    # that generate() starts from the model's BOS token.
    input_ids = encoded_prompt if encoded_prompt.size()[-1] > 0 else None

    # Generate output
    output_sequences = model.generate(input_ids=input_ids,
                                      max_length=length + len(encoded_prompt[0]),
                                      temperature=temperature,
                                      top_k=top_k,
                                      top_p=top_p,
                                      do_sample=True,
                                      num_return_sequences=num_return_sequences
                                      )

    # The decoded prompt is the same for every sequence, so decode it once.
    # Special tokens are skipped in both decodes so that markers such as the
    # EOS token do not leak into the text shown to the user.
    prompt_text = tokenizer.decode(encoded_prompt[0],
                                   clean_up_tokenization_spaces=True,
                                   skip_special_tokens=True)

    generated_sequences = []
    for generated_sequence in output_sequences:
        text = tokenizer.decode(generated_sequence.tolist(),
                                clean_up_tokenization_spaces=True,
                                skip_special_tokens=True)
        # Remove the prompt, then everything after the last stop token.
        completion = _trim_completion(text[len(prompt_text):], stop_token)
        generated_sequences.append(completion)

    return generated_sequences[0]
|
115 |
+
|
116 |
+
def prepare_dataset(dataset):
    """Load and return the ``train`` split of the dataset named ``dataset``."""
    return load_dataset(dataset, split='train')
|
120 |
+
|
121 |
+
def load_prompts(dataset):
    """Collect the prompt text of every row in ``dataset``.

    Args:
        dataset: An iterable of rows, where each row is a mapping whose
            ``'prompt'`` entry is itself a mapping with a ``'text'`` entry.

    Returns:
        A list with one prompt string per row, in dataset order.
    """
    # Iterate over the rows directly instead of indexing via range(len(...)).
    return [row['prompt']['text'] for row in dataset]
|
124 |
+
|
125 |
+
def random_sample(prompt_list, k=10):
    """Return up to ``k`` prompts drawn at random without replacement.

    Args:
        prompt_list: Sequence of prompts to draw from; may be empty or
            ``None``.
        k: Maximum number of prompts to return (10 by default).

    Returns:
        A new list with ``min(k, len(prompt_list))`` distinct entries of
        ``prompt_list``, or an empty list when there is nothing to draw from.
    """
    if not prompt_list:
        return []
    # random.sample raises ValueError when asked for more items than exist.
    return sample(prompt_list, min(k, len(prompt_list)))
|
128 |
+
|
129 |
+
def show_dataset(dataset):
    """Load the prompt dataset and reveal the prompt-picking widgets.

    Args:
        dataset: Dataset identifier handed to ``prepare_dataset``.

    Returns:
        A 3-tuple: an update that fills the dropdown with a random subset of
        prompts, labels it and makes it visible; an update that makes a
        second component visible; and the full list of prompts.
    """
    prompts = load_prompts(prepare_dataset(dataset))

    dropdown_update = gr.update(
        choices=random_sample(prompts),
        label='You can find below a random subset from the RealToxicityPrompts dataset',
        visible=True,
    )
    reveal_update = gr.update(visible=True)

    return dropdown_update, reveal_update, prompts
|
139 |
+
|
140 |
+
def update_dropdown(prompts):
    """Return an update that replaces the choices with a new random subset of ``prompts``."""
    fresh_choices = random_sample(prompts)
    return gr.update(choices=fresh_choices)
|
142 |
+
|
143 |
+
def show_text(text):
    """Return an update that shows ``text`` prefixed with ``"lol "``."""
    return gr.update(visible=True, value="lol " + text)
|
146 |
+
|
147 |
+
def process_user_input(model, input):
    """Generate a continuation for the user's prompt, or show a warning.

    Args:
        model: Name of the selected model, or ``None`` when none is chosen.
        input: Prompt text; may be ``None`` or blank. (The name shadows the
            built-in ``input`` but is kept so existing callers keep working.)

    Returns:
        A pair of updates. The first makes its component visible and sets
        its value to the generated text, or to a warning message when the
        model or the prompt is missing. The second makes its component
        visible after a successful generation and hides it otherwise.
    """
    warning = 'Please enter a valid prompt.'

    if model is None:
        # Without a selected model there is nothing to load.
        return (
            gr.update(visible=True, value='Please select a model.'),
            gr.update(visible=False)
        )

    if input is None or not input.strip():
        # Show the warning itself instead of feeding it to the model as a
        # prompt; a blank prompt is treated the same as a missing one.
        return (
            gr.update(visible=True, value=warning),
            gr.update(visible=False)
        )

    generated = generate(model, input)

    return (
        gr.update(visible=True, value=generated),
        gr.update(visible=True)
    )
|
157 |
+
|
158 |
+
def pass_to_textbox(input):
    """Return an update whose value is ``input``, leaving everything else unchanged."""
    textbox_update = gr.update(value=input)
    return textbox_update
|
160 |
+
|
161 |
+
def run_detoxify(text):
    """Score ``text`` with the Detoxify ``'original'`` model.

    A new ``Detoxify`` instance is created on every call.

    Returns:
        An update that makes its component visible and sets its value to a
        dict mapping each result key to its score converted with ``float``.
    """
    raw_scores = Detoxify('original').predict(text)

    # Convert every score to a built-in float before handing it to the UI.
    scores = {}
    for category, score in raw_scores.items():
        scores[category] = float(score)

    return gr.update(value=scores, visible=True)
|
166 |
+
|
167 |
+
|
168 |
+
# NOTE(review): the original indentation of this layout was not visible; the
# nesting below (two columns inside one row) is assumed - confirm.
with gr.Blocks() as demo:
    gr.Markdown("# Project Interface proposal")

    # Hidden state: the dataset identifier and the loaded list of prompts.
    dataset = gr.Variable(value=DATASET)
    prompts_var = gr.Variable(value=None)

    with gr.Row(equal_height=True):
        # Left column: write a prompt or pick one from the dataset.
        with gr.Column():
            gr.Markdown("### 1. Select a prompt")

            input_text = gr.Textbox(
                label="Write your prompt below.",
                interactive=True,
            )
            gr.Markdown("β or β")
            inspo_button = gr.Button('Click here if you need some inspiration')

            # Choosing an entry in the dropdown copies it into the textbox.
            prompts_drop = gr.Dropdown(visible=False)
            prompts_drop.change(
                fn=pass_to_textbox,
                inputs=prompts_drop,
                outputs=input_text,
            )

            randomize_button = gr.Button('Show another subset', visible=False)

            inspo_button.click(
                fn=show_dataset,
                inputs=dataset,
                outputs=[prompts_drop, randomize_button, prompts_var],
            )
            randomize_button.click(
                fn=update_dropdown,
                inputs=prompts_var,
                outputs=prompts_drop,
            )

        # Right column: choose a model, generate text and score it.
        with gr.Column():
            gr.Markdown("### 2. Evaluate output")

            generate_button = gr.Button('Pick a model below and submit your prompt')
            model_radio = gr.Radio(
                choices=list(CHECKPOINTS.keys()),
                label='Model',
                interactive=True,
            )
            # Mirror the radio selection into a state variable.
            model_choice = gr.Variable(value=None)
            model_radio.change(
                fn=lambda value: value,
                inputs=model_radio,
                outputs=model_choice,
            )

            output_text = gr.Textbox(label="Generated prompt.", visible=False)

            toxi_button = gr.Button(
                "Run a toxicity analysis of the model's output",
                visible=False,
            )
            toxi_scores = gr.JSON(visible=False)

            generate_button.click(
                fn=process_user_input,
                inputs=[model_choice, input_text],
                outputs=[output_text, toxi_button],
            )

            toxi_button.click(
                fn=run_detoxify,
                inputs=output_text,
                outputs=toxi_scores,
            )

#demo.launch(debug=True)
if __name__ == "__main__":
    demo.launch(enable_queue=False)
|