Spaces:
Runtime error
Runtime error
marksverdhei
commited on
Commit
β’
7eee83c
1
Parent(s):
0e7f280
Add more handlers
Browse files- app.py +2 -1
- src/handler.py +132 -59
- src/interface.py +10 -2
- src/state.py +49 -13
- src/utils.py +20 -0
app.py
CHANGED
@@ -1,11 +1,12 @@
|
|
1 |
import logging
|
|
|
2 |
from src import interface
|
3 |
|
4 |
logging.basicConfig(level="DEBUG")
|
5 |
|
6 |
|
7 |
def main():
|
8 |
-
demo = interface.get_demo(wip=
|
9 |
demo.launch(debug=True)
|
10 |
|
11 |
|
|
|
1 |
import logging
|
2 |
+
|
3 |
from src import interface
|
4 |
|
5 |
logging.basicConfig(level="DEBUG")
|
6 |
|
7 |
|
8 |
def main():
|
9 |
+
demo = interface.get_demo(wip=False)
|
10 |
demo.launch(debug=True)
|
11 |
|
12 |
|
src/handler.py
CHANGED
@@ -27,15 +27,26 @@ def get_model_predictions(input_text: str) -> torch.Tensor:
|
|
27 |
return top_3
|
28 |
|
29 |
|
30 |
-
def guess_is_correct(text: str
|
31 |
"""
|
32 |
We check if the predicted token or a corresponding one with a leading whitespace
|
33 |
matches that of the next token
|
34 |
"""
|
35 |
-
|
|
|
36 |
predicted_token_start, predicted_token_whitespace = get_start_and_whitespace_tokens(text)
|
37 |
logger.debug(tokenizer.convert_ids_to_tokens([predicted_token_start, predicted_token_whitespace]))
|
38 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
|
40 |
|
41 |
def get_start_and_whitespace_tokens(word: str) -> tuple[int, int]:
|
@@ -44,73 +55,135 @@ def get_start_and_whitespace_tokens(word: str) -> tuple[int, int]:
|
|
44 |
return predicted_token_start, predicted_token_whitespace
|
45 |
|
46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
def handle_guess(
|
48 |
text: str,
|
49 |
-
remaining_attempts: int,
|
50 |
*args,
|
51 |
**kwargs,
|
52 |
) -> str:
|
53 |
"""
|
54 |
* Retreives model predictions and compares the top 3 predicted tokens
|
55 |
"""
|
56 |
-
logger.debug(
|
57 |
-
f"Params:\ntext = {text}\n"
|
58 |
-
f"remaining_attempts = {remaining_attempts}\n"
|
59 |
-
f"args = {args}\n"
|
60 |
-
f"kwargs = {kwargs}\n"
|
61 |
-
)
|
62 |
logger.debug(f"Initial STATE:\n{STATE}")
|
63 |
|
64 |
-
|
65 |
-
|
66 |
-
player_guesses = ""
|
67 |
-
lm_guesses = ""
|
68 |
|
69 |
if not text:
|
70 |
-
|
71 |
-
return (
|
72 |
-
current_text,
|
73 |
-
STATE.player_points,
|
74 |
-
STATE.lm_points,
|
75 |
-
STATE.player_guess_str,
|
76 |
-
STATE.get_lm_guess_display(remaining_attempts),
|
77 |
-
remaining_attempts,
|
78 |
-
"",
|
79 |
-
"Guess!"
|
80 |
-
)
|
81 |
-
|
82 |
-
if remaining_attempts == 0:
|
83 |
-
STATE.next_word()
|
84 |
-
current_tokens = all_tokens[: STATE.current_word_index]
|
85 |
-
remaining_attempts = MAX_ATTEMPTS
|
86 |
-
|
87 |
-
remaining_attempts -= 1
|
88 |
-
|
89 |
-
next_token = all_tokens[STATE.current_word_index]
|
90 |
-
|
91 |
-
if guess_is_correct(text, next_token):
|
92 |
-
# STATE.correct_guess()
|
93 |
-
STATE.player_points += 1
|
94 |
-
remaining_attempts = 0
|
95 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
else:
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
return top_3
|
28 |
|
29 |
|
30 |
+
def guess_is_correct(text: str) -> bool:
|
31 |
"""
|
32 |
We check if the predicted token or a corresponding one with a leading whitespace
|
33 |
matches that of the next token
|
34 |
"""
|
35 |
+
current_target = all_tokens[STATE.current_word_index]
|
36 |
+
logger.debug("Next token: '{}'".format(tokenizer.convert_ids_to_tokens([current_target])))
|
37 |
predicted_token_start, predicted_token_whitespace = get_start_and_whitespace_tokens(text)
|
38 |
logger.debug(tokenizer.convert_ids_to_tokens([predicted_token_start, predicted_token_whitespace]))
|
39 |
+
return current_target in (predicted_token_start, predicted_token_whitespace)
|
40 |
+
|
41 |
+
|
42 |
+
def lm_is_correct() -> bool:
|
43 |
+
# NOTE: out of range if remaining attempts is 0
|
44 |
+
if STATE.remaining_attempts > 1:
|
45 |
+
return False
|
46 |
+
|
47 |
+
current_guess = STATE.lm_guesses[MAX_ATTEMPTS - STATE.remaining_attempts]
|
48 |
+
current_target = all_tokens[STATE.current_word_index]
|
49 |
+
return current_guess == current_target
|
50 |
|
51 |
|
52 |
def get_start_and_whitespace_tokens(word: str) -> tuple[int, int]:
|
|
|
55 |
return predicted_token_start, predicted_token_whitespace
|
56 |
|
57 |
|
58 |
+
def get_current_text():
|
59 |
+
return tokenizer.decode(all_tokens[: STATE.current_word_index])
|
60 |
+
|
61 |
+
|
62 |
+
def handle_player_win():
|
63 |
+
# TODO: point system
|
64 |
+
points = 1
|
65 |
+
STATE.player_points += points
|
66 |
+
STATE.button_label = "Next word"
|
67 |
+
return STATE.get_tuple(
|
68 |
+
get_current_text(),
|
69 |
+
bottom_html=f"Player gets {points} point!",
|
70 |
+
)
|
71 |
+
|
72 |
+
|
73 |
+
def handle_lm_win():
|
74 |
+
points = 1
|
75 |
+
STATE.lm_points += points
|
76 |
+
STATE.button_label = "Next word"
|
77 |
+
return STATE.get_tuple(
|
78 |
+
get_current_text(),
|
79 |
+
bottom_html=f"GPT2 gets {points} point!",
|
80 |
+
)
|
81 |
+
|
82 |
+
|
83 |
+
def handle_out_of_attempts():
|
84 |
+
STATE.button_label = "Next word"
|
85 |
+
return STATE.get_tuple(
|
86 |
+
get_current_text(),
|
87 |
+
bottom_html="Out of attempts. No one gets points!",
|
88 |
+
)
|
89 |
+
|
90 |
+
|
91 |
+
def handle_tie():
|
92 |
+
STATE.button_label = "Next word"
|
93 |
+
return STATE.get_tuple(
|
94 |
+
get_current_text(),
|
95 |
+
bottom_html="TIE! No one gets points!",
|
96 |
+
)
|
97 |
+
|
98 |
+
|
99 |
+
def handle_next_attempt():
|
100 |
+
STATE.remaining_attempts -= 1
|
101 |
+
return STATE.get_tuple(
|
102 |
+
get_current_text(), bottom_html=f"That was not it... {STATE.remaining_attempts} attempts left"
|
103 |
+
)
|
104 |
+
|
105 |
+
|
106 |
+
def handle_no_input():
|
107 |
+
return STATE.get_tuple(
|
108 |
+
get_current_text(),
|
109 |
+
bottom_html="Please write something",
|
110 |
+
)
|
111 |
+
|
112 |
+
|
113 |
+
def handle_next_word():
|
114 |
+
STATE.next_word()
|
115 |
+
STATE.lm_guesses = get_model_predictions(get_current_text())
|
116 |
+
return STATE.get_tuple()
|
117 |
+
|
118 |
+
|
119 |
def handle_guess(
|
120 |
text: str,
|
|
|
121 |
*args,
|
122 |
**kwargs,
|
123 |
) -> str:
|
124 |
"""
|
125 |
* Retreives model predictions and compares the top 3 predicted tokens
|
126 |
"""
|
127 |
+
logger.debug("Params:\n" f"text = {text}\n" f"args = {args}\n" f"kwargs = {kwargs}\n")
|
|
|
|
|
|
|
|
|
|
|
128 |
logger.debug(f"Initial STATE:\n{STATE}")
|
129 |
|
130 |
+
if STATE.button_label == "Next word":
|
131 |
+
return handle_next_word()
|
|
|
|
|
132 |
|
133 |
if not text:
|
134 |
+
return handle_no_input()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
135 |
|
136 |
+
STATE.player_guesses.append(text)
|
137 |
+
|
138 |
+
player_correct = guess_is_correct(text)
|
139 |
+
lm_correct = lm_is_correct()
|
140 |
+
|
141 |
+
if player_correct and lm_correct:
|
142 |
+
return handle_tie()
|
143 |
+
elif player_correct and not lm_correct:
|
144 |
+
return handle_player_win()
|
145 |
+
elif lm_correct and not player_correct:
|
146 |
+
return handle_lm_win()
|
147 |
+
elif STATE.remaining_attempts == 0:
|
148 |
+
return handle_out_of_attempts()
|
149 |
else:
|
150 |
+
return handle_next_attempt()
|
151 |
+
|
152 |
+
|
153 |
+
STATE.lm_guesses = get_model_predictions(get_current_text())
|
154 |
+
|
155 |
+
|
156 |
+
# # STATE.correct_guess()
|
157 |
+
# # remaining_attempts = 0
|
158 |
+
# # elif lm_guess_is_correct():
|
159 |
+
# # pass
|
160 |
+
# else:
|
161 |
+
# return handle_incorrect_guess()
|
162 |
+
# # elif remaining_attempts == 0:
|
163 |
+
# # return handle_out_of_attempts()
|
164 |
+
|
165 |
+
# remaining_attempts -= 1
|
166 |
+
# STATE.player_guesses.append(text)
|
167 |
+
|
168 |
+
# if remaining_attempts == 0:
|
169 |
+
# STATE.next_word()
|
170 |
+
# current_tokens = all_tokens[: STATE.current_word_index]
|
171 |
+
# remaining_attempts = MAX_ATTEMPTS
|
172 |
+
|
173 |
+
# # FIXME: unoptimized, computing all three every time
|
174 |
+
# current_text = tokenizer.decode(current_tokens)
|
175 |
+
# logger.debug(f"lm_guesses: {tokenizer.decode(STATE.lm_guesses)}")
|
176 |
+
# logger.debug(f"Pre-return STATE:\n{STATE}")
|
177 |
+
|
178 |
+
# # BUG: if you enter the word guess field when it says next
|
179 |
+
# # word, it will guess it as the next
|
180 |
+
# return (
|
181 |
+
# current_text,
|
182 |
+
# STATE.player_points,
|
183 |
+
# STATE.lm_points,
|
184 |
+
# STATE.player_guess_str,
|
185 |
+
# STATE.get_lm_guess_display(remaining_attempts),
|
186 |
+
# remaining_attempts,
|
187 |
+
# "",
|
188 |
+
# "Guess!" if remaining_attempts else "Next word",
|
189 |
+
# )
|
src/interface.py
CHANGED
@@ -7,7 +7,11 @@ from src.state import STATE
|
|
7 |
from src.state import tokenizer
|
8 |
|
9 |
|
10 |
-
def build_demo():
|
|
|
|
|
|
|
|
|
11 |
with gr.Blocks() as demo:
|
12 |
with gr.Row():
|
13 |
gr.Markdown("<h1><center>Can you beat a language model?</center></h1>")
|
@@ -35,6 +39,7 @@ def build_demo():
|
|
35 |
value=MAX_ATTEMPTS,
|
36 |
label="Remaining attempts",
|
37 |
precision=0,
|
|
|
38 |
)
|
39 |
current_guesses = gr.Textbox(label="Your guesses")
|
40 |
with gr.Column():
|
@@ -45,11 +50,13 @@ def build_demo():
|
|
45 |
guess_field = gr.Textbox(label="")
|
46 |
guess_button = gr.Button(value="Guess!")
|
47 |
|
|
|
|
|
|
|
48 |
guess_button.click(
|
49 |
handle_guess,
|
50 |
inputs=[
|
51 |
guess_field,
|
52 |
-
remaining_attempts,
|
53 |
],
|
54 |
outputs=[
|
55 |
prompt_text,
|
@@ -60,6 +67,7 @@ def build_demo():
|
|
60 |
remaining_attempts,
|
61 |
guess_field,
|
62 |
guess_button,
|
|
|
63 |
],
|
64 |
)
|
65 |
|
|
|
7 |
from src.state import tokenizer
|
8 |
|
9 |
|
10 |
+
def build_demo():
|
11 |
+
"""
|
12 |
+
Builds and returns the gradio app interface
|
13 |
+
"""
|
14 |
+
|
15 |
with gr.Blocks() as demo:
|
16 |
with gr.Row():
|
17 |
gr.Markdown("<h1><center>Can you beat a language model?</center></h1>")
|
|
|
39 |
value=MAX_ATTEMPTS,
|
40 |
label="Remaining attempts",
|
41 |
precision=0,
|
42 |
+
interactive=False,
|
43 |
)
|
44 |
current_guesses = gr.Textbox(label="Your guesses")
|
45 |
with gr.Column():
|
|
|
50 |
guess_field = gr.Textbox(label="")
|
51 |
guess_button = gr.Button(value="Guess!")
|
52 |
|
53 |
+
with gr.Row():
|
54 |
+
bottom_html = gr.HTML()
|
55 |
+
|
56 |
guess_button.click(
|
57 |
handle_guess,
|
58 |
inputs=[
|
59 |
guess_field,
|
|
|
60 |
],
|
61 |
outputs=[
|
62 |
prompt_text,
|
|
|
67 |
remaining_attempts,
|
68 |
guess_field,
|
69 |
guess_button,
|
70 |
+
bottom_html,
|
71 |
],
|
72 |
)
|
73 |
|
src/state.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
from dataclasses import dataclass
|
2 |
|
3 |
from transformers import AutoModelForCausalLM
|
@@ -5,31 +6,64 @@ from transformers import AutoTokenizer
|
|
5 |
|
6 |
from src.constants import MAX_ATTEMPTS
|
7 |
|
|
|
|
|
8 |
|
9 |
@dataclass
|
10 |
class ProgramState:
|
11 |
current_word_index: int
|
12 |
-
player_guesses: list
|
13 |
player_points: int
|
14 |
-
lm_guesses: list
|
15 |
lm_points: int
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
-
|
18 |
-
|
19 |
-
self.player_points += 1
|
20 |
-
self.next_word()
|
21 |
|
22 |
def next_word(self):
|
23 |
self.current_word_index += 1
|
24 |
self.player_guesses = []
|
25 |
-
self.lm_guesses = []
|
26 |
-
|
27 |
-
@property
|
28 |
-
def player_guess_str(self):
|
29 |
-
return "\n".join(self.player_guesses)
|
30 |
|
31 |
-
def
|
32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
|
34 |
|
35 |
STATE = ProgramState(
|
@@ -38,6 +72,8 @@ STATE = ProgramState(
|
|
38 |
lm_guesses=[],
|
39 |
player_points=0,
|
40 |
lm_points=0,
|
|
|
|
|
41 |
)
|
42 |
|
43 |
tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
|
|
1 |
+
import logging
|
2 |
from dataclasses import dataclass
|
3 |
|
4 |
from transformers import AutoModelForCausalLM
|
|
|
6 |
|
7 |
from src.constants import MAX_ATTEMPTS
|
8 |
|
9 |
+
logger = logging.getLogger(__name__)
|
10 |
+
|
11 |
|
12 |
@dataclass
|
13 |
class ProgramState:
|
14 |
current_word_index: int
|
15 |
+
player_guesses: list[str]
|
16 |
player_points: int
|
17 |
+
lm_guesses: list[int]
|
18 |
lm_points: int
|
19 |
+
remaining_attempts: int
|
20 |
+
button_label: str
|
21 |
+
|
22 |
+
@property
|
23 |
+
def player_guess_str(self):
|
24 |
+
return "\n".join(self.player_guesses)
|
25 |
+
|
26 |
+
@property
|
27 |
+
def lm_guess_str(self):
|
28 |
+
strings = list(map(tokenizer.decode, self.lm_guesses))
|
29 |
+
logger.debug(strings)
|
30 |
+
|
31 |
+
n_censored = self.remaining_attempts
|
32 |
+
for i in range(1, n_censored + 1):
|
33 |
+
strings[-i] = "****"
|
34 |
|
35 |
+
logger.debug(strings)
|
36 |
+
return "\n".join(strings)
|
|
|
|
|
37 |
|
38 |
def next_word(self):
|
39 |
self.current_word_index += 1
|
40 |
self.player_guesses = []
|
41 |
+
self.lm_guesses = [] # TODO: make guesses?
|
42 |
+
self.button_label = "Guess!"
|
|
|
|
|
|
|
43 |
|
44 |
+
def get_tuple(
|
45 |
+
self,
|
46 |
+
prompt_text=None,
|
47 |
+
player_points=None,
|
48 |
+
lm_points=None,
|
49 |
+
player_guess_str=None,
|
50 |
+
lm_guess_str=None,
|
51 |
+
remaining_attempts=None,
|
52 |
+
text_field=None,
|
53 |
+
button_label=None,
|
54 |
+
bottom_html=None,
|
55 |
+
) -> tuple:
|
56 |
+
return (
|
57 |
+
prompt_text or "", # FIXME
|
58 |
+
player_points or self.player_points,
|
59 |
+
lm_points or self.lm_points,
|
60 |
+
player_guess_str or self.player_guess_str,
|
61 |
+
lm_guess_str or self.lm_guess_str,
|
62 |
+
remaining_attempts or self.remaining_attempts,
|
63 |
+
text_field or "", # FIXME
|
64 |
+
button_label or self.button_label,
|
65 |
+
bottom_html or "", # FIXME
|
66 |
+
)
|
67 |
|
68 |
|
69 |
STATE = ProgramState(
|
|
|
72 |
lm_guesses=[],
|
73 |
player_points=0,
|
74 |
lm_points=0,
|
75 |
+
remaining_attempts=MAX_ATTEMPTS,
|
76 |
+
button_label="Guess!",
|
77 |
)
|
78 |
|
79 |
tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
src/utils.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import namedtuple
|
2 |
+
|
3 |
+
|
4 |
+
HandlerInput = namedtuple(
|
5 |
+
typename="HandlerInput",
|
6 |
+
field_names=[
|
7 |
+
"text",
|
8 |
+
"remaining_attempts",
|
9 |
+
"button_label",
|
10 |
+
],
|
11 |
+
)
|
12 |
+
|
13 |
+
HandlerOutput = namedtuple(
|
14 |
+
typename="HandlerOutput",
|
15 |
+
field_names=[
|
16 |
+
"text",
|
17 |
+
"remaining_attempts",
|
18 |
+
"button_label",
|
19 |
+
],
|
20 |
+
)
|