Dimitre commited on
Commit
d8e827d
1 Parent(s): 1586c56

refactoring app

Browse files
Files changed (3) hide show
  1. app.py +3 -137
  2. hangman.py +35 -0
  3. hf_utils.py +109 -0
app.py CHANGED
@@ -1,28 +1,19 @@
1
  import logging
2
  import os
3
- import string
4
- import re
5
 
6
  import streamlit as st
7
- from streamlit import session_state
8
  import torch
9
  from dotenv import load_dotenv
10
- from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
11
 
12
- # from common import CATEGORIES, MAX_TRIES, configs
13
- # from hangman import guess_letter
14
- # from hf_utils import query_hint, query_word
15
 
16
 
17
  CONFIGS_PATH = "configs.yaml"
18
  MAX_TRIES = 6
19
  CATEGORIES = ["Country", "Animal", "Food", "Movie"]
20
 
21
- GEMMA_WORD_PATTERNS = [
22
- "(?<=\*)(.*?)(?=\*)",
23
- '(?<=")(.*?)(?=")',
24
- ]
25
-
26
  configs = {
27
  "os_model": "google/gemma-2b-it",
28
  "device": "cpu",
@@ -35,131 +26,6 @@ configs = {
35
  }
36
 
37
 
38
- def guess_letter(letter: str, session: session_state) -> session_state:
39
- """Take a letter and evaluate if it is part of the hangman puzzle
40
- then updates the session object accordingly.
41
-
42
- Args:Chosen letter
43
- letter (str): Streamlit session object
44
- session (session_state): _description_
45
-
46
- Returns:
47
- session_state: Updated session
48
- """
49
- logger.info(f"Letter '{letter}' picked")
50
- if letter in session["word"]:
51
- session["correct_letters"].append(letter)
52
- else:
53
- session["missed_letters"].append(letter)
54
-
55
- hangman = "".join(
56
- [
57
- (letter if letter in session["correct_letters"] else "_")
58
- for letter in session["word"]
59
- ]
60
- )
61
- session["hangman"] = hangman
62
- logger.info("Session state updated")
63
- return session
64
-
65
-
66
- def query_hf(
67
- query: str,
68
- model: AutoModelForCausalLM,
69
- tokenizer: AutoTokenizer,
70
- generation_config: dict,
71
- device: str,
72
- ) -> str:
73
- """Queries an LLM model using the Vertex AI API.
74
-
75
- Args:
76
- query (str): Query sent to the Vertex API
77
- model (str): Model target by Vertex
78
- generation_config (dict): Configurations used by the model
79
-
80
- Returns:
81
- str: Vertex AI text response
82
- """
83
- generation_config = GenerationConfig(
84
- do_sample=True,
85
- max_new_tokens=generation_config["max_output_tokens"],
86
- top_k=generation_config["top_k"],
87
- top_p=generation_config["top_p"],
88
- temperature=generation_config["temperature"],
89
- )
90
-
91
- input_ids = tokenizer(query, return_tensors="pt").to(device)
92
- outputs = model.generate(**input_ids, generation_config=generation_config)
93
- outputs = tokenizer.decode(outputs[0], skip_special_tokens=True)
94
- outputs = outputs.replace(query, "")
95
- return outputs
96
-
97
-
98
- def query_word(
99
- category: str,
100
- model: AutoModelForCausalLM,
101
- tokenizer: AutoTokenizer,
102
- generation_config: dict,
103
- device: str,
104
- ) -> str:
105
- """Queries a word to be used for the hangman game.
106
-
107
- Args:
108
- category (str): Category used as source sample a word
109
- model (str): Model target by Vertex
110
- generation_config (dict): Configurations used by the model
111
-
112
- Returns:
113
- str: Queried word
114
- """
115
- logger.info(f"Quering word for category: '{category}'...")
116
- query = f"Name a single existing {category}."
117
-
118
- matched_word = ""
119
- while not matched_word:
120
- # word = query_hf(query, model, tokenizer, generation_config, device)
121
- word = "placeholder word"
122
-
123
- # Extract word of interest from Gemma's output
124
- for pattern in GEMMA_WORD_PATTERNS:
125
- matched_words = re.findall(rf"{pattern}", word)
126
- matched_words = [x for x in matched_words if x != ""]
127
- if matched_words:
128
- matched_word = matched_words[-1]
129
-
130
- matched_word = matched_word.translate(str.maketrans("", "", string.punctuation))
131
- matched_word = matched_word.lower()
132
-
133
- logger.info("Word queried successful")
134
- return matched_word
135
-
136
-
137
- def query_hint(
138
- word: str,
139
- model: AutoModelForCausalLM,
140
- tokenizer: AutoTokenizer,
141
- generation_config: dict,
142
- device: str,
143
- ) -> str:
144
- """Queries a hint for the hangman game.
145
-
146
- Args:
147
- word (str): Word used as source to create the hint
148
- model (str): Model target by Vertex
149
- generation_config (dict): Configurations used by the model
150
-
151
- Returns:
152
- str: Queried hint
153
- """
154
- logger.info(f"Quering hint for word: '{word}'...")
155
- query = f"Describe the word '{word}' without mentioning it."
156
- # hint = query_hf(query, model, tokenizer, generation_config, device)
157
- hint = "placeholder hint"
158
- hint = re.sub(re.escape(word), "***", hint, flags=re.IGNORECASE)
159
- logger.info("Hint queried successful")
160
- return hint
161
-
162
-
163
  @st.cache_resource()
164
  def setup(model_id: str, device: str) -> None:
165
  """Initializes the model and tokenizer.
 
1
  import logging
2
  import os
 
 
3
 
4
  import streamlit as st
 
5
  import torch
6
  from dotenv import load_dotenv
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer
8
 
9
+ from hangman import guess_letter
10
+ from hf_utils import query_hint, query_word
 
11
 
12
 
13
  CONFIGS_PATH = "configs.yaml"
14
  MAX_TRIES = 6
15
  CATEGORIES = ["Country", "Animal", "Food", "Movie"]
16
 
 
 
 
 
 
17
  configs = {
18
  "os_model": "google/gemma-2b-it",
19
  "device": "cpu",
 
26
  }
27
 
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  @st.cache_resource()
30
  def setup(model_id: str, device: str) -> None:
31
  """Initializes the model and tokenizer.
hangman.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ from streamlit import session_state
4
+
5
+
6
+ def guess_letter(letter: str, session: session_state) -> session_state:
7
+ """Take a letter and evaluate if it is part of the hangman puzzle
8
+ then updates the session object accordingly.
9
+
10
+ Args:Chosen letter
11
+ letter (str): Streamlit session object
12
+ session (session_state): _description_
13
+
14
+ Returns:
15
+ session_state: Updated session
16
+ """
17
+ logger.info(f"Letter '{letter}' picked")
18
+ if letter in session["word"]:
19
+ session["correct_letters"].append(letter)
20
+ else:
21
+ session["missed_letters"].append(letter)
22
+
23
+ hangman = "".join(
24
+ [
25
+ (letter if letter in session["correct_letters"] else "_")
26
+ for letter in session["word"]
27
+ ]
28
+ )
29
+ session["hangman"] = hangman
30
+ logger.info("Session state updated")
31
+ return session
32
+
33
+
34
+ logging.basicConfig(level=logging.INFO)
35
+ logger = logging.getLogger(__file__)
hf_utils.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import re
3
+ import string
4
+
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
6
+
7
+ GEMMA_WORD_PATTERNS = [
8
+ "(?<=\*)(.*?)(?=\*)",
9
+ '(?<=")(.*?)(?=")',
10
+ ]
11
+
12
+
13
+ def query_hf(
14
+ query: str,
15
+ model: AutoModelForCausalLM,
16
+ tokenizer: AutoTokenizer,
17
+ generation_config: dict,
18
+ device: str,
19
+ ) -> str:
20
+ """Queries an LLM model using the Vertex AI API.
21
+
22
+ Args:
23
+ query (str): Query sent to the Vertex API
24
+ model (str): Model target by Vertex
25
+ generation_config (dict): Configurations used by the model
26
+
27
+ Returns:
28
+ str: Vertex AI text response
29
+ """
30
+ generation_config = GenerationConfig(
31
+ do_sample=True,
32
+ max_new_tokens=generation_config["max_output_tokens"],
33
+ top_k=generation_config["top_k"],
34
+ top_p=generation_config["top_p"],
35
+ temperature=generation_config["temperature"],
36
+ )
37
+
38
+ input_ids = tokenizer(query, return_tensors="pt").to(device)
39
+ outputs = model.generate(**input_ids, generation_config=generation_config)
40
+ outputs = tokenizer.decode(outputs[0], skip_special_tokens=True)
41
+ outputs = outputs.replace(query, "")
42
+ return outputs
43
+
44
+
45
+ def query_word(
46
+ category: str,
47
+ model: AutoModelForCausalLM,
48
+ tokenizer: AutoTokenizer,
49
+ generation_config: dict,
50
+ device: str,
51
+ ) -> str:
52
+ """Queries a word to be used for the hangman game.
53
+
54
+ Args:
55
+ category (str): Category used as source sample a word
56
+ model (str): Model target by Vertex
57
+ generation_config (dict): Configurations used by the model
58
+
59
+ Returns:
60
+ str: Queried word
61
+ """
62
+ logger.info(f"Quering word for category: '{category}'...")
63
+ query = f"Name a single existing {category}."
64
+
65
+ matched_word = ""
66
+ while not matched_word:
67
+ word = query_hf(query, model, tokenizer, generation_config, device)
68
+
69
+ # Extract word of interest from Gemma's output
70
+ for pattern in GEMMA_WORD_PATTERNS:
71
+ matched_words = re.findall(rf"{pattern}", word)
72
+ matched_words = [x for x in matched_words if x != ""]
73
+ if matched_words:
74
+ matched_word = matched_words[-1]
75
+
76
+ matched_word = matched_word.translate(str.maketrans("", "", string.punctuation))
77
+ matched_word = matched_word.lower()
78
+
79
+ logger.info("Word queried successful")
80
+ return matched_word
81
+
82
+
83
+ def query_hint(
84
+ word: str,
85
+ model: AutoModelForCausalLM,
86
+ tokenizer: AutoTokenizer,
87
+ generation_config: dict,
88
+ device: str,
89
+ ) -> str:
90
+ """Queries a hint for the hangman game.
91
+
92
+ Args:
93
+ word (str): Word used as source to create the hint
94
+ model (str): Model target by Vertex
95
+ generation_config (dict): Configurations used by the model
96
+
97
+ Returns:
98
+ str: Queried hint
99
+ """
100
+ logger.info(f"Quering hint for word: '{word}'...")
101
+ query = f"Describe the word '{word}' without mentioning it."
102
+ hint = query_hf(query, model, tokenizer, generation_config, device)
103
+ hint = re.sub(re.escape(word), "***", hint, flags=re.IGNORECASE)
104
+ logger.info("Hint queried successful")
105
+ return hint
106
+
107
+
108
+ logging.basicConfig(level=logging.INFO)
109
+ logger = logging.getLogger(__file__)