tmzh commited on
Commit
e640a42
1 Parent(s): 80920ec

working version

Browse files
Files changed (1) hide show
  1. app.py +299 -0
app.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import io
3
+ import math
4
+ import os
5
+ import random
6
+ import json
7
+ import re
8
+ from typing import List
9
+
10
+ import gradio as gr
11
+ import outlines
12
+ import requests
13
+ from outlines import models, generate, samplers
14
+ from pydantic import BaseModel
15
+
16
+
17
+ def merge_games(clues, num_merges=10):
18
+ """Generates around 10 merges of words from the given clues.
19
+
20
+ Args:
21
+ clues: A list of clues, where each clue is a list containing the words, the answer, and the explanation.
22
+ num_merges: The approximate number of merges to generate (default: 10).
23
+
24
+ Returns:
25
+ A list of tuples, where each tuple contains the merged words and the indices of the selected rows.
26
+ """
27
+
28
+ merges = []
29
+ while len(merges) < num_merges:
30
+ num_rows = random.choice([3, 4])
31
+ selected_rows = random.sample(range(len(clues)), num_rows)
32
+ merged_words = " ".join([word for row in [clues[i][0] for i in selected_rows] for word in row])
33
+ if len(merged_words.split()) in [8, 9]:
34
+ merges.append((merged_words.split(), selected_rows))
35
+
36
+ return merges
37
+
38
+
39
+ class Clue(BaseModel):
40
+ word: str
41
+ explanation: str
42
+
43
+
44
+ class Group(BaseModel):
45
+ words: List[str]
46
+ clue: str
47
+ explanation: str
48
+
49
+
50
+ class Groups(BaseModel):
51
+ groups: List[Group]
52
+
53
+
54
+ example_clues = [
55
+ (['ARROW', 'TIE', 'HONOR'], 'BOW', 'such as a bow and arrow, a bow tie, or a bow as a sign of honor'),
56
+ (['DOG', 'TREE'], 'BARK', 'such as the sound a dog makes, or a tree is made of bark'),
57
+ (['MONEY', 'RIVER', 'ROB', 'BLOOD'], 'CRIME', 'such as money being stolen, a river being a potential crime scene, '
58
+ 'robbery, or blood being a result of a violent crime'),
59
+ (['BEEF', 'TURKEY', 'FIELD', 'GRASS'], 'GROUND',
60
+ 'such as ground beef, a turkey being a ground-dwelling bird, a field or grass being a type of ground'),
61
+ (['BANK', 'GUITAR', 'LIBRARY'], 'NOTE',
62
+ 'such as a bank note, a musical note on a guitar, or a note being a written comment in a library book'),
63
+ (['ROOM', 'PIANO', 'TYPEWRITER'], 'KEYS', 'such as a room key, piano keys, or typewriter keys'),
64
+ (['TRAFFIC', 'RADAR', 'PHONE'], 'SIGNAL', 'such as traffic signals, radar signals, or phone signals'),
65
+ (['FENCE', 'PICTURE', 'COOKIE'], 'FRAME',
66
+ 'such as a frame around a yard, a picture frame, or a cookie cutter being a type of frame'),
67
+ (['YARN', 'VIOLIN', 'DRESS'], 'STRING', 'strings like material, instrument, clothing fastener'),
68
+ (['JUMP', 'FLOWER', 'CLOCK'], 'SPRING',
69
+ 'such as jumping, flowers blooming in the spring, or a clock having a sprint component'),
70
+ (['SPY', 'KNIFE'], 'WAR',
71
+ 'Both relate to aspects of war, such as spies being involved in war or knives being used as weapons'),
72
+ (['STADIUM', 'SHOE', 'FIELD'], 'SPORT', 'Sports like venues, equipment, playing surfaces'),
73
+ (['TEACHER', 'CLUB'], 'SCHOOL',
74
+ 'such as a teacher being a school staff member or a club being a type of school organization'),
75
+ (['CYCLE', 'ARMY', 'COURT', 'FEES'], 'CHARGE', 'charges like electricity, battle, legal, payments'),
76
+ (['FRUIT', 'MUSIC', 'TRAFFIC', 'STUCK'], 'JAM',
77
+ 'Jams such as fruit jam, a music jam session, traffic jam, or being stuck in a jam'),
78
+ (['POLICE', 'DOG', 'THIEF'], 'CRIME',
79
+ 'such as police investigating crimes, dogs being used to detect crimes, or a thief committing a crime'),
80
+ (['ARCTIC', 'SHUT', 'STAMP'], 'SEAL',
81
+ 'such as the Arctic being home to seals, or shutting a seal on an envelope, or a stamp being a type of seal'),
82
+ ]
83
+
84
+
85
+ def group_words(words):
86
+ @outlines.prompt
87
+ def chat_group_template(system_prompt, query, history=[]):
88
+ '''<s><|system|>
89
+ {{ system_prompt }}
90
+ {% for example in history %}
91
+ <|user|>
92
+ {{ example[0] }}<|end|>
93
+ <|assistant|>
94
+ {{ example[1] }}<|end|>
95
+ {% endfor %}
96
+ <|user|>
97
+ {{ query }}<|end|>
98
+ <|assistant|>
99
+ '''
100
+
101
+ grouping_system_prompt = ("You are an assistant for the game Codenames. Your task is to help players by grouping a "
102
+ "given group of secrets into 3 to 4 groups. Each group should consist of secrets that "
103
+ "share a common theme or other word connections such as homonym, hypernyms or synonyms")
104
+ example_groupings = []
105
+ merges = merge_games(example_clues, 5)
106
+ for merged_words, indices in merges:
107
+ groups = [{
108
+ "secrets": example_clues[i][0],
109
+ "clue": example_clues[i][1],
110
+ "explanation": example_clues[i][2]
111
+ } for i in indices]
112
+ example_groupings.append((merged_words, json.dumps(groups, separators=(',', ':'))))
113
+
114
+ prompt = chat_group_template(grouping_system_prompt, words, example_groupings)
115
+ sampler = samplers.greedy()
116
+ generator = generate.json(model, Groups, sampler)
117
+
118
+ print("Grouping words:", words)
119
+ generations = generator(
120
+ prompt,
121
+ max_tokens=500
122
+ )
123
+ print("Got groupings: ", generations)
124
+ return generations.groups
125
+
126
+
127
+ def generate_clues(group):
128
+ @outlines.prompt
129
+ def chat_clue_template(system, query, history=[]):
130
+ '''<s><|system|>
131
+ {{ system }}
132
+ {% for example in history %}
133
+ <|user|>
134
+ {{ example[0] }}<|end|>
135
+ <|assistant|>
136
+ {"Clue": "{{ example[1] }}", "Description": "{{ example[2] }}" }<|end|>
137
+ {% endfor %}
138
+ <|user|>
139
+ {{ query }}<|end|>
140
+ <|assistant|>
141
+ '''
142
+
143
+ clue_system_prompt = ("You are a codenames game companion. Your task is to give a single word clue related to "
144
+ "a given group of words. You will only respond with a single word clue. Compound words are "
145
+ "allowed. Do not include the word 'Clue'. Do not provide explanations or notes.")
146
+
147
+ prompt = chat_clue_template(clue_system_prompt, group, example_clues)
148
+ # sampler = samplers.greedy()
149
+ sampler = samplers.multinomial(2, top_k=10)
150
+ generator = generate.json(model, Clue, sampler)
151
+ generations = generator(prompt, max_tokens=100)
152
+ print("Got clues: ", generations)
153
+ return generations[0]
154
+
155
+
156
+ def jpeg_with_target_size(im, target):
157
+ """Return the image as JPEG with the given name at best quality that makes less than "target" bytes
158
+
159
+ https://stackoverflow.com/a/52281257
160
+ """
161
+ # Min and Max quality
162
+ qmin, qmax = 25, 96
163
+ # Highest acceptable quality found
164
+ qacc = -1
165
+ while qmin <= qmax:
166
+ m = math.floor((qmin + qmax) / 2)
167
+
168
+ # Encode into memory and get size
169
+ buffer = io.BytesIO()
170
+ im.save(buffer, format="JPEG", quality=m)
171
+ s = buffer.getbuffer().nbytes
172
+
173
+ if s <= target:
174
+ qacc = m
175
+ qmin = m + 1
176
+ elif s > target:
177
+ qmax = m - 1
178
+
179
+ # Write to disk at the defined quality
180
+ if qacc > -1:
181
+ image_byte_array = io.BytesIO()
182
+ print("Acceptable quality", im, im.format, f"{im.size}x{im.mode}")
183
+ im.save(image_byte_array, format='JPEG', quality=qacc)
184
+ return image_byte_array.getvalue()
185
+
186
+
187
+ def process_image(img):
188
+ # Resize the image
189
+ max_size = (1024, 1024)
190
+ img.thumbnail(max_size)
191
+
192
+ image_byte_array = jpeg_with_target_size(img, 180_000)
193
+ image_b64 = base64.b64encode(image_byte_array).decode()
194
+
195
+ invoke_url = "https://ai.api.nvidia.com/v1/vlm/microsoft/phi-3-vision-128k-instruct"
196
+ stream = False
197
+
198
+ if os.environ.get("NVIDIA_API_KEY", "").startswith("nvapi-"):
199
+ print("Valid NVIDIA_API_KEY already in the environment. Delete to reset")
200
+
201
+ headers = {
202
+ "Authorization": f"Bearer {os.environ.get('NVIDIA_API_KEY', '')}",
203
+ "Accept": "text/event-stream" if stream else "application/json"
204
+ }
205
+
206
+ payload = {
207
+ "messages": [
208
+ {
209
+ "role": "user",
210
+ "content": f'Identify the words in this game of Codenames. Provide only a list of words. Provide the '
211
+ f'words in capital letters only. <img src="data:image/png;base64,{image_b64}" />'
212
+ }
213
+ ],
214
+ "max_tokens": 512,
215
+ "temperature": 0.1,
216
+ "top_p": 0.70,
217
+ "stream": stream
218
+ }
219
+
220
+ response = requests.post(invoke_url, headers=headers, json=payload)
221
+ if response.ok:
222
+ print(response.json())
223
+ # Define the pattern to match uppercase words separated by commas
224
+ pattern = r'[A-Z]+(?:\s+[A-Z]+)?'
225
+ words = re.findall(pattern, response.json()['choices'][0]['message']['content'])
226
+
227
+ return gr.update(choices=words, value=words)
228
+
229
+
230
+ if __name__ == '__main__':
231
+ with gr.Blocks() as demo:
232
+ gr.Markdown("# *Codenames* clue generator")
233
+ gr.Markdown("Provide a list of words to generate a clue")
234
+
235
+ with gr.Row():
236
+ game_image = gr.Image(type="pil")
237
+ word_list_input = gr.Dropdown(label="Enter list of words (comma separated)",
238
+ choices='WEREWOLF, CHAIN, MOSQUITO, CRAFT, RANCH, LIP, VALENTINE, CLOUD, '
239
+ 'BEARD, BUNK, SECOND, SADDLE, BUCKET, JAIL, ANT, POCKET, LACE, '
240
+ 'BREAK, CUCKOO, FLAT, NIL, TIN, CHERRY, CHRISTMAS, MOSES, '
241
+ 'TEAM'.split(', '),
242
+ multiselect=True,
243
+ interactive=True)
244
+
245
+ with gr.Row():
246
+ detect_words_button = gr.Button("Detect Words")
247
+ group_words_button = gr.Button("Group Words")
248
+
249
+ dropdowns, buttons, outputs = [], [], []
250
+
251
+ for i in range(4):
252
+ with gr.Row():
253
+ group_input = gr.Dropdown(label=f"Group {i + 1}",
254
+ choices=[],
255
+ allow_custom_value=True,
256
+ multiselect=True,
257
+ interactive=True)
258
+ clue_button = gr.Button("Generate Clue", size='sm')
259
+ clue_output = gr.Textbox(label=f"Clue {i + 1}")
260
+ dropdowns.append(group_input)
261
+ buttons.append(clue_button)
262
+ outputs.append(clue_output)
263
+
264
+ def pad_or_truncate(lst, n=4):
265
+ # Ensure the length of the list is at most n
266
+ truncated_lst = lst[:n]
267
+ return truncated_lst + (n - len(truncated_lst)) * [Group(words=[],clue='',explanation='')]
268
+
269
+
270
+ def group_words_callback(words):
271
+ groups = group_words(words)
272
+ groups = pad_or_truncate(groups, 4)
273
+ print("Got groups: ", groups, type(groups))
274
+ return [gr.update(value=groups[i].words, choices=groups[i].words, info=groups[i].explanation) for i in range(4)]
275
+
276
+
277
+
278
+ def generate_clues_callback(group):
279
+ print("Generating clues: ", group)
280
+ g = generate_clues(group)
281
+ return gr.update(value=g.word, info=g.explanation)
282
+
283
+
284
+ model = models.transformers("microsoft/Phi-3-mini-4k-instruct",
285
+ model_kwargs={'device_map': "cuda", 'torch_dtype': "auto",
286
+ 'trust_remote_code': True,
287
+ 'attn_implementation': "flash_attention_2"})
288
+
289
+ detect_words_button.click(fn=process_image,
290
+ inputs=game_image,
291
+ outputs=[word_list_input])
292
+ group_words_button.click(fn=group_words_callback,
293
+ inputs=word_list_input,
294
+ outputs=dropdowns)
295
+
296
+ for i in range(4):
297
+ buttons[i].click(generate_clues_callback, inputs=dropdowns[i], outputs=outputs[i])
298
+
299
+ demo.launch(share=False)