File size: 3,490 Bytes
d5b2eed
 
 
 
 
 
 
 
 
 
e3032e8
 
d5b2eed
 
 
 
 
 
 
7a4749b
d5b2eed
 
 
 
 
 
 
 
 
 
 
e3032e8
 
d5b2eed
 
e3032e8
d5b2eed
fb34e92
d5b2eed
 
 
fb34e92
 
d5b2eed
fb34e92
d5b2eed
 
fb34e92
e3032e8
d5b2eed
 
 
 
 
 
e3032e8
d5b2eed
 
 
 
 
 
 
 
 
 
a868cbd
6b8a006
a8e5895
d073ff1
6b8a006
d5b2eed
 
 
 
 
e3032e8
d5b2eed
a54b97e
 
d5b2eed
 
 
a54b97e
e22cf04
d5b2eed
 
73f2dfd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# Basic example for doing model-in-the-loop dynamic adversarial data collection
# using Gradio Blocks.

import json
import random
from urllib.parse import parse_qs

import gradio as gr
import requests
from transformers import pipeline

# Sentiment model under attack. The checkpoint is pinned explicitly: without
# `model=`, transformers falls back to a task default (and warns that this is
# not recommended in production), so the attacked model could change silently
# with a library upgrade.
pipe = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english")

demo = gr.Blocks()

with demo:
    total_cnt = 2  # How many examples per HIT
    dummy = gr.Textbox(visible=False)  # dummy for passing assignmentId

    # We keep track of state as a Variable
    state_dict = {"assignmentId": "", "cnt": 0, "fooled": 0, "data": [], "metadata": {}}
    state = gr.Variable(state_dict)

    gr.Markdown("# DADC in Gradio example")
    gr.Markdown("Try to fool the model and find an example where it predicts the wrong label!")

    state_display = gr.Markdown(f"State: 0/{total_cnt} (0 fooled)")

    # Generate model prediction
    # Default model: distilbert-base-uncased-finetuned-sst-2-english
    def _predict(txt, tgt, state):
        """Score one worker-written example against the model and update the HIT state.

        Args:
            txt: the statement typed by the worker.
            tgt: the label the worker says is correct (one of the radio choices),
                or None when no choice has been selected.
            state: the per-session state dict (initialised from ``state_dict``).

        Returns:
            A 6-tuple matching the ``outputs`` of ``submit_ex_button.click``:
            (label confidences, result markdown, state, example-submit column
            update, final-submit column update, progress markdown).
        """
        progress_md = f"State: {state['cnt']}/{total_cnt} ({state['fooled']} fooled)"
        # Reject incomplete input without counting it: a missing target (None)
        # never equals the prediction and would otherwise be scored as "fooled".
        if not txt or not txt.strip() or not tgt:
            msg = "Please enter a statement and select the target label before submitting."
            return gr.update(), msg, state, gr.update(), gr.update(), progress_md

        pred = pipe(txt)[0]
        label = pred["label"].lower()
        other_label = "negative" if label == "positive" else "positive"
        pred_confidences = {label: pred["score"], other_label: 1 - pred["score"]}

        # Title-case so the model label compares against the radio choices ("Positive"/"Negative").
        pred["label"] = pred["label"].title()
        fooled = pred["label"] != tgt
        ret = f"Target: **{tgt}**. Model prediction: **{pred['label']}**\n\n"
        if fooled:
            state["fooled"] += 1
            ret += " You fooled the model! Well done!"
        else:
            ret += " You did not fool the model! Too bad, try again!"
        # Keep the example itself (not just the feedback text) so it can be submitted with the HIT.
        state["data"].append(
            {
                "text": txt,
                "target": tgt,
                "prediction": pred["label"],
                "score": float(pred["score"]),
                "fooled": fooled,
            }
        )
        state["cnt"] += 1

        done = state["cnt"] >= total_cnt
        toggle_final_submit = gr.update(visible=done)
        toggle_example_submit = gr.update(visible=not done)
        new_state_md = f"State: {state['cnt']}/{total_cnt} ({state['fooled']} fooled)"
        return pred_confidences, ret, state, toggle_example_submit, toggle_final_submit, new_state_md

    # Input fields
    text_input = gr.Textbox(placeholder="Enter model-fooling statement", show_label=False)
    labels = ["Positive", "Negative"]
    # Shuffled once while the UI is built, i.e. at startup rather than per session.
    random.shuffle(labels)
    label_input = gr.Radio(choices=labels, label="Target (correct) label")
    label_output = gr.Label()
    text_output = gr.Markdown()
    # Only one of these columns is visible at a time; _predict swaps them when the HIT is complete.
    with gr.Column() as example_submit:
        submit_ex_button = gr.Button("Submit")
    with gr.Column(visible=False) as final_submit:
        submit_hit_button = gr.Button("Submit HIT")

    # Submit state to MTurk backend for ExternalQuestion
    # Update the URL below to switch from Sandbox to real data collection
    def _submit(state, dummy):
        """Post the collected examples to MTurk's externalSubmit endpoint.

        Args:
            state: the per-session state dict filled in by ``_predict``.
            dummy: the page's query string (``window.location.search``), injected
                by the ``_js`` hook on ``submit_hit_button``; must carry ``assignmentId``.

        Returns:
            (response markdown, updated state, dummy), matching the click ``outputs``.

        Raises:
            ValueError: if the query string has no ``assignmentId``.
        """
        # Strip the leading "?" if present instead of blindly dropping the first character.
        query = parse_qs((dummy or "").lstrip("?"))
        if "assignmentId" not in query:
            # Explicit raise: an assert would be stripped under `python -O`.
            raise ValueError("No assignment ID provided, unable to submit")
        state["assignmentId"] = query["assignmentId"][0]
        url = "https://workersandbox.mturk.com/mturk/externalSubmit"
        # The collected examples are the HIT's answer fields; JSON-encode the list
        # so it fits in a single form field.
        answers = {
            "cnt": state["cnt"],
            "fooled": state["fooled"],
            "data": json.dumps(state["data"]),
        }
        # params= URL-encodes the assignmentId; timeout avoids hanging the worker forever.
        x = requests.post(url, params={"assignmentId": state["assignmentId"]}, data=answers, timeout=30)
        return str(x) + " With assignmentId " + state["assignmentId"] + "\n" + x.text, state, dummy

    # Button event handlers
    submit_ex_button.click(
        _predict,
        inputs=[text_input, label_input, state],
        outputs=[label_output, text_output, state, example_submit, final_submit, state_display],
    )

    response_output = gr.Markdown()
    submit_hit_button.click(
        _submit,
        inputs=[state, dummy],
        outputs=[response_output, state, dummy],
        # Runs in the browser before _submit: replaces the hidden textbox value
        # with the page's query string so _submit can read the assignmentId.
        _js="function(state, dummy) { return [state, window.location.search]; }",
    )

# Start the server only when run as a script, not when this module is imported.
if __name__ == "__main__":
    demo.launch()