# Page header captured together with this file (not code): author "sedrickkeh",
# commit 5a72dbb ("major updates demo v2"), file size 15.2 kB.
import gradio as gr
from response_db import ResponseDb
from response_db import get_code
from create_cache import Game_Cache
import numpy as np
from PIL import Image
import pandas as pd
import torch
import pickle
import uuid
import nltk
# Fetch the NLTK "punkt" tokenizer data when the module is imported.
nltk.download('punkt')

# Response store; ask_a_question() records every turn through db.add().
db = ResponseDb()

# Chat-bubble styling for the conversation HTML built in ask_a_question().
css = "\n".join([
    "",
    ".chatbot {display:flex;flex-direction:column}",
    ".msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}",
    ".msg.user {background-color:cornflowerblue;color:white;align-self:self-end}",
    ".msg.bot {background-color:lightgray}",
    ".na_button {background-color:red;color:red}",
    "",
])

# Client-side JS run on page load: returns the page's query-string
# parameters as a plain object (consumed by the demo.load handler).
get_window_url_params = "\n".join([
    "",
    "function(url_params) {",
    "console.log(url_params);",
    "const params = new URLSearchParams(window.location.search);",
    "url_params = Object.fromEntries(params);",
    "return url_params;",
    "}",
    "",
])

# Special task IDs (presumably qualification tasks). set_images() replaces each
# with the pilot-study row index given here, and ask_a_question() does not log
# turns to db for these IDs.
quals = {
    1001: 99,
    1002: 136,
    1003: 56,
    1004: 105,
}
# Model sets are loaded once at import time; every Game_Session keeps
# references to these shared objects rather than loading its own.
from model.run_question_asking_model import (
    return_modules,
    return_modules_yn,
)

# Models for open-ended games.
(
    question_model,
    response_model_simul,
    response_model_gtruth,
    caption_model,
) = return_modules()

# Models for yes/no games.
(
    question_model_yn,
    response_model_simul_yn,
    response_model_gtruth_yn,
    caption_model_yn,
) = return_modules_yn()
class Game_Session:
    """Per-player state for one guessing game.

    Holds references to the module-level models (open-ended and yes/no
    variants), the candidate images and data derived from them (filled in by
    ``set_images``), and the alternating question/answer history.
    """

    def __init__(self, taskid, yn, hard_setting):
        """Create an empty session bound to the shared models.

        Args:
            taskid: Accepted for the caller's convenience but not stored.
                NOTE(review): confirm it is intentionally unused.
            yn: Truthy to use the yes/no model set, falsy for the
                open-ended model set (see ``set_curr_models``).
            hard_setting: Stored as-is; not otherwise used by this class.
        """
        self.yn = yn
        self.hard_setting = hard_setting

        # Shared models loaded at module import. Reading module globals needs
        # no ``global`` declaration, so none is used here.
        self.question_model = question_model
        self.response_model_simul = response_model_simul
        self.response_model_gtruth = response_model_gtruth
        self.caption_model = caption_model
        self.question_model_yn = question_model_yn
        self.response_model_simul_yn = response_model_simul_yn
        self.response_model_gtruth_yn = response_model_gtruth_yn
        self.caption_model_yn = caption_model_yn

        # Per-game data, populated by set_images(). ``images_np`` must use this
        # exact name: it is what get_next_question() and set_images() read and write.
        self.image_files = None
        self.images_np = None
        self.p_y_x = None
        self.p_r_qy = None
        self.p_y_xqr = None
        self.captions = None
        self.questions = None
        self.target_questions = None

        # Alternating [question, answer, question, ...] transcript.
        self.history = []
        self.game_id = str(uuid.uuid4())
        self.set_curr_models()

    def set_curr_models(self):
        """Point the ``curr_*`` attributes at the yes/no or open-ended models per ``self.yn``."""
        if self.yn:
            self.curr_question_model = self.question_model_yn
            self.curr_caption_model = self.caption_model_yn
            self.curr_response_model_simul = self.response_model_simul_yn
            self.curr_response_model_gtruth = self.response_model_gtruth_yn
        else:
            self.curr_question_model = self.question_model
            self.curr_caption_model = self.caption_model
            self.curr_response_model_simul = self.response_model_simul
            self.curr_response_model_gtruth = self.response_model_gtruth

    def get_next_question(self):
        """Ask the current question model for the best next question."""
        return self.curr_question_model.select_best_question(
            self.p_y_x, self.questions, self.images_np, self.captions, self.curr_response_model_simul
        )

    def get_model_gtruth_response(self, question):
        """Answer ``question`` for Image 1 using the ground-truth response model.

        NOTE(review): this always uses the open-ended ``response_model_gtruth``
        (with ``is_a=self.yn``) rather than ``curr_response_model_gtruth``;
        confirm that is intended for yes/no games.
        """
        return self.response_model_gtruth.get_response(
            question, self.images_np[0], self.captions[0], self.target_questions, is_a=self.yn
        )
