"""This module contains the utilities necessary to pre-generate predicted tokens."""
import torch
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers import PreTrainedTokenizer

from src.constants import MAX_ATTEMPTS
from src.text import get_text

# Shared tokenizer, loaded once at import time so callers can pass it into
# make_predictions() without re-loading the GPT-2 vocabulary.
tokenizer = AutoTokenizer.from_pretrained("gpt2")


def make_predictions(tokenizer: PreTrainedTokenizer) -> tuple:
    """Pre-compute GPT-2's top-k next-token predictions for the game text.

    Run this on startup.

    Args:
        tokenizer: Tokenizer used to encode the text and decode predictions.

    Returns:
        A tuple ``(target_prediction_pairs, target_prediction_tokens)``.
        Entry ``i`` of each list pairs the actual token at position ``i + 1``
        with the model's top-``MAX_ATTEMPTS`` predictions made at position
        ``i``:

        - ``target_prediction_pairs``: ``(target_token_id, [predicted_token_ids])``
        - ``target_prediction_tokens``: ``(target_token_str, [predicted_token_strs])``
    """
    text = get_text()
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    model.eval()

    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        # logits shape: (1, seq_len, vocab_size)
        logits = model(**inputs).logits

    # Top MAX_ATTEMPTS candidate ids at every position; reduce over the vocab
    # dimension explicitly (squeeze drops the batch dim of 1).
    top_n = torch.topk(logits, MAX_ATTEMPTS, dim=-1)
    token_id_preds = top_n.indices.squeeze().tolist()
    tokens = [tokenizer.convert_ids_to_tokens(ids) for ids in token_id_preds]

    # Reuse the ids already produced by the tokenizer call above rather than
    # re-encoding the text a second time (identical ids for GPT-2, which adds
    # no special tokens on either path).
    whole_text_token_ids = inputs["input_ids"].squeeze().tolist()
    whole_text_tokens = tokenizer.convert_ids_to_tokens(whole_text_token_ids)

    # The prediction made at position i targets the actual token at i + 1, so
    # shift the targets by one; zip() drops the final (unscored) prediction.
    target_prediction_pairs = list(zip(whole_text_token_ids[1:], token_id_preds))
    target_prediction_tokens = list(zip(whole_text_tokens[1:], tokens))
    return target_prediction_pairs, target_prediction_tokens