marksverdhei committed on
Commit
6d4a32a
•
1 Parent(s): 1532c35

:white_check_mark: Fix imports

Files changed (6)
  1. src/handlers.py +2 -19
  2. src/interface.py +4 -4
  3. src/predictions.py +2 -3
  4. src/reducer.py +10 -11
  5. src/shared.py +3 -3
  6. src/utils.py +18 -2
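The fix is the same across all six files: the bare top-level imports (from constants import ..., import reducer) only resolve when src/ itself is on sys.path; once the app is started from the repository root, only the src package is importable, so every module is now addressed as src.<module>. A minimal sketch of the failure mode, using a throwaway package (the MAX_ATTEMPTS value here is made up):

import importlib
import pathlib
import sys
import tempfile

# Throwaway repo layout mirroring this project: <root>/src/constants.py
root = pathlib.Path(tempfile.mkdtemp())
(root / "src").mkdir()
(root / "src" / "__init__.py").write_text("")
(root / "src" / "constants.py").write_text("MAX_ATTEMPTS = 3\n")  # value is made up

sys.path.insert(0, str(root))  # what running from the repo root gives you

# Bare import, as before this commit: src/ itself is not on sys.path.
try:
    importlib.import_module("constants")
except ModuleNotFoundError as err:
    print(err)  # No module named 'constants'

# Package-qualified import, as after this commit: resolves via the src package.
constants = importlib.import_module("src.constants")
print(constants.MAX_ATTEMPTS)  # 3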
src/handlers.py CHANGED
@@ -2,29 +2,12 @@ import logging
 
 from transformers import PreTrainedTokenizer
 
-from constants import MAX_ATTEMPTS
-from params import ReducerParams
-from shared import all_tokens
-from shared import text
-from shared import token_id_predictions
-from shared import tokenizer
-from utils import get_start_and_whitespace_tokens
+from src.params import ReducerParams
+from src.shared import all_tokens
 
 logger = logging.getLogger(__name__)
 
 
-def lm_is_correct(params: ReducerParams) -> bool:
-    # NOTE: out of range if remaining attempts is 0
-    if params.remaining_attempts < 1:
-        return False
-
-    idx = MAX_ATTEMPTS - params.remaining_attempts
-
-    current_guess = params.lm_guesses[idx]
-    current_target = all_tokens[params.current_word_index]
-    return current_guess == current_target
-
-
 def get_current_text(params: ReducerParams, tokenizer: PreTrainedTokenizer):
     return tokenizer.decode(all_tokens[: params.word_number])
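The surviving helper, get_current_text, just decodes a prefix of the pre-tokenized text. A standalone sketch of that prefix-decoding idea (the gpt2 checkpoint is the one this repo loads; the sample sentence is made up):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
all_tokens = tokenizer.encode("The quick brown fox jumps over the lazy dog")

# Decoding the first n token ids recovers the text revealed so far.
print(tokenizer.decode(all_tokens[:4]))  # "The quick brown fox"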
 
src/interface.py CHANGED
@@ -1,9 +1,9 @@
 import gradio as gr
 
-import reducer
+from src import reducer
+from src import shared
 from src.constants import MAX_ATTEMPTS
-from src.state import STATE
-from src.state import tokenizer
+from src.constants import STARTING_INDEX
 
 
 def build_demo():
@@ -23,7 +23,7 @@ def build_demo():
         )
         with gr.Row():
             prompt_text = gr.Textbox(
-                value=tokenizer.decode(all_tokens[: STATE.current_word_index]),
+                value=shared.all_tokens[:STARTING_INDEX],
                 label="Context",
                 interactive=False,
             )
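For reference, a self-contained sketch of the widget pattern build_demo uses here, with the shared state replaced by a literal string (everything in this sketch except the Gradio calls is made up):

import gradio as gr

with gr.Blocks() as demo:
    with gr.Row():
        prompt_text = gr.Textbox(
            value="The quick brown",  # stand-in for the shared context
            label="Context",
            interactive=False,  # read-only: the player types guesses elsewhere
        )

demo.launch()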
src/predictions.py CHANGED
@@ -6,11 +6,10 @@ predicted tokens.
 import torch
 from transformers import AutoModelForCausalLM
 from transformers import AutoTokenizer
-from transformers import PreTrainedModel
 from transformers import PreTrainedTokenizer
 
-from constants import MAX_ATTEMPTS
-from text import get_text
+from src.constants import MAX_ATTEMPTS
+from src.text import get_text
 
 tokenizer = AutoTokenizer.from_pretrained("gpt2")
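The imports here suggest make_predictions scores candidate next tokens with gpt2. A hedged sketch of that pattern, not the repo's actual implementation (the helper name and the top-k framing are assumptions):

import torch
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()

def top_k_next_tokens(prompt: str, k: int = 3) -> list[int]:
    # Score every vocabulary entry as the next token and keep the k best.
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    with torch.no_grad():
        logits = model(input_ids).logits  # shape: (1, seq_len, vocab_size)
    return torch.topk(logits[0, -1], k).indices.tolist()

print(tokenizer.convert_ids_to_tokens(top_k_next_tokens("The quick brown")))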
 
src/reducer.py CHANGED
@@ -1,16 +1,15 @@
 import logging
 
-from handlers import handle_lm_win
-from handlers import handle_next_attempt
-from handlers import handle_next_word
-from handlers import handle_no_input
-from handlers import handle_out_of_attempts
-from handlers import handle_player_win
-from handlers import handle_tie
-from handlers import lm_is_correct
-from params import ReducerParams
-from shared import tokenizer, token_id_predictions, token_predictions
-from utils import guess_is_correct
+from src.handlers import handle_lm_win
+from src.handlers import handle_next_attempt
+from src.handlers import handle_next_word
+from src.handlers import handle_no_input
+from src.handlers import handle_out_of_attempts
+from src.handlers import handle_player_win
+from src.handlers import handle_tie
+from src.params import ReducerParams
+from src.utils import guess_is_correct
+from src.utils import lm_is_correct
 
 logger = logging.getLogger(__name__)
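The import list outlines the reducer's shape: one entry point inspects the game state and hands off to exactly one handler. A toy version of that dispatch pattern (the branch conditions and fields are invented; only the handler names echo the diff):

from dataclasses import dataclass

@dataclass
class Params:  # stand-in for ReducerParams; these fields are invented
    guess: str
    remaining_attempts: int

def handle_no_input(p: Params) -> str:
    return "type a guess first"

def handle_out_of_attempts(p: Params) -> str:
    return "moving to the next word"

def handle_next_attempt(p: Params) -> str:
    return f"{p.remaining_attempts} attempts left"

def reduce(p: Params) -> str:
    # Single dispatch point: each state maps to exactly one handler.
    if not p.guess.strip():
        return handle_no_input(p)
    if p.remaining_attempts == 0:
        return handle_out_of_attempts(p)
    return handle_next_attempt(p)

print(reduce(Params(guess="", remaining_attempts=3)))     # type a guess first
print(reduce(Params(guess="fox", remaining_attempts=2)))  # 2 attempts left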
 
src/shared.py CHANGED
@@ -1,9 +1,9 @@
 from transformers import AutoTokenizer
 
-from predictions import make_predictions
-from text import get_text
+from src.predictions import make_predictions
+from src.text import get_text
 
 tokenizer = AutoTokenizer.from_pretrained("gpt2")
-token_id_predictions, token_predictions = make_predictions()
+token_id_predictions, token_predictions = make_predictions(tokenizer)
 text = get_text()
 all_tokens = tokenizer.encode(text)
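shared.py holds module-level singletons, so the tokenizer load and make_predictions(tokenizer) run once, at first import, and every importer sees the same objects. Judging by the indexing in src/utils.py below, token_id_predictions pairs each word position with a target id and ranked guess ids; a sketch of that assumed shape (the ids are made up and the structure is inferred, not confirmed by this diff):

# Per word position: (target_token_id, [guess ids ranked by confidence])
token_id_predictions = [
    (101, [101, 7, 42]),  # word 0: top guess matches the target
    (250, [9, 250, 63]),  # word 1: second guess matches the target
]

word_number, idx = 1, 1
current_guess = token_id_predictions[word_number][1][idx]  # 250
current_target = token_id_predictions[word_number][0]      # 250
print(current_guess == current_target)  # True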
src/utils.py CHANGED
@@ -2,7 +2,9 @@ import logging
 
 from transformers import PreTrainedTokenizer
 
-from params import ReducerParams
+from src.constants import MAX_ATTEMPTS
+from src.params import ReducerParams
+from src.shared import token_id_predictions
 
 logger = logging.getLogger(__name__)
 
@@ -19,12 +21,26 @@ def get_start_and_whitespace_tokens(
     return predicted_token_start, predicted_token_whitespace
 
 
+def lm_is_correct(params: ReducerParams) -> bool:
+    # NOTE: out of range if remaining attempts is 0
+    if params.remaining_attempts < 1:
+        return False
+
+    idx = MAX_ATTEMPTS - params.remaining_attempts
+
+    current_guess = token_id_predictions[params.word_number][1][idx]
+    current_target = token_id_predictions[params.word_number][0]
+
+    return current_guess == current_target
+
+
 def guess_is_correct(params: ReducerParams, current_target: int, tokenizer: PreTrainedTokenizer) -> bool:
     """
     We check if the predicted token or a corresponding one with a leading whitespace
     matches that of the next token
     """
     logger.debug("Next token: '{}'".format(tokenizer.convert_ids_to_tokens([current_target])))
-    predicted_token_start, predicted_token_whitespace = get_start_and_whitespace_tokens(text)
+
+    predicted_token_start, predicted_token_whitespace = get_start_and_whitespace_tokens(params.guess_field)
     logger.debug(tokenizer.convert_ids_to_tokens([predicted_token_start, predicted_token_whitespace]))
     return current_target in (predicted_token_start, predicted_token_whitespace)
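The index arithmetic in lm_is_correct turns the counting-down remaining_attempts into a guess rank; a worked example with MAX_ATTEMPTS = 3 (the real constant lives in src/constants.py and its value is not shown in this diff):

MAX_ATTEMPTS = 3  # assumed value, for illustration only

# remaining_attempts counts down, so attempts already used give the guess rank:
for remaining_attempts in (3, 2, 1):
    idx = MAX_ATTEMPTS - remaining_attempts
    print(remaining_attempts, "->", idx)  # 3 -> 0, 2 -> 1, 1 -> 2

# remaining_attempts == 0 would yield idx == MAX_ATTEMPTS, one past the last
# stored guess, hence the early `return False` guard in lm_is_correct.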