Gemma-Hangman / app.py
Dimitre's picture
Update app.py
cf4e3ae verified
import logging
import os
import streamlit as st
import torch
from dotenv import load_dotenv
from transformers import AutoModelForCausalLM, AutoTokenizer
from hangman import guess_letter
from hf_utils import query_hint, query_word
# Path of an external config file.
# NOTE(review): nothing in this file reads it; the settings actually used are
# the inline ``configs`` mapping below.
CONFIGS_PATH = "configs.yaml"

# The game is lost once this many letters have been missed.
MAX_TRIES = 6

# Categories offered in the "Choose a category" selectbox.
CATEGORIES = ["Country", "Animal", "Food", "Movie"]

# Runtime settings: Hub model ID, torch device, and the generation options
# forwarded to ``query_word`` / ``query_hint``.
configs = dict(
    os_model="google/gemma-2b-it",
    device="cpu",
    # NOTE(review): ``transformers`` ``generate`` calls its length limit
    # ``max_new_tokens``; presumably hf_utils translates ``max_output_tokens``
    # — confirm.
    generation_config=dict(
        max_output_tokens=128,
        temperature=1,
        top_p=1,
        top_k=4,
    ),
)
@st.cache_resource()
def setup(model_id: str, device: str) -> dict:
    """Load the tokenizer and model for ``model_id`` and move the model to ``device``.

    ``st.cache_resource`` caches the returned objects, so Streamlit reruns
    with the same arguments reuse them instead of reloading the weights.

    Args:
        model_id (str): Hugging Face model ID used to load the tokenizer and model.
        device (str): Torch device the model is moved to, e.g. ``"cpu"``.

    Returns:
        dict: ``{"tokenizer": <tokenizer>, "model": <model>}``.
    """
    logger.info("Loading model and tokenizer from model: '%s'", model_id)

    # Gated Hub repositories need an access token; public ones do not. When the
    # variable is unset, ``token=None`` lets transformers fall back to its
    # default Hub credentials instead of failing here with a KeyError.
    hf_token = os.environ.get("HF_ACCESS_TOKEN")

    tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)

    # NOTE(review): float16 weights are loaded regardless of ``device``; on CPU
    # this can be slow, and some older torch builds lack Half CPU kernels —
    # confirm against the deployed torch version.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        token=hf_token,
    ).to(device)

    logger.info("Setup finished")
    return {"tokenizer": tokenizer, "model": model}
# Root logging at INFO; this module's logger is named after the file path.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__file__)

# Browser-tab title and icon.
st.set_page_config(page_title="Gemma Hangman", page_icon="🧩")

# Populate os.environ from a local .env file (if present) before setup()
# reads HF_ACCESS_TOKEN from the environment.
load_dotenv()

# setup() is wrapped in st.cache_resource, so reruns reuse the loaded objects.
assets = setup(configs["os_model"], configs["device"])
tokenizer, model = assets["tokenizer"], assets["model"]
# Seed every per-session game key that is not set yet. Checking key by key
# (rather than "is session_state empty?") guarantees that the keys read later
# in the script exist even when the session already holds other entries.
# The tuple is rebuilt on every run, so each session gets its own fresh lists.
for _key, _default in (
    ("word", ""),
    ("hint", ""),
    ("hangman", ""),
    ("missed_letters", []),
    ("correct_letters", []),
):
    if _key not in st.session_state:
        st.session_state[_key] = _default
# The model in use is Gemma (configs["os_model"]), matching the page title.
st.title("Gemma Hangman")
st.markdown("## Guess the word based on a hint")

# Left column: category picker; right column: start/reset controls.
col1, col2 = st.columns(2)
with col1:
    category = st.selectbox(
        "Choose a category",
        CATEGORIES,
    )
with col2:
    start_btn = st.button("Start game")
    reset_btn = st.button("Reset game")

if start_btn:
    # Ask the model for a word in the chosen category, then for a hint about it.
    # NOTE(review): the mask below is sized from the raw returned string, so any
    # stray whitespace/newline from query_word becomes a slot to guess —
    # confirm query_word strips its output.
    st.session_state["word"] = query_word(
        category,
        model,
        tokenizer,
        configs["generation_config"],
        configs["device"],
    )
    st.session_state["hint"] = query_hint(
        st.session_state["word"],
        model,
        tokenizer,
        configs["generation_config"],
        configs["device"],
    )
    # One "_" placeholder per character of the word.
    st.session_state["hangman"] = "_" * len(st.session_state["word"])
    st.session_state["missed_letters"] = []
    st.session_state["correct_letters"] = []

if reset_btn:
    # Back to the empty, no-game-in-progress state.
    st.session_state["word"] = ""
    st.session_state["hint"] = ""
    st.session_state["hangman"] = ""
    st.session_state["missed_letters"] = []
    st.session_state["correct_letters"] = []

st.markdown(
    """
Note: you must input whitespaces and special characters.
"""
)
st.markdown(f'### Hint:\n{st.session_state["hint"]}')

col3, col4 = st.columns(2)

with col3:
    guess = st.text_input(label="Enter letter")
    guess_btn = st.button("Guess letter")

    # Guesses only count while a game is running: a word has been generated,
    # it is not fully revealed yet, and the player still has tries left.
    # Without this check, guessing could continue after a loss (or before any
    # word exists) and both end-of-game messages could appear together.
    game_in_progress = (
        st.session_state["word"] != ""
        and st.session_state["word"] != st.session_state["hangman"]
        and len(st.session_state["missed_letters"]) < MAX_TRIES
    )

    if guess_btn and guess and game_in_progress:
        # guess_letter presumably mutates and returns the session state —
        # verify in hangman.py. Never assign to ``st.session_state`` itself:
        # that rebinds the streamlit module attribute for every session in
        # the process. Copy results into this session's mapping instead.
        updated_state = guess_letter(guess, st.session_state)
        if updated_state is not None and updated_state is not st.session_state:
            st.session_state.update(updated_state)

with col4:
    # Read-only display of the game state, rendered after any guess above has
    # been applied so it reflects the latest state.
    st.text_input(
        label="Hangman",
        value=st.session_state["hangman"],
    )
    st.text_input(
        label=f"Missed letters (max {MAX_TRIES} tries)",
        value=", ".join(st.session_state["missed_letters"]),
    )

# Won: a word exists and every character has been revealed.
if st.session_state["word"] and st.session_state["word"] == st.session_state["hangman"]:
    st.success("You won!")
    st.balloons()
elif len(st.session_state["missed_letters"]) >= MAX_TRIES:
    st.error(f"""You lost, the correct word was '{st.session_state["word"]}'""")
    st.snow()