marksverdhei committed on
Commit
09c334f
β€’
1 Parent(s): 6d4a32a

:construction: Make the repo work somewhat

Browse files
Files changed (6) hide show
  1. src/handlers.py +2 -0
  2. src/interface.py +2 -2
  3. src/params.py +3 -0
  4. src/predictions.py +7 -21
  5. src/reducer.py +5 -4
  6. src/utils.py +12 -2
src/handlers.py CHANGED
@@ -4,6 +4,7 @@ from transformers import PreTrainedTokenizer
4
 
5
  from src.params import ReducerParams
6
  from src.shared import all_tokens
 
7
 
8
  logger = logging.getLogger(__name__)
9
 
@@ -56,4 +57,5 @@ def handle_next_word(params: ReducerParams) -> ReducerParams:
56
  params.word_number += 1
57
  params.button_label = "Guess!"
58
  params.bottom_html = ""
 
59
  return params
 
4
 
5
  from src.params import ReducerParams
6
  from src.shared import all_tokens
7
+ from src.utils import get_current_prompt_text
8
 
9
  logger = logging.getLogger(__name__)
10
 
 
57
  params.word_number += 1
58
  params.button_label = "Guess!"
59
  params.bottom_html = ""
60
+ params.prompt_text = get_current_prompt_text(params.word_number)
61
  return params
src/interface.py CHANGED
@@ -23,7 +23,7 @@ def build_demo():
23
  )
24
  with gr.Row():
25
  prompt_text = gr.Textbox(
26
- value=shared.all_tokens[:STARTING_INDEX],
27
  label="Context",
28
  interactive=False,
29
  )
@@ -52,7 +52,7 @@ def build_demo():
52
  with gr.Row():
53
  bottom_html = gr.HTML()
54
  with gr.Row():
55
- word_number = gr.Number(label="Word no.", interactive=False)
56
 
57
  guess_button.click(
58
  reducer.handle_guess,
 
23
  )
24
  with gr.Row():
25
  prompt_text = gr.Textbox(
26
+ value=shared.tokenizer.decode(shared.all_tokens[:STARTING_INDEX]),
27
  label="Context",
28
  interactive=False,
29
  )
 
52
  with gr.Row():
53
  bottom_html = gr.HTML()
54
  with gr.Row():
55
+ word_number = gr.Number(label="Word no.", interactive=False, precision=0)
56
 
57
  guess_button.click(
58
  reducer.handle_guess,
src/params.py CHANGED
@@ -23,3 +23,6 @@ class ReducerParams:
23
 
24
  def __iter__(self) -> Iterator:
25
  return map(partial(getattr, self), self.__dataclass_fields__)
 
 
 
 
23
 
24
  def __iter__(self) -> Iterator:
25
  return map(partial(getattr, self), self.__dataclass_fields__)
26
+
27
+ def __getitem__(self, index) -> str | int:
28
+ return list(self)[index]
src/predictions.py CHANGED
@@ -14,35 +14,21 @@ from src.text import get_text
14
  tokenizer = AutoTokenizer.from_pretrained("gpt2")
15
 
16
 
17
- # def get_model_predictions(
18
- # *,
19
- # input_text: str,
20
- # model: PreTrainedModel,
21
- # tokenizer: PreTrainedTokenizer,
22
- # ) -> torch.Tensor:
23
- # """
24
- # Returns the indices as a torch tensor of the top 3 predicted tokens.
25
- # """
26
- # inputs = tokenizer(input_text, return_tensors="pt")
27
-
28
- # with torch.no_grad():
29
- # logits = model(**inputs).logits
30
-
31
- # last_token = logits[0, :]
32
- # top_3 = torch.topk(last_token, MAX_ATTEMPTS).indices.tolist()
33
- # return top_3
34
-
35
-
36
  def make_predictions(tokenizer: PreTrainedTokenizer) -> tuple:
37
  """
38
- Run this on startup
 
 
39
  """
40
  text = get_text()
41
  model = AutoModelForCausalLM.from_pretrained("gpt2")
42
  model.eval()
43
 
44
  inputs = tokenizer(text, return_tensors="pt")
45
- logits = model(**inputs).logits
 
 
 
46
  top_n = torch.topk(logits, MAX_ATTEMPTS)
47
  token_id_preds = top_n.indices.squeeze().tolist()
48
  tokens = list(map(tokenizer.convert_ids_to_tokens, token_id_preds))
 
14
  tokenizer = AutoTokenizer.from_pretrained("gpt2")
15
 
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  def make_predictions(tokenizer: PreTrainedTokenizer) -> tuple:
18
  """
19
+ Run this on startup.
20
+ Returns a tuple of target_prediction_pairs and target_prediction_tokens.
21
+
22
  """
23
  text = get_text()
24
  model = AutoModelForCausalLM.from_pretrained("gpt2")
25
  model.eval()
26
 
27
  inputs = tokenizer(text, return_tensors="pt")
28
+
29
+ with torch.no_grad():
30
+ logits = model(**inputs).logits
31
+
32
  top_n = torch.topk(logits, MAX_ATTEMPTS)
33
  token_id_preds = top_n.indices.squeeze().tolist()
34
  tokens = list(map(tokenizer.convert_ids_to_tokens, token_id_preds))
src/reducer.py CHANGED
@@ -8,6 +8,7 @@ from src.handlers import handle_out_of_attempts
8
  from src.handlers import handle_player_win
9
  from src.handlers import handle_tie
10
  from src.params import ReducerParams
 
11
  from src.utils import guess_is_correct
12
  from src.utils import lm_is_correct
13
 
@@ -28,12 +29,12 @@ def _handle_guess(params: ReducerParams) -> ReducerParams:
28
  if params.button_label == "Next word":
29
  return handle_next_word(params)
30
 
31
- if not params.prompt_text:
32
  return handle_no_input(params)
33
 
34
- params.player_guesses += "\n" + params.prompt_text
35
- player_correct = guess_is_correct(params.prompt_text)
36
- lm_correct = lm_is_correct()
37
 
38
  if player_correct and lm_correct:
39
  return handle_tie(params)
 
8
  from src.handlers import handle_player_win
9
  from src.handlers import handle_tie
10
  from src.params import ReducerParams
11
+ from src.shared import tokenizer
12
  from src.utils import guess_is_correct
13
  from src.utils import lm_is_correct
14
 
 
29
  if params.button_label == "Next word":
30
  return handle_next_word(params)
31
 
32
+ if not params.guess_field:
33
  return handle_no_input(params)
34
 
35
+ params.current_guesses += "\n" + params.guess_field
36
+ player_correct = guess_is_correct(params, tokenizer)
37
+ lm_correct = lm_is_correct(params)
38
 
39
  if player_correct and lm_correct:
40
  return handle_tie(params)
src/utils.py CHANGED
@@ -2,13 +2,19 @@ import logging
2
 
3
  from transformers import PreTrainedTokenizer
4
 
 
5
  from src.constants import MAX_ATTEMPTS
 
6
  from src.params import ReducerParams
7
  from src.shared import token_id_predictions
8
 
9
  logger = logging.getLogger(__name__)
10
 
11
 
 
 
 
 
12
  def get_start_and_whitespace_tokens(
13
  word: str,
14
  tokenizer: PreTrainedTokenizer,
@@ -34,13 +40,17 @@ def lm_is_correct(params: ReducerParams) -> bool:
34
  return current_guess == current_target
35
 
36
 
37
- def guess_is_correct(params: ReducerParams, current_target: int, tokenizer: PreTrainedTokenizer) -> bool:
38
  """
39
  We check if the predicted token or a corresponding one with a leading whitespace
40
  matches that of the next token
41
  """
 
 
 
 
42
  logger.debug("Next token: '{}'".format(tokenizer.convert_ids_to_tokens([current_target])))
43
 
44
- predicted_token_start, predicted_token_whitespace = get_start_and_whitespace_tokens(params.guess_field)
45
  logger.debug(tokenizer.convert_ids_to_tokens([predicted_token_start, predicted_token_whitespace]))
46
  return current_target in (predicted_token_start, predicted_token_whitespace)
 
2
 
3
  from transformers import PreTrainedTokenizer
4
 
5
+ from src import shared
6
  from src.constants import MAX_ATTEMPTS
7
+ from src.constants import STARTING_INDEX
8
  from src.params import ReducerParams
9
  from src.shared import token_id_predictions
10
 
11
  logger = logging.getLogger(__name__)
12
 
13
 
14
+ def get_current_prompt_text(word_number):
15
+ return shared.tokenizer.decode(shared.all_tokens[: STARTING_INDEX + word_number])
16
+
17
+
18
  def get_start_and_whitespace_tokens(
19
  word: str,
20
  tokenizer: PreTrainedTokenizer,
 
40
  return current_guess == current_target
41
 
42
 
43
+ def guess_is_correct(params: ReducerParams, tokenizer: PreTrainedTokenizer) -> bool:
44
  """
45
  We check if the predicted token or a corresponding one with a leading whitespace
46
  matches that of the next token
47
  """
48
+ # FIXME: handle IndexError
49
+ print(STARTING_INDEX + params.word_number)
50
+ current_target = shared.all_tokens[STARTING_INDEX + params.word_number]
51
+
52
  logger.debug("Next token: '{}'".format(tokenizer.convert_ids_to_tokens([current_target])))
53
 
54
+ predicted_token_start, predicted_token_whitespace = get_start_and_whitespace_tokens(params.guess_field, tokenizer)
55
  logger.debug(tokenizer.convert_ids_to_tokens([predicted_token_start, predicted_token_whitespace]))
56
  return current_target in (predicted_token_start, predicted_token_whitespace)