momergul committed on
Commit
32814fc
·
1 Parent(s): 1245fe3
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. README.md +2 -2
  2. app.py +235 -0
  3. clip_similarities/page-A-similarities.pkl +0 -0
  4. clip_similarities/page-B-similarities.pkl +0 -0
  5. clip_similarities/page-C-similarities.pkl +0 -0
  6. clip_similarities/page-D-similarities.pkl +0 -0
  7. clip_similarities/page-E-similarities.pkl +0 -0
  8. clip_similarities/page-F-similarities.pkl +0 -0
  9. clip_similarities/page-G-similarities.pkl +0 -0
  10. clip_similarities/page-H-similarities.pkl +0 -0
  11. clip_similarities/page-I-similarities.pkl +0 -0
  12. clip_similarities/page-J-similarities.pkl +0 -0
  13. clip_similarities/page-K-similarities.pkl +0 -0
  14. clip_similarities/page-L-similarities.pkl +0 -0
  15. clip_similarities/page1-0-similarities.pkl +0 -0
  16. clip_similarities/page1-1-similarities.pkl +0 -0
  17. clip_similarities/page1-10-similarities.pkl +0 -0
  18. clip_similarities/page1-103-similarities.pkl +0 -0
  19. clip_similarities/page1-105-similarities.pkl +0 -0
  20. clip_similarities/page1-106-similarities.pkl +0 -0
  21. clip_similarities/page1-107-similarities.pkl +0 -0
  22. clip_similarities/page1-108-similarities.pkl +0 -0
  23. clip_similarities/page1-109-similarities.pkl +0 -0
  24. clip_similarities/page1-110-similarities.pkl +0 -0
  25. clip_similarities/page1-112-similarities.pkl +0 -0
  26. clip_similarities/page1-113-similarities.pkl +0 -0
  27. clip_similarities/page1-114-similarities.pkl +0 -0
  28. clip_similarities/page1-116-similarities.pkl +0 -0
  29. clip_similarities/page1-117-similarities.pkl +0 -0
  30. clip_similarities/page1-118-similarities.pkl +0 -0
  31. clip_similarities/page1-119-similarities.pkl +0 -0
  32. clip_similarities/page1-122-similarities.pkl +0 -0
  33. clip_similarities/page1-125-similarities.pkl +0 -0
  34. clip_similarities/page1-128-similarities.pkl +0 -0
  35. clip_similarities/page1-129-similarities.pkl +0 -0
  36. clip_similarities/page1-13-similarities.pkl +0 -0
  37. clip_similarities/page1-130-similarities.pkl +0 -0
  38. clip_similarities/page1-132-similarities.pkl +0 -0
  39. clip_similarities/page1-133-similarities.pkl +0 -0
  40. clip_similarities/page1-136-similarities.pkl +0 -0
  41. clip_similarities/page1-137-similarities.pkl +0 -0
  42. clip_similarities/page1-14-similarities.pkl +0 -0
  43. clip_similarities/page1-142-similarities.pkl +0 -0
  44. clip_similarities/page1-143-similarities.pkl +0 -0
  45. clip_similarities/page1-147-similarities.pkl +0 -0
  46. clip_similarities/page1-148-similarities.pkl +0 -0
  47. clip_similarities/page1-149-similarities.pkl +0 -0
  48. clip_similarities/page1-150-similarities.pkl +0 -0
  49. clip_similarities/page1-151-similarities.pkl +0 -0
  50. clip_similarities/page1-153-similarities.pkl +0 -0
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
  title: Cogen
3
  emoji: 🔥
4
- colorFrom: yellow
5
- colorTo: green
6
  sdk: gradio
7
  sdk_version: 4.44.0
8
  app_file: app.py
 
1
  ---
2
  title: Cogen
3
  emoji: 🔥
4
+ colorFrom: pink
5
+ colorTo: yellow
6
  sdk: gradio
7
  sdk_version: 4.44.0
8
  app_file: app.py
