File size: 3,258 Bytes
d5b2eed
 
 
 
 
 
 
 
 
 
e3032e8
 
d5b2eed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e3032e8
 
d5b2eed
 
e3032e8
d5b2eed
 
 
 
 
 
 
 
 
 
 
 
e3032e8
d5b2eed
 
 
 
 
 
e3032e8
d5b2eed
 
 
 
 
 
 
 
 
 
 
 
d876d62
d5b2eed
 
 
 
 
 
e3032e8
d5b2eed
 
 
 
 
 
 
 
8c01c8d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# Basic example for doing model-in-the-loop dynamic adversarial data collection
# using Gradio Blocks.

import random
from urllib.parse import parse_qs

import gradio as gr
import requests
from transformers import pipeline

# Sentiment classifier used as the model in the loop. No model name is passed,
# so the pipeline's default for the "sentiment-analysis" task is loaded.
pipe = pipeline("sentiment-analysis")

# Top-level Blocks container; components created inside `with demo:` belong to it.
demo = gr.Blocks()

with demo:
    total_cnt = 2 # How many examples per HIT
    # Hidden textbox used only as an input slot of the "Submit HIT" handler,
    # whose JS hook fills it with the page's query string.
    dummy = gr.Textbox(visible=False)  # dummy for passing assignmentId

    # We keep track of state as a Variable
    # Keys, as used by the handlers in this file:
    #   assignmentId - written by _submit from the page URL's query string
    #   cnt          - number of examples submitted so far (_predict)
    #   fooled       - number of examples whose prediction differed from the target
    #   data         - one feedback message per submitted example (_predict)
    #   metadata     - not written by any code visible in this file
    state_dict = {"assignmentId": "", "cnt": 0, "fooled": 0, "data": [], "metadata": {}}
    state = gr.Variable(state_dict)

    gr.Markdown("# DADC in Gradio example")
    gr.Markdown("Try to fool the model and find an example where it predicts the wrong label!")

    # Progress line; _predict returns the replacement text after each example.
    state_display = gr.Markdown(f"State: 0/{total_cnt} (0 fooled)")

    # Generate model prediction
    # Default model: distilbert-base-uncased-finetuned-sst-2-english
    def _predict(txt, tgt, state):
        """Score one submitted example and record the outcome in ``state``.

        The sentiment pipeline is run on ``txt``; its title-cased label is
        compared with the target label ``tgt``. ``state`` is updated in place:
        "fooled" is incremented on a mismatch, the feedback message is appended
        to "data", and "cnt" is incremented.

        Returns a 6-tuple, in this order: a label -> confidence mapping, the
        feedback Markdown string, the state dict, a visibility update that is
        shown until ``total_cnt`` examples are reached, a visibility update
        that is shown once they are, and the progress text.
        """
        pred = pipe(txt)[0]
        score = pred["score"]
        predicted_lc = pred["label"].lower()
        complement = "negative" if predicted_lc == "positive" else "positive"
        # Predicted label first, then the complementary label with the
        # remaining probability mass.
        pred_confidences = {predicted_lc: score, complement: 1 - score}

        # Title-case the label so it can be compared with the radio choices.
        pred["label"] = pred["label"].title()
        shown_label = pred["label"]
        ret = f"Target: **{tgt}**. Model prediction: **{shown_label}**\n\n"
        if shown_label == tgt:
            ret += " You did not fool the model! Too bad, try again!"
        else:
            state["fooled"] += 1
            ret += " You fooled the model! Well done!"
        state["data"].append(ret)
        state["cnt"] += 1

        done = state["cnt"] == total_cnt
        return (
            pred_confidences,
            ret,
            state,
            gr.update(visible=not done),
            gr.update(visible=done),
            f"State: {state['cnt']}/{total_cnt} ({state['fooled']} fooled)",
        )

    # Input fields
    # Free-text example written by the annotator.
    text_input = gr.Textbox(placeholder="Enter model-fooling statement", show_label=False)
    # The two target labels, shuffled in place when this script builds the UI,
    # so the order of the radio choices is randomized.
    labels = ["Positive", "Negative"]
    random.shuffle(labels)
    label_input = gr.Radio(choices=labels, label="Target (correct) label")
    # Outputs of _predict: per-label confidences and the feedback message.
    label_output = gr.Label()
    text_output = gr.Markdown()
    # Two columns whose visibility _predict toggles: the per-example submit
    # button, and the final "Submit HIT" button, which starts out hidden.
    with gr.Column() as example_submit:
        submit_ex_button = gr.Button("Submit")
    with gr.Column(visible=False) as final_submit:
        submit_hit_button = gr.Button("Submit HIT")

    # Submit state to MTurk backend for ExternalQuestion
    # Update the URL below to switch from Sandbox to real data collection
    def _submit(state, dummy):
        """Attach the assignment ID to ``state`` and POST it to MTurk.

        Args:
            state: Session state dict; its "assignmentId" entry is overwritten
                with the ID taken from the query string.
            dummy: The page's URL query string (``window.location.search``, as
                supplied by the click handler's JS hook), for example
                "?assignmentId=ABC". A leading "?" is optional.

        Returns:
            The ``requests.Response`` returned by the submit endpoint.

        Raises:
            ValueError: If the query string carries no ``assignmentId`` value.
        """
        # Strip the leading "?" only if it is present, and tolerate a missing
        # value instead of failing on slicing.
        query = parse_qs((dummy or "").lstrip("?"))
        assignment_ids = query.get("assignmentId")
        if not assignment_ids:
            # Raise explicitly: an `assert` would be stripped under `python -O`
            # and the request would then be sent without an assignment ID.
            raise ValueError("No assignment ID provided, unable to submit")
        # parse_qs maps every key to a list of values; store the scalar ID.
        state["assignmentId"] = assignment_ids[0]
        url = "https://workersandbox.mturk.com/mturk/htmlSubmit"
        # Bound the request so a stalled connection cannot hang the handler.
        return requests.post(url, data=state, timeout=30)

    # Button event handlers
    # "Submit" runs _predict on the current text, target label and session
    # state. Its six return values map, in order, onto the outputs listed here.
    submit_ex_button.click(
        _predict,
        inputs=[text_input, label_input, state],
        outputs=[label_output, text_output, state, example_submit, final_submit, state_display],
    )
    # "Submit HIT" runs _submit. The _js function replaces the hidden dummy
    # textbox's value with window.location.search, which is how the
    # assignmentId query parameter reaches the Python handler.
    # NOTE(review): _submit returns a Response although outputs=None - confirm
    # the return value is meant to be discarded.
    submit_hit_button.click(
        _submit,
        inputs=[state, dummy],
        outputs=None,
        _js="function(state, dummy) { return [state, window.location.search]; }",
    )

# Start the app; this runs whenever the module is executed or imported.
demo.launch()