import logging

from transformers import PreTrainedTokenizer

from src import shared
from src.constants import MAX_ATTEMPTS, STARTING_INDEX
from src.params import ReducerParams
from src.shared import token_id_predictions

logger = logging.getLogger(__name__)


def get_current_lm_guess_str(word_number, remaining_attempts):
    # FIXME: can raise IndexError once word_number runs past token_id_predictions
    guess_list = token_id_predictions[(STARTING_INDEX + word_number) - 1][1]
    guess_list = [shared.tokenizer.decode(i) for i in guess_list]
    censored_list = ["*****"] * MAX_ATTEMPTS
    # Reveal one past guess per attempt already used; the rest stay censored.
    for i in range(MAX_ATTEMPTS - remaining_attempts):
        censored_list[i] = guess_list[i]
    return "\n".join(censored_list)


def get_current_prompt_text(word_number):
    # FIXME: guard against word_number running past the end of all_tokens
    return shared.tokenizer.decode(shared.all_tokens[: STARTING_INDEX + word_number])
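
# Illustrative example (hypothetical values): with STARTING_INDEX == 10 and
# word_number == 3, this decodes the first 13 tokens of the text, i.e. the
# prompt shown to the player up to the word currently being guessed.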


def get_start_and_whitespace_tokens(
    word: str,
    tokenizer: PreTrainedTokenizer,
) -> tuple[int, int]:
    """
    It is difficult to tell whether the model predicts the word with or
    without a leading whitespace, so we return both candidate token ids:
    the word encoded on its own, and the word encoded after ". " so that
    its second token is the leading-whitespace variant.
    """
    predicted_token_start = tokenizer.encode(word, add_special_tokens=False)[0]
    predicted_token_whitespace = tokenizer.encode(". " + word, add_special_tokens=False)[1]
    return predicted_token_start, predicted_token_whitespace
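
# Why ". " + word: in BPE vocabularies such as GPT-2's, "hello" and " hello"
# are distinct tokens. Encoding the bare word yields the sentence-initial id,
# while encoding it after ". " makes index 1 the leading-whitespace variant.
# Sketch (illustrative ids, assuming a GPT-2-style tokenizer):
#
#   tokenizer.encode("hello", add_special_tokens=False)    # e.g. [31373]
#   tokenizer.encode(". hello", add_special_tokens=False)  # e.g. [13, 23748]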


def lm_is_correct(params: ReducerParams) -> bool:
    # With no attempts left there is no guess at idx to check.
    if params.remaining_attempts < 1:
        return False
    idx = MAX_ATTEMPTS - params.remaining_attempts
    # FIXME: can raise IndexError once word_number runs past token_id_predictions
    current_guess = token_id_predictions[STARTING_INDEX + params.word_number - 1][1][idx]
    current_target = token_id_predictions[STARTING_INDEX + params.word_number - 1][0]
    return current_guess == current_target
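
# Worked example (assumes MAX_ATTEMPTS == 3): with remaining_attempts == 2,
# idx == 1, so the second guess for the current word is compared against the
# stored target. Each token_id_predictions entry is assumed to have the shape
# (target_token_id, [guess_id_0, guess_id_1, ...]), matching the indexing above.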


def guess_is_correct(params: ReducerParams, tokenizer: PreTrainedTokenizer) -> bool:
    """
    Check whether the predicted token, or the corresponding token with a
    leading whitespace, matches the next token in the text.
    """
    # FIXME: handle IndexError when word_number runs past all_tokens
    logger.debug("Current token index: %s", STARTING_INDEX + params.word_number)
    current_target = shared.all_tokens[STARTING_INDEX + params.word_number]
    logger.debug("Next token: '{}'".format(tokenizer.convert_ids_to_tokens([current_target])))
    predicted_token_start, predicted_token_whitespace = get_start_and_whitespace_tokens(
        params.guess_field, tokenizer
    )
    logger.debug(tokenizer.convert_ids_to_tokens([predicted_token_start, predicted_token_whitespace]))
    return current_target in (predicted_token_start, predicted_token_whitespace)
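
# Usage sketch (hypothetical: assumes shared.tokenizer, shared.all_tokens and
# token_id_predictions have been initialised by the app's setup code, and that
# ReducerParams can be constructed like this):
#
#   params = ReducerParams(
#       word_number=0,
#       remaining_attempts=MAX_ATTEMPTS,
#       guess_field="hello",
#   )
#   if guess_is_correct(params, shared.tokenizer):
#       ...  # the player guessed the next word
#   elif lm_is_correct(params):
#       ...  # the language model guessed it first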