Update app.py

#2
by abidlabs HF staff - opened
Files changed (1) hide show
  1. app.py +9 -5
app.py CHANGED
@@ -8,6 +8,8 @@ import gradio as gr
8
  import requests
9
  from transformers import pipeline
10
 
 
 
11
  demo = gr.Blocks()
12
 
13
  with demo:
@@ -26,11 +28,12 @@ with demo:
26
  # Generate model prediction
27
  # Default model: distilbert-base-uncased-finetuned-sst-2-english
28
  def _predict(txt, tgt, state):
29
- pipe = pipeline("sentiment-analysis")
30
  pred = pipe(txt)[0]
 
 
31
 
32
  pred["label"] = pred["label"].title()
33
- ret = f"Target: {tgt}. Model prediction: {pred['label']} ({pred['score']} confidence). {pred['label'] != tgt}\n\n"
34
  if pred["label"] != tgt:
35
  state["fooled"] += 1
36
  ret += " You fooled the model! Well done!"
@@ -43,13 +46,14 @@ with demo:
43
  toggle_final_submit = gr.update(visible=done)
44
  toggle_example_submit = gr.update(visible=not done)
45
  new_state_md = f"State: {state['cnt']}/{total_cnt} ({state['fooled']} fooled)"
46
- return ret, state, toggle_example_submit, toggle_final_submit, new_state_md
47
 
48
  # Input fields
49
  text_input = gr.Textbox(placeholder="Enter model-fooling statement", show_label=False)
50
  labels = ["Positive", "Negative"]
51
  random.shuffle(labels)
52
  label_input = gr.Radio(choices=labels, label="Target (correct) label")
 
53
  text_output = gr.Markdown()
54
  with gr.Column() as example_submit:
55
  submit_ex_button = gr.Button("Submit")
@@ -69,7 +73,7 @@ with demo:
69
  submit_ex_button.click(
70
  _predict,
71
  inputs=[text_input, label_input, state],
72
- outputs=[text_output, state, example_submit, final_submit, state_display],
73
  )
74
 
75
  submit_hit_button.click(
@@ -79,4 +83,4 @@ with demo:
79
  _js="function(state, dummy) { return [state, window.location.search]; }",
80
  )
81
 
82
- demo.launch(favicon_path="https://huggingface.co/favicon.ico")
 
8
  import requests
9
  from transformers import pipeline
10
 
11
+ pipe = pipeline("sentiment-analysis")
12
+
13
  demo = gr.Blocks()
14
 
15
  with demo:
 
28
  # Generate model prediction
29
  # Default model: distilbert-base-uncased-finetuned-sst-2-english
30
  def _predict(txt, tgt, state):
 
31
  pred = pipe(txt)[0]
32
+ other_label = 'negative' if pred['label'].lower() == "positive" else "positive"
33
+ pred_confidences = {pred['label'].lower(): pred['score'], other_label: 1 - pred['score']}
34
 
35
  pred["label"] = pred["label"].title()
36
+ ret = f"Target: **{tgt}**. Model prediction: **{pred['label']}**\n\n"
37
  if pred["label"] != tgt:
38
  state["fooled"] += 1
39
  ret += " You fooled the model! Well done!"
 
46
  toggle_final_submit = gr.update(visible=done)
47
  toggle_example_submit = gr.update(visible=not done)
48
  new_state_md = f"State: {state['cnt']}/{total_cnt} ({state['fooled']} fooled)"
49
+ return pred_confidences, ret, state, toggle_example_submit, toggle_final_submit, new_state_md
50
 
51
  # Input fields
52
  text_input = gr.Textbox(placeholder="Enter model-fooling statement", show_label=False)
53
  labels = ["Positive", "Negative"]
54
  random.shuffle(labels)
55
  label_input = gr.Radio(choices=labels, label="Target (correct) label")
56
+ label_output = gr.Label()
57
  text_output = gr.Markdown()
58
  with gr.Column() as example_submit:
59
  submit_ex_button = gr.Button("Submit")
 
73
  submit_ex_button.click(
74
  _predict,
75
  inputs=[text_input, label_input, state],
76
+ outputs=[label_output, text_output, state, example_submit, final_submit, state_display],
77
  )
78
 
79
  submit_hit_button.click(
 
83
  _js="function(state, dummy) { return [state, window.location.search]; }",
84
  )
85
 
86
+ demo.launch(favicon_path="https://huggingface.co/favicon.ico")