import torch
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers import AutoConfig
from typing import Optional


class FasterChatGLM(PreTrainedModel):
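    """HuggingFace-compatible wrapper that routes ChatGLM-6B inference through
    an external accelerated kernel.

    The kernel is expected to expose ``infer()`` and ``set_context_mode()``;
    this class only prepares the inputs, invokes the kernel, and repackages the
    outputs as ``CausalLMOutputWithPast`` so that ``generate()`` works unchanged.
    """
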
    def __init__(self, model_dir, kernel, *inputs, **kwargs):
        config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
        # Expose GPT-2-style aliases for the ChatGLM config fields.
        config.n_head = config.num_attention_heads
        config.n_embd = config.hidden_size
        config.n_layer = config.num_layers
        super().__init__(config, *inputs, **kwargs)
        self.kernel = kernel
        # Placeholder layer so the module owns a torch parameter; the actual
        # computation runs inside the kernel.
        self.fake_reg = torch.nn.Linear(2, 2)
        self.position_encoding_2d = True

    def forward(self, input_ids, position_ids, attention_mask, past_key_values, *args, **kwargs):
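        # The kernel consumes [input_ids, position_ids, attention_mask] plus any
        # cached key/values; it returns the logits first, followed by the present
        # key/value tensors when a cache is produced.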
        inputs_values = [input_ids, position_ids, attention_mask]
        if past_key_values is not None:
            inputs_values = inputs_values + past_key_values

        computed = self.kernel.infer(inputs_values)
        logits = computed[0]
        if len(computed) == 1:
            present_key_values = None
        else:
            present_key_values = computed[1:]

        return CausalLMOutputWithPast(logits=logits, past_key_values=present_key_values)

    def get_masks_and_position_ids(self, seq, mask_position, context_length, device, gmask=False):
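        # GLM-style prefix mask: the prompt context attends bidirectionally and
        # only generated tokens are causal; True marks positions to mask out.
        # With 2D position encoding, a second row of block positions counts up
        # from the bos token.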
        attention_mask = torch.ones((1, context_length, context_length), device=device)
        attention_mask.tril_()
        attention_mask[..., :context_length - 1] = 1
        attention_mask.unsqueeze_(1)
        attention_mask = (attention_mask < 0.5).bool()

        if self.position_encoding_2d:
            seq_length = seq.index(150004)  # 150004: ChatGLM-6B bos (<sop>) token id
            position_ids = torch.arange(context_length, dtype=torch.long, device=device)
            if not gmask:
                position_ids[seq_length:] = mask_position
            block_position_ids = torch.cat((
                torch.zeros(seq_length, dtype=torch.long, device=device),
                torch.arange(context_length - seq_length, dtype=torch.long, device=device) + 1
            ))
            position_ids = torch.stack((position_ids, block_position_ids), dim=0)
        else:
            position_ids = torch.arange(context_length, dtype=torch.long, device=device)
            if not gmask:
                position_ids[context_length - 1:] = mask_position

        position_ids = position_ids.unsqueeze(0)

        return attention_mask, position_ids

    def prepare_one_sample(self, input_id, mask_token, past, past_key_values, use_gmask):

        seq = input_id.tolist()

        # Check for the mask token before calling list.index(), which would
        # otherwise raise its own ValueError first.
        if mask_token not in seq:
            raise ValueError("You have to add either [MASK] or [gMASK] in your input")
        mask_position = seq.index(mask_token)

        # only last token for input_ids if past is not None
        if past is not None or past_key_values is not None:
            context_length = seq.index(150004)  # position of the bos (<sop>) token, i.e. the prompt length
            last_token = input_id[-1].unsqueeze(-1).unsqueeze(0)  # 2 dim
            proc_input_id = last_token
            if self.position_encoding_2d:
                position_ids = torch.tensor([[[mask_position], [len(seq) - context_length]]], dtype=torch.long,
                                            device=input_id.device)
            else:
                position_ids = torch.tensor([[mask_position]], dtype=torch.long, device=input_id.device)

            attention_mask = torch.zeros(1, 1, 1, 1, device=input_id.device)
        else:
            proc_input_id = input_id.unsqueeze(0)
            attention_mask, position_ids = self.get_masks_and_position_ids(
                seq=seq,
                mask_position=mask_position,
                context_length=len(seq),
                device=input_id.device,
                gmask=use_gmask
            )

        return (proc_input_id.to(torch.int32), position_ids.to(torch.int32),
                attention_mask.to(torch.bool))

    def prepare_inputs_for_generation(
            self,
            input_ids: torch.LongTensor,
            past: Optional[torch.Tensor] = None,
            past_key_values: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            use_cache: bool = None,
            **kwargs
    ) -> dict:

        # ChatGLM-6B special token ids: [MASK] = 150000, [gMASK] = 150001.
        MASK, gMASK = 150000, 150001
        mask_token = MASK if MASK in input_ids else gMASK
        use_gmask = MASK not in input_ids

        batch_input_ids, batch_position_ids, batch_attention_mask = [], [], []
        for input_id in input_ids:
            # Masks and position ids are rebuilt per sample; the incoming
            # attention_mask argument is not used directly.
            proc_input_id, position_id, att_mask = self.prepare_one_sample(
                input_id, mask_token, past, past_key_values, use_gmask)
            batch_input_ids.append(proc_input_id)
            batch_position_ids.append(position_id)
            batch_attention_mask.append(att_mask)

        batch_input_ids = torch.vstack(batch_input_ids)
        batch_position_ids = torch.vstack(batch_position_ids)
        batch_attention_mask = torch.vstack(batch_attention_mask)

        if past is None:
            past = past_key_values

        if past is not None:
            # A cache already exists: switch the kernel to incremental decoding.
            self.kernel.set_context_mode(False)
        else:
            # First (context/prefill) pass over the full prompt.
            self.kernel.set_context_mode(self.config.use_cache)

        return {
            "input_ids": batch_input_ids,
            "past_key_values": past_key_values,
            "position_ids": batch_position_ids,
            "attention_mask": batch_attention_mask
        }
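

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). "DummyKernel" is a hypothetical stand-in
# that mirrors the two methods this wrapper calls (infer / set_context_mode)
# and returns random logits with no cache, so the snippet only exercises the
# input-preparation and generate() plumbing, not real accelerated inference.
# "THUDM/chatglm-6b" is assumed to be an original ChatGLM-6B checkpoint whose
# tokenizer appends the [gMASK] and <sop> tokens relied on above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    class DummyKernel:
        def __init__(self, vocab_size: int):
            self.vocab_size = vocab_size
            self.context_mode = True

        def set_context_mode(self, is_context: bool) -> None:
            # The real kernel would switch between prefill and incremental decoding.
            self.context_mode = is_context

        def infer(self, inputs):
            input_ids = inputs[0]
            batch, seq_len = input_ids.shape
            # Logits only, no present key/values (the len(computed) == 1 branch).
            return [torch.rand(batch, seq_len, self.vocab_size)]

    model_dir = "THUDM/chatglm-6b"  # assumed checkpoint path
    config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)

    model = FasterChatGLM(model_dir, kernel=DummyKernel(config.vocab_size))
    inputs = tokenizer("Hello", return_tensors="pt")
    outputs = model.generate(inputs["input_ids"], max_new_tokens=8)
    print(tokenizer.decode(outputs[0]))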