Spaces:
Runtime error
Runtime error
File size: 1,265 Bytes
1532c35 6d4a32a 1532c35 09c334f 1532c35 09c334f 1532c35 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
"""
This module contains the utilities necessary to pre-generate
predicted tokens.
"""
import torch
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers import PreTrainedTokenizer
from src.constants import MAX_ATTEMPTS
from src.text import get_text
# Module-level gpt2 tokenizer, built once at import time and passed to
# make_predictions by callers. NOTE(review): from_pretrained may fetch
# the vocab over the network on first import — confirm this is intended
# at module load rather than inside an app-startup hook.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
def make_predictions(tokenizer: PreTrainedTokenizer) -> tuple:
    """
    Pre-generate next-token predictions for the app's text; run on startup.

    Runs gpt2 over the full text once and, for every position, records the
    model's top ``MAX_ATTEMPTS`` candidate next tokens, aligned with the
    token that actually follows.

    :param tokenizer: tokenizer used both to encode the text and to map
        predicted ids back to token strings.
    :return: tuple of (target_prediction_pairs, target_prediction_tokens);
        the first pairs each actual token id with the list of predicted ids
        made at the preceding position, the second is the same pairing as
        token strings.
    """
    text = get_text()
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    model.eval()
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    # Top-k candidate ids per position (topk acts over the last, vocab, dim).
    top_n = torch.topk(logits, MAX_ATTEMPTS)
    # squeeze(0) — not squeeze() — so a text that tokenizes to a single
    # token keeps its (seq_len, k) shape instead of collapsing to (k,).
    token_id_preds = top_n.indices.squeeze(0).tolist()
    # Each element of token_id_preds is a list of k ids for one position.
    tokens = [tokenizer.convert_ids_to_tokens(ids) for ids in token_id_preds]
    whole_text_token_ids = tokenizer.encode(text)
    whole_text_tokens = tokenizer.convert_ids_to_tokens(whole_text_token_ids)
    # Logits at position i predict the token at position i+1, so pair each
    # target token (from position 1 onward) with the predictions emitted at
    # the previous position; zip drops the final prediction, which has no
    # target to compare against.
    target_prediction_pairs = list(zip(whole_text_token_ids[1:], token_id_preds))
    target_prediction_tokens = list(zip(whole_text_tokens[1:], tokens))
    return target_prediction_pairs, target_prediction_tokens
|