""" | |
This module contains the utilities necessary to pre-generate | |
predicted tokens. | |
""" | |
import torch
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers import PreTrainedTokenizer
from src.constants import MAX_ATTEMPTS
from src.text import get_text
# Shared tokenizer, created once at import time so callers can pass it
# straight into make_predictions().
tokenizer = AutoTokenizer.from_pretrained("gpt2")
def make_predictions(tokenizer: PreTrainedTokenizer) -> tuple:
    """
    Run this on startup.

    Returns a tuple of (target_prediction_pairs, target_prediction_tokens):
    the first pairs each actual token id in the text with the model's
    top-MAX_ATTEMPTS predicted ids for that position; the second holds
    the same pairs as token strings.
    """
    text = get_text()
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    model.eval()
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    # logits has shape (1, seq_len, vocab_size); take the MAX_ATTEMPTS
    # highest-scoring candidate next-token ids at every position.
    top_n = torch.topk(logits, MAX_ATTEMPTS)
    # squeeze(0) drops only the batch dimension, so the result stays
    # (seq_len, MAX_ATTEMPTS) even when MAX_ATTEMPTS == 1.
    token_id_preds = top_n.indices.squeeze(0).tolist()
    tokens = list(map(tokenizer.convert_ids_to_tokens, token_id_preds))
    whole_text_token_ids = tokenizer.encode(text)
    whole_text_tokens = tokenizer.convert_ids_to_tokens(whole_text_token_ids)
    # The prediction at position i is for the token at position i + 1, so
    # offset the targets by one; zip drops the final prediction, which has
    # no target token inside the text.
    target_prediction_pairs = list(zip(whole_text_token_ids[1:], token_id_preds))
    target_prediction_tokens = list(zip(whole_text_tokens[1:], tokens))
    return target_prediction_pairs, target_prediction_tokens
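

# A minimal usage sketch: run the module directly to precompute the pairs
# and eyeball the first few. This assumes src.constants.MAX_ATTEMPTS and
# src.text.get_text (imported above) are available in the running
# environment; it is an illustration, not part of the app's startup path.
if __name__ == "__main__":
    pairs, token_pairs = make_predictions(tokenizer)
    # Each entry pairs the actual next token with the model's top candidates.
    for target, preds in token_pairs[:5]:
        print(f"target={target!r} predictions={preds}")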