app.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import gradio as gr
3
+ import torch
4
+
5
+ import random
6
+ import os
7
+ from typing import List, Tuple
8
+
9
+ from config_generator import generate_complete_game
10
+ from dataset import get_processor, joint_speaker_input, joint_listener_input, get_index_to_token
11
+ from models import get_model
12
+
13
# Custom stylesheet handed to gr.Blocks(css=...). It turns the element matched
# by `.radio-group .wrap` into a 5 x 5 CSS grid; the `radio-group` class is
# attached to the listener's radio component via `elem_classes`.
css = """
.radio-group .wrap {
display: grid;
grid-template-columns: repeat(5, 1fr);
grid-template-rows: repeat(5, 1fr);
width: 100%;
height: 100%
}
"""
22
+
23
def initialize_game() -> List[List[str]]:
    """Build the 12-round schedule for one play session.

    Four games are drawn from ``generate_complete_game``. Each game supplies
    three consecutive rounds that share the game's ``"speaker_context"`` and
    ``"listener_context"`` and use ``"targets"[0]``, ``[1]`` and ``[2]`` in
    turn. The role attached to each round alternates in blocks of three:
    speaker, listener, speaker, listener.

    Returns:
        A list with one ``(speaker_context, listener_context, target, role)``
        tuple per round, in play order.
    """
    games = [generate_complete_game() for _ in range(4)]
    role_schedule = (["speaker"] * 3 + ["listener"] * 3) * 2

    # One (speaker_context, listener_context, target) triple per round.
    rounds = []
    for game in games:
        for target_pos in range(3):
            rounds.append((
                game["speaker_context"],
                game["listener_context"],
                game["targets"][target_pos],
            ))

    return [
        (speaker_ctx, listener_ctx, target, role)
        for (speaker_ctx, listener_ctx, target), role in zip(rounds, role_schedule)
    ]
38
+
39
@spaces.GPU
def get_model_response(
    model, adapter_name, processor, index_to_token, role: str,
    image_paths: List[str], user_message: str = "", target_image: str = ""
) -> str:
    """Run the model for one round, as speaker or as listener.

    Args:
        model: Object exposing ``model.set_adapter`` / ``model.active_adapter``,
            ``get_listener()``, ``generate`` and ``comprehension_side``.
        adapter_name: Name of the adapter to activate before inference.
        processor: Forwarded to the input builders and to ``generate``.
        index_to_token: Forwarded to the model calls.
        role: ``"speaker"`` to produce a description; any other value is
            handled as the listener role.
        image_paths: The model's context images.
        user_message: The human's description; only used in the listener role.
        target_image: The image to describe; only used in the speaker role.

    Returns:
        Speaker role: the first caption returned by ``model.generate``.
        Listener role: the entry of ``image_paths`` at the argmax of
        ``joint_log_probs[0]``.
    """
    model.model.set_adapter(adapter_name)
    print(model.model.active_adapter)
    device = model.get_listener().device

    if role == "speaker":
        speaker_inputs = joint_speaker_input(processor, image_paths, target_image, device)
        input_tokens, attn_mask, images, image_attn_mask, label = speaker_inputs
        with torch.no_grad():
            # generate() is given the context wrapped in an outer list.
            generated = model.generate(
                images, input_tokens, attn_mask, image_attn_mask, label,
                [image_paths], processor, "tangram_pngs", index_to_token,
                max_steps=30, sampling_type="nucleus", temperature=0.7,
                top_k=50, top_p=1, repetition_penalty=1, num_samples=10
            )
        captions, _, _, _, _ = generated
        return captions[0]

    # Listener role.
    (
        images, l_input_tokens, l_attn_mask, l_image_attn_mask,
        s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_mask, s_target_label,
    ) = joint_listener_input(processor, image_paths, user_message, device)

    comprehension_args = [
        images, l_input_tokens, l_attn_mask, l_image_attn_mask, index_to_token,
        s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_mask, s_target_label,
    ]
    with torch.no_grad():
        _, _, joint_log_probs = model.comprehension_side(comprehension_args)
        best_idx = joint_log_probs[0].argmax().item()

    return image_paths[best_idx]
