cogen / app.py
momergul
Update
33f8437
raw
history blame
10.6 kB
# `spaces` stays first: Hugging Face ZeroGPU expects it to be imported before
# CUDA-using packages such as torch.
import spaces
import gradio as gr
import torch
import random
import os
from typing import Generator, List, Optional, Tuple
from config_generator import generate_complete_game
from dataset import get_processor, joint_speaker_input, joint_listener_input, get_index_to_token
from models import get_model
css="""
.radio-group .wrap {
display: grid;
grid-template-columns: repeat(5, 1fr);
grid-template-rows: repeat(5, 1fr);
width: 100%;
height: 100%
}
"""
def initialize_game(
    num_contexts: int = 4, rounds_per_context: int = 3
) -> List[Tuple[List[str], List[str], str, str]]:
    """Build the round schedule for one game.

    Calls ``generate_complete_game()`` ``num_contexts`` times and plays
    ``rounds_per_context`` rounds on each generated context, taking that
    context's targets in order. The defaults give the 12 rounds the UI
    announces.

    The model's role alternates by context: "speaker" on the 1st, 3rd, ...
    context and "listener" on the 2nd, 4th, ...; ``interaction`` gives the
    human the opposite role.

    Args:
        num_contexts: number of contexts to generate.
        rounds_per_context: rounds played on each context; each context's
            ``"targets"`` entry must hold at least this many items.

    Returns:
        One ``(speaker_context, listener_context, target, model_role)`` tuple
        per round, in play order.
    """
    context_dicts = [generate_complete_game() for _ in range(num_contexts)]
    rounds = []
    for context_idx, context_dict in enumerate(context_dicts):
        # Even-indexed contexts: the model speaks; odd-indexed: the model listens.
        model_role = "speaker" if context_idx % 2 == 0 else "listener"
        for round_idx in range(rounds_per_context):
            rounds.append((
                context_dict["speaker_context"],
                context_dict["listener_context"],
                context_dict["targets"][round_idx],
                model_role,
            ))
    return rounds
@spaces.GPU(duration=120)
def get_model_response(
    model, adapter_name, processor, index_to_token, role: str,
    image_paths: List[str], user_message: str = "", target_image: str = ""
) -> str:
    """Let the model play one turn: either describe an image or pick one.

    The adapter ``adapter_name`` is activated first and the active adapter is
    printed. With ``role == "speaker"`` the model captions ``target_image``
    via ``model.generate`` and the first caption is returned. Any other role
    is treated as the listener: the entry of ``image_paths`` with the highest
    joint log-probability for ``user_message`` is returned.

    Args:
        model: model exposing ``model.set_adapter``, ``get_listener``,
            ``generate`` and ``comprehension_side``.
        adapter_name: adapter to activate before inference.
        processor: forwarded to the ``dataset`` input builders and ``generate``.
        index_to_token: forwarded to generation / comprehension.
        role: "speaker" selects captioning; anything else selects guessing.
        image_paths: the model's own context images.
        user_message: description to interpret (listener path only).
        target_image: image to describe (speaker path only).

    Returns:
        A caption (speaker) or one element of ``image_paths`` (listener).
    """
    model.model.set_adapter(adapter_name)
    print(model.model.active_adapter)
    device = model.get_listener().device

    if role != "speaker":
        # Listener: score every context image against the human's description.
        (images, l_input_tokens, l_attn_mask, l_image_attn_mask,
         s_input_tokens, s_attn_mask, s_image_attn_mask,
         s_target_mask, s_target_label) = joint_listener_input(
            processor, image_paths, user_message, device
        )
        with torch.no_grad():
            comprehension_batch = [
                images, l_input_tokens, l_attn_mask, l_image_attn_mask, index_to_token,
                s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_mask, s_target_label,
            ]
            _, _, joint_log_probs = model.comprehension_side(comprehension_batch)
            best_idx = joint_log_probs[0].argmax().item()
        return image_paths[best_idx]

    # Speaker: caption the target image for the human listener.
    img_dir = "tangram_pngs"
    input_tokens, attn_mask, images, image_attn_mask, label = joint_speaker_input(
        processor, image_paths, target_image, device
    )
    with torch.no_grad():
        captions, _, _, _, _ = model.generate(
            images, input_tokens, attn_mask, image_attn_mask, label,
            [image_paths], processor, img_dir, index_to_token,
            max_steps=30, sampling_type="nucleus", temperature=0.7,
            top_k=50, top_p=1, repetition_penalty=1, num_samples=10,
        )
    return captions[0]
