import torch


def forward(model_name, model, input_ids, past, device='cpu'):
    """Run one decoding step, dispatching on the quirks of each model family."""
    if "gpt2" in model_name or "ctrl" in model_name:
        if past is not None:
            # With cached key/values, only the newest token needs to be fed.
            # Slice with `-1:` so the input stays 2D, shape (batch, 1);
            # `input_ids[:, -1]` would drop the sequence dimension.
            return model(input_ids[:, -1:], past=past)
        return model(input_ids)
    elif "xlnet" in model_name:
        # Append a dummy token that the model will be asked to predict.
        input_ids = torch.cat((
            input_ids,
            torch.zeros((input_ids.shape[0], 1), dtype=torch.long, device=device)
        ), dim=1)

        # Permutation mask: no token may attend to the dummy last token.
        perm_mask = torch.zeros(
            (input_ids.shape[0], input_ids.shape[1], input_ids.shape[1]),
            dtype=torch.float,
            device=device
        )
        perm_mask[:, :, -1] = 1.0

        # Target mapping: predict only the last (dummy) position.
        target_mapping = torch.zeros(
            (input_ids.shape[0], 1, input_ids.shape[1]),
            dtype=torch.float,
            device=device
        )
        target_mapping[:, 0, -1] = 1.0

        return model(input_ids, perm_mask=perm_mask, target_mapping=target_mapping)
    elif "transfo-xl" in model_name:
        # Transformer-XL threads its recurrence memory through `mems`.
        return model(input_ids, mems=past)
    else:
        return model(input_ids)
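

# A minimal greedy-decoding sketch showing the caching contract of `forward`
# above. Written with the GPT-2/CTRL branch in mind, and assuming the
# pre-1.0 pytorch-transformers API in which the LM head models return a
# `(logits, past)` tuple; the helper name and `length` default are illustrative.
def generate_greedy(model_name, model, context_tokens, length=20, device='cpu'):
    input_ids = torch.tensor([context_tokens], dtype=torch.long, device=device)
    past = None
    with torch.no_grad():
        for _ in range(length):
            logits, past = forward(model_name, model, input_ids, past, device=device)
            # Greedy pick of the next token; a real sampler would apply
            # temperature / top-k here.
            next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
            input_ids = torch.cat((input_ids, next_token), dim=1)
    return input_ids[0].tolist()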


def create_context(model_name, tokenizer, initial_text="", padding_text=None, max_tokens=512):
    if not initial_text and "gpt2" in model_name:
        initial_text = "<|endoftext|>"
    if "xlnet" in model_name or "transfo-xl" in model_name:
        # Memory models prefer a long prompt; prepend the padding text,
        # guarding against a missing one.
        initial_text = (padding_text or "") + initial_text

    if "transfo-xl" in model_name:
        # Transformer-XL gets a shorter prompt.
        max_tokens //= 2

    context_tokens = tokenizer.encode(initial_text)[-max_tokens:]

    if "gpt2" in model_name:
        eot_token = tokenizer.encoder["<|endoftext|>"]
        if len(context_tokens) == 0:
            context_tokens = [eot_token]
    elif "xlnet" in model_name:
        eot_token = tokenizer.convert_tokens_to_ids("<eop>")
    else:
        eot_token = None
    # Token id of "." (take the last piece in case the tokenizer splits it).
    dot_token = tokenizer.encode(".")[-1]

    return context_tokens, eot_token, dot_token
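

# End-to-end demo tying the pieces together, using the `generate_greedy`
# sketch above. Assumes the pytorch-transformers package (the `past=`
# keyword and `tokenizer.encoder` usage above match its pre-1.0 GPT-2 API).
if __name__ == "__main__":
    from pytorch_transformers import GPT2LMHeadModel, GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    model.eval()

    context_tokens, eot_token, dot_token = create_context(
        "gpt2", tokenizer, initial_text="The meaning of life is")
    print(tokenizer.decode(generate_greedy("gpt2", model, context_tokens)))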