"""
This module contains the utilities necessary to pre-generate
predicted tokens.
"""

import torch
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers import PreTrainedTokenizer

from src.constants import MAX_ATTEMPTS
from src.text import get_text

# Loaded once at import time; callers pass it into make_predictions().
tokenizer = AutoTokenizer.from_pretrained("gpt2")


def make_predictions(tokenizer: PreTrainedTokenizer) -> tuple:
    """
    Run this on startup.
    Returns tuple of target_prediction_pairs and target_prediction_tokens:

    """
    text = get_text()
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    model.eval()

    # Tokenize the full text; a single forward pass scores every position.
    inputs = tokenizer(text, return_tensors="pt")

    with torch.no_grad():
        logits = model(**inputs).logits  # shape: (1, seq_len, vocab_size)

    # For every position, keep the MAX_ATTEMPTS highest-scoring token ids.
    top_n = torch.topk(logits, MAX_ATTEMPTS, dim=-1)
    token_id_preds = top_n.indices.squeeze(0).tolist()
    tokens = [tokenizer.convert_ids_to_tokens(ids) for ids in token_id_preds]

    # The prediction at position i is for the token at position i + 1, so
    # shift the targets by one; zip() drops the final prediction, which
    # points past the end of the text.
    whole_text_token_ids = tokenizer.encode(text)
    whole_text_tokens = tokenizer.convert_ids_to_tokens(whole_text_token_ids)
    target_prediction_pairs = list(zip(whole_text_token_ids[1:], token_id_preds))
    target_prediction_tokens = list(zip(whole_text_tokens[1:], tokens))

    return target_prediction_pairs, target_prediction_tokens
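

# A minimal usage sketch, not part of the original module. It assumes
# src.text.get_text() returns a passage of more than one token and that
# MAX_ATTEMPTS is a small integer (e.g. 3): warm the predictions once,
# then inspect the first few target/prediction pairs.
if __name__ == "__main__":
    _, token_pairs = make_predictions(tokenizer)
    for target, predictions in token_pairs[:5]:
        print(f"{target!r} -> {predictions}")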