Spaces:
Runtime error
Runtime error
J-Antoine ZAGATO
commited on
Commit
•
40d38f3
1
Parent(s):
5962754
Added multi model structure wo api key this time
Browse files
app.py
CHANGED
@@ -126,6 +126,18 @@ def generate(model_name,
|
|
126 |
|
127 |
return generated_sequences[0]
|
128 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
129 |
def prepare_dataset(dataset):
|
130 |
dataset = load_dataset(dataset, split='train')
|
131 |
return dataset
|
@@ -252,9 +264,14 @@ def upload_flag(*args):
|
|
252 |
if flagging_callback.flag(list(args), flag_option = None):
|
253 |
return gr.update(visible=True)
|
254 |
|
255 |
-
|
256 |
-
|
257 |
-
|
|
|
|
|
|
|
|
|
|
|
258 |
|
259 |
dataset = gr.Variable(value=DATASET)
|
260 |
prompts_var = gr.Variable(value=None)
|
@@ -264,76 +281,106 @@ with gr.Blocks() as demo:
|
|
264 |
custom_model_path = gr.Variable(value=None)
|
265 |
flag_choice = gr.Variable(label = "Flag", value=None)
|
266 |
|
267 |
-
|
268 |
flagging_callback = gr.HuggingFaceDatasetSaver(hf_token = HF_AUTH_TOKEN,
|
269 |
dataset_name = "fsdlredteam/flagged_2",
|
270 |
organization = "fsdlredteam",
|
271 |
private = True )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
272 |
|
273 |
-
|
274 |
|
275 |
-
|
276 |
-
gr.Markdown("### 1. Select a prompt")
|
277 |
|
278 |
-
|
|
|
|
|
279 |
|
280 |
-
|
281 |
-
|
282 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
283 |
|
284 |
-
|
285 |
|
286 |
-
|
287 |
|
288 |
-
|
289 |
-
with gr.Column(scale=1): # Model choice & output
|
290 |
-
gr.Markdown("### 2. Evaluate output")
|
291 |
|
|
|
292 |
|
293 |
-
|
294 |
-
|
295 |
-
|
|
|
|
|
|
|
296 |
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
output_spans = gr.HighlightedText(visible=True, label="Generated text")
|
303 |
|
304 |
-
|
305 |
-
|
306 |
-
with gr.Row(): # Flagging
|
307 |
-
|
308 |
-
with gr.Column(scale=1):
|
309 |
-
flag_radio = gr.Radio(choices=["Toxic", "Offensive", "Repetitive", "Incorrect", "Other",],
|
310 |
-
label="What's wrong with the output ?",
|
311 |
-
interactive=True,
|
312 |
-
visible=False)
|
313 |
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
|
|
335 |
|
336 |
|
|
|
|
|
|
|
|
|
337 |
inspo_button.click(fn=show_dataset,
|
338 |
inputs=dataset,
|
339 |
outputs=[prompts_drop, randomize_button, prompts_var])
|
|
|
126 |
|
127 |
return generated_sequences[0]
|
128 |
|
129 |
+
def show_mode(mode):
|
130 |
+
if mode == 'Single Model':
|
131 |
+
return (
|
132 |
+
gr.update(visible=True),
|
133 |
+
gr.update(visible=False)
|
134 |
+
)
|
135 |
+
if mode == 'Multi-Model':
|
136 |
+
return (
|
137 |
+
gr.update(visible=False),
|
138 |
+
gr.update(visible=True)
|
139 |
+
)
|
140 |
+
|
141 |
def prepare_dataset(dataset):
|
142 |
dataset = load_dataset(dataset, split='train')
|
143 |
return dataset
|
|
|
264 |
if flagging_callback.flag(list(args), flag_option = None):
|
265 |
return gr.update(visible=True)
|
266 |
|
267 |
+
CSS = """
|
268 |
+
#inside_group {
|
269 |
+
padding-top: 0.6em;
|
270 |
+
padding-bottom: 0.6em;
|
271 |
+
}
|
272 |
+
"""
|
273 |
+
|
274 |
+
with gr.Blocks(css=CSS) as demo:
|
275 |
|
276 |
dataset = gr.Variable(value=DATASET)
|
277 |
prompts_var = gr.Variable(value=None)
|
|
|
281 |
custom_model_path = gr.Variable(value=None)
|
282 |
flag_choice = gr.Variable(label = "Flag", value=None)
|
283 |
|
|
|
284 |
flagging_callback = gr.HuggingFaceDatasetSaver(hf_token = HF_AUTH_TOKEN,
|
285 |
dataset_name = "fsdlredteam/flagged_2",
|
286 |
organization = "fsdlredteam",
|
287 |
private = True )
|
288 |
+
|
289 |
+
gr.Markdown("# Project Interface proposal")
|
290 |
+
gr.Markdown("### Pick a text generation model below, write a prompt and explore the output")
|
291 |
+
gr.Markdown("### Or compare multiple models")
|
292 |
+
|
293 |
+
choose_mode = gr.Radio(choices=['Single Model', "Multi-Model"],
|
294 |
+
value='Single Model',
|
295 |
+
interactive=True,
|
296 |
+
visible=True,
|
297 |
+
show_label=False)
|
298 |
+
|
299 |
+
with gr.Group() as single_model:
|
300 |
+
with gr.Row():
|
301 |
+
|
302 |
+
with gr.Column(scale=1): # input & prompts dataset exploration
|
303 |
+
gr.Markdown("### 1. Select a prompt", elem_id="inside_group")
|
304 |
+
|
305 |
+
input_text = gr.Textbox(label="Write your prompt below.",
|
306 |
+
interactive=True,
|
307 |
+
lines=4,
|
308 |
+
elem_id="inside_group")
|
309 |
+
|
310 |
+
gr.Markdown("— or —", elem_id="inside_group")
|
311 |
+
|
312 |
+
inspo_button = gr.Button('Click here if you need some inspiration', elem_id="inside_group")
|
313 |
|
314 |
+
prompts_drop = gr.Dropdown(visible=False, elem_id="inside_group")
|
315 |
|
316 |
+
randomize_button = gr.Button('Show another subset', visible=False, elem_id="inside_group")
|
|
|
317 |
|
318 |
+
|
319 |
+
with gr.Column(scale=1): # Model choice & output
|
320 |
+
gr.Markdown("### 2. Evaluate output")
|
321 |
|
322 |
+
|
323 |
+
model_radio = gr.Radio(choices=list(CHECKPOINTS.keys()),
|
324 |
+
label='Model',
|
325 |
+
interactive=True,
|
326 |
+
elem_id="inside_group")
|
327 |
+
|
328 |
+
search_bar = gr.Textbox(label="Search model",
|
329 |
+
interactive=True,
|
330 |
+
visible=False,
|
331 |
+
elem_id="inside_group")
|
332 |
+
model_drop = gr.Dropdown(visible=False)
|
333 |
|
334 |
+
generate_button = gr.Button('Submit your prompt')
|
335 |
|
336 |
+
output_spans = gr.HighlightedText(visible=True, label="Generated text", elem_id="inside_group")
|
337 |
|
338 |
+
flag_button = gr.Button("Report output here", visible=False)
|
|
|
|
|
339 |
|
340 |
+
with gr.Row(): # Flagging
|
341 |
|
342 |
+
with gr.Column(scale=1):
|
343 |
+
flag_radio = gr.Radio(choices=["Toxic", "Offensive", "Repetitive", "Incorrect", "Other",],
|
344 |
+
label="What's wrong with the output ?",
|
345 |
+
interactive=True,
|
346 |
+
visible=False,
|
347 |
+
elem_id="inside_group")
|
348 |
|
349 |
+
user_comment = gr.Textbox(label="(Optional) Briefly describe the issue",
|
350 |
+
visible=False,
|
351 |
+
interactive=True,
|
352 |
+
elem_id="inside_group")
|
|
|
|
|
353 |
|
354 |
+
confirm_flag_button = gr.Button("Confirm report", visible=False, elem_id="inside_group")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
355 |
|
356 |
+
with gr.Row(): # Flagging success
|
357 |
+
success_message = gr.Markdown("Your report has been successfully registered. Thank you!",
|
358 |
+
visible=False,
|
359 |
+
elem_id="inside_group")
|
360 |
+
|
361 |
+
with gr.Row(): # Toxicity buttons
|
362 |
+
toxi_button = gr.Button("Run a toxicity analysis of the model's output", visible=False, elem_id="inside_group")
|
363 |
+
toxi_button_compare = gr.Button("Compare toxicity on input and output", visible=False, elem_id="inside_group")
|
364 |
+
|
365 |
+
with gr.Row(): # Toxicity scores
|
366 |
+
toxi_scores_input = gr.JSON(label = "Detoxify classification of your input",
|
367 |
+
visible=False,
|
368 |
+
elem_id="inside_group")
|
369 |
+
toxi_scores_output = gr.JSON(label="Detoxify classification of the model's output",
|
370 |
+
visible=False,
|
371 |
+
elem_id="inside_group")
|
372 |
+
toxi_scores_compare = gr.JSON(label = "Percentage change between Input and Output",
|
373 |
+
visible=False,
|
374 |
+
elem_id="inside_group")
|
375 |
+
|
376 |
+
with gr.Group() as multi_model:
|
377 |
+
gr.Markdown("Model comparison will be here")
|
378 |
|
379 |
|
380 |
+
choose_mode.change(fn=show_mode,
|
381 |
+
inputs=choose_mode,
|
382 |
+
outputs=[single_model, multi_model])
|
383 |
+
|
384 |
inspo_button.click(fn=show_dataset,
|
385 |
inputs=dataset,
|
386 |
outputs=[prompts_drop, randomize_button, prompts_var])
|