def _render_history(history):
    """Render the question/answer history as chat-bubble HTML.

    Even indices are the model's questions (``bot``) and odd indices the
    player's answers (``user``), matching the order ask_a_question appends them.
    """
    bubbles = "".join(
        "<div class='msg {}'> {}</div>".format("bot" if m % 2 == 0 else "user", msg)
        for m, msg in enumerate(history)
    )
    return "<div class='chatbot'>" + bubbles + "</div>"


def _answer_controls(show_open_ended, show_yes_no):
    """Build visibility updates for the answer widgets.

    Returns updates, in order, for: answer dropdown, N/A button, reset button,
    submit button, task-ID box, start button, Yes button, No button. The
    task-ID box and start button are always hidden once a game is running.
    """
    if show_open_ended:
        dropdown = gr.Dropdown.update(visible=True, value='')
    else:
        dropdown = gr.Dropdown.update(visible=False)
    return (
        dropdown,
        gr.Button.update(visible=show_open_ended),  # na_box
        gr.Button.update(visible=show_open_ended),  # clear_box
        gr.Button.update(visible=show_open_ended),  # submit
        gr.Number.update(visible=False),            # taskid
        gr.Button.update(visible=False),            # start_button
        gr.Button.update(visible=show_yes_no),      # yes_box
        gr.Button.update(visible=show_yes_no),      # no_box
    )


def ask_a_question(input, taskid, gs):
    """Apply the player's answer and advance the game by one turn.

    Args:
        input: The player's answer: "n/a", "yes", "no", or a word in the
            current simulated response model's ``label2id`` vocabulary.
        taskid: Task ID from the UI. Turns are logged to ``db`` only when it
            is not one of the IDs in ``quals``.
        gs: The current ``Game_Session``; mutated in place.

    Returns:
        A 12-tuple for: conversation, game_session_state, answer, na_box,
        clear_box, submit, taskid, start_button, yes_box, no_box, reward_code,
        vocab_warning.
    """
    vocab = gs.curr_response_model_simul.model.config.label2id
    if not input or (input not in ("n/a", "yes", "no") and input not in vocab):
        # Unknown or empty answer: leave the game untouched and show the warning.
        return (
            _render_history(gs.history), gs,
            *_answer_controls(True, False),
            gr.Textbox.update(visible=False),
            gr.HTML.update(visible=True),
        )

    gs.history.append(input)
    question = gs.history[-2]

    # Multiply the current distribution over images by the answer likelihood
    # and renormalise; an all-zero product stays all zeros.
    gs.p_r_qy = gs.curr_response_model_simul.get_p_r_qy(input, question, gs.images_np, gs.captions)
    unnormalized = gs.p_y_x * gs.p_r_qy
    total = torch.sum(unnormalized)
    gs.p_y_xqr = unnormalized / total if total != 0 else torch.zeros_like(unnormalized)
    gs.p_y_x = gs.p_y_xqr

    gs.questions.remove(question)
    log_turns = taskid not in quals
    if log_turns:
        db.add(gs.game_id, taskid, len(gs.history) // 2 - 1, question, input)

    gs.history.append(gs.get_next_question())
    top_prob = torch.max(gs.p_y_x).item()
    top_pred = torch.argmax(gs.p_y_x).item()

    # Decide once, before trimming the history. Re-testing the length after
    # the trim only agreed because the history length is always odd here.
    game_over = top_prob > 0.8 or len(gs.history) > 19
    if game_over:
        # Drop the follow-up question that will never be asked.
        gs.history = gs.history[:-1]
        if log_turns:
            # NOTE: the db record uses the 0-based image index; the HTML below shows it 1-based.
            db.add(gs.game_id, taskid, len(gs.history) // 2, f"Guess: Image {top_pred}", "")

    html = _render_history(gs.history)

    if game_over:
        html += f"<p>The model identified <b>Image {top_pred+1}</b> as the image. Please select a new task ID to continue.</p>"
        finish_html = "<h2>Congratulations on finishing the game! Please copy the Task Finish Code below to MTurk to complete your task. You can now exit this window.</h2>"
        return (
            html, gs,
            *_answer_controls(False, False),
            gr.Textbox.update(value=get_code(taskid, gs.history, top_pred), visible=True),
            gr.HTML.update(value=finish_html, visible=True),
        )

    yes_no = bool(gs.yn)
    return (
        html, gs,
        *_answer_controls(not yes_no, yes_no),
        gr.Textbox.update(visible=False),
        gr.HTML.update(visible=False),
    )
def set_images(taskid):
    """Load the game for ``taskid`` and pose the model's first question.

    ``taskid`` indexes a row of ``pilot-study.csv``; IDs listed in ``quals``
    are first remapped to the row index given there. That row's ``mscoco-id``
    selects the pickled game cache under ``cache-soft/``. The cache supplies
    the ten candidate images (Image 1 is the player's target), the yes/no flag
    and the precomputed question bank.

    Returns:
        A 21-tuple for: img0, img1..img10, game_session_state, conversation,
        answer, na_box, clear_box, submit, taskid, start_button, yes_box, no_box.
    """
    num_images = 10  # one per image widget (img1..img10)

    pilot_study = pd.read_csv("pilot-study.csv")
    row = int(taskid)
    if row in quals:
        row = quals[row]
    mscoco_id = int(pilot_study['mscoco-id'].tolist()[row])

    # NOTE: pickle.load can execute arbitrary code; only trusted cache files belong here.
    with open(f'cache-soft/{mscoco_id}.p', 'rb') as fp:
        game_cache = pickle.load(fp)
    gs = Game_Session(mscoco_id, game_cache.yn, game_cache.hard_setting)

    image_paths = [f"./mscoco-images/val2014/{game_cache.image_files[k]}" for k in range(num_images)]
    # Strip "./mscoco-images", leaving e.g. "/val2014/<file>" (leading "/" kept).
    # These exact strings are passed to the caption and question models below.
    gs.image_files = [path[len("./mscoco-images"):] for path in image_paths]

    images = []
    for rel_path in gs.image_files:
        with Image.open(f"./mscoco-images/{rel_path}") as img:
            arr = np.asarray(img)
        # 2-D (single-channel) images are replicated into three channels.
        images.append(np.dstack([arr] * 3) if arr.ndim == 2 else arr)
    gs.images_np = images

    # Start from a uniform distribution over the candidate images.
    gs.p_y_x = (torch.ones(num_images) / num_images).to(gs.curr_question_model.device)
    gs.captions = gs.curr_caption_model.get_captions(gs.image_files)
    gs.questions, gs.target_questions = gs.curr_question_model.get_questions(gs.image_files, gs.captions, 0)
    # Swap in this game's cached question bank before choosing a question.
    gs.curr_question_model.reset_question_bank()
    gs.curr_question_model.question_bank = game_cache.question_dict
    first_question = gs.curr_question_model.select_best_question(
        gs.p_y_x, gs.questions, gs.images_np, gs.captions, gs.curr_response_model_simul
    )
    first_question_html = f"<div class='chatbot'><div class='msg bot'>{first_question}</div></div>"
    gs.history.append(first_question)

    # Open-ended games answer via the dropdown; yes/no games via the Yes/No buttons.
    open_ended = not gs.yn
    if open_ended:
        answer_update = gr.Dropdown.update(visible=True, value='')
    else:
        answer_update = gr.Dropdown.update(visible=False)
    return (
        image_paths[0],  # img0: the large view of Image 1
        *image_paths,    # img1..img10
        gs,
        first_question_html,
        answer_update,
        gr.Button.update(visible=open_ended),      # na_box
        gr.Button.update(visible=open_ended),      # clear_box
        gr.Button.update(visible=open_ended),      # submit
        gr.Number.update(visible=False),           # taskid
        gr.Button.update(visible=False),           # start_button
        gr.Button.update(visible=not open_ended),  # yes_box
        gr.Button.update(visible=not open_ended),  # no_box
    )
def reset_dropdown():
    """Clear the answer dropdown's selection, keeping the dropdown visible."""
    cleared = gr.Dropdown.update(value='', visible=True)
    return cleared
with gr.Blocks(title="Image Q&A Guessing Game", css=css) as demo:
    # Instructions. Adjacent string literals (not backslash continuations) keep
    # source indentation out of the HTML; the last list item is now closed.
    gr.HTML(
        "<h1>Image Q&A Guessing Game</h1>"
        "<p style='font-size:120%;'>"
        "Imagine you are playing 20-questions with an AI model.<br>"
        "The AI model plays the role of the question asker. You play the role of the responder. <br>"
        "There are 10 images. <b>Your image is Image 1</b>. The other images are distraction images. "
        "The model can see all 10 images and all the questions and answers for the current set of images. It will ask a question based on the available information.<br>"
        "<span style='color: #0000ff'>The goal of the model is to accurately guess the correct image (i.e. <b><span style='color: #0000ff'>Image 1</span></b>) in as few turns as possible.<br>"
        "Your goal is to help the model guess the image by answering as clearly and accurately as possible.</span><br><br>"
        "(Note: We are testing multiple game settings. In some instances, the game will be open-ended, while in other instances, the answer choices will be limited to yes/no.)<br><br>"
        "<b>Selecting N/A:</b><br>"
        "<ul style='font-size:120%;'>"
        "<li>In some games, there will be an N/A option. Please select N/A only if the question is unanswerable BECAUSE IT DOES NOT APPLY TO THE IMAGE.</li>"
        "<li>Otherwise, please select the closest possible option.</li>"
        "<li>e.g. Q:\"What is the dog doing?\" Please select N/A if there is no dog in the image.</li>"
        "</ul> "
        "<br>"
    )
    # NOTE(review): the nesting of the layout containers below was reconstructed;
    # confirm it matches the intended layout.
    with gr.Column():
        with gr.Row():
            # Hidden; filled from the ?p= URL parameter by demo.load below.
            taskid = gr.Number(label="Task ID (Enter a number from 0 to 160)", visible=False)
            start_button = gr.Button("Start")
        with gr.Column() as img_block:
            with gr.Row():
                img1 = gr.Image(label="Image 1", show_label=True)
                img2 = gr.Image(label="Image 2", show_label=True)
                img3 = gr.Image(label="Image 3", show_label=True)
                img4 = gr.Image(label="Image 4", show_label=True)
                img5 = gr.Image(label="Image 5", show_label=True)
            with gr.Row():
                img6 = gr.Image(label="Image 6", show_label=True)
                img7 = gr.Image(label="Image 7", show_label=True)
                img8 = gr.Image(label="Image 8", show_label=True)
                img9 = gr.Image(label="Image 9", show_label=True)
                img10 = gr.Image(label="Image 10", show_label=True)
        conversation = gr.HTML()
        game_session_state = gr.State()

        # Dropdown choices: words starting with a letter first, then numbers;
        # the "None"/None entries are excluded.
        # NOTE(review): this vocabulary comes from the yes/no simulated response
        # model, but ask_a_question() validates open-ended answers against the
        # open-ended model's label2id. Confirm the two vocabularies match.
        full_vocab_dict = response_model_simul_yn.model.config.label2id
        vocab_words = [w for w in full_vocab_dict if w is not None and w != "None"]
        vocab_list_numbers = [w for w in vocab_words if w[0] in "0123456789"]
        vocab_list_letters = [w for w in vocab_words if w[0] not in "0123456789"]
        with gr.Row():
            answer = gr.Dropdown(
                vocab_list_letters + vocab_list_numbers,
                label="Answer the given question.",
                info="If you cannot find your exact answer, pick the word you feel would be most appropriate. ONLY SELECT N/A IF THE QUESTION DOES NOT APPLY TO THE IMAGE.",
                visible=False,
            )
            clear_box = gr.Button("Reset Selection \n(Use this to clear the dropdown selection.)", visible=False)
        with gr.Row():
            vocab_warning = gr.HTML("<h3>The word you typed in is not a valid word in the model vocabulary. Please clear it and select a valid word from the dropdown menu.</h3>", visible=False)
        # Hidden constant inputs so the N/A, Yes and No buttons can reuse ask_a_question.
        null_answer = gr.Textbox("n/a", visible=False)
        yes_answer = gr.Textbox("yes", visible=False)
        no_answer = gr.Textbox("no", visible=False)
        with gr.Column():
            with gr.Row():
                yes_box = gr.Button("Yes", visible=False)
                no_box = gr.Button("No", visible=False)
        with gr.Column():
            with gr.Row():
                na_box = gr.Button("N/A", visible=False, elem_classes="na_button")
                submit = gr.Button("Submit", visible=False)
        with gr.Row():
            reward_code = gr.Textbox("", label="Task Finish Code", visible=False)
        with gr.Column() as img_block0:
            with gr.Row():
                img0 = gr.Image(label="Image 1", show_label=True).style(height=700, width=700)

    ### Button click events
    start_button.click(fn=set_images, inputs=taskid, outputs=[img0, img1, img2, img3, img4, img5, img6, img7, img8, img9, img10, game_session_state, conversation, answer, na_box, clear_box, submit, taskid, start_button, yes_box, no_box])
    # Components refreshed after every answer, in the order ask_a_question() returns them.
    answer_outputs = [conversation, game_session_state, answer, na_box, clear_box, submit, taskid, start_button, yes_box, no_box, reward_code, vocab_warning]
    submit.click(fn=ask_a_question, inputs=[answer, taskid, game_session_state], outputs=answer_outputs)
    na_box.click(fn=ask_a_question, inputs=[null_answer, taskid, game_session_state], outputs=answer_outputs)
    yes_box.click(fn=ask_a_question, inputs=[yes_answer, taskid, game_session_state], outputs=answer_outputs)
    no_box.click(fn=ask_a_question, inputs=[no_answer, taskid, game_session_state], outputs=answer_outputs)
    clear_box.click(fn=reset_dropdown, inputs=[], outputs=[answer])

    url_params = gr.JSON({}, visible=False, label="URL Params")

    def _taskid_from_url(params):
        """Pre-fill the hidden task-ID box from the ``?p=<id>`` URL parameter.

        Leaves the box unchanged when the parameter is absent instead of
        raising KeyError.
        """
        raw = (params or {}).get('p')
        if raw in (None, ""):
            return gr.Number.update()
        return gr.Number.update(value=int(raw))

    demo.load(fn=_taskid_from_url, inputs=[url_params], outputs=taskid, _js=get_window_url_params)

demo.launch()