76
+
77
def interaction(model, processor, index_to_token, model_iteration: str) -> Tuple[List[str], List[str]]:
    """Drive one 12-round reference game as a generator.

    Protocol: each ``yield`` hands the UI a 5-tuple
    ``(human_context, conversation, human_role, turn_message, acc_message)``
    and waits for the human's action to be delivered with ``.send(...)``:

    * model is speaker / human is listener: the sent value is the 1-based
      index of the guessed image;
    * model is listener / human is speaker: the sent value is the
      description text.

    After the last round one more tuple is yielded with the final accuracy
    and a "The game is over!" line; the generator then finishes.

    Args:
        model: Forwarded to ``get_model_response``.
        processor: Forwarded to ``get_model_response``.
        index_to_token: Forwarded to ``get_model_response``.
        model_iteration: ``"Initial System"`` selects the ``"initial"``
            adapter; any other value selects ``"final"``.

    Yields:
        The 5-tuple described above. ``conversation`` is the same list object
        on every yield and keeps growing.
    """
    image_role_pairs = initialize_game()
    conversation = []
    turn = 0
    num_correct = 0
    human_role = None
    adapter_name = "initial" if model_iteration == "Initial System" else "final"

    for speaker_image, listener_image, target_image, model_role in image_role_pairs:
        # The accuracy shown during a round only covers rounds already played,
        # so it is computed before the turn counter advances.
        acc_message = f"{num_correct}/{turn}"
        turn += 1
        turn_message = f"{turn}/12"

        if model_role == "speaker":
            human_role = "Listener"
            human_context = listener_image
            model_context = speaker_image
            target_idx = human_context.index(target_image)

            conversation.extend([
                f"TURN: {turn_message}",
                "Guess the target image given the speaker's description. ",
            ])
            model_message = get_model_response(
                model, adapter_name, processor, index_to_token,
                model_role, model_context, target_image=target_image,
            )
            conversation.append(f"Model: {model_message}")
            conversation.append("You: The target is Image ")
            user_message = yield human_context, conversation, human_role, turn_message, acc_message

            conversation[-1] += f"{user_message}"
            # If the human presses "Send" without picking an image, the value
            # delivered here is not a number (e.g. an empty string). Count
            # that as a wrong guess instead of letting int() raise, which
            # would terminate the generator and abort the game.
            try:
                guess = int(user_message)
            except (TypeError, ValueError):
                guess = None

            if guess == target_idx + 1:
                conversation.append("Correct!\n")
                num_correct += 1
            else:
                conversation.append("Incorrect!\n")
        else:
            # Model is the listener; the human describes the target.
            human_role = "Speaker"
            human_context = speaker_image
            model_context = listener_image
            target_idx = human_context.index(target_image)

            conversation.extend([
                f"TURN: {turn_message}",
                f"Generate a description for the target image. Your target is Image {target_idx + 1}",
            ])

            user_message = yield human_context, conversation, human_role, turn_message, acc_message
            conversation.append(f"You: {user_message}")
            model_message = get_model_response(
                model, adapter_name, processor, index_to_token,
                model_role, model_context, user_message=user_message,
            )
            # The model answers with an image taken from its own context;
            # locate that image in the human's context to compare positions.
            model_idx = human_context.index(model_message)

            if model_idx == target_idx:
                conversation.append("The model guessed correctly!\n")
                num_correct += 1
            else:
                conversation.append("The model guessed incorrectly.\n")

    acc_message = f"{num_correct}/{turn}"
    conversation.append("The game is over!")
    yield human_context, conversation, human_role, turn_message, acc_message
