import logging

from transformers import PreTrainedTokenizer

from src import shared
from src.constants import MAX_ATTEMPTS, STARTING_INDEX
from src.params import ReducerParams
from src.shared import token_id_predictions

logger = logging.getLogger(__name__)

def get_current_lm_guess_str(word_number, remaining_attempts):
    censored_list = ["*****"] * MAX_ATTEMPTS
    prediction_idx = (STARTING_INDEX + word_number) - 1
    # Guard against running past the precomputed predictions (was a FIXME: IndexError).
    if prediction_idx >= len(token_id_predictions):
        return "\n".join(censored_list)
    guess_list = token_id_predictions[prediction_idx][1]
    guess_list = [shared.tokenizer.decode(i) for i in guess_list]
    # Reveal one guess per attempt already used; keep the rest censored.
    for i in range(MAX_ATTEMPTS - remaining_attempts):
        censored_list[i] = guess_list[i]
    return "\n".join(censored_list)
def get_current_prompt_text(word_number):
    # The slice clamps at the end of the sequence, so it cannot raise IndexError.
    return shared.tokenizer.decode(shared.all_tokens[: STARTING_INDEX + word_number])
def get_start_and_whitespace_tokens(
    word: str,
    tokenizer: PreTrainedTokenizer,
) -> tuple[int, int]:
    """
    It is difficult to tell whether the model will predict the word with or
    without a leading whitespace, so return the token id for both variants.
    """
    # Token id of the word at the start of a string (no leading space).
    predicted_token_start = tokenizer.encode(word, add_special_tokens=False)[0]
    # Prefixing ". " forces the leading-space variant; index 1 skips the "." token.
    predicted_token_whitespace = tokenizer.encode(". " + word, add_special_tokens=False)[1]
    return predicted_token_start, predicted_token_whitespace
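
# Example (illustrative, assuming a BPE-style vocabulary such as GPT-2's;
# exact ids depend on the tokenizer):
#   tokenizer.encode("world", add_special_tokens=False)   -> [id("world")]
#   tokenizer.encode(". world", add_special_tokens=False) -> [id("."), id(" world")]
# "world" and " world" map to different ids, which is why both are returned.
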
def lm_is_correct(params: ReducerParams) -> bool:
    # No attempts left means the model never matched the target.
    if params.remaining_attempts < 1:
        return False
    idx = MAX_ATTEMPTS - params.remaining_attempts
    prediction_idx = STARTING_INDEX + params.word_number - 1
    # Guard against indexing past the precomputed predictions (was a FIXME: IndexError).
    if prediction_idx >= len(token_id_predictions):
        return False
    current_guess = token_id_predictions[prediction_idx][1][idx]
    current_target = token_id_predictions[prediction_idx][0]
    return current_guess == current_target
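
# Worked example (illustrative): with MAX_ATTEMPTS == 3, a player with all 3
# attempts remaining is checked against the model's first guess (idx 0); with
# 1 attempt remaining, against the third guess (idx 2).
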
def guess_is_correct(params: ReducerParams, tokenizer: PreTrainedTokenizer) -> bool:
    """
    Check whether the predicted token, or its variant with a leading
    whitespace, matches the next token in the text.
    """
    target_idx = STARTING_INDEX + params.word_number
    logger.debug("Target token index: %d", target_idx)
    # Guard against reading past the end of the text (was a FIXME: IndexError).
    if target_idx >= len(shared.all_tokens):
        return False
    current_target = shared.all_tokens[target_idx]
    logger.debug("Next token: '{}'".format(tokenizer.convert_ids_to_tokens([current_target])))
    predicted_token_start, predicted_token_whitespace = get_start_and_whitespace_tokens(
        params.guess_field, tokenizer
    )
    logger.debug(tokenizer.convert_ids_to_tokens([predicted_token_start, predicted_token_whitespace]))
    return current_target in (predicted_token_start, predicted_token_whitespace)
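

if __name__ == "__main__":
    # Minimal smoke test for get_start_and_whitespace_tokens (a sketch, assuming
    # the "gpt2" checkpoint is available; any PreTrainedTokenizer should work).
    from transformers import AutoTokenizer

    demo_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    start_id, whitespace_id = get_start_and_whitespace_tokens("world", demo_tokenizer)
    # In a BPE vocabulary "world" and " world" are distinct tokens, so the two
    # ids normally differ.
    print(demo_tokenizer.convert_ids_to_tokens([start_id, whitespace_id]))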