tmzh
fix bug
b084b77
raw
history blame contribute delete
No virus
8.66 kB
import json
import os
import random
from typing import List, Tuple
import google.generativeai as genai
import gradio as gr
from jinja2 import Environment, FileSystemLoader
# The Gemini API key is read from the environment; it is None when unset.
GOOGLE_API_KEY = os.environ.get('GOOGLE_API_KEY')
genai.configure(api_key=GOOGLE_API_KEY)

# Shared model handle used by every request below; the default generation
# config asks for responses with the JSON MIME type.
MODEL = genai.GenerativeModel(
    "gemini-1.5-flash-latest",
    generation_config={"response_mime_type": "application/json"},
)
# Example clues and groupings.
# Each entry is (words, clue, explanation): the board words a clue covers, the
# single-word clue itself, and a short rationale. These are rendered into the
# prompts as few-shot examples.
EXAMPLE_CLUES = [
    (['ARROW', 'TIE', 'HONOR'], 'BOW', 'such as a bow and arrow, a bow tie, or a bow as a sign of honor'),
    (['DOG', 'TREE'], 'BARK', 'such as the sound a dog makes, or a tree is made of bark'),
    (['MONEY', 'RIVER', 'ROB', 'BLOOD'], 'CRIME', 'such as money being stolen, a river being a potential crime scene, robbery, or blood being a result of a violent crime'),
    (['BEEF', 'TURKEY', 'FIELD', 'GRASS'], 'GROUND', 'such as ground beef, a turkey being a ground-dwelling bird, a field or grass being a type of ground'),
    (['BANK', 'GUITAR', 'LIBRARY'], 'NOTE', 'such as a bank note, a musical note on a guitar, or a note being a written comment in a library book'),
    (['ROOM', 'PIANO', 'TYPEWRITER'], 'KEYS', 'such as a room key, piano keys, or typewriter keys'),
    (['TRAFFIC', 'RADAR', 'PHONE'], 'SIGNAL', 'such as traffic signals, radar signals, or phone signals'),
    (['FENCE', 'PICTURE', 'COOKIE'], 'FRAME', 'such as a frame around a yard, a picture frame, or a cookie cutter being a type of frame'),
    (['YARN', 'VIOLIN', 'DRESS'], 'STRING', 'strings like material, instrument, clothing fastener'),
    # Fixed typo: the explanation previously read "sprint component".
    (['JUMP', 'FLOWER', 'CLOCK'], 'SPRING', 'such as jumping, flowers blooming in the spring, or a clock having a spring component'),
    (['SPY', 'KNIFE'], 'WAR', 'Both relate to aspects of war, such as spies being involved in war or knives being used as weapons'),
    (['STADIUM', 'SHOE', 'FIELD'], 'SPORT', 'Sports like venues, equipment, playing surfaces'),
    (['TEACHER', 'CLUB'], 'SCHOOL', 'such as a teacher being a school staff member or a club being a type of school organization'),
    (['CYCLE', 'ARMY', 'COURT', 'FEES'], 'CHARGE', 'charges like electricity, battle, legal, payments'),
    (['FRUIT', 'MUSIC', 'TRAFFIC', 'STUCK'], 'JAM', 'Jams such as fruit jam, a music jam session, traffic jam, or being stuck in a jam'),
    (['POLICE', 'DOG', 'THIEF'], 'CRIME', 'such as police investigating crimes, dogs being used to detect crimes, or a thief committing a crime'),
    (['ARCTIC', 'SHUT', 'STAMP'], 'SEAL', 'such as the Arctic being home to seals, or shutting a seal on an envelope, or a stamp being a type of seal'),
]
def create_random_word_groups(clues: List[Tuple[List[str], str, str]], num_groups: int = 10) -> List[Tuple[List[str], List[int]]]:
    """Build ``num_groups`` random word pools by merging several clue entries.

    Each attempt picks 3 or 4 distinct clue indices at random and concatenates
    their word lists. The attempt is kept only when the merged pool has 8 or 9
    words; otherwise it is discarded and another attempt is made.

    Returns a list of ``(merged_words, clue_indices)`` tuples.
    """
    accepted = []
    while len(accepted) < num_groups:
        how_many = random.choice([3, 4])
        picked = random.sample(range(len(clues)), how_many)

        # Flatten the word lists of the picked clues, preserving pick order.
        pool = []
        for idx in picked:
            pool.extend(clues[idx][0])

        if len(pool) in (8, 9):
            accepted.append((pool, picked))
    return accepted
