andrecornman committed
Commit ac9212b
1 Parent(s): d850aa6

Update modeling_glm2.py

Files changed (1):
  1. modeling_glm2.py +99 -197
modeling_glm2.py CHANGED
@@ -1,12 +1,11 @@
  """PyTorch gLM2 model.

- Requires flash attention.
  Some modules adapted from:
  https://github.com/meta-llama/llama/blob/main/llama/model.py
  """
- import math
  import torch
- from einops import rearrange
  from typing import Optional, Tuple, Union
  from torch import nn
  from torch.nn import CrossEntropyLoss
@@ -17,30 +16,51 @@ from transformers.modeling_outputs import (
  )
  from transformers.modeling_utils import PreTrainedModel
  from transformers.utils import logging

- try:
-     from flash_attn.ops.activations import swiglu
-     from flash_attn.layers.rotary import apply_rotary_emb_func
-     from flash_attn import (
-         flash_attn_kvpacked_func,
-         flash_attn_varlen_kvpacked_func,
-     )
-     from flash_attn.bert_padding import pad_input, unpad_input
-     from flash_attn.ops.triton.layer_norm import RMSNorm
- except ImportError:
-     raise ImportError(
-         "gLM2 requires flash attention: `pip install flash-attn --no-build-isolation`")

- from .configuration_glm2 import gLM2Config, gLM2EmbedConfig

- logger = logging.get_logger(__name__)


  class RotaryEmbedding(torch.nn.Module):
      """
      Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py.
-     Changed to only support passing in q or k individually, so that we can use varlen rotary.
      """

      def __init__(
@@ -138,92 +158,52 @@ class RotaryEmbedding(torch.nn.Module):

      def forward(
          self,
-         q: torch.Tensor,
-         k: torch.Tensor,
-         seqlen_offset: Union[int, torch.Tensor] = 0,
-         cu_seqlens: Optional[torch.Tensor] = None,
          max_seqlen: Optional[int] = None,
      ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
          """
-         q: (batch, seqlen, nheads, headdim). If cu_seqlens is not None,
-             shape (total_seqlen, nheads, headdim).
-         k: (batch, seqlen, nheads, headdim). If cu_seqlens is not None,
-             shape (total_seqlen, nheads, headdim).
-         seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount.
-             Most commonly used in inference when we have KV cache.
-             If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
-             should pass in max_seqlen, which will update the cos / sin cache up to that length.
-         Apply rotary embedding *inplace* to qkv and / or kv.
          """
-         if cu_seqlens is not None:
-             assert max_seqlen is not None
-         seqlen = q.shape[1] if max_seqlen is None else max_seqlen
-         if max_seqlen is not None:
              self._update_cos_sin_cache(
-                 max_seqlen, device=q.device, dtype=q.dtype)
-         elif isinstance(seqlen_offset, int):
              self._update_cos_sin_cache(
-                 seqlen + seqlen_offset, device=q.device, dtype=q.dtype
-             )
-         q = apply_rotary_emb_func(
-             q,
-             self._cos_cached,
-             self._sin_cached,
-             interleaved=self.interleaved,
-             inplace=True,
-             seqlen_offsets=seqlen_offset,
-             cu_seqlens=cu_seqlens,
-             max_seqlen=max_seqlen,
          )
-         if self.scale is None:
-             k = apply_rotary_emb_func(
-                 k,
-                 self._cos_cached,
-                 self._sin_cached,
-                 interleaved=self.interleaved,
-                 inplace=True,
-                 seqlen_offsets=seqlen_offset,
-                 cu_seqlens=cu_seqlens,
-                 max_seqlen=max_seqlen,
-             )
-         else:
-             k = apply_rotary_emb_func(
-                 k,
-                 self._cos_k_cached,
-                 self._sin_k_cached,
-                 interleaved=self.interleaved,
-                 inplace=True,
-                 seqlen_offsets=seqlen_offset,
-                 cu_seqlens=cu_seqlens,
-                 max_seqlen=max_seqlen,
-             )
-         return q, k


  # @torch.jit.script
- # def rmsnorm_func(hidden_states, weight, variance_epsilon):
- #     """Apply the root mean square normalization."""
- #     input_dtype = hidden_states.dtype
- #     hidden_states = hidden_states.to(torch.float32)
- #     variance = hidden_states.pow(2).mean(-1, keepdim=True)
- #     hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
- #     return (weight * hidden_states).to(input_dtype)


- # class RMSNorm(nn.Module):
- #     """Root mean square normalization."""

- #     def __init__(self, dim, eps=1e-6):
- #         super().__init__()
- #         self.weight = nn.Parameter(torch.ones(dim))
- #         self.register_buffer(
- #             "variance_epsilon",
- #             torch.tensor(eps),
- #             persistent=False,
- #         )

- #     def forward(self, hidden_states):
- #         return rmsnorm_func(hidden_states, self.weight, self.variance_epsilon)


  class Attention(nn.Module):
@@ -241,67 +221,33 @@ class Attention(nn.Module):

          self.rotary_emb = RotaryEmbedding(self.head_dim)

-     def _forward_varlen(
-         self,
-         x: torch.Tensor,
-         cu_seqlens: Optional[torch.Tensor] = None,
-         max_seq_len: Optional[torch.Tensor] = None,
-     ) -> torch.Tensor:
-         total_seqlen, h_size = x.shape
-         qkv = self.wqkv(x)
-         q, k, v = torch.split(qkv, self.n_heads * self.head_dim, dim=-1)
-
-         q = q.view(total_seqlen, self.n_heads, self.head_dim)
-         k = k.view(total_seqlen, self.n_heads, self.head_dim)
-         v = v.view(total_seqlen, self.n_heads, self.head_dim)
-
-         q, k = self.rotary_emb(
-             q, k, cu_seqlens=cu_seqlens, max_seqlen=max_seq_len)
-
-         # (seqlen, 2, n_heads, head_dim)
-         kv = torch.stack([k, v], 1)
-
-         # (seqlen, n_heads, head_dim)
-         output = flash_attn_varlen_kvpacked_func(
-             q,
-             kv,
-             cu_seqlens_q=cu_seqlens,
-             cu_seqlens_k=cu_seqlens,
-             max_seqlen_q=max_seq_len,
-             max_seqlen_k=max_seq_len,
-             dropout_p=0.0,
-             causal=False,
-         )
-         output = output.view(total_seqlen, h_size)
-         return self.wo(output)
-
      def forward(
          self,
          x: torch.Tensor,
-         cu_seqlens: Optional[torch.Tensor] = None,
-         max_seq_len: Optional[torch.Tensor] = None,
      ) -> torch.Tensor:
-         if cu_seqlens is not None:
-             assert max_seq_len is not None
-             return self._forward_varlen(x, cu_seqlens, max_seq_len)
-
          bsz, seqlen, h_size = x.shape
          qkv = self.wqkv(x)
-         q, k, v = torch.split(qkv, self.n_heads * self.head_dim, dim=-1)
-         q = q.view(bsz, seqlen, self.n_heads, self.head_dim)
-         k = k.view(bsz, seqlen, self.n_heads, self.head_dim)
-         v = v.view(bsz, seqlen, self.n_heads, self.head_dim)
-
-         q, k = self.rotary_emb(q, k)
-         # (bs, seqlen, 2, n_heads, head_dim)
-         kv = torch.stack([k, v], 2)
-
-         output = flash_attn_kvpacked_func(
-             q,
-             kv,
-             dropout_p=0.0,
-             causal=False,
          )
          output = output.view(bsz, seqlen, h_size)
          return self.wo(output)

@@ -336,7 +282,7 @@ class FeedForward(nn.Module):
          self.w3 = nn.Linear(dim, hidden_dim, bias=False)

      def forward(self, x):
-         return self.w2(swiglu(self.w1(x), self.w3(x)))


  class TransformerBlock(nn.Module):
@@ -358,12 +304,10 @@ class TransformerBlock(nn.Module):
      def forward(
          self,
          x: torch.Tensor,
-         cu_seqlens: Optional[torch.Tensor] = None,
-         max_seq_len: Optional[torch.Tensor] = None,
      ) -> torch.Tensor:
-         r = self.attention(
-             self.attention_norm(x), cu_seqlens, max_seq_len
-         )
          h = x + r
          r = self.feed_forward(self.ffn_norm(h))
          out = h + r
@@ -377,19 +321,6 @@ class TransformerLayers(nn.Module):
          self.layers = torch.nn.ModuleList(
              [TransformerBlock(config=config) for _ in range(config.depth)]
          )
-         self.apply(self._init_weights)
-         # Apply special scaled init to the residual projections, per GPT-2 paper.
-         # Weight w2 is output of FeedForward. Weight wo is output of Attention.
-         for pn, p in self.named_parameters():
-             if pn.endswith('w2.weight') or pn.endswith('wo.weight'):
-                 torch.nn.init.normal_(
-                     p, mean=0.0, std=0.02/math.sqrt(2 * self.config.depth))
-
-     def _init_weights(self, module):
-         if isinstance(module, nn.Linear):
-             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
-             if module.bias is not None:
-                 torch.nn.init.zeros_(module.bias)

      def forward(
          self,
@@ -401,26 +332,12 @@ class TransformerLayers(nn.Module):
              raise ValueError(
                  f"Input feature dim should be {self.config.dim}, but input has shape {x.shape}"
              )
-         batch_size, seq_len = x.shape[:2]
-         should_unpad = attention_mask is not None and not attention_mask.all()
-         if should_unpad:
-             x, indices, cu_seqlens, max_seq_len_in_batch = unpad_input(
-                 x, attention_mask
-             )
-         else:
-             indices, cu_seqlens, max_seq_len_in_batch = None, None, None
          hiddens = []
          for layer in self.layers:
-             x = layer(x, cu_seqlens, max_seq_len_in_batch)
              if return_all_hiddens:
                  hiddens.append(x)

-         if should_unpad:
-             x = pad_input(x, indices, batch_size, seq_len)
-             if return_all_hiddens:
-                 hiddens = [pad_input(h, indices, batch_size, seq_len)
-                            for h in hiddens]
-
          if return_all_hiddens:
              return x, hiddens
          return x
@@ -455,16 +372,9 @@ class gLM2Model(gLM2PreTrainedModel):
          self.config = config

          self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
-         self._init_weights(self.tok_embeddings)
          self.encoder = TransformerLayers(config)
-
-     def _init_weights(self, module):
-         if isinstance(module, nn.Linear):
-             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
-             if module.bias is not None:
-                 torch.nn.init.zeros_(module.bias)
-         elif isinstance(module, nn.Embedding):
-             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

      def forward(
          self,
@@ -556,15 +466,7 @@ class gLM2ForMaskedLM(gLM2PreTrainedModel):

          self.glm2 = gLM2Model(config)
          self.lm_head = gLM2LMHead(config)
-         self._init_weights(self.lm_head)
-
-     def _init_weights(self, module):
-         if isinstance(module, nn.Linear):
-             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
-             if module.bias is not None:
-                 torch.nn.init.zeros_(module.bias)
-         elif isinstance(module, nn.Embedding):
-             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

      def forward(
          self,
@@ -616,4 +518,4 @@ class gLM2LMHead(nn.Module):
              config.dim, config.vocab_size, bias=False)

      def forward(self, features):
-         return self.proj_output(self.norm(features))
 
  """PyTorch gLM2 model.

  Some modules adapted from:
  https://github.com/meta-llama/llama/blob/main/llama/model.py
  """
+
  import torch
+ from einops import rearrange, repeat
  from typing import Optional, Tuple, Union
  from torch import nn
  from torch.nn import CrossEntropyLoss

  )
  from transformers.modeling_utils import PreTrainedModel
  from transformers.utils import logging
+ from .configuration_glm2 import gLM2Config, gLM2EmbedConfig

+ logger = logging.get_logger(__name__)

+
+ def rotate_half(x, interleaved=False):
+     if not interleaved:
+         x1, x2 = x.chunk(2, dim=-1)
+         return torch.cat((-x2, x1), dim=-1)
+     else:
+         x1, x2 = x[..., ::2], x[..., 1::2]
+         return rearrange(
+             torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
+         )


+ def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
+     """
+     x: (batch_size, seqlen, nheads, headdim)
+     cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
+     """
+     ro_dim = cos.shape[-1] * 2
+     assert ro_dim <= x.shape[-1]
+     seqlen = x.shape[1]
+     cos, sin = cos[:seqlen], sin[:seqlen]
+     cos = repeat(
+         cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
+     )
+     sin = repeat(
+         sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
+     )
+     return torch.cat(
+         [
+             x[..., :ro_dim] * cos +
+             rotate_half(x[..., :ro_dim], interleaved) * sin,
+             x[..., ro_dim:],
+         ],
+         dim=-1,
+     )
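
# Standalone sanity check for the pure-torch rotary helpers above (an
# illustrative sketch, not part of the commit): with cos = 1 and sin = 0 the
# rotation is a no-op, which confirms the expected
# (batch, seqlen, nheads, headdim) layout.
import torch

x = torch.randn(2, 8, 4, 16)
cos = torch.ones(8, 8)    # (seqlen, rotary_dim / 2)
sin = torch.zeros(8, 8)
assert torch.equal(apply_rotary_emb_torch(x, cos, sin), x)
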

  class RotaryEmbedding(torch.nn.Module):
      """
      Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py.
+     Changed to use the torch version of apply_rotary_emb_func.
      """

      def __init__(

      def forward(
          self,
+         qkv: torch.Tensor,
          max_seqlen: Optional[int] = None,
      ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
          """
+         qkv: (batch, seqlen, 3, nheads, headdim)
          """
+         seqlen = qkv.shape[1]
+         if seqlen > self._seq_len_cached:
              self._update_cos_sin_cache(
+                 seqlen, device=qkv.device, dtype=qkv.dtype)
+         elif max_seqlen is not None:
              self._update_cos_sin_cache(
+                 max_seqlen, device=qkv.device, dtype=qkv.dtype)
+         q_rot = apply_rotary_emb_torch(
+             qkv[:, :, 0], self._cos_cached, self._sin_cached, self.interleaved
          )
+         k_rot = apply_rotary_emb_torch(
+             qkv[:, :, 1], self._cos_cached, self._sin_cached, self.interleaved
+         )
+         return torch.stack((q_rot, k_rot, qkv[:, :, 2]), dim=2)

  # @torch.jit.script
+ def rmsnorm_func(hidden_states, weight, variance_epsilon):
+     """Apply the root mean square normalization."""
+     input_dtype = hidden_states.dtype
+     hidden_states = hidden_states.to(torch.float32)
+     variance = hidden_states.pow(2).mean(-1, keepdim=True)
+     hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
+     return (weight * hidden_states).to(input_dtype)


+ class RMSNorm(nn.Module):
+     """Root mean square normalization."""

+     def __init__(self, dim, eps=1e-6):
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(dim))
+         self.register_buffer(
+             "variance_epsilon",
+             torch.tensor(eps),
+             persistent=False,
+         )

+     def forward(self, hidden_states):
+         return rmsnorm_func(hidden_states, self.weight, self.variance_epsilon)
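
# Quick numerical check of rmsnorm_func above (illustrative sketch, not part
# of the commit): with weight fixed to ones, RMSNorm reduces to
# x / sqrt(mean(x**2, dim=-1) + eps), computed in float32 and cast back to the
# input dtype.
import torch

x = torch.randn(2, 5, 32, dtype=torch.float16)
weight = torch.ones(32)
eps = torch.tensor(1e-6)
ref = (x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + 1e-6)).to(x.dtype)
assert torch.allclose(rmsnorm_func(x, weight, eps), ref)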


  class Attention(nn.Module):

          self.rotary_emb = RotaryEmbedding(self.head_dim)

      def forward(
          self,
          x: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
      ) -> torch.Tensor:
          bsz, seqlen, h_size = x.shape
          qkv = self.wqkv(x)
+
+         qkv = qkv.view(bsz, seqlen, 3, self.n_heads, self.head_dim)
+         qkv = self.rotary_emb(qkv)
+
+         # (batch, nheads, 3, seqlen, headdim)
+         qkv = torch.transpose(qkv, 3, 1)
+         q = qkv[:, :, 0]
+         k = qkv[:, :, 1]
+         v = qkv[:, :, 2]
+         if attention_mask is not None:
+             attention_mask = attention_mask[:, None, None, :]
+             attention_mask = attention_mask.expand(
+                 bsz, self.n_heads, seqlen, seqlen
+             ).bool()
+         # [B, heads, seq, D]
+         output = torch.nn.functional.scaled_dot_product_attention(
+             q, k, v, attn_mask=attention_mask
          )
+         output = output.permute(0, 2, 1, 3).contiguous()
+
          output = output.view(bsz, seqlen, h_size)
          return self.wo(output)
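
# Sketch of the padding-mask convention used by the new Attention.forward
# (illustrative, with made-up shapes; not part of the commit). For a boolean
# attn_mask, scaled_dot_product_attention lets a query attend only where the
# mask is True, so expanding a (batch, seqlen) padding mask over the key axis
# hides padded keys; real positions then match the unpadded computation.
import torch
import torch.nn.functional as F

bsz, nheads, seqlen, headdim, n_real = 1, 2, 6, 8, 4
q = torch.randn(bsz, nheads, seqlen, headdim)
k = torch.randn(bsz, nheads, seqlen, headdim)
v = torch.randn(bsz, nheads, seqlen, headdim)
padding = torch.tensor([[1, 1, 1, 1, 0, 0]])  # 1 = real token, 0 = pad
mask = padding[:, None, None, :].expand(bsz, nheads, seqlen, seqlen).bool()
masked = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
unpadded = F.scaled_dot_product_attention(
    q[..., :n_real, :], k[..., :n_real, :], v[..., :n_real, :]
)
assert torch.allclose(masked[..., :n_real, :], unpadded, atol=1e-5)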

          self.w3 = nn.Linear(dim, hidden_dim, bias=False)

      def forward(self, x):
+         return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))
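
# The line above is SwiGLU in plain PyTorch: flash-attn's fused swiglu(a, b)
# computes silu(a) * b, and silu(a) is just a * sigmoid(a)
# (illustrative check, not part of the commit).
import torch
import torch.nn.functional as F

a, b = torch.randn(4, 8), torch.randn(4, 8)
assert torch.allclose(F.silu(a) * b, (a * torch.sigmoid(a)) * b)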


  class TransformerBlock(nn.Module):

      def forward(
          self,
          x: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
      ) -> torch.Tensor:
+         r = self.attention(self.attention_norm(
+             x), attention_mask=attention_mask)
          h = x + r
          r = self.feed_forward(self.ffn_norm(h))
          out = h + r

          self.layers = torch.nn.ModuleList(
              [TransformerBlock(config=config) for _ in range(config.depth)]
          )

      def forward(
          self,

              raise ValueError(
                  f"Input feature dim should be {self.config.dim}, but input has shape {x.shape}"
              )
          hiddens = []
          for layer in self.layers:
+             x = layer(x, attention_mask=attention_mask)
              if return_all_hiddens:
                  hiddens.append(x)

          if return_all_hiddens:
              return x, hiddens
          return x
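
# Hypothetical call pattern for TransformerLayers after this change (names and
# sizes are illustrative): padded batches stay padded, and the (batch, seqlen)
# attention_mask is threaded down to scaled_dot_product_attention instead of
# going through flash-attn's unpad_input / pad_input round trip.
#
#     encoder = TransformerLayers(config)            # config.dim == 256
#     x = torch.randn(2, 6, 256)                     # embeddings, right-padded
#     attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
#                                    [1, 1, 1, 1, 1, 1]])
#     h = encoder(x, attention_mask=attention_mask)  # (2, 6, 256)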
 
          self.config = config

          self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
          self.encoder = TransformerLayers(config)
+         # Initialize weights and apply final processing
+         self.post_init()

      def forward(
          self,

          self.glm2 = gLM2Model(config)
          self.lm_head = gLM2LMHead(config)
+         self.init_weights()

      def forward(
          self,

              config.dim, config.vocab_size, bias=False)

      def forward(self, features):
+         return self.proj_output(self.norm(features))
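
# End-to-end usage sketch for the updated module (hypothetical: the repo id,
# the example sequence, and the tokenizer conventions are assumptions, not
# part of this commit). Since the flash-attn dependency is gone, an inference
# pass like this should also run on CPU.
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

repo_id = "tattabio/gLM2_650M"  # assumed checkpoint name
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(repo_id, trust_remote_code=True)

sequence = "<+>MSEQVENCE"  # placeholder input; see the model card for the real format
inputs = tokenizer(sequence, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.logits.shape)  # expected: (batch, seq_len, vocab_size)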