File size: 2,476 Bytes
1532c35
7eee83c
1532c35
7eee83c
09c334f
6d4a32a
09c334f
6d4a32a
 
1532c35
 
 
 
a88f931
 
 
 
 
 
 
 
 
 
 
09c334f
a88f931
09c334f
 
 
1532c35
 
 
 
 
 
 
 
 
 
 
 
6d4a32a
 
 
 
 
 
 
a88f931
 
 
6d4a32a
 
 
 
09c334f
1532c35
 
 
 
09c334f
 
 
 
1532c35
6d4a32a
09c334f
1532c35
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import logging

from transformers import PreTrainedTokenizer

from src import shared
from src.constants import MAX_ATTEMPTS
from src.constants import STARTING_INDEX
from src.params import ReducerParams
from src.shared import token_id_predictions

logger = logging.getLogger(__name__)


def get_current_lm_guess_str(word_number, remaining_attempts):
    """Return the LM's guesses for the current word, one per line.

    Guesses are revealed one at a time as attempts are used; positions the
    player has not yet unlocked are masked with "*****".

    :param word_number: index of the current word in the game
        (offset by ``STARTING_INDEX`` into the precomputed predictions).
    :param remaining_attempts: attempts the player still has left.
    :return: ``MAX_ATTEMPTS`` newline-joined lines, each a decoded guess
        or the "*****" mask.
    """
    prediction_index = (STARTING_INDEX + word_number) - 1
    try:
        guess_ids = token_id_predictions[prediction_index][1]
    except IndexError:
        # word_number walked past the precomputed predictions (the old
        # FIXME); show a fully censored list instead of crashing the UI.
        logger.warning("No predictions for word_number=%s", word_number)
        return "\n".join(["*****"] * MAX_ATTEMPTS)

    guesses = [shared.tokenizer.decode(token_id) for token_id in guess_ids]
    censored_list = ["*****"] * MAX_ATTEMPTS
    # Clamp so a short guess list can never raise IndexError.
    revealed = min(MAX_ATTEMPTS - remaining_attempts, len(guesses))
    for i in range(revealed):
        censored_list[i] = guesses[i]

    return "\n".join(censored_list)


def get_current_prompt_text(word_number):
    """Decode and return the prompt shown so far: every token up to
    (but not including) the one the player is currently guessing.

    NOTE(review): Python slicing clamps at the sequence end, so the old
    "FIXME: indexerror" looks unreachable here — confirm ``all_tokens``
    is a plain sequence.
    """
    end = STARTING_INDEX + word_number
    visible_tokens = shared.all_tokens[:end]
    return shared.tokenizer.decode(visible_tokens)


def get_start_and_whitespace_tokens(
    word: str,
    tokenizer: PreTrainedTokenizer,
) -> tuple[int, int]:
    """Return two candidate token ids for *word*.

    It is difficult to tell whether the model will emit the word as a
    sequence-initial token or as a mid-sentence token with a leading
    whitespace, so both forms are computed:

    * the first token of ``word`` encoded on its own, and
    * the second token of ``". " + word`` — the word as it appears after
      other text, typically the leading-whitespace variant.

    :param word: the word to tokenize (assumed non-empty; an empty string
        would make the ``[0]`` lookup raise IndexError).
    :param tokenizer: tokenizer providing ``encode``.
    :return: ``(start_token_id, whitespace_token_id)``.
    """
    # [0]: first sub-token of the bare word.
    predicted_token_start = tokenizer.encode(word, add_special_tokens=False)[0]
    # [1]: skip the "." token, keep the (whitespace-prefixed) word token.
    predicted_token_whitespace = tokenizer.encode(". " + word, add_special_tokens=False)[1]
    return predicted_token_start, predicted_token_whitespace


def lm_is_correct(params: ReducerParams) -> bool:
    """Return True iff the LM's guess for the current attempt matches the target.

    :param params: reducer state; reads ``remaining_attempts`` and
        ``word_number``.
    :return: False when no attempts remain, when the prediction lookup is
        out of range, or when the guess differs from the target token.
    """
    # With 0 attempts left, idx below would run past the guess list.
    if params.remaining_attempts < 1:
        return False

    idx = MAX_ATTEMPTS - params.remaining_attempts
    prediction_index = STARTING_INDEX + params.word_number - 1

    try:
        # entry layout: (target_token_id, [guess_token_ids...])
        entry = token_id_predictions[prediction_index]
        current_target = entry[0]
        current_guess = entry[1][idx]
    except IndexError:
        # word_number or idx walked past the precomputed predictions
        # (the old FIXME); treat as an incorrect guess rather than crash.
        logger.warning(
            "Prediction lookup out of range: word_number=%s idx=%s",
            params.word_number,
            idx,
        )
        return False

    return current_guess == current_target


def guess_is_correct(params: ReducerParams, tokenizer: PreTrainedTokenizer) -> bool:
    """
    We check if the predicted token or a corresponding one with a leading
    whitespace matches that of the next token.

    :param params: reducer state; reads ``word_number`` and ``guess_field``.
    :param tokenizer: tokenizer used to encode the player's guess.
    :return: True iff either candidate token id equals the next target token.
    """
    target_index = STARTING_INDEX + params.word_number
    # Was a stray debug print; keep the information at debug level instead.
    logger.debug("Target token index: %s", target_index)

    try:
        current_target = shared.all_tokens[target_index]
    except IndexError:
        # Past the end of the prepared token sequence (the old FIXME);
        # there is no target left to match, so the guess cannot be correct.
        logger.warning("Target index %s out of range", target_index)
        return False

    logger.debug("Next token: '%s'", tokenizer.convert_ids_to_tokens([current_target]))

    predicted_token_start, predicted_token_whitespace = get_start_and_whitespace_tokens(params.guess_field, tokenizer)
    logger.debug(tokenizer.convert_ids_to_tokens([predicted_token_start, predicted_token_whitespace]))
    return current_target in (predicted_token_start, predicted_token_whitespace)