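"""Phi-2 model built on Hugging Face `PreTrainedModel`.

Contains:
  * `Phi2PreTrainedModel` - weight initialization and generation input preparation.
  * `Embedding` - token embedding with dropout.
  * `Phi2Model` - a stack of `ParallelAttentionBlock`s over the embeddings.
  * `Phi2ModelForCausalLM` - adds a layer-norm + linear LM head and the LM loss.
"""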
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from typing import Any, cast

from .attention import ParallelAttentionBlock, KVCache
from .phi2_configuration import Phi2Config


class Phi2PreTrainedModel(PreTrainedModel):
    config_class = Phi2Config  # lets from_pretrained() resolve the right config; also needed to register with auto classes
    supports_gradient_checkpointing = False
    # _no_split_modules = ["ParallelAttentionBlock"]

    def __init__(self, config: Phi2Config):
        super().__init__(config)
        self.config = config

    def _init_weights(self, module: nn.Module) -> None:
        # initialize weights - will get overwritten by saved weights in from_pretrained() if they exist
        if isinstance(module, (nn.Linear,)):
            module.weight.data.normal_(mean=0.0, std=self.config.weight_initialization_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.weight_initialization_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            if module.bias is not None:
                module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,  # dim: (batch_size, seq_len)
        past_key_values: KVCache | None = None,  # must be named "past_key_values" for transformers' generation loop
        key_padding_mask: torch.LongTensor | torch.BoolTensor | None = None,
        **kwargs,  # absorbs extra kwargs that generate() may pass (e.g. use_cache, attention_mask)
    ) -> dict[str, Any]:
        kv_cache = past_key_values
        if kv_cache is None:
            kv_cache = KVCache(
                max_seqlen=self.config.initial_cos_sin_cache_len,
                max_batch_size=input_ids.shape[0],
                seqlen_offset=0,
                batch_size_offset=0,
                kv_block_map={},
                lengths_per_sample=None,
            )
        else:
            # assume the cache already holds keys/values for all but the last token, so only the newest token is fed through
            kv_cache.seqlen_offset = input_ids.shape[1] - 1
            input_ids = cast(torch.LongTensor, input_ids[:, -1].unsqueeze(-1))

        return {  # to be passed to forward()
            "input_ids": input_ids,
            "kv_cache": kv_cache,
            "key_padding_mask": key_padding_mask,
        }
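
# Assumed usage sketch (not called anywhere in this file): a hand-rolled greedy
# decode loop showing how the dict returned by prepare_inputs_for_generation()
# feeds back into forward(); `model` and `max_new_tokens` are placeholders.
#
#     model_inputs = model.prepare_inputs_for_generation(input_ids)
#     for _ in range(max_new_tokens):
#         out = model(**model_inputs)
#         next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
#         input_ids = torch.cat([input_ids, next_token], dim=-1)
#         model_inputs = model.prepare_inputs_for_generation(
#             input_ids, past_key_values=out.past_key_values
#         )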


class Embedding(nn.Module):
    """Token embedding with dropout."""

    def __init__(
        self,
        vocab_size: int,
        d_embedding: int,
        embd_pdrop: float,
    ) -> None:
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, d_embedding)
        self.dropout = nn.Dropout(embd_pdrop)

    def forward(
        self,
        input_ids: torch.LongTensor,  # dim: (batch_size, seq_len)
    ) -> torch.FloatTensor:
        x = self.embeddings(  # dim: (batch_size, seq_len, d_embedding)
            input_ids.view(-1, input_ids.size(-1))
        )
        x = self.dropout(x)
        return x


class Phi2Model(Phi2PreTrainedModel):
    def __init__(self, config: Phi2Config) -> None:
        super().__init__(config)
        self.embedding = Embedding(
            vocab_size=config.vocab_size,
            d_embedding=config.d_embedding,
            embd_pdrop=config.embd_pdrop,
        )
        self.parallel_blocks = nn.ModuleList([
            ParallelAttentionBlock(
                resid_pdrop=config.resid_pdrop,
                layer_norm_epsilon=config.layer_norm_epsilon,
                d_embedding=config.d_embedding,
                n_attn_heads=config.n_attn_heads,
                block_n=i,
                initial_cos_sin_cache_len=config.initial_cos_sin_cache_len,
                attn_pdrop=config.attn_pdrop,
                use_flash_rotary=config.use_flash_rotary,
                use_flash_attn=config.use_flash_attn,
                use_fused_dense=config.use_fused_dense,
                checkpointing=config.checkpointing,
            )
            for i in range(config.n_attn_blocks)
        ])
        self.gradient_checkpointing_disable()  # https://github.com/cybertronai/gradient-checkpointing - checkpointing is handled per block via config.checkpointing, so the transformers-level mechanism stays disabled
        self.post_init()  # calls self._init_weights() for all modules

    """
    def get_input_embeddings(self) -> nn.Embedding:
        return self.embedding.embeddings

    def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
        self.embedding.embeddings = new_embeddings
    """

    def forward(
        self,
        input_ids: torch.LongTensor,
        kv_cache: KVCache | None = None,
        key_padding_mask: torch.BoolTensor | None = None,
    ) -> torch.FloatTensor:
        x = self.embedding(input_ids)
        for block in self.parallel_blocks:
            x = block(
                x,
                kv_cache=kv_cache,
                key_padding_mask=key_padding_mask,
            )
        return x


class Phi2ModelForCausalLM(Phi2PreTrainedModel):
    def __init__(self, config: Phi2Config) -> None:
        super().__init__(config)
        self.model = Phi2Model(config)
        self.lm_head_layer_norm = nn.LayerNorm(config.d_embedding, eps=config.layer_norm_epsilon)
        self.lm_head_linear = nn.Linear(config.d_embedding, config.vocab_size)
        self.loss_fn = nn.CrossEntropyLoss()
        self.post_init()  # calls self._init_weights() for all modules

    def forward(
        self,
        input_ids: torch.LongTensor,
        kv_cache: KVCache | None = None,
        key_padding_mask: torch.BoolTensor | None = None,
        labels: torch.LongTensor | None = None,
        **kwargs,  # absorbs any extra arguments generate() passes to forward()
    ) -> CausalLMOutputWithPast:
        x = self.model(input_ids, kv_cache=kv_cache, key_padding_mask=key_padding_mask)
        x = self.lm_head_layer_norm(x)
        logits = self.lm_head_linear(x).to(torch.float32)
        loss = None
        if labels is not None:
            # standard causal-LM objective: predict token t+1 from tokens <= t,
            # so shift logits and labels by one position before the cross-entropy
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=kv_cache)
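

# --- Usage sketch -----------------------------------------------------------
# A minimal smoke test, assuming Phi2Config accepts the attributes referenced
# above as keyword arguments (vocab_size, d_embedding, n_attn_heads, ...);
# adjust the names and values to whatever phi2_configuration actually defines.
if __name__ == "__main__":
    config = Phi2Config(
        vocab_size=1000,
        d_embedding=64,
        n_attn_heads=4,
        n_attn_blocks=2,
        initial_cos_sin_cache_len=128,
        embd_pdrop=0.0,
        resid_pdrop=0.0,
        attn_pdrop=0.0,
        layer_norm_epsilon=1e-5,
        weight_initialization_range=0.02,
        use_flash_rotary=False,
        use_flash_attn=False,
        use_fused_dense=False,
        checkpointing=False,
    )
    model = Phi2ModelForCausalLM(config)
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    output = model(input_ids=input_ids, labels=input_ids)
    print(output.logits.shape, output.loss)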