# Provenance (HuggingFace web-viewer header, kept as a comment so the module parses):
# author marksverdhei — commit 09c334f ":construction: Make the repo work somewhat"
# viewer chrome: raw / history blame / 1.27 kB
"""
This module contains the utilities necessary to pre-generate
predicted tokens.
"""
import torch
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers import PreTrainedTokenizer
from src.constants import MAX_ATTEMPTS
from src.text import get_text
# Module-level default tokenizer, loaded once at import time (network/disk
# side effect). NOTE(review): make_predictions takes a `tokenizer` parameter
# of the same name, which shadows this one inside the function — callers are
# presumably expected to pass this instance in.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
def make_predictions(tokenizer: PreTrainedTokenizer) -> tuple:
    """
    Pre-generate next-token predictions for the game text.

    Run this on startup.

    Args:
        tokenizer: Tokenizer matching the "gpt2" causal LM loaded below.

    Returns:
        Tuple of (target_prediction_pairs, target_prediction_tokens):
        - target_prediction_pairs: list of (target_token_id, top_k_predicted_ids)
          pairs — the model's logits at position i predict token i + 1.
        - target_prediction_tokens: the same pairs as token strings.
    """
    text = get_text()
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    model.eval()
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    # Top MAX_ATTEMPTS candidate ids per position (topk over the last, vocab dim).
    top_n = torch.topk(logits, MAX_ATTEMPTS)
    # squeeze(0), not squeeze(): only drop the batch dimension, so a
    # single-token text still yields a list of id-lists rather than a
    # flat list of ints (squeeze() would also collapse a length-1 seq dim).
    token_id_preds = top_n.indices.squeeze(0).tolist()
    tokens = [tokenizer.convert_ids_to_tokens(ids) for ids in token_id_preds]
    # Reuse the ids already computed for the forward pass instead of
    # re-encoding the text a second time.
    whole_text_token_ids = inputs["input_ids"][0].tolist()
    whole_text_tokens = tokenizer.convert_ids_to_tokens(whole_text_token_ids)
    # Predictions at position i target token i + 1; zip truncates, dropping
    # the final prediction, whose target lies beyond the text.
    target_prediction_pairs = list(zip(whole_text_token_ids[1:], token_id_preds))
    target_prediction_tokens = list(zip(whole_text_tokens[1:], tokens))
    return target_prediction_pairs, target_prediction_tokens