Spaces:
Runtime error
Runtime error
File size: 3,589 Bytes
e662df9 1fd4dc6 e662df9 d8e827d e662df9 1586c56 9843dcc ac51a1c e662df9 ac51a1c 9843dcc e662df9 d3cd834 e662df9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import logging
import os
import streamlit as st
import torch
from dotenv import load_dotenv
from transformers import AutoModelForCausalLM, AutoTokenizer
from hangman import guess_letter
from hf_utils import query_hint, query_word
# Path of a YAML config file; not referenced by the code shown in this file.
CONFIGS_PATH = "configs.yaml"
# The game is lost once this many letters have been missed.
MAX_TRIES = 6
# Word categories offered in the category select box.
CATEGORIES = ["Country", "Animal", "Food", "Movie"]
# Model selection and sampling parameters handed to the hf_utils query helpers.
# NOTE(review): `max_output_tokens` is the Gemini-API spelling; transformers'
# `generate` uses `max_new_tokens` — confirm how hf_utils maps these keys.
configs = dict(
    os_model="google/gemma-2b-it",
    device="cpu",
    generation_config=dict(
        max_output_tokens=128,
        temperature=1,
        top_p=1,
        top_k=4,
    ),
)
@st.cache_resource()
def setup(model_id: str, device: str) -> dict:
    """Load the tokenizer and causal-LM model for ``model_id``.

    ``st.cache_resource`` memoizes the result, so Streamlit reruns reuse the
    already-loaded objects instead of reloading the weights.

    Args:
        model_id (str): Model ID used to load the tokenizer and model.
        device (str): Torch device string the model is moved to (e.g. "cpu").

    Returns:
        dict: ``{"tokenizer": tokenizer, "model": model}``.

    Raises:
        KeyError: If the ``HF_ACCESS_TOKEN`` environment variable is not set.
    """
    logger.info(f"Loading model and tokenizer from model: '{model_id}'")
    # Read the access token once and reuse it for both downloads.
    token = os.environ["HF_ACCESS_TOKEN"]
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        token=token,
    )
    # float16 is an accelerator format: on CPU its kernel coverage depends on
    # the PyTorch version and it is typically much slower than float32, so
    # only use half precision off-CPU.
    dtype = torch.float32 if torch.device(device).type == "cpu" else torch.float16
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=dtype,
        token=token,
    ).to(device)
    logger.info("Setup finished")
    return {"tokenizer": tokenizer, "model": model}
logging.basicConfig(level=logging.INFO)
# Conventional module logger; `__file__` would give a path-shaped logger name.
logger = logging.getLogger(__name__)
# Streamlit requires set_page_config to be the first Streamlit command.
st.set_page_config(
    page_title="Gemma Hangman",
    page_icon="🧩",
)
# Load variables from a local .env file into os.environ before `setup` reads
# HF_ACCESS_TOKEN from it.
load_dotenv()
assets = setup(configs["os_model"], configs["device"])
tokenizer = assets["tokenizer"]
model = assets["model"]
def _reset_game_state(*, only_missing: bool = False) -> None:
    """Set the per-game session keys to their empty values.

    Args:
        only_missing: If True, only keys not yet present in
            ``st.session_state`` are set, so an in-progress game survives
            Streamlit reruns.
    """
    # Built on every call so each game gets its own fresh list objects.
    empty_state = {
        "word": "",
        "hint": "",
        "hangman": "",
        "missed_letters": [],
        "correct_letters": [],
    }
    for key, value in empty_state.items():
        if only_missing and key in st.session_state:
            continue
        st.session_state[key] = value


# Initialise per key instead of testing `not st.session_state`: any widget
# created with `key=` also lives in session_state and would otherwise cause
# the whole initialisation to be skipped.
_reset_game_state(only_missing=True)

# Title matches the page title and the Gemma model actually used.
st.title("Gemma Hangman")
st.markdown("## Guess the word based on a hint")
col1, col2 = st.columns(2)
with col1:
    category = st.selectbox(
        "Choose a category",
        CATEGORIES,
    )
with col2:
    start_btn = st.button("Start game")
    reset_btn = st.button("Reset game")
if start_btn:
    # Query the model before touching session state so that a failing call
    # leaves the previous game intact instead of half-updated.
    word = query_word(
        category,
        model,
        tokenizer,
        configs["generation_config"],
        configs["device"],
    )
    hint = query_hint(
        word,
        model,
        tokenizer,
        configs["generation_config"],
        configs["device"],
    )
    _reset_game_state()
    st.session_state["word"] = word
    st.session_state["hint"] = hint
    # One placeholder per character of the word, including spaces/punctuation.
    st.session_state["hangman"] = "_" * len(word)
if reset_btn:
    _reset_game_state()
st.markdown(
    """
Note: you must input whitespaces and special characters.
"""
)
st.markdown(f'### Hint:\n{st.session_state["hint"]}')
# A game is in progress when a word exists, it is not fully revealed, and the
# miss limit has not been reached (same criteria as the win/loss checks below).
game_active = (
    st.session_state["word"] != ""
    and st.session_state["word"] != st.session_state["hangman"]
    and len(st.session_state["missed_letters"]) < MAX_TRIES
)
col3, col4 = st.columns(2)
with col3:
    guess = st.text_input(label="Enter letter")
    guess_btn = st.button("Guess letter")
    # Ignore empty guesses and guesses made before a game starts or after it ends.
    if guess_btn and guess and game_active:
        updated_state = guess_letter(guess, st.session_state)
        # Never rebind `st.session_state` itself: that replaces the attribute
        # on the shared `streamlit` module. Copy entries back only if
        # guess_letter returned a different mapping.
        # NOTE(review): presumably guess_letter mutates and returns the same
        # object — confirm in hangman.guess_letter.
        if updated_state is not None and updated_state is not st.session_state:
            for key, value in updated_state.items():
                st.session_state[key] = value
with col4:
    # Read-only displays of the current game state.
    st.text_input(
        label="Hangman",
        value=st.session_state["hangman"],
        disabled=True,
    )
    st.text_input(
        label=f"Missed letters (max {MAX_TRIES} tries)",
        value=", ".join(st.session_state["missed_letters"]),
        disabled=True,
    )
# Won: the revealed pattern equals the (non-empty) word.
if st.session_state["word"] == st.session_state["hangman"] != "":
    st.success("You won!")
    st.balloons()
# Lost: the number of missed letters reached the limit.
if len(st.session_state["missed_letters"]) >= MAX_TRIES:
    st.error(f"""You lost, the correct word was '{st.session_state["word"]}'""")
    st.snow()
|