139
+
140
def create_app():
    """Build the Gradio Blocks UI for the tangram reference game.

    Creates the widgets, loads the model / processor / token table once, and
    wires the "Start Game" and "Send" buttons to handlers that step the
    ``interaction`` generator.

    Returns:
        The constructed ``gr.Blocks`` object (not launched).
    """
    # NOTE(review): SOURCE carried no indentation, so the nesting of the
    # layout containers below (in particular the Column placed inside the
    # last Row, and "Send" at the top level) is reconstructed — confirm
    # against the deployed layout.
    with gr.Blocks(css=css) as app:
        gr.Markdown("# Tangram Reference Game")
        gr.Markdown(
            '### You will be playing a sequence of reference games against a model. To start a game, first select whether '
            'you wish to play against our initial trained model ("Initial System") or our model at the end of deployment ("Final System") '
            'and press the "Start Game" button. There will be 12 rounds of reference games. You will take on a "listener" or a "speaker" role at each round.'
        )

        gr.Markdown(
            '### In the speaker role, you will be assigned a target image. Your goal will be to describe this image (via a message in the textbox) '
            'so that your partner can guess what it is.'
        )
        gr.Markdown(
            '### In the listener role, you will be given a description. Your goal will be '
            'to select the image that the description best describes (by clicking on the relevant button).'
        )
        gr.Markdown(
            '### Press "Send" to submit your action in either role and make the game proceed.'
        )

        with gr.Row():
            model_iteration = gr.Radio(["Initial System", "Final System"], label="Model Iteration")
            start_btn = gr.Button("Start Game")

        with gr.Row():
            current_role = gr.Textbox(label="YOUR ROLE")
            current_turn = gr.Textbox(label="TURN")
            accuracy = gr.Textbox(label="FINAL ACCURACY")

        with gr.Row():
            image_output = gr.Gallery(
                label="CONTEXT", show_label=False, elem_id="gallery",
                columns=5, rows=2, object_fit="contain", height="250px",
                allow_preview=False, container=True
            )

        with gr.Row():
            conversation_output = gr.Textbox(label="Interaction History")

            with gr.Column():
                user_input = gr.Textbox(label="Your Message as Speaker", interactive=False)
                radio_buttons = gr.Radio(
                    label="Your Guess as Listener",
                    elem_classes="radio-group",
                    choices=list(range(1, 11)),
                    interactive=False,
                )

        send_btn = gr.Button("Send")

        # NOTE(review): this is a single closure variable, so every caller of
        # the handlers below shares one game. If several people may play at
        # once, per-session state (e.g. gr.State) would be needed — confirm.
        interaction_generator = None
        model = get_model()
        processor = get_processor()
        index_to_token = get_index_to_token()

        # Both handlers feed the same eight components, in this order.
        handler_outputs = [
            image_output, conversation_output, current_role, current_turn,
            accuracy, user_input, radio_buttons, send_btn,
        ]

        def gallery_items(images):
            # Gallery entries are (file path, caption) pairs, captions 1-based.
            return [(f"tangram_pngs/{img}", f"Image {i+1}") for i, img in enumerate(images)]

        def start_interaction(model_iteration):
            """Start a fresh game and return the first round's UI state."""
            nonlocal interaction_generator
            if model_iteration is None:
                return [], "Please select a model iteration.", "", "", "", gr.update(interactive=False), \
                    gr.update(interactive=False), gr.update(interactive=False)

            interaction_generator = interaction(model, processor, index_to_token, model_iteration)
            images, conversation, role, turn, acc_message = next(interaction_generator)
            human_listener = role == "Listener"
            # Enable only the input that matches the human's role.
            return gallery_items(images), "\n".join(conversation), role, turn, acc_message, \
                gr.update(interactive=not human_listener), gr.update(interactive=human_listener), gr.update(interactive=True)

        def send_message(message, radio_choice):
            """Deliver the human's action to the game and return the next UI state."""
            if interaction_generator is None:
                # Must return one value per output component (eight); a
                # shorter tuple makes Gradio reject the handler's result.
                return [], "Please start the interaction first.", "", "", "", \
                    gr.update(interactive=False), gr.update(interactive=False, value=None), gr.update()

            # The radio is reset to None after every turn, so a non-None
            # choice means the human just acted as listener.
            user_output = message if radio_choice is None else radio_choice
            try:
                images, conversation, role, turn, acc_message = interaction_generator.send(user_output)
            except StopIteration:
                # Game already finished. Leave the text boxes as displayed
                # (a component's `.value` attribute is its construction-time
                # value, so returning it would blank them) and lock inputs.
                return [], gr.update(), gr.update(), gr.update(), gr.update(), \
                    gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False)

            human_listener = role == "Listener"
            return gallery_items(images), "\n".join(conversation), role, turn, acc_message, \
                gr.update(interactive=not human_listener, value=""), gr.update(interactive=human_listener, value=None), gr.update(interactive=True)

        start_btn.click(
            start_interaction,
            inputs=[model_iteration],
            outputs=handler_outputs,
        )
        send_btn.click(send_message, inputs=[user_input, radio_buttons], outputs=handler_outputs)

    return app
233
+
234
# Build the UI when the module is executed and start serving it; `app` stays
# available as a module-level name.
app = create_app()
app.launch()
clip_similarities/page-A-similarities.pkl ADDED
Binary file (20.8 kB). View file
 
clip_similarities/page-B-similarities.pkl ADDED
Binary file (20.8 kB). View file
 
clip_similarities/page-C-similarities.pkl ADDED
Binary file (20.8 kB). View file
 
clip_similarities/page-D-similarities.pkl ADDED
Binary file (20.8 kB). View file
 
clip_similarities/page-E-similarities.pkl ADDED
Binary file (20.8 kB). View file
 