def create_example_groupings(clues: List[Tuple[List[str], str, str]], num_groups: int = 5) -> List[Tuple[List[str], str]]:
    """Produce few-shot (input, output) pairs for the grouping prompt.

    For every random word pool, the input is the merged word list and the
    output is a compact JSON string listing, for each source clue, its words,
    clue and explanation.
    """
    examples = []
    for merged_words, indices in create_random_word_groups(clues, num_groups):
        answer = []
        for i in indices:
            group_words_, clue, explanation = clues[i][0], clues[i][1], clues[i][2]
            answer.append({
                "words": group_words_,
                "clue": clue,
                "explanation": explanation,
            })
        # Compact separators keep the serialized example free of padding spaces.
        examples.append((merged_words, json.dumps(answer, separators=(',', ':'))))
    return examples
# Few-shot examples for the grouping prompt. Built once at import time from
# randomly merged clues, so the examples differ between process starts.
EXAMPLE_GROUPINGS = create_example_groupings(EXAMPLE_CLUES)
def render_template(template: str, system_prompt: str, history: List[Tuple], query: str) -> str:
    """Render a Jinja2 template string into a prompt.

    Args:
        template: Jinja2 template source (a string, not a file name).
        system_prompt: Instruction text. Exposed to the template as both
            ``system_prompt`` and ``system``.
        history: Few-shot examples, exposed to the template as ``history``.
        query: The current input, exposed to the template as ``query``.

    Returns:
        The rendered prompt text.
    """
    env = Environment(loader=FileSystemLoader('.'))
    compiled = env.from_string(template)
    # The templates in this file reference ``{{ system }}``. Jinja2 renders an
    # undefined variable as an empty string, so passing the text only as
    # ``system_prompt`` silently dropped the instructions from the prompt.
    return compiled.render(
        system=system_prompt,
        system_prompt=system_prompt,
        history=history,
        query=query,
    )
def group_words(words: List[str]) -> List[dict]:
    """Ask the model to split ``words`` into themed groups.

    The prompt contains the few-shot examples from ``EXAMPLE_GROUPINGS``, the
    query words, the grouping instructions, and a description of the expected
    JSON schema.

    Args:
        words: The words to be grouped.

    Returns:
        The parsed JSON response as a list. Each item is expected to be a dict
        with ``words``, ``clue`` and ``explanation`` keys, but the shape is
        ultimately whatever the model returned.

    Raises:
        json.JSONDecodeError: If the model response is not valid JSON.
    """
    # ``system_prompt`` is the name render_template() passes to the template;
    # the previous ``{{ system }}`` was undefined and rendered as empty text,
    # so the grouping instructions never reached the model.
    template = '''
{% for example in history %}
INPUT:
{{ example[0] }}
OUTPUT:
{{ example[1] }}
{% endfor %}
INPUT:
{{ query }}
OUTPUT:
{{ system_prompt }}
Groups = {'words': list[str], 'clue': str, 'explanation': str}
Return: Groups
'''
    grouping_prompt = ("You are an assistant for the game Codenames. Group the given words into 3 to 4 sets of 2 to 4 words each. "
                       "Each group should share a common theme or word connection. Avoid generic or easily guessable clues.")
    prompt = render_template(template, grouping_prompt, EXAMPLE_GROUPINGS, words)
    response = MODEL.generate_content(prompt, generation_config={'top_k': 3, 'temperature': 1.1})
    parsed = json.loads(response.text)
    # Callers slice the result, so wrap a lone group object in a list.
    if isinstance(parsed, dict):
        parsed = [parsed]
    return parsed
def generate_clue(group: List[str]) -> dict:
    """Ask the model for a single-word clue covering ``group``.

    The prompt contains the few-shot examples from ``EXAMPLE_CLUES``, the query
    words, the clue-giving instructions, and a description of the expected
    JSON schema.

    Args:
        group: The words the clue should connect.

    Returns:
        The parsed JSON response; expected to carry ``clue`` and
        ``explanation`` keys, but the shape is whatever the model returned.

    Raises:
        json.JSONDecodeError: If the model response is not valid JSON.
    """
    # Two fixes relative to the previous template:
    #  * ``{{ system }}`` was undefined (render_template() passes
    #    ``system_prompt``), so the instructions rendered as empty text.
    #  * Example outputs were emitted with unquoted values; ``tojson`` makes
    #    each few-shot answer valid JSON, matching the JSON response format
    #    requested from the model.
    template = '''
{% for example in history %}
INPUT:
{{ example[0] }}
OUTPUT:
{"clue": {{ example[1] | tojson }}, "explanation": {{ example[2] | tojson }} }
{% endfor %}
INPUT:
{{ query }}
OUTPUT:
{{ system_prompt }}
Clue = {'clue': str, 'explanation': str}
Return: Clue
'''
    clue_prompt = ("As a Codenames game companion, provide a single-word clue for the given group of words. "
                   "The clue should relate to a common theme or word connection. DO NOT reuse any of the given "
                   "words as a clue. Avoid generic or easily guessable clues.")
    prompt = render_template(template, clue_prompt, EXAMPLE_CLUES, group)
    response = MODEL.generate_content(prompt, generation_config={'top_k': 3, 'temperature': 1.1})
    return json.loads(response.text)
