nz committed on
Commit
af3f5e2
1 Parent(s): 410950a

Create rita_modeling.py

Files changed (1)
  1. rita_modeling.py +249 -0
rita_modeling.py ADDED
@@ -0,0 +1,249 @@
import math
import os
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss

from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging

from .rita_configuration import RITAConfig

logger = logging.get_logger(__name__)

@torch.jit.script
def RITA_gelu(hidden_states):
    # Tanh approximation of GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
    return hidden_states * 0.5 * (1.0 + torch.tanh(0.79788456 * hidden_states * (1 + 0.044715 * hidden_states * hidden_states)))


class RITAGELU(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, hidden_states):
        return RITA_gelu(hidden_states)


def rotate_half(x):
    # Split the last dimension into two halves and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=x1.ndim - 1)


class RotaryEmbedding(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.d_model % config.num_heads == 0

        self.d_model = config.d_model
        self.num_heads = config.num_heads
        self.max_seq_len = config.max_seq_len

        head_dim = self.d_model // self.num_heads
        inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x: torch.FloatTensor, seq_dim=1) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        # Recompute the cos/sin tables only when the sequence length changes.
        seq_len = x.shape[seq_dim]
        if seq_len != self.seq_len_cached:
            self.seq_len_cached = seq_len
            t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.cos_cached = emb.cos()[None, None, :, :]
            self.sin_cached = emb.sin()[None, None, :, :]
        return self.cos_cached, self.sin_cached

    def apply_rotary_pos_emb(self, q, k, cos, sin):
        return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)


class SelfAttention(nn.Module):
    """Implementation of MultiHeadAttention following `Karpathy's MinGPT <https://github.com/karpathy/minGPT>`_,
    modified to use rotary embeddings.

    Parameters
    ----------
    d_model: int,
        total dimension of the model.
    num_heads: int,
        number of parallel attention heads.
    num_layers: int,
        number of layers in the model, used for the Megatron-like init.
    rotary_embedding: Optional[RotaryEmbedding], default None,
        a RotaryEmbedding module to add positional information to queries and keys.
    dropout: float, default 0.1,
        amount of dropout on the attention weights.
    sigma: float, default 0.02,
        standard deviation used for the init.
    use_cache: bool, default False,
        whether key/value states are cached.
    bias: bool, default True,
        whether the linear projections use a bias term.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        num_layers: int,
        rotary_embedding=None,
        dropout: float = 0.1,
        sigma=0.02,
        use_cache: bool = False,
        bias=True,
    ):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = self.d_model // self.num_heads
        self.num_layers = num_layers
        self.dropout = dropout
        self.sigma = sigma
        self.bias = bias

        # key, query, value projections for all heads
        self.key = nn.Linear(d_model, d_model, bias=bias)
        self.query = nn.Linear(d_model, d_model, bias=bias)
        self.value = nn.Linear(d_model, d_model, bias=bias)
        # regularization
        self.attn_drop = nn.Dropout(dropout)
        self.resid_drop = nn.Dropout(dropout)
        # output projection
        self.proj = nn.Linear(d_model, d_model, bias=bias)

        self.rotary_embedding = rotary_embedding
        self.layer_id = None  # will be set by the Transformer itself
        self.use_cache = use_cache
        self.qkv = None

    def forward(
        self,
        x,
        attn_mask: Optional[torch.BoolTensor] = None,
        padding_mask: Optional[torch.BoolTensor] = None,
    ) -> torch.FloatTensor:

        N, L, D = x.size()  # Batch_size, Context_size, d_model

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        k = (
            self.key(x).view(N, L, self.num_heads, D // self.num_heads).transpose(1, 2)
        )  # (N, nh, L, hs)
        q = (
            self.query(x).view(N, L, self.num_heads, D // self.num_heads).transpose(1, 2)
        )  # (N, nh, L, hs)
        v = (
            self.value(x).view(N, L, self.num_heads, D // self.num_heads).transpose(1, 2)
        )  # (N, nh, L, hs)

        if self.rotary_embedding is not None:
            cos, sin = self.rotary_embedding(x)
            q, k = self.rotary_embedding.apply_rotary_pos_emb(q, k, cos, sin)

        # causal self-attention; Self-attend: (N, nh, L, hs) x (N, nh, hs, L) -> (N, nh, L, L)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))

        if attn_mask is not None:
            att[:, :, -L:, -L:].masked_fill_(attn_mask.view(1, 1, L, L), float("-inf"))

        att = (
            att.transpose(0, 2)
            .masked_fill(padding_mask.view(1, 1, N, L), float("-inf"))
            .transpose(0, 2)
            if padding_mask is not None
            else att
        )

        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v  # (N, nh, L, L) x (N, nh, L, hs) -> (N, nh, L, hs)
        y = (
            y.transpose(1, 2).contiguous().view(N, L, D)
        )  # re-assemble all head outputs side by side

        # output projection
        y = self.resid_drop(self.proj(y))
        return y


class DecoderLayer(nn.Module):
    """Transformer block containing the self-attention module and the feedforward module."""

    def __init__(self, config):
        super().__init__()
        self.self_attention = SelfAttention(
            config.d_model,
            config.num_heads,
            config.num_layers,
            rotary_embedding=RotaryEmbedding(config),
            dropout=config.dropout,
        )
        self.attn_norm = nn.LayerNorm(config.d_model)
        self.attn_dropout = nn.Dropout(config.dropout)

        self.mlp = nn.Sequential(
            nn.Linear(config.d_model, config.d_feedforward, bias=True),
            RITAGELU(),
            nn.Linear(config.d_feedforward, config.d_model, bias=True),
        )
        self.mlp_norm = nn.LayerNorm(config.d_model)
        self.mlp_dropout = nn.Dropout(config.dropout)

    def forward(
        self,
        x: torch.FloatTensor,
        attn_mask: torch.BoolTensor,
        padding_mask: Optional[torch.BoolTensor] = None,
    ) -> torch.FloatTensor:
        # Pre-norm residual block around self-attention.
        y = self.attn_norm(x)
        y = self.self_attention(y, attn_mask=attn_mask, padding_mask=padding_mask)
        x = x + self.attn_dropout(y)

        # Pre-norm residual block around the feedforward MLP.
        y = self.mlp_norm(x)
        y = self.mlp(y)
        x = x + self.mlp_dropout(y)
        return x


class RITAModel(PreTrainedModel):
    config_class = RITAConfig

    def __init__(self, config):
        super().__init__(config)
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.num_layers)])
        self.final_norm = nn.LayerNorm(config.d_model)
        self.projector = nn.Linear(config.d_model, config.vocab_size, bias=False)

    def forward(self, input_ids, attn_mask=None, padding_mask=None, return_hidden=False) -> torch.FloatTensor:
        x = self.embedding(input_ids)  # N x L x D
        if attn_mask is None:
            # Default causal mask: True strictly above the diagonal, i.e. future positions a token must not attend to.
            attn_mask = (
                (torch.triu(torch.ones(input_ids.size(1), input_ids.size(1))) == 0)
                .transpose(0, 1)
                .contiguous()
                .to(input_ids.device)
            )
        for layer in self.layers:
            x = layer(x, attn_mask=attn_mask, padding_mask=padding_mask)
        x = self.final_norm(x)  # N x L x D

        if return_hidden:
            return x
        else:
            return self.projector(x)

    # Some common HF functions.
    def get_input_embeddings(self):
        return self.embedding

    def set_input_embeddings(self, new_embeddings):
        self.embedding = new_embeddings

    def get_output_embeddings(self):
        return self.projector

    def set_output_embeddings(self, new_projector):
        self.projector = new_projector
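
Not part of the commit, but for orientation, below is a minimal usage sketch of the module added above. It assumes the companion rita_configuration.RITAConfig accepts and stores the fields this file reads (vocab_size, d_model, num_heads, num_layers, d_feedforward, dropout, max_seq_len), that the two files are importable together as a package (the relative import above implies they ship side by side, e.g. when loaded with trust_remote_code); the package name and hyperparameter values are placeholders, not the published RITA sizes.

    # Illustrative sketch only: hypothetical package layout, placeholder hyperparameters.
    import torch

    from rita.rita_configuration import RITAConfig  # assumed package name
    from rita.rita_modeling import RITAModel

    config = RITAConfig(
        vocab_size=32,       # placeholder values, assumed accepted by RITAConfig
        d_model=256,
        num_heads=8,
        num_layers=4,
        d_feedforward=1024,
        dropout=0.1,
        max_seq_len=1024,
    )
    model = RITAModel(config).eval()

    input_ids = torch.randint(0, config.vocab_size, (2, 16))  # (batch, length)
    with torch.no_grad():
        logits = model(input_ids)                              # (2, 16, vocab_size)
        hidden = model(input_ids, return_hidden=True)          # (2, 16, d_model)
    print(logits.shape, hidden.shape)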