dadc / app.py
Tristan Thrush
bugfixes
29025ba
raw
history blame
6.02 kB
# Basic example for doing model-in-the-loop dynamic adversarial data collection
# using Gradio Blocks.
import os
import random
from urllib.parse import parse_qs
import gradio as gr
import requests
from transformers import pipeline
from huggingface_hub import Repository
from dotenv import load_dotenv
from pathlib import Path
import json
# These variables are for storing the mturk HITs in a Hugging Face dataset.
# Settings are read from the process environment; a local ".env" file, when
# one exists, is loaded into the environment first.
if Path(".env").is_file():
    load_dotenv(".env")

DATASET_REPO_URL = os.environ.get("DATASET_REPO_URL")
HF_TOKEN = os.environ.get("HF_TOKEN")

# Finished HITs are appended as JSON lines to this file inside the cloned
# dataset repository (local_dir="data" below).
DATA_FILENAME = "data.jsonl"
DATA_FILE = str(Path("data") / DATA_FILENAME)

repo = Repository(
    local_dir="data",
    clone_from=DATASET_REPO_URL,
    use_auth_token=HF_TOKEN,
)
# Now let's run the app!
# Pin the checkpoint explicitly instead of relying on the task's implicit
# default: transformers warns that an unspecified model is not recommended in
# production, and the default for a task can change between releases.
pipe = pipeline(
    "sentiment-analysis",
    model="distilbert-base-uncased-finetuned-sst-2-english",
)
demo = gr.Blocks()
with demo:
    total_cnt = 2  # How many examples per HIT
    dummy = gr.Textbox(visible=False)  # dummy for passing assignmentId

    def _initial_state():
        """Return a fresh, empty HIT state (a new dict on every call)."""
        return {"assignmentId": "", "cnt": 0, "fooled": 0, "data": [], "metadata": {}}

    def _state_md(state):
        """Render the progress line, e.g. State: 1/2 (0 fooled)."""
        return f"State: {state['cnt']}/{total_cnt} ({state['fooled']} fooled)"

    # We keep track of state as a hidden JSON component. It is passed through
    # the browser-side JS hooks below, so its contents are client-controlled.
    state_dict = _initial_state()
    state = gr.JSON(state_dict, visible=False)
    gr.Markdown("# DADC in Gradio example")
    gr.Markdown("Try to fool the model and find an example where it predicts the wrong label!")
    state_display = gr.Markdown(_state_md(state_dict))

    # Generate model prediction
    # Default model: distilbert-base-uncased-finetuned-sst-2-english
    def _predict(txt, tgt, state, dummy):
        """Score one user-written example and record it in the HIT state.

        Args:
            txt: Statement typed by the user.
            tgt: Target (correct) label chosen in the radio group, "Positive"
                or "Negative"; falsy when nothing was selected.
            state: HIT state dict (see _initial_state); mutated and returned.
            dummy: window.location.search of the page, substituted in by
                get_window_location_search_js (e.g. "?assignmentId=...").

        Returns:
            A 7-tuple matching the outputs of submit_ex_button.click:
            (label confidences, result markdown, state, example_submit update,
            final_submit update, progress markdown, dummy).
        """
        query = parse_qs((dummy or "").lstrip("?"))
        if "assignmentId" in query:
            # It seems that someone is using this app on mturk. We need to
            # store the assignmentId in the state before submit_hit_button
            # is clicked, so that the turker can get credit for their HIT.
            state["assignmentId"] = query["assignmentId"][0]

        if not tgt or not txt or not txt.strip():
            # Without a target, `pred != tgt` would always count as fooling
            # the model; without text there is nothing to collect. Reject the
            # submission and leave counters and visibility untouched.
            return (
                gr.update(),
                "Please enter a statement and select the target label.",
                state,
                gr.update(),
                gr.update(),
                _state_md(state),
                dummy,
            )

        pred = pipe(txt)[0]
        pred_label = pred["label"].title()
        other_label = "negative" if pred_label.lower() == "positive" else "positive"
        pred_confidences = {pred_label.lower(): pred["score"], other_label: 1 - pred["score"]}

        fooled = pred_label != tgt
        ret = f"Target: **{tgt}**. Model prediction: **{pred_label}**\n\n"
        if fooled:
            state["fooled"] += 1
            ret += " You fooled the model! Well done!"
        else:
            ret += " You did not fool the model! Too bad, try again!"
        # Record the example itself (not just the feedback message) so the
        # collected dataset actually contains the adversarial text and label.
        state["data"].append(
            {"text": txt, "target": tgt, "model_pred": pred_label, "fooled": fooled}
        )
        state["cnt"] += 1

        # `>=` rather than `==` so an overshooting count still ends the HIT.
        done = state["cnt"] >= total_cnt
        toggle_final_submit = gr.update(visible=done)
        toggle_example_submit = gr.update(visible=not done)
        return (
            pred_confidences,
            ret,
            state,
            toggle_example_submit,
            toggle_final_submit,
            _state_md(state),
            dummy,
        )

    # Input fields
    text_input = gr.Textbox(placeholder="Enter model-fooling statement", show_label=False)
    labels = ["Positive", "Negative"]
    # Shuffled once at startup; presumably so the radio order does not bias
    # the chosen target (TODO confirm intent).
    random.shuffle(labels)
    label_input = gr.Radio(choices=labels, label="Target (correct) label")
    label_output = gr.Label()
    text_output = gr.Markdown()
    with gr.Column() as example_submit:
        submit_ex_button = gr.Button("Submit")
    with gr.Column(visible=False) as final_submit:
        submit_hit_button = gr.Button("Submit HIT")

    # Store the HIT data into a Hugging Face dataset.
    # The HIT is also stored and logged on mturk when post_hit_js is run below.
    # This _store_in_huggingface_dataset function just demonstrates how easy it is
    # to automatically create a Hugging Face dataset from mturk.
    def _store_in_huggingface_dataset(state):
        """Append the finished HIT to DATA_FILE and push the dataset repo.

        Returns:
            A 4-tuple matching the outputs of submit_hit_button.click:
            (state, final_submit update, example_submit update, progress
            markdown).
        """
        # NOTE(review): concurrent submissions share this file and git repo;
        # consider serializing writes/pushes if traffic grows.
        with open(DATA_FILE, "a", encoding="utf-8") as jsonlfile:
            jsonlfile.write(json.dumps(state) + "\n")
        repo.push_to_hub()

        if state.get("assignmentId"):
            # mturk: post_hit_js (run client-side before this handler) submits
            # the HIT form. Hide "Submit HIT" so the HIT is not stored twice;
            # leave the other outputs unchanged.
            return state, gr.update(visible=False), gr.update(), gr.update()

        # If assignmentId is not set, then someone is using this app on
        # huggingface.co, and we can reset the app to its initial state
        # after they submit their fake "HIT".
        new_state = _initial_state()
        return (
            new_state,
            gr.update(visible=False),
            gr.update(visible=True),
            # Plain string for the Markdown output (as _predict does), not a
            # newly constructed gr.Markdown component.
            _state_md(new_state),
        )

    # Button event handlers
    # Runs in the browser before _predict: passes the inputs through but
    # replaces `dummy` with window.location.search so _predict can read the
    # mturk query parameters.
    get_window_location_search_js = """
function(text_input, label_input, state, dummy) {
    return [text_input, label_input, state, window.location.search];
}
"""
    submit_ex_button.click(
        _predict,
        inputs=[text_input, label_input, state, dummy],
        outputs=[label_output, text_output, state, example_submit, final_submit, state_display, dummy],
        _js=get_window_location_search_js,
    )

    # Runs in the browser before _store_in_huggingface_dataset. On mturk it
    # posts every state field as a hidden form input. Non-string values
    # (cnt, fooled, data, metadata) are JSON-encoded; assigning an object to
    # an input's value would otherwise post "[object Object]".
    # NOTE(review): this targets the mturk *sandbox*; the production endpoint
    # is presumably https://www.mturk.com/mturk/externalSubmit - confirm
    # before going live.
    post_hit_js = """
function(state) {
    if (state["assignmentId"]) {
        // It seems that someone is using this app on mturk, so submit the HIT.
        const form = document.createElement('form');
        form.action = 'https://workersandbox.mturk.com/mturk/externalSubmit';
        form.method = 'post';
        for (const key in state) {
            const value = state[key];
            const hiddenField = document.createElement('input');
            hiddenField.type = 'hidden';
            hiddenField.name = key;
            hiddenField.value = (typeof value === 'string') ? value : JSON.stringify(value);
            form.appendChild(hiddenField);
        }
        document.body.appendChild(form);
        form.submit();
    }
    return [state];
}
"""
    submit_hit_button.click(
        _store_in_huggingface_dataset,
        inputs=[state],
        outputs=[state, final_submit, example_submit, state_display],
        _js=post_hit_js,
    )

demo.launch()