marksverdhei committed • Commit 09c334f
Parent: 6d4a32a

:construction: Make the repo work somewhat
Files changed:
- src/handlers.py +2 -0
- src/interface.py +2 -2
- src/params.py +3 -0
- src/predictions.py +7 -21
- src/reducer.py +5 -4
- src/utils.py +12 -2
src/handlers.py CHANGED
```diff
@@ -4,6 +4,7 @@ from transformers import PreTrainedTokenizer
 
 from src.params import ReducerParams
 from src.shared import all_tokens
+from src.utils import get_current_prompt_text
 
 logger = logging.getLogger(__name__)
 
@@ -56,4 +57,5 @@ def handle_next_word(params: ReducerParams) -> ReducerParams:
     params.word_number += 1
     params.button_label = "Guess!"
     params.bottom_html = ""
+    params.prompt_text = get_current_prompt_text(params.word_number)
     return params
```
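The hunk above follows the repo's reducer-style convention: a handler takes a `ReducerParams`, updates a few fields, and returns it, and `handle_next_word` now also refreshes the visible context instead of leaving it stale. A minimal sketch of that shape, using a stand-in dataclass rather than the real `ReducerParams`:

```python
# Stand-in for ReducerParams; the real fields live in src/params.py.
from dataclasses import dataclass

@dataclass
class FakeParams:
    word_number: int = 0
    button_label: str = ""
    bottom_html: str = ""
    prompt_text: str = ""

def handle_next_word(params: FakeParams) -> FakeParams:
    params.word_number += 1
    params.button_label = "Guess!"
    params.bottom_html = ""
    # the line this commit adds: re-derive the context from word_number
    params.prompt_text = f"<context up to word {params.word_number}>"
    return params

print(handle_next_word(FakeParams()).prompt_text)  # <context up to word 1>
```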
src/interface.py CHANGED
```diff
@@ -23,7 +23,7 @@ def build_demo():
     )
     with gr.Row():
         prompt_text = gr.Textbox(
-            value=shared.all_tokens[:STARTING_INDEX],
+            value=shared.tokenizer.decode(shared.all_tokens[:STARTING_INDEX]),
             label="Context",
             interactive=False,
         )
@@ -52,7 +52,7 @@ def build_demo():
     with gr.Row():
         bottom_html = gr.HTML()
     with gr.Row():
-        word_number = gr.Number(label="Word no.", interactive=False)
+        word_number = gr.Number(label="Word no.", interactive=False, precision=0)
 
     guess_button.click(
         reducer.handle_guess,
```
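The first hunk fixes a type bug: `shared.all_tokens[:STARTING_INDEX]` is a list of token ids, not text, so the Textbox would have displayed raw integers. A standalone sketch of the difference (the gpt2 tokenizer matches the repo; the `STARTING_INDEX` value here is assumed):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
all_tokens = tokenizer("The quick brown fox jumps over the lazy dog")["input_ids"]
STARTING_INDEX = 4  # assumed value; the real one lives in src.constants

print(all_tokens[:STARTING_INDEX])                    # token ids, e.g. [464, 2068, 7586, 21831]
print(tokenizer.decode(all_tokens[:STARTING_INDEX]))  # "The quick brown fox"
```

The second hunk's `precision=0` makes `gr.Number` round to an integer, which suits a word counter.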
src/params.py CHANGED
```diff
@@ -23,3 +23,6 @@ class ReducerParams:
 
     def __iter__(self) -> Iterator:
         return map(partial(getattr, self), self.__dataclass_fields__)
+
+    def __getitem__(self, index) -> str | int:
+        return list(self)[index]
```
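With `__iter__` already in place, the added `__getitem__` lets a `ReducerParams` be indexed positionally as well as unpacked, which is convenient when Gradio shuttles component values around as flat sequences. A self-contained sketch with a stand-in class:

```python
from dataclasses import dataclass
from functools import partial
from typing import Iterator

@dataclass
class Params:
    prompt_text: str
    word_number: int

    def __iter__(self) -> Iterator:
        # yield field values in declaration order
        return map(partial(getattr, self), self.__dataclass_fields__)

    def __getitem__(self, index) -> str | int:
        return list(self)[index]

p = Params("Once upon a time", 7)
text, n = p   # unpacking via __iter__
print(p[1])   # 7, via the new __getitem__
```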
src/predictions.py CHANGED
```diff
@@ -14,35 +14,21 @@ from src.text import get_text
 tokenizer = AutoTokenizer.from_pretrained("gpt2")
 
 
-# def get_model_predictions(
-#     *,
-#     input_text: str,
-#     model: PreTrainedModel,
-#     tokenizer: PreTrainedTokenizer,
-# ) -> torch.Tensor:
-#     """
-#     Returns the indices as a torch tensor of the top 3 predicted tokens.
-#     """
-#     inputs = tokenizer(input_text, return_tensors="pt")
-
-#     with torch.no_grad():
-#         logits = model(**inputs).logits
-
-#     last_token = logits[0, :]
-#     top_3 = torch.topk(last_token, MAX_ATTEMPTS).indices.tolist()
-#     return top_3
-
-
 def make_predictions(tokenizer: PreTrainedTokenizer) -> tuple:
     """
-    Run this on startup
+    Run this on startup.
+    Returns tuple of target_prediction_pairs and target_prediction_tokens:
+
     """
     text = get_text()
     model = AutoModelForCausalLM.from_pretrained("gpt2")
     model.eval()
 
     inputs = tokenizer(text, return_tensors="pt")
-
+
+    with torch.no_grad():
+        logits = model(**inputs).logits
+
     top_n = torch.topk(logits, MAX_ATTEMPTS)
     token_id_preds = top_n.indices.squeeze().tolist()
     tokens = list(map(tokenizer.convert_ids_to_tokens, token_id_preds))
```
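The commit drops the commented-out per-call helper in favor of a single forward pass inside `make_predictions`: the whole text goes through GPT-2 once under `torch.no_grad()`, and `torch.topk` then yields the top candidate ids at every position. A standalone sketch of that step (the `MAX_ATTEMPTS` value is assumed):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MAX_ATTEMPTS = 3  # assumed; the real value lives in src.constants
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()

inputs = tokenizer("The quick brown fox", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits                # [1, seq_len, vocab_size]

top_n = torch.topk(logits, MAX_ATTEMPTS)           # top-k over the vocab dim
token_id_preds = top_n.indices.squeeze().tolist()  # seq_len lists of k ids
# candidates for the token following the last input token
print(tokenizer.convert_ids_to_tokens(token_id_preds[-1]))
```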
src/reducer.py CHANGED
```diff
@@ -8,6 +8,7 @@ from src.handlers import handle_out_of_attempts
 from src.handlers import handle_player_win
 from src.handlers import handle_tie
 from src.params import ReducerParams
+from src.shared import tokenizer
 from src.utils import guess_is_correct
 from src.utils import lm_is_correct
 
@@ -28,12 +29,12 @@ def _handle_guess(params: ReducerParams) -> ReducerParams:
     if params.button_label == "Next word":
         return handle_next_word(params)
 
-    if not params.
+    if not params.guess_field:
        return handle_no_input(params)
 
-    params.
-    player_correct = guess_is_correct(params
-    lm_correct = lm_is_correct()
+    params.current_guesses += "\n" + params.guess_field
+    player_correct = guess_is_correct(params, tokenizer)
+    lm_correct = lm_is_correct(params)
 
     if player_correct and lm_correct:
         return handle_tie(params)
```
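After this commit `_handle_guess` reads as a straight dispatch: short-circuit on the button state and on empty input, record the guess, score both players, then branch on the outcome. A condensed sketch of that branch order (handlers abbreviated to strings; the real ones return updated `ReducerParams`):

```python
def dispatch(button_label: str, guess_field: str,
             player_correct: bool, lm_correct: bool) -> str:
    if button_label == "Next word":
        return "handle_next_word"
    if not guess_field:
        return "handle_no_input"
    if player_correct and lm_correct:
        return "handle_tie"
    # the remaining branches (player win, LM win, out of attempts)
    # follow the same pattern in src/reducer.py
    return "..."

print(dispatch("Guess!", "fox", True, True))  # handle_tie
```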
src/utils.py CHANGED
```diff
@@ -2,13 +2,19 @@ import logging
 
 from transformers import PreTrainedTokenizer
 
+from src import shared
 from src.constants import MAX_ATTEMPTS
+from src.constants import STARTING_INDEX
 from src.params import ReducerParams
 from src.shared import token_id_predictions
 
 logger = logging.getLogger(__name__)
 
 
+def get_current_prompt_text(word_number):
+    return shared.tokenizer.decode(shared.all_tokens[: STARTING_INDEX + word_number])
+
+
 def get_start_and_whitespace_tokens(
     word: str,
     tokenizer: PreTrainedTokenizer,
@@ -34,13 +40,17 @@ def lm_is_correct(params: ReducerParams) -> bool:
     return current_guess == current_target
 
 
-def guess_is_correct(params: ReducerParams,
+def guess_is_correct(params: ReducerParams, tokenizer: PreTrainedTokenizer) -> bool:
     """
     We check if the predicted token or a corresponding one with a leading whitespace
     matches that of the next token
     """
+    # FIXME: handle IndexError
+    print(STARTING_INDEX + params.word_number)
+    current_target = shared.all_tokens[STARTING_INDEX + params.word_number]
+
     logger.debug("Next token: '{}'".format(tokenizer.convert_ids_to_tokens([current_target])))
 
-    predicted_token_start, predicted_token_whitespace = get_start_and_whitespace_tokens(params.guess_field)
+    predicted_token_start, predicted_token_whitespace = get_start_and_whitespace_tokens(params.guess_field, tokenizer)
     logger.debug(tokenizer.convert_ids_to_tokens([predicted_token_start, predicted_token_whitespace]))
     return current_target in (predicted_token_start, predicted_token_whitespace)
```
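The two-token check in `guess_is_correct` exists because GPT-2's BPE vocabulary distinguishes a word from its leading-whitespace variant, and a player shouldn't need to know which form the hidden text uses. A standalone illustration of the idea behind `get_start_and_whitespace_tokens`:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
word = "fox"
start_id = tokenizer.encode(word)[0]              # first token of "fox"
whitespace_id = tokenizer.encode(" " + word)[0]   # first token of " fox"
print(tokenizer.convert_ids_to_tokens([start_id, whitespace_id]))
# e.g. ['fox', 'Ġfox']; a guess counts if the target id matches either variant
```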