def process_image(img) -> gr.update:
    """Detect the words on an uploaded Codenames board image.

    Sends the image and an instruction prompt to the model and reads the
    ``Game`` list out of the JSON reply.

    Args:
        img: The uploaded image (the UI supplies a PIL image), or ``None``
            when nothing has been uploaded.

    Returns:
        A Gradio update that sets the detected words as both the available
        choices and the current selection.

    Raises:
        gr.Error: If no image was provided.
        json.JSONDecodeError: If the model response is not valid JSON.
        KeyError: If the response has no ``Game`` key.
    """
    # Without this guard, clicking the button before uploading anything would
    # send ``None`` to the model and fail with an opaque error.
    if img is None:
        raise gr.Error("Please upload an image of the Codenames board first.")

    prompt = ('Identify the words in this Codenames game image. Provide only a list of words in capital letters. '
              'Group these words into 6 or 8 sets that can be guessed together using a single-word clue. '
              'Respond with JSON in the format: {"Game": <list of words in the game>}')
    response = MODEL.generate_content([prompt, img], stream=True)
    # Wait for the streamed response to finish before reading its text.
    response.resolve()
    words = json.loads(response.text)['Game']
    return gr.update(choices=words, value=words)
def pad_or_truncate(lst: List, n: int = 4) -> List:
    """Return a new list of exactly ``n`` items built from ``lst``.

    Items beyond the first ``n`` are dropped. A shorter input is padded with
    empty dicts (not ``None``) so callers can use ``.get`` on every element.

    Args:
        lst: The source sequence; it is not modified.
        n: The desired length.

    Returns:
        A list with the first ``n`` items of ``lst`` followed by as many
        independent ``{}`` placeholders as needed.
    """
    head = list(lst[:n])
    # Build a fresh dict per slot: ``[{}] * k`` would repeat one shared object,
    # so mutating one placeholder would change them all.
    return head + [{} for _ in range(n - len(head))]
def group_words_callback(words: List[str]) -> List[gr.update]:
    """Group the selected words and populate the four group dropdowns.

    Args:
        words: The words currently selected in the word-list dropdown.

    Returns:
        Exactly four Gradio updates, one per group dropdown. Each sets the
        group's words as the selection, all input words as the choices, and
        the group's explanation as the info text.
    """
    groups = pad_or_truncate(group_words(words), 4)
    return [
        gr.update(
            # The group dropdowns are multiselect, so an absent group must be
            # an empty list; the previous "" default is not a list of choices.
            value=group.get("words", []),
            choices=words,
            info=group.get("explanation", ""),
        )
        for group in groups
    ]
def generate_clue_callback(group: List[str]) -> gr.update:
    """Produce a clue for one group and show it in that group's textbox.

    The clue word becomes the textbox value and the explanation becomes its
    info text.
    """
    result = generate_clue(group)
    clue_word = result['clue']
    rationale = result['explanation']
    return gr.update(value=clue_word, info=rationale)
# UI Setup
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    gr.Markdown("# *Codenames* Clue Generator")
    gr.Markdown("Provide a list of words to generate clues")

    # Board image on the left, detected words on the right.
    with gr.Row():
        game_image = gr.Image(type="pil")
        word_list_input = gr.Dropdown(label="Detected words", choices=[], multiselect=True, interactive=True)

    with gr.Row():
        detect_words_button = gr.Button("Detect Words")
        group_words_button = gr.Button("Group Words")

    # One row per group: editable word selection, a button, and the clue box.
    group_inputs, clue_buttons, clue_outputs = [], [], []
    for i in range(4):
        with gr.Row():
            group_input = gr.Dropdown(label=f"Group {i + 1}", choices=[], allow_custom_value=True, multiselect=True, interactive=True)
            clue_button = gr.Button("Generate Clue", size='sm')
            clue_output = gr.Textbox(label=f"Clue {i + 1}")
        group_inputs.append(group_input)
        clue_buttons.append(clue_button)
        clue_outputs.append(clue_output)

    # Event handlers
    detect_words_button.click(fn=process_image, inputs=game_image, outputs=[word_list_input])
    group_words_button.click(fn=group_words_callback, inputs=word_list_input, outputs=group_inputs)
    # Wire each row's button to its own dropdown and textbox.
    for row_button, row_group, row_clue in zip(clue_buttons, group_inputs, clue_outputs):
        row_button.click(generate_clue_callback, inputs=row_group, outputs=row_clue)

demo.launch(share=True, debug=True)