clip_similarities/page-F-similarities.pkl ADDED
Binary file (20.8 kB). View file
 
clip_similarities/page-G-similarities.pkl ADDED
Binary file (20.8 kB). View file
 
clip_similarities/page-H-similarities.pkl ADDED
Binary file (20.8 kB). View file
 
clip_similarities/page-I-similarities.pkl ADDED
Binary file (20.8 kB). View file
 
clip_similarities/page-J-similarities.pkl ADDED
Binary file (20.8 kB). View file
 
clip_similarities/page-K-similarities.pkl ADDED
Binary file (20.8 kB). View file
 
clip_similarities/page-L-similarities.pkl ADDED
Binary file (20.8 kB). View file
 
clip_similarities/page1-0-similarities.pkl ADDED
Binary file (20.8 kB). View file
 
clip_similarities/page1-1-similarities.pkl ADDED
Binary file (20.8 kB). View file
 
clip_similarities/page1-10-similarities.pkl ADDED
Binary file (20.8 kB). View file
 
clip_similarities/page1-103-similarities.pkl ADDED
Binary file (20.8 kB). View file
 
clip_similarities/page1-105-similarities.pkl ADDED
Binary file (20.8 kB). View file
 
clip_similarities/page1-106-similarities.pkl ADDED
Binary file (20.8 kB). View file
 
clip_similarities/page1-107-similarities.pkl ADDED
Binary file (20.8 kB). View file
 
clip_similarities/page1-108-similarities.pkl ADDED
Binary file (20.8 kB). View file
 
clip_similarities/page1-109-similarities.pkl ADDED
Binary file (20.8 kB). View file
 
clip_similarities/page1-110-similarities.pkl ADDED
Binary file (20.8 kB). View file
 
clip_similarities/page1-112-similarities.pkl ADDED
Binary file (20.8 kB). View file
 
clip_similarities/page1-113-similarities.pkl ADDED
Binary file (20.8 kB). View file
 
clip_similarities/page1-114-similarities.pkl ADDED
Binary file (20.8 kB). View file
 
clip_similarities/page1-116-similarities.pkl ADDED
Binary file (20.8 kB). View file
 
clip_similarities/page1-117-similarities.pkl ADDED
Binary file (20.8 kB). View file
 
clip_similarities/page1-118-similarities.pkl ADDED
Binary file (20.8 kB). View file
 
clip_similarities/page1-119-similarities.pkl ADDED
Binary file (20.8 kB). View file
 
clip_similarities/page1-122-similarities.pkl ADDED
Binary file (20.8 kB). View file
 
clip_similarities/page1-125-similarities.pkl ADDED
Binary file (20.8 kB). View file
 
clip_similarities/page1-128-similarities.pkl ADDED
Binary file (20.8 kB). View file
 
clip_similarities/page1-129-similarities.pkl ADDED
Binary file (20.8 kB). View file
 
clip_similarities/page1-13-similarities.pkl ADDED
Binary file (20.8 kB). View file
 
clip_similarities/page1-130-similarities.pkl ADDED
Binary file (20.8 kB). View file
 
clip_similarities/page1-132-similarities.pkl ADDED
Binary file (20.8 kB). View file
 
clip_similarities/page1-133-similarities.pkl ADDED
Binary file (20.8 kB). View file
 
clip_similarities/page1-136-similarities.pkl ADDED
Binary file (20.8 kB). View file
 
clip_similarities/page1-137-similarities.pkl ADDED
Binary file (20.8 kB). View file
 
clip_similarities/page1-14-similarities.pkl ADDED
Binary file (20.8 kB). View file
 
clip_similarities/page1-142-similarities.pkl ADDED
Binary file (20.8 kB). View file
 
clip_similarities/page1-143-similarities.pkl ADDED
Binary file (20.8 kB). View file
 
clip_similarities/page1-147-similarities.pkl ADDED
Binary file (20.8 kB). View file
 
clip_similarities/page1-148-similarities.pkl ADDED
Binary file (20.8 kB). View file
 
clip_similarities/page1-149-similarities.pkl ADDED
Binary file (20.8 kB). View file
 
clip_similarities/page1-150-similarities.pkl ADDED
Binary file (20.8 kB). View file
 
clip_similarities/page1-151-similarities.pkl ADDED
Binary file (20.8 kB). View file
 
clip_similarities/page1-153-similarities.pkl ADDED
Binary file (20.8 kB). View file