def interaction(model, processor, index_to_token, model_iteration: str) -> Tuple[List[str], List[str]]:
image_role_pairs = initialize_game()
conversation = []
turn = 0
num_correct = 0
human_role = None
adapter_name = "initial" if model_iteration == "Initial System" else "final"
internal_model = model
for speaker_image, listener_image, target_image, model_role in image_role_pairs:
acc_message = f"{num_correct}/{turn}"
if model_role == "speaker":
human_role = "Listener"
turn += 1
turn_message = f"{turn}/12"
human_context = listener_image
model_context = speaker_image
target_idx = human_context.index(target_image)
conversation.extend([
f"TURN: {turn}/12",
f"Guess the target image given the speaker's description. ",
])
model_message = get_model_response(internal_model, adapter_name, processor, index_to_token, model_role, model_context, target_image=target_image)
conversation.append(f"Model: {model_message}")
conversation.append("You: The target is Image ")
user_message = yield human_context, conversation, human_role, turn_message, acc_message
conversation[-1] += f"{user_message}"
if int(user_message) == target_idx + 1:
conversation.append("Correct!\n")
num_correct += 1
else:
conversation.append(f"Incorrect!\n")
else:
# listener
human_role = "Speaker"
turn += 1
turn_message = f"{turn}/12"
human_context = speaker_image
model_context = listener_image
target_idx = human_context.index(target_image)
conversation.extend([
f"TURN: {turn}/12",
f"Generate a description for the target image. Your target is Image {target_idx + 1}",
])
user_message = yield human_context, conversation, human_role, turn_message, acc_message
conversation.append(f"You: {user_message}")
model_message = get_model_response(internal_model, adapter_name, processor, index_to_token, model_role, model_context, user_message=user_message)
model_idx = human_context.index(model_message)
if int(model_idx) == int(target_idx):
conversation.append("The model guessed correctly!\n")
num_correct += 1
else:
conversation.append(f"The model guessed incorrectly.\n")
acc_message = f"{num_correct}/{turn}"
conversation.append("The game is over!")
yield human_context, conversation, human_role, turn_message, acc_message
def create_app():
    """Build the Gradio UI for the tangram reference game.

    Loads the model, processor and vocabulary once, lays out the
    instructions, controls, image gallery and conversation log, and wires
    "Start Game" / "Send" to an ``interaction`` generator. Every handler
    returns exactly one value per output component.

    Returns:
        The ``gr.Blocks`` app (not launched).
    """
    with gr.Blocks(css=css) as app:
        gr.Markdown("# Tangram Reference Game")
        gr.Markdown(
            '### You will be playing a sequence of reference games against a model. To start a game, first select whether '
            'you wish to play against our initial trained model ("Initial System") or our model at the end of deployment ("Final System") '
            'and press the "Start Game" button. There will be 12 rounds of reference games. You will take on a "listener" or a "speaker" role at each round.'
        )
        gr.Markdown(
            '### In the speaker role, you will be assigned a target image. Your goal will be to describe this image (via a message in the textbox) '
            'so that your partner can guess what it is.'
        )
        gr.Markdown(
            '### In the listener role, you will be given a description. Your goal will be '
            'to select the image that the description best describes (by clicking on the relevant button).'
        )
        gr.Markdown(
            '### Press "Send" to submit your action in either role and make the game proceed.'
        )
        with gr.Row():
            model_iteration = gr.Radio(["Initial System", "Final System"], label="Model Iteration")
            start_btn = gr.Button("Start Game")
        with gr.Row():
            current_role = gr.Textbox(label="YOUR ROLE")
            current_turn = gr.Textbox(label="TURN")
            accuracy = gr.Textbox(label="FINAL ACCURACY")
        with gr.Row():
            image_output = gr.Gallery(
                label="CONTEXT", show_label=False, elem_id="gallery",
                columns=5, rows=2, object_fit="contain", height="250px",
                allow_preview=False, container=True
            )
        with gr.Row():
            conversation_output = gr.Textbox(label="Interaction History")
            with gr.Column():
                user_input = gr.Textbox(label="Your Message as Speaker", interactive=False)
                radio_buttons = gr.Radio(
                    label="Your Guess as Listener",
                    elem_classes="radio-group",
                    choices=list(range(1, 11)),
                    interactive=False,
                )
                send_btn = gr.Button("Send")

        # NOTE(review): this closure variable is shared by every browser session
        # served by this app, so concurrent players would drive (and replace)
        # the same game. Per-session state (e.g. gr.State) could avoid that.
        interaction_generator = None
        model = get_model()
        processor = get_processor()
        index_to_token = get_index_to_token()

        def gallery_items(images):
            # Gallery entries are (file path, caption) pairs, numbered from 1 to
            # match the listener's radio choices.
            return [(f"tangram_pngs/{img}", f"Image {i+1}") for i, img in enumerate(images)]

        def start_interaction(model_iteration):
            nonlocal interaction_generator
            if model_iteration is None:
                return (
                    [], "Please select a model iteration.", "", "", "",
                    gr.update(interactive=False),
                    gr.update(interactive=False),
                    gr.update(interactive=False),
                )
            interaction_generator = interaction(model, processor, index_to_token, model_iteration)
            images, conversation, role, turn, acc_message = next(interaction_generator)
            human_listener = role == "Listener"
            return (
                gallery_items(images), "\n".join(conversation), role, turn, acc_message,
                gr.update(interactive=not human_listener),
                gr.update(interactive=human_listener),
                gr.update(interactive=True),
            )

        def send_message(message, radio_choice):
            nonlocal interaction_generator
            if interaction_generator is None:
                # One value per output component (8), like every other path.
                return (
                    [], "Please start the interaction first.", "", "", "",
                    gr.update(interactive=False),
                    gr.update(interactive=False, value=None),
                    gr.update(interactive=False),
                )
            # The radio holds the listener's guess; with nothing selected, fall
            # back to the speaker textbox.
            user_output = message if radio_choice is None else radio_choice
            try:
                images, conversation, role, turn, acc_message = interaction_generator.send(user_output)
            except StopIteration:
                # Game finished: clear the board, leave the log/role/turn/accuracy
                # as displayed (an argument-less gr.update() changes nothing; a
                # component's `.value` would only be its construction-time value),
                # and lock the inputs.
                return (
                    [], gr.update(), gr.update(), gr.update(), gr.update(),
                    gr.update(interactive=False),
                    gr.update(interactive=False),
                    gr.update(interactive=False),
                )
            human_listener = role == "Listener"
            return (
                gallery_items(images), "\n".join(conversation), role, turn, acc_message,
                gr.update(interactive=not human_listener, value=""),
                gr.update(interactive=human_listener, value=None),
                gr.update(interactive=True),
            )

        game_outputs = [
            image_output, conversation_output, current_role, current_turn,
            accuracy, user_input, radio_buttons, send_btn,
        ]
        start_btn.click(start_interaction, inputs=[model_iteration], outputs=game_outputs)
        send_btn.click(send_message, inputs=[user_input, radio_buttons], outputs=game_outputs)
    return app
# Build the UI at import time so `app` stays importable, but only start the
# server when this file is executed directly (e.g. `python app.py`).
app = create_app()

if __name__ == "__main__":
    app.launch()