# Hugging Face Spaces page residue (not part of the program):
# Spaces: Runtime error
# File size: 11,206 Bytes
# Revisions: 022601f 1de1fd2 016285f 2424844 1de1fd2 016285f 022601f 016285f 022601f 016285f 022601f 016285f 022601f 016285f 022601f 016285f 022601f 016285f 022601f 016285f 022601f 016285f 022601f 016285f 022601f 016285f
import gradio as gr
from response_db import ResponseDb
from create_cache import Game_Cache
import numpy as np
from PIL import Image
import pandas as pd
import torch
import pickle
import uuid
import nltk
# Download NLTK's 'punkt' resource when the app starts. NOTE(review): nothing in
# this file tokenizes text itself, so presumably the imported model code needs
# it — confirm before removing.
nltk.download('punkt')
# One module-level ResponseDb shared by every game; ask_a_question records each
# question/answer turn and the final guess through db.add(...).
db = ResponseDb()
# Stylesheet passed to gr.Blocks(css=...).
# - .chatbot / .msg.* style the conversation HTML emitted by set_images and
#   ask_a_question (bot questions in grey, user answers in blue, right-aligned).
# - .na_button is attached to the N/A button via elem_classes. Its text is white:
#   the previous rule coloured the text red on a red background, which made the
#   label invisible.
css = """
.chatbot {display:flex;flex-direction:column}
.msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
.msg.user {background-color:cornflowerblue;color:white;align-self:self-end}
.msg.bot {background-color:lightgray}
.na_button {background-color:red;color:white}
"""
# Load both model families once, at import time; every Game_Session stores
# references to these same module-level objects.
# Each loader's result is unpacked into exactly four items; the third is unused
# here. The *_yn set is selected for games whose cache has `yn` set
# (Game_Session.set_curr_models).
from model.run_question_asking_model import return_modules, return_modules_yn
question_model, response_model_simul, _, caption_model = return_modules()
question_model_yn, response_model_simul_yn, _, caption_model_yn = return_modules_yn()
class Game_Session:
    """State for one player's game: model handles, image data, beliefs and chat history.

    The model objects are the module-level ones loaded at import time, so all
    sessions share the same model instances. ``set_curr_models`` selects either
    the yes/no trio or the open-ended trio according to ``yn``.
    """

    def __init__(self, taskid, yn, hard_setting):
        """Create an empty session; per-game image data is filled in afterwards.

        Args:
            taskid: task identifier (set_images passes the MSCOCO id); kept as
                ``self.taskid``.
            yn: truthy for a yes/no game, falsy for an open-ended one.
            hard_setting: copied from the game cache; stored but not used here.
        """
        self.taskid = taskid
        self.yn = yn
        self.hard_setting = hard_setting

        # Shared module-level models. They are only read here, so no `global`
        # statement is needed.
        self.question_model = question_model
        self.response_model_simul = response_model_simul
        self.caption_model = caption_model
        self.question_model_yn = question_model_yn
        self.response_model_simul_yn = response_model_simul_yn
        self.caption_model_yn = caption_model_yn

        # Per-game data, assigned by set_images / ask_a_question.
        # `images_np` must match the name every reader uses (get_next_question,
        # set_images, ask_a_question); it used to be initialised as `image_np`,
        # leaving `images_np` undefined until set_images ran.
        self.image_files = None
        self.images_np = None
        self.p_y_x = None    # probability tensor over the candidate images
        self.p_r_qy = None   # last response-model output (ask_a_question)
        self.p_y_xqr = None  # last unnormalised/normalised product (ask_a_question)
        self.captions = None
        self.questions = None
        self.target_questions = None

        # Alternating [question, answer, question, ...] strings, question first.
        self.history = []
        self.game_id = str(uuid.uuid4())
        self.set_curr_models()

    def set_curr_models(self):
        """Point the ``curr_*`` attributes at the yes/no or open-ended model trio."""
        if self.yn:
            self.curr_question_model = self.question_model_yn
            self.curr_caption_model = self.caption_model_yn
            self.curr_response_model_simul = self.response_model_simul_yn
        else:
            self.curr_question_model = self.question_model
            self.curr_caption_model = self.caption_model
            self.curr_response_model_simul = self.response_model_simul

    def get_next_question(self):
        """Return the question the current question model picks for the current state."""
        return self.curr_question_model.select_best_question(
            self.p_y_x,
            self.questions,
            self.images_np,
            self.captions,
            self.curr_response_model_simul,
        )
def ask_a_question(input, taskid, gs):
    """Record the player's answer, update the belief over images, and choose the next step.

    Args:
        input: the answer ("nothing" from N/A, "yes"/"no" from the yes/no buttons,
            or free text from the answer box). Keeps the name ``input`` (which
            shadows the builtin) so the existing signature is unchanged.
        taskid: Task ID value from the UI; only forwarded to ``db.add``.
        gs: the current Game_Session; mutated in place and returned.

    Returns:
        A 9-tuple matching the answer-button outputs: conversation HTML, the
        session, then updates for the answer box, N/A, Submit, Task ID, Enter,
        Yes and No controls.
    """
    from html import escape

    finish_threshold = 0.8  # max probability at which the model commits to a guess

    gs.history.append(input)
    question, answer = gs.history[-2], gs.history[-1]

    # Multiply the current distribution by the response model's output for this
    # (answer, question) and renormalise.
    gs.p_r_qy = gs.curr_response_model_simul.get_p_r_qy(answer, question, gs.images_np, gs.captions)
    gs.p_y_xqr = gs.p_y_x * gs.p_r_qy
    normaliser = torch.sum(gs.p_y_xqr)
    if normaliser != 0:
        gs.p_y_xqr = gs.p_y_xqr / normaliser
        gs.p_y_x = gs.p_y_xqr
    # else: the product is zero for every image. Keep the previous distribution;
    # replacing it with zeros would make every later product zero, so the game
    # could never reach the guess threshold.

    gs.questions.remove(question)
    db.add(gs.game_id, taskid, len(gs.history) // 2 - 1, question, answer)

    gs.history.append(gs.get_next_question())
    top_prob = torch.max(gs.p_y_x).item()
    top_pred = torch.argmax(gs.p_y_x).item()
    finished = top_prob > finish_threshold
    if finished:
        # The game ends with a guess, so drop the follow-up question just generated.
        gs.history = gs.history[:-1]
        # NOTE(review): the DB stores the 0-based index while the page shows
        # top_pred + 1 — confirm the analysis side expects 0-based.
        db.add(gs.game_id, taskid, len(gs.history) // 2, f"Guess: Image {top_pred}", "")

    # Render the chat; escape text because answers are user-typed and questions
    # are model-generated, and both are injected into gr.HTML.
    bubbles = ["<div class='chatbot'>"]
    for turn, msg in enumerate(gs.history):
        if msg == "nothing":
            msg = "n/a"
        role = "bot" if turn % 2 == 0 else "user"
        bubbles.append("<div class='msg {}'> {}</div>".format(role, escape(str(msg))))
    bubbles.append("</div>")
    chat_html = "".join(bubbles)

    if finished:
        chat_html += f"<p>The model identified <b>Image {top_pred+1}</b> as the image. Please select a new task ID to continue.</p>"

    open_ended = not finished and not gs.yn
    yes_no = not finished and bool(gs.yn)
    return (
        chat_html,
        gs,
        gr.Textbox.update(visible=open_ended, value=''),  # answer box, cleared for the next turn
        gr.Button.update(visible=open_ended),  # N/A
        gr.Button.update(visible=open_ended),  # Submit
        gr.Number.update(visible=finished),    # Task ID box
        gr.Button.update(visible=finished),    # Enter
        gr.Button.update(visible=yes_no),      # Yes (previously sent a gr.Number update)
        gr.Button.update(visible=yes_no),      # No
    )
def _load_image_array(path):
    """Read the image at ``path`` into a numpy array, closing the file afterwards.

    A 2-D (single-channel) array is stacked three times along a new last axis, so
    grayscale images come back with three channels; other arrays are unchanged.
    """
    with Image.open(path) as img:
        arr = np.asarray(img)
    if arr.ndim == 2:
        arr = np.dstack([arr] * 3)
    return arr


def set_images(taskid):
    """Start a new game for the Task ID entered in the UI.

    Maps ``taskid`` (a row of pilot-study.csv) to its MSCOCO id and loads the
    pickled game cache for it. Builds a Game_Session with images, captions and
    questions, then asks the first question.

    Args:
        taskid: value of the Task ID number box (row index into pilot-study.csv).

    Returns:
        A 20-tuple matching the start_button outputs: the ten image paths, the new
        session, the first-question HTML, the task banner update, then updates for
        the answer box, N/A, Submit, Task ID, Enter, Yes and No controls.

    Raises:
        gr.Error: if ``taskid`` is empty or not a valid row of pilot-study.csv.
    """
    from html import escape

    num_images = 10  # the UI has exactly ten gr.Image slots

    pilot_study = pd.read_csv("pilot-study.csv")
    mscoco_ids = pilot_study['mscoco-id'].tolist()
    # Reject bad ids up front. A negative index would silently select a row
    # from the end of the list, and an oversized one would raise IndexError.
    if taskid is None or not 0 <= int(taskid) < len(mscoco_ids):
        raise gr.Error(f"Task ID must be a number from 0 to {len(mscoco_ids) - 1}.")
    task_index = int(taskid)
    mscoco_id = int(mscoco_ids[task_index])

    # NOTE(review): pickle.load can execute code stored in the file. This is only
    # acceptable because cache/ is presumably produced locally (create_cache);
    # never load user-supplied files this way.
    with open(f'cache/{mscoco_id}.p', 'rb') as fp:
        game_cache = pickle.load(fp)
    gs = Game_Session(mscoco_id, game_cache.yn, game_cache.hard_setting)

    # Paths handed to the ten gr.Image outputs (indexing keeps the original
    # IndexError if the cache lists fewer than ten files).
    display_paths = [f"./mscoco-images/val2014/{game_cache.image_files[k]}" for k in range(num_images)]
    # Model-facing names are "/val2014/<file>": the display path minus its
    # "./mscoco-images" prefix (formerly written as the magic slice x[15:]).
    gs.image_files = [path[len("./mscoco-images"):] for path in display_paths]
    gs.images_np = [_load_image_array(f"./mscoco-images/{name}") for name in gs.image_files]

    gs.p_y_x = (torch.ones(num_images) / num_images).to(gs.curr_question_model.device)
    gs.captions = gs.curr_caption_model.get_captions(gs.image_files)
    gs.questions, gs.target_questions = gs.curr_question_model.get_questions(gs.image_files, gs.captions, 0)
    # NOTE(review): curr_question_model is a module-level object shared by all
    # sessions, so this overwrites the question bank for any concurrent game.
    # Confirm the Space only serves one game at a time.
    gs.curr_question_model.reset_question_bank()
    gs.curr_question_model.question_bank = game_cache.question_dict
    first_question = gs.curr_question_model.select_best_question(
        gs.p_y_x, gs.questions, gs.images_np, gs.captions, gs.curr_response_model_simul
    )
    gs.history.append(first_question)

    first_question_html = f"<div class='chatbot'><div class='msg bot'>{escape(str(first_question))}</div></div>"
    banner_html = f"<p>Current Task ID: <b>{task_index}</b></p>"

    open_ended = not gs.yn
    if open_ended:
        answer_update = gr.Textbox.update(visible=True, value='')
    else:
        answer_update = gr.Textbox.update(visible=False)
    return (
        *display_paths,
        gs,
        first_question_html,
        gr.HTML.update(value=banner_html, visible=True),
        answer_update,
        gr.Button.update(visible=open_ended),      # N/A
        gr.Button.update(visible=open_ended),      # Submit
        gr.Number.update(visible=False),           # Task ID box
        gr.Button.update(visible=False),           # Enter
        gr.Button.update(visible=not open_ended),  # Yes
        gr.Button.update(visible=not open_ended),  # No
    )
with gr.Blocks(title="Image Q&A Guessing Game", css=css) as demo:
    # Instructions. Adjacent string literals are joined into one HTML string; a
    # space is now included between "distraction images." and "The model",
    # which previously rendered as "images.The".
    gr.HTML(
        "<h1>Image Q&A Guessing Game</h1>"
        "<p style='font-size:120%;'>"
        "Imagine you are playing 20-questions with an AI model.<br>"
        "The AI model plays the role of the question asker. You play the role of the responder. <br>"
        "There are 10 images. <b>Your image is Image 1</b>. The other images are distraction images. "
        "The model can see all 10 images and all the questions and answers for the current set of images. It will ask a question based on the available information.<br>"
        "<span style='color: #0000ff'>The goal of the model is to accurately guess the correct image (i.e. <b><span style='color: #0000ff'>Image 1</span></b>) in as few turns as possible.<br>"
        "Your goal is to help the model guess the image by answering as clearly and accurately as possible.</span><br><br>"
        "<b>Guidelines:</b><br>"
        "<ol style='font-size:120%;'>"
        "<li>It is best to keep your answers short (a single word or a short phrase). No need to answer in full sentences.</li>"
        "<li>If you feel that the question cannot be answered or does not apply to Image 1, please select N/A.</li>"
        "</ol> "
        "<br>"
        "(Note: We are testing multiple game settings. In some instances, the game will be open-ended, while in other instances, the answer choices will be limited to yes/no.)<br></p>"
        "<br>"
        "<h2>Please enter a TaskID to start</h2>"
    )

    # Task selection. NOTE(review): nesting below is reconstructed; the scraped
    # source had lost its indentation.
    with gr.Column():
        with gr.Row():
            taskid = gr.Number(label="Task ID (Enter a number from 0 to 160)", value=0)
            start_button = gr.Button("Enter")
        with gr.Row():
            task_text = gr.HTML()

    # The ten candidate images, two rows of five.
    with gr.Column() as img_block:
        with gr.Row():
            img1 = gr.Image(label="Image 1", show_label=True)
            img2 = gr.Image(label="Image 2", show_label=True)
            img3 = gr.Image(label="Image 3", show_label=True)
            img4 = gr.Image(label="Image 4", show_label=True)
            img5 = gr.Image(label="Image 5", show_label=True)
        with gr.Row():
            img6 = gr.Image(label="Image 6", show_label=True)
            img7 = gr.Image(label="Image 7", show_label=True)
            img8 = gr.Image(label="Image 8", show_label=True)
            img9 = gr.Image(label="Image 9", show_label=True)
            img10 = gr.Image(label="Image 10", show_label=True)

    conversation = gr.HTML()
    game_session_state = gr.State()
    answer = gr.Textbox(placeholder="Insert answer here.", label="Answer the given question.", visible=False)
    # Hidden constant inputs so the N/A / Yes / No buttons can reuse ask_a_question.
    null_answer = gr.Textbox("nothing", visible=False)
    yes_answer = gr.Textbox("yes", visible=False)
    no_answer = gr.Textbox("no", visible=False)
    with gr.Column():
        with gr.Row():
            yes_box = gr.Button("Yes", visible=False)
            no_box = gr.Button("No", visible=False)
    with gr.Column():
        with gr.Row():
            na_box = gr.Button("N/A", visible=False, elem_classes="na_button")
            submit = gr.Button("Submit", visible=False)

    ### Button click events
    start_button.click(
        fn=set_images,
        inputs=taskid,
        outputs=[img1, img2, img3, img4, img5, img6, img7, img8, img9, img10,
                 game_session_state, conversation, task_text, answer, na_box,
                 submit, taskid, start_button, yes_box, no_box],
    )
    # Every answer control calls ask_a_question and refreshes the same nine
    # outputs, in the order of the tuple ask_a_question returns.
    turn_outputs = [conversation, game_session_state, answer, na_box, submit,
                    taskid, start_button, yes_box, no_box]
    for answer_button, answer_source in ((submit, answer), (na_box, null_answer),
                                         (yes_box, yes_answer), (no_box, no_answer)):
        answer_button.click(fn=ask_a_question,
                            inputs=[answer_source, taskid, game_session_state],
                            outputs=turn_outputs)
demo.launch()
|