marksverdhei committed on
Commit
7eee83c
β€’
1 Parent(s): 0e7f280

Add more handlers

Browse files
Files changed (5) hide show
  1. app.py +2 -1
  2. src/handler.py +132 -59
  3. src/interface.py +10 -2
  4. src/state.py +49 -13
  5. src/utils.py +20 -0
app.py CHANGED
@@ -1,11 +1,12 @@
1
  import logging
 
2
  from src import interface
3
 
4
  logging.basicConfig(level="DEBUG")
5
 
6
 
7
  def main():
8
- demo = interface.get_demo(wip=True)
9
  demo.launch(debug=True)
10
 
11
 
 
1
  import logging
2
+
3
  from src import interface
4
 
5
  logging.basicConfig(level="DEBUG")
6
 
7
 
8
  def main():
9
+ demo = interface.get_demo(wip=False)
10
  demo.launch(debug=True)
11
 
12
 
src/handler.py CHANGED
@@ -27,15 +27,26 @@ def get_model_predictions(input_text: str) -> torch.Tensor:
27
  return top_3
28
 
29
 
30
- def guess_is_correct(text: str, next_token: int) -> bool:
31
  """
32
  We check if the predicted token or a corresponding one with a leading whitespace
33
  matches that of the next token
34
  """
35
- logger.debug("Next token: '{}'".format(tokenizer.convert_ids_to_tokens([next_token])))
 
36
  predicted_token_start, predicted_token_whitespace = get_start_and_whitespace_tokens(text)
37
  logger.debug(tokenizer.convert_ids_to_tokens([predicted_token_start, predicted_token_whitespace]))
38
- return next_token in (predicted_token_start, predicted_token_whitespace)
 
 
 
 
 
 
 
 
 
 
39
 
40
 
41
  def get_start_and_whitespace_tokens(word: str) -> tuple[int, int]:
@@ -44,73 +55,135 @@ def get_start_and_whitespace_tokens(word: str) -> tuple[int, int]:
44
  return predicted_token_start, predicted_token_whitespace
45
 
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  def handle_guess(
48
  text: str,
49
- remaining_attempts: int,
50
  *args,
51
  **kwargs,
52
  ) -> str:
53
  """
54
  * Retrieves model predictions and compares the top 3 predicted tokens
55
  """
56
- logger.debug(
57
- f"Params:\ntext = {text}\n"
58
- f"remaining_attempts = {remaining_attempts}\n"
59
- f"args = {args}\n"
60
- f"kwargs = {kwargs}\n"
61
- )
62
  logger.debug(f"Initial STATE:\n{STATE}")
63
 
64
- current_tokens = all_tokens[:STATE.current_word_index]
65
- current_text = tokenizer.decode(current_tokens)
66
- player_guesses = ""
67
- lm_guesses = ""
68
 
69
  if not text:
70
- logger.debug("Returning early")
71
- return (
72
- current_text,
73
- STATE.player_points,
74
- STATE.lm_points,
75
- STATE.player_guess_str,
76
- STATE.get_lm_guess_display(remaining_attempts),
77
- remaining_attempts,
78
- "",
79
- "Guess!"
80
- )
81
-
82
- if remaining_attempts == 0:
83
- STATE.next_word()
84
- current_tokens = all_tokens[: STATE.current_word_index]
85
- remaining_attempts = MAX_ATTEMPTS
86
-
87
- remaining_attempts -= 1
88
-
89
- next_token = all_tokens[STATE.current_word_index]
90
-
91
- if guess_is_correct(text, next_token):
92
- # STATE.correct_guess()
93
- STATE.player_points += 1
94
- remaining_attempts = 0
95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  else:
97
- STATE.player_guesses.append(text)
98
-
99
- # FIXME: unoptimized, computing all three every time
100
- current_text = tokenizer.decode(current_tokens)
101
- STATE.lm_guesses = get_model_predictions(current_text)[: MAX_ATTEMPTS - remaining_attempts]
102
- logger.debug(f"lm_guesses: {tokenizer.decode(STATE.lm_guesses)}")
103
- logger.debug(f"Pre-return STATE:\n{STATE}")
104
-
105
- # BUG: if you enter the word guess field when it says next
106
- # word, it will guess it as the next
107
- return (
108
- current_text,
109
- STATE.player_points,
110
- STATE.lm_points,
111
- STATE.player_guess_str,
112
- STATE.get_lm_guess_display(remaining_attempts),
113
- remaining_attempts,
114
- "",
115
- "Guess!" if remaining_attempts else "Next word",
116
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  return top_3
28
 
29
 
30
+ def guess_is_correct(text: str) -> bool:
31
  """
32
  We check if the predicted token or a corresponding one with a leading whitespace
33
  matches that of the next token
34
  """
35
+ current_target = all_tokens[STATE.current_word_index]
36
+ logger.debug("Next token: '{}'".format(tokenizer.convert_ids_to_tokens([current_target])))
37
  predicted_token_start, predicted_token_whitespace = get_start_and_whitespace_tokens(text)
38
  logger.debug(tokenizer.convert_ids_to_tokens([predicted_token_start, predicted_token_whitespace]))
39
+ return current_target in (predicted_token_start, predicted_token_whitespace)
40
+
41
+
42
+ def lm_is_correct() -> bool:
43
+ # NOTE: out of range if remaining attempts is 0
44
+ if STATE.remaining_attempts > 1:
45
+ return False
46
+
47
+ current_guess = STATE.lm_guesses[MAX_ATTEMPTS - STATE.remaining_attempts]
48
+ current_target = all_tokens[STATE.current_word_index]
49
+ return current_guess == current_target
50
 
51
 
52
  def get_start_and_whitespace_tokens(word: str) -> tuple[int, int]:
 
55
  return predicted_token_start, predicted_token_whitespace
56
 
57
 
58
+ def get_current_text():
59
+ return tokenizer.decode(all_tokens[: STATE.current_word_index])
60
+
61
+
62
+ def handle_player_win():
63
+ # TODO: point system
64
+ points = 1
65
+ STATE.player_points += points
66
+ STATE.button_label = "Next word"
67
+ return STATE.get_tuple(
68
+ get_current_text(),
69
+ bottom_html=f"Player gets {points} point!",
70
+ )
71
+
72
+
73
+ def handle_lm_win():
74
+ points = 1
75
+ STATE.lm_points += points
76
+ STATE.button_label = "Next word"
77
+ return STATE.get_tuple(
78
+ get_current_text(),
79
+ bottom_html=f"GPT2 gets {points} point!",
80
+ )
81
+
82
+
83
+ def handle_out_of_attempts():
84
+ STATE.button_label = "Next word"
85
+ return STATE.get_tuple(
86
+ get_current_text(),
87
+ bottom_html="Out of attempts. No one gets points!",
88
+ )
89
+
90
+
91
+ def handle_tie():
92
+ STATE.button_label = "Next word"
93
+ return STATE.get_tuple(
94
+ get_current_text(),
95
+ bottom_html="TIE! No one gets points!",
96
+ )
97
+
98
+
99
+ def handle_next_attempt():
100
+ STATE.remaining_attempts -= 1
101
+ return STATE.get_tuple(
102
+ get_current_text(), bottom_html=f"That was not it... {STATE.remaining_attempts} attempts left"
103
+ )
104
+
105
+
106
+ def handle_no_input():
107
+ return STATE.get_tuple(
108
+ get_current_text(),
109
+ bottom_html="Please write something",
110
+ )
111
+
112
+
113
+ def handle_next_word():
114
+ STATE.next_word()
115
+ STATE.lm_guesses = get_model_predictions(get_current_text())
116
+ return STATE.get_tuple()
117
+
118
+
119
  def handle_guess(
120
  text: str,
 
121
  *args,
122
  **kwargs,
123
  ) -> str:
124
  """
125
  * Retrieves model predictions and compares the top 3 predicted tokens
126
  """
127
+ logger.debug("Params:\n" f"text = {text}\n" f"args = {args}\n" f"kwargs = {kwargs}\n")
 
 
 
 
 
128
  logger.debug(f"Initial STATE:\n{STATE}")
129
 
130
+ if STATE.button_label == "Next word":
131
+ return handle_next_word()
 
 
132
 
133
  if not text:
134
+ return handle_no_input()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
+ STATE.player_guesses.append(text)
137
+
138
+ player_correct = guess_is_correct(text)
139
+ lm_correct = lm_is_correct()
140
+
141
+ if player_correct and lm_correct:
142
+ return handle_tie()
143
+ elif player_correct and not lm_correct:
144
+ return handle_player_win()
145
+ elif lm_correct and not player_correct:
146
+ return handle_lm_win()
147
+ elif STATE.remaining_attempts == 0:
148
+ return handle_out_of_attempts()
149
  else:
150
+ return handle_next_attempt()
151
+
152
+
153
+ STATE.lm_guesses = get_model_predictions(get_current_text())
154
+
155
+
156
+ # # STATE.correct_guess()
157
+ # # remaining_attempts = 0
158
+ # # elif lm_guess_is_correct():
159
+ # # pass
160
+ # else:
161
+ # return handle_incorrect_guess()
162
+ # # elif remaining_attempts == 0:
163
+ # # return handle_out_of_attempts()
164
+
165
+ # remaining_attempts -= 1
166
+ # STATE.player_guesses.append(text)
167
+
168
+ # if remaining_attempts == 0:
169
+ # STATE.next_word()
170
+ # current_tokens = all_tokens[: STATE.current_word_index]
171
+ # remaining_attempts = MAX_ATTEMPTS
172
+
173
+ # # FIXME: unoptimized, computing all three every time
174
+ # current_text = tokenizer.decode(current_tokens)
175
+ # logger.debug(f"lm_guesses: {tokenizer.decode(STATE.lm_guesses)}")
176
+ # logger.debug(f"Pre-return STATE:\n{STATE}")
177
+
178
+ # # BUG: if you enter the word guess field when it says next
179
+ # # word, it will guess it as the next
180
+ # return (
181
+ # current_text,
182
+ # STATE.player_points,
183
+ # STATE.lm_points,
184
+ # STATE.player_guess_str,
185
+ # STATE.get_lm_guess_display(remaining_attempts),
186
+ # remaining_attempts,
187
+ # "",
188
+ # "Guess!" if remaining_attempts else "Next word",
189
+ # )
src/interface.py CHANGED
@@ -7,7 +7,11 @@ from src.state import STATE
7
  from src.state import tokenizer
8
 
9
 
10
- def build_demo():
 
 
 
 
11
  with gr.Blocks() as demo:
12
  with gr.Row():
13
  gr.Markdown("<h1><center>Can you beat a language model?</center></h1>")
@@ -35,6 +39,7 @@ def build_demo():
35
  value=MAX_ATTEMPTS,
36
  label="Remaining attempts",
37
  precision=0,
 
38
  )
39
  current_guesses = gr.Textbox(label="Your guesses")
40
  with gr.Column():
@@ -45,11 +50,13 @@ def build_demo():
45
  guess_field = gr.Textbox(label="")
46
  guess_button = gr.Button(value="Guess!")
47
 
 
 
 
48
  guess_button.click(
49
  handle_guess,
50
  inputs=[
51
  guess_field,
52
- remaining_attempts,
53
  ],
54
  outputs=[
55
  prompt_text,
@@ -60,6 +67,7 @@ def build_demo():
60
  remaining_attempts,
61
  guess_field,
62
  guess_button,
 
63
  ],
64
  )
65
 
 
7
  from src.state import tokenizer
8
 
9
 
10
+ def build_demo():
11
+ """
12
+ Builds and returns the gradio app interface
13
+ """
14
+
15
  with gr.Blocks() as demo:
16
  with gr.Row():
17
  gr.Markdown("<h1><center>Can you beat a language model?</center></h1>")
 
39
  value=MAX_ATTEMPTS,
40
  label="Remaining attempts",
41
  precision=0,
42
+ interactive=False,
43
  )
44
  current_guesses = gr.Textbox(label="Your guesses")
45
  with gr.Column():
 
50
  guess_field = gr.Textbox(label="")
51
  guess_button = gr.Button(value="Guess!")
52
 
53
+ with gr.Row():
54
+ bottom_html = gr.HTML()
55
+
56
  guess_button.click(
57
  handle_guess,
58
  inputs=[
59
  guess_field,
 
60
  ],
61
  outputs=[
62
  prompt_text,
 
67
  remaining_attempts,
68
  guess_field,
69
  guess_button,
70
+ bottom_html,
71
  ],
72
  )
73
 
src/state.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from dataclasses import dataclass
2
 
3
  from transformers import AutoModelForCausalLM
@@ -5,31 +6,64 @@ from transformers import AutoTokenizer
5
 
6
  from src.constants import MAX_ATTEMPTS
7
 
 
 
8
 
9
  @dataclass
10
  class ProgramState:
11
  current_word_index: int
12
- player_guesses: list
13
  player_points: int
14
- lm_guesses: list
15
  lm_points: int
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
- def correct_guess(self):
18
- # FIXME: not 1 for every point
19
- self.player_points += 1
20
- self.next_word()
21
 
22
  def next_word(self):
23
  self.current_word_index += 1
24
  self.player_guesses = []
25
- self.lm_guesses = []
26
-
27
- @property
28
- def player_guess_str(self):
29
- return "\n".join(self.player_guesses)
30
 
31
- def get_lm_guess_display(self, remaining_attempts: int) -> str:
32
- return "\n".join(map(tokenizer.decode, self.lm_guesses[: MAX_ATTEMPTS - remaining_attempts]))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
 
35
  STATE = ProgramState(
@@ -38,6 +72,8 @@ STATE = ProgramState(
38
  lm_guesses=[],
39
  player_points=0,
40
  lm_points=0,
 
 
41
  )
42
 
43
  tokenizer = AutoTokenizer.from_pretrained("gpt2")
 
1
+ import logging
2
  from dataclasses import dataclass
3
 
4
  from transformers import AutoModelForCausalLM
 
6
 
7
  from src.constants import MAX_ATTEMPTS
8
 
9
+ logger = logging.getLogger(__name__)
10
+
11
 
12
  @dataclass
13
  class ProgramState:
14
  current_word_index: int
15
+ player_guesses: list[str]
16
  player_points: int
17
+ lm_guesses: list[int]
18
  lm_points: int
19
+ remaining_attempts: int
20
+ button_label: str
21
+
22
+ @property
23
+ def player_guess_str(self):
24
+ return "\n".join(self.player_guesses)
25
+
26
+ @property
27
+ def lm_guess_str(self):
28
+ strings = list(map(tokenizer.decode, self.lm_guesses))
29
+ logger.debug(strings)
30
+
31
+ n_censored = self.remaining_attempts
32
+ for i in range(1, n_censored + 1):
33
+ strings[-i] = "****"
34
 
35
+ logger.debug(strings)
36
+ return "\n".join(strings)
 
 
37
 
38
  def next_word(self):
39
  self.current_word_index += 1
40
  self.player_guesses = []
41
+ self.lm_guesses = [] # TODO: make guesses?
42
+ self.button_label = "Guess!"
 
 
 
43
 
44
+ def get_tuple(
45
+ self,
46
+ prompt_text=None,
47
+ player_points=None,
48
+ lm_points=None,
49
+ player_guess_str=None,
50
+ lm_guess_str=None,
51
+ remaining_attempts=None,
52
+ text_field=None,
53
+ button_label=None,
54
+ bottom_html=None,
55
+ ) -> tuple:
56
+ return (
57
+ prompt_text or "", # FIXME
58
+ player_points or self.player_points,
59
+ lm_points or self.lm_points,
60
+ player_guess_str or self.player_guess_str,
61
+ lm_guess_str or self.lm_guess_str,
62
+ remaining_attempts or self.remaining_attempts,
63
+ text_field or "", # FIXME
64
+ button_label or self.button_label,
65
+ bottom_html or "", # FIXME
66
+ )
67
 
68
 
69
  STATE = ProgramState(
 
72
  lm_guesses=[],
73
  player_points=0,
74
  lm_points=0,
75
+ remaining_attempts=MAX_ATTEMPTS,
76
+ button_label="Guess!",
77
  )
78
 
79
  tokenizer = AutoTokenizer.from_pretrained("gpt2")
src/utils.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import namedtuple
2
+
3
+
4
+ HandlerInput = namedtuple(
5
+ typename="HandlerInput",
6
+ field_names=[
7
+ "text",
8
+ "remaining_attempts",
9
+ "button_label",
10
+ ],
11
+ )
12
+
13
+ HandlerOutput = namedtuple(
14
+ typename="HandlerOutput",
15
+ field_names=[
16
+ "text",
17
+ "remaining_attempts",
18
+ "button_label",
19
+ ],
20
+ )