Commit 6d4a32a (1 parent: 1532c35), committed by marksverdhei:
:white_check_mark: Fix imports
src/handlers.py     +2 -19
src/interface.py    +4  -4
src/predictions.py  +2  -3
src/reducer.py     +10 -11
src/shared.py       +3  -3
src/utils.py       +18  -2
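This commit qualifies every intra-project import with the `src` package (e.g. `from shared import all_tokens` becomes `from src.shared import all_tokens`), so the modules resolve when the Space is launched from the repository root rather than from inside `src/`. A minimal sketch of a compatible entry point, assuming a launcher file at the repo root (the launcher itself is not part of this commit):

```python
# app.py — hypothetical entry point at the repository root (assumed, not in
# this commit). Running `python app.py` from the root puts the root on
# sys.path, so the absolute `src.*` imports now resolve.
from src.interface import build_demo  # build_demo is defined in src/interface.py

demo = build_demo()
demo.launch()  # standard Gradio launch call
```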
src/handlers.py  CHANGED
@@ -2,29 +2,12 @@ import logging
 
 from transformers import PreTrainedTokenizer
 
-from constants import MAX_ATTEMPTS
-from params import ReducerParams
-from shared import all_tokens
-from shared import text
-from shared import token_id_predictions
-from shared import tokenizer
-from utils import get_start_and_whitespace_tokens
+from src.params import ReducerParams
+from src.shared import all_tokens
 
 logger = logging.getLogger(__name__)
 
 
-def lm_is_correct(params: ReducerParams) -> bool:
-    # NOTE: out of range if remaining attempts is 0
-    if params.remaining_attempts < 1:
-        return False
-
-    idx = MAX_ATTEMPTS - params.remaining_attempts
-
-    current_guess = params.lm_guesses[idx]
-    current_target = all_tokens[params.current_word_index]
-    return current_guess == current_target
-
-
 def get_current_text(params: ReducerParams, tokenizer: PreTrainedTokenizer):
     return tokenizer.decode(all_tokens[: params.word_number])
 
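The removed `lm_is_correct` is not deleted outright; it reappears in `src/utils.py` below, rewritten to read from the precomputed `token_id_predictions` instead of `all_tokens`. Judging from the indexing in the new version (`[word_number][1][idx]` for the guess, `[word_number][0]` for the target), each entry appears to pair a target token id with the model's ranked guesses. A hypothetical sketch of that shape, for illustration only:

```python
# Assumed structure (not confirmed by this commit): one entry per word
# position, pairing the correct next-token id with the model's ranked guesses.
#   token_id_predictions[word_number][0]       -> target token id
#   token_id_predictions[word_number][1][idx]  -> the model's idx-th guess
token_id_predictions = {
    0: (262, [262, 287, 286]),  # e.g. target id 262 with three ranked guesses
    1: (318, [373, 318, 468]),
}
```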
src/interface.py  CHANGED
@@ -1,9 +1,9 @@
 import gradio as gr
 
-import reducer
+from src import reducer
+from src import shared
 from src.constants import MAX_ATTEMPTS
-from src.
-from src.state import tokenizer
+from src.constants import STARTING_INDEX
 
 
 def build_demo():
@@ -23,7 +23,7 @@ def build_demo():
         )
         with gr.Row():
             prompt_text = gr.Textbox(
-                value=
+                value=shared.all_tokens[:STARTING_INDEX],
                 label="Context",
                 interactive=False,
             )
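One thing worth flagging: `shared.all_tokens` holds token ids (`tokenizer.encode(text)` in `src/shared.py`), so the `Textbox` receives an id list rather than text. If a readable string is intended, a decoded value mirroring `get_current_text` in `src/handlers.py` might look like this (a speculative variant, not part of the commit):

```python
import gradio as gr
from src import shared
from src.constants import STARTING_INDEX

# Speculative alternative (not in this commit): decode the ids so the Context
# textbox shows text, as get_current_text does in src/handlers.py.
prompt_text = gr.Textbox(
    value=shared.tokenizer.decode(shared.all_tokens[:STARTING_INDEX]),
    label="Context",
    interactive=False,
)
```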
src/predictions.py  CHANGED
@@ -6,11 +6,10 @@ predicted tokens.
 import torch
 from transformers import AutoModelForCausalLM
 from transformers import AutoTokenizer
-from transformers import PreTrainedModel
 from transformers import PreTrainedTokenizer
 
-from constants import MAX_ATTEMPTS
-from text import get_text
+from src.constants import MAX_ATTEMPTS
+from src.text import get_text
 
 tokenizer = AutoTokenizer.from_pretrained("gpt2")
 
src/reducer.py  CHANGED
@@ -1,16 +1,15 @@
 import logging
 
-from handlers import handle_lm_win
-from handlers import handle_next_attempt
-from handlers import handle_next_word
-from handlers import handle_no_input
-from handlers import handle_out_of_attempts
-from handlers import handle_player_win
-from handlers import handle_tie
-from
-from
-from
-from utils import guess_is_correct
+from src.handlers import handle_lm_win
+from src.handlers import handle_next_attempt
+from src.handlers import handle_next_word
+from src.handlers import handle_no_input
+from src.handlers import handle_out_of_attempts
+from src.handlers import handle_player_win
+from src.handlers import handle_tie
+from src.params import ReducerParams
+from src.utils import guess_is_correct
+from src.utils import lm_is_correct
 
 logger = logging.getLogger(__name__)
 
src/shared.py  CHANGED
@@ -1,9 +1,9 @@
 from transformers import AutoTokenizer
 
-from predictions import make_predictions
-from text import get_text
+from src.predictions import make_predictions
+from src.text import get_text
 
 tokenizer = AutoTokenizer.from_pretrained("gpt2")
-token_id_predictions, token_predictions = make_predictions()
+token_id_predictions, token_predictions = make_predictions(tokenizer)
 text = get_text()
 all_tokens = tokenizer.encode(text)
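Besides the import fix, `make_predictions` now receives the shared tokenizer instead of relying on its own, which keeps a single tokenizer instance across modules. A hypothetical stub compatible with this call site (the real body lives in `src/predictions.py` and is not shown in this diff):

```python
from transformers import PreTrainedTokenizer

# Hypothetical signature matching the call in src/shared.py; the return value
# is unpacked there as (token_id_predictions, token_predictions).
def make_predictions(tokenizer: PreTrainedTokenizer) -> tuple[dict, dict]:
    ...
```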
src/utils.py  CHANGED
@@ -2,7 +2,9 @@ import logging
 
 from transformers import PreTrainedTokenizer
 
-from params import ReducerParams
+from src.constants import MAX_ATTEMPTS
+from src.params import ReducerParams
+from src.shared import token_id_predictions
 
 logger = logging.getLogger(__name__)
 
@@ -19,12 +21,26 @@ def get_start_and_whitespace_tokens(
     return predicted_token_start, predicted_token_whitespace
 
 
+def lm_is_correct(params: ReducerParams) -> bool:
+    # NOTE: out of range if remaining attempts is 0
+    if params.remaining_attempts < 1:
+        return False
+
+    idx = MAX_ATTEMPTS - params.remaining_attempts
+
+    current_guess = token_id_predictions[params.word_number][1][idx]
+    current_target = token_id_predictions[params.word_number][0]
+
+    return current_guess == current_target
+
+
 def guess_is_correct(params: ReducerParams, current_target: int, tokenizer: PreTrainedTokenizer) -> bool:
     """
     We check if the predicted token or a corresponding one with a leading whitespace
     matches that of the next token
     """
     logger.debug("Next token: '{}'".format(tokenizer.convert_ids_to_tokens([current_target])))
-
+
+    predicted_token_start, predicted_token_whitespace = get_start_and_whitespace_tokens(params.guess_field)
     logger.debug(tokenizer.convert_ids_to_tokens([predicted_token_start, predicted_token_whitespace]))
     return current_target in (predicted_token_start, predicted_token_whitespace)
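The `# NOTE` guard in `lm_is_correct` matters because of how the attempt index is derived: `idx = MAX_ATTEMPTS - remaining_attempts` walks forward through the ranked guesses as attempts are used up, so with `remaining_attempts == 0` it would index one past the end of a guess list with `MAX_ATTEMPTS` entries. A small worked example, assuming `MAX_ATTEMPTS = 3` (the real value lives in `src/constants.py` and is not shown in this diff):

```python
MAX_ATTEMPTS = 3  # assumed value, for illustration only

# remaining_attempts -> index into the ranked guess list
for remaining in (3, 2, 1, 0):
    if remaining < 1:
        print(remaining, "-> short-circuit: False (no attempts left)")
    else:
        idx = MAX_ATTEMPTS - remaining
        print(remaining, "->", idx)  # 3 -> 0 (first guess) ... 1 -> 2 (last guess)
```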
|