J-Antoine ZAGATO committed
Commit 0509539 • 1 parent: ba936cb

Add app file

Files changed (1)
  1. app.py +215 -0
app.py ADDED
@@ -0,0 +1,215 @@
+ import torch
+ import numpy as np
+ import gradio as gr
+
+ from random import sample
+ from detoxify import Detoxify
+ from datasets import load_dataset
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM
+ from transformers import BloomTokenizerFast, BloomForCausalLM
+
+ DATASET = "allenai/real-toxicity-prompts"
+
+ CHECKPOINTS = {
+     "DistilGPT2 by HuggingFace 🤗": "distilgpt2",
+     "GPT-Neo 125M by EleutherAI 🤖": "EleutherAI/gpt-neo-125M",
+     "BLOOM 560M by BigScience 🌸": "bigscience/bloom-560m",
+ }
+
+ MODEL_CLASSES = {
+     "DistilGPT2 by HuggingFace 🤗": (GPT2LMHeadModel, GPT2Tokenizer),
+     "GPT-Neo 125M by EleutherAI 🤖": (GPTNeoForCausalLM, GPT2Tokenizer),
+     "BLOOM 560M by BigScience 🌸": (BloomForCausalLM, BloomTokenizerFast),
+ }
+
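+ # Load the selected checkpoint and its tokenizer; these causal LMs ship
+ # without a pad token, so the EOS token is reused for padding.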
+ def load_model(model_name):
+     model_class, tokenizer_class = MODEL_CLASSES[model_name]
+
+     model_path = CHECKPOINTS[model_name]
+     model = model_class.from_pretrained(model_path)
+     tokenizer = tokenizer_class.from_pretrained(model_path)
+
+     tokenizer.pad_token = tokenizer.eos_token
+     model.config.pad_token_id = model.config.eos_token_id
+
+     model.eval()
+
+     return model, tokenizer
+
+ MAX_LENGTH = 10000  # Hardcoded max length to avoid an infinite loop
+
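+ # Seed NumPy and PyTorch (including all visible GPUs) for reproducible sampling.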
+ def set_seed(seed, n_gpu):
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     if n_gpu > 0:
+         torch.cuda.manual_seed_all(seed)
+
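+ # Clamp the requested length to the model's context window (only referenced
+ # by the commented-out call in generate below).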
+ def adjust_length_to_model(length, max_sequence_length):
+     if length < 0 and max_sequence_length > 0:
+         length = max_sequence_length
+     elif 0 < max_sequence_length < length:
+         length = max_sequence_length  # No generation longer than the model's context
+     elif length < 0:
+         length = MAX_LENGTH  # avoid infinite loop
+     return length
+
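+ # Sample a continuation of the prompt with top-k/top-p sampling, strip the
+ # prompt from the decoded output, and truncate at the last stop_token.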
+ def generate(model_name,
+              input_sentence,
+              length=75,
+              temperature=0.7,
+              top_k=50,
+              top_p=0.95,
+              seed=42,
+              no_cuda=False,
+              num_return_sequences=1,
+              stop_token='.'):
+
+     # Load device
+     device = torch.device("cuda" if torch.cuda.is_available() and not no_cuda else "cpu")
+     n_gpu = 0 if no_cuda else torch.cuda.device_count()
+
+     # Set seed
+     set_seed(seed, n_gpu)
+
+     # Load model
+     model, tokenizer = load_model(model_name)
+     model.to(device)
+
+     #length = adjust_length_to_model(length, max_sequence_length=model.config.max_position_embeddings)
+
+     # Tokenize input
+     encoded_prompt = tokenizer.encode(input_sentence,
+                                       add_special_tokens=False,
+                                       return_tensors='pt')
+     input_ids = encoded_prompt.to(device)
+
+     # Generate output
+     output_sequences = model.generate(input_ids=input_ids,
+                                       max_length=length + len(encoded_prompt[0]),
+                                       temperature=temperature,
+                                       top_k=top_k,
+                                       top_p=top_p,
+                                       do_sample=True,
+                                       num_return_sequences=num_return_sequences)
+     generated_sequences = []
+
+     for generated_sequence in output_sequences:
+         text = tokenizer.decode(generated_sequence.tolist(), clean_up_tokenization_spaces=True)
+
+         # Remove the prompt from the decoded output
+         text = text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)):]
+
+         # Remove all text after the last occurrence of stop_token, if any
+         stop_idx = text.rfind(stop_token)
+         if stop_idx != -1:
+             text = text[:stop_idx + 1]
+
+         generated_sequences.append(text)
+
+     return generated_sequences[0]
+
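+ # Dataset helpers: download the RealToxicityPrompts train split, pull out the
+ # raw prompt strings, and draw random ten-prompt subsets for the dropdown.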
+ def prepare_dataset(dataset):
+     dataset = load_dataset(dataset, split='train')
+     return dataset
+
+ def load_prompts(dataset):
+     prompts = [dataset[i]['prompt']['text'] for i in range(len(dataset))]
+     return prompts
+
+ def random_sample(prompt_list):
+     # Draw ten prompts without replacement
+     return sample(prompt_list, 10)
+
+ def show_dataset(dataset):
+     raw_data = prepare_dataset(dataset)
+     prompts = load_prompts(raw_data)
+
+     return (gr.update(choices=random_sample(prompts),
+                       label='You can find below a random subset from the RealToxicityPrompts dataset',
+                       visible=True),
+             gr.update(visible=True),
+             prompts,
+             )
+
+ def update_dropdown(prompts):
+     return gr.update(choices=random_sample(prompts))
+
+ def show_text(text):  # Currently unused debugging helper
+     new_text = "lol " + text
+     return gr.update(visible=True, value=new_text)
+
+ def process_user_input(model, input):
+     # Fall back to a warning prompt when the textbox is empty
+     if not input:
+         input = 'Please enter a valid prompt.'
+     generated = generate(model, input)
+
+     return (gr.update(visible=True, value=generated),
+             gr.update(visible=True))
+
+ def pass_to_textbox(input):
+     return gr.update(value=input)
+
+ def run_detoxify(text):
+     results = Detoxify('original').predict(text)
+     json_ready_results = {cat: float(score) for (cat, score) in results.items()}
+
+     return gr.update(value=json_ready_results, visible=True)
+
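+ # Two-column Gradio layout: prompt selection (free text or dataset samples) on
+ # the left; model choice, generation and toxicity scoring on the right.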
+ with gr.Blocks() as demo:
+     gr.Markdown("# Project Interface proposal")
+
+     dataset = gr.Variable(value=DATASET)
+     prompts_var = gr.Variable(value=None)
+
+     with gr.Row(equal_height=True):
+         with gr.Column():
+             gr.Markdown("### 1. Select a prompt")
+
+             input_text = gr.Textbox(label="Write your prompt below.", interactive=True)
+             gr.Markdown("— or —")
+             inspo_button = gr.Button('Click here if you need some inspiration')
+
+             prompts_drop = gr.Dropdown(visible=False)
+             prompts_drop.change(fn=pass_to_textbox, inputs=prompts_drop, outputs=input_text)
+
+             randomize_button = gr.Button('Show another subset', visible=False)
+
+             inspo_button.click(fn=show_dataset, inputs=dataset, outputs=[prompts_drop, randomize_button, prompts_var])
+             randomize_button.click(fn=update_dropdown, inputs=prompts_var, outputs=prompts_drop)
+
+         with gr.Column():
+             gr.Markdown("### 2. Evaluate output")
+
+             generate_button = gr.Button('Pick a model below and submit your prompt')
+             model_radio = gr.Radio(choices=list(CHECKPOINTS.keys()),
+                                    label='Model',
+                                    interactive=True)
+             model_choice = gr.Variable(value=None)
+             model_radio.change(fn=lambda value: value, inputs=model_radio, outputs=model_choice)
+
+             output_text = gr.Textbox(label="Generated text.", visible=False)
+
+             toxi_button = gr.Button("Run a toxicity analysis of the model's output", visible=False)
+             toxi_scores = gr.JSON(visible=False)
+
+             generate_button.click(fn=process_user_input,
+                                   inputs=[model_choice, input_text],
+                                   outputs=[output_text, toxi_button])
+
+             toxi_button.click(fn=run_detoxify, inputs=output_text, outputs=toxi_scores)
+
+ #demo.launch(debug=True)
+ if __name__ == "__main__":
+     demo.launch(enable_queue=False)
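For a quick local check, the generation path can be exercised without launching the UI. A minimal smoke-test sketch, assuming the file above is saved as app.py in the working directory and torch, transformers, datasets, detoxify, and gradio are installed; the model label must match one of the CHECKPOINTS keys, and the prompt string is arbitrary:

import app  # safe to import: demo.launch() now sits behind the __main__ guard

# Generate a continuation, then score it with Detoxify.
text = app.generate("DistilGPT2 by HuggingFace 🤗", "The weather today is")
print(text)                    # continuation, truncated at the last '.'
print(app.run_detoxify(text))  # gr.update payload carrying the Detoxify scores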