gugarosa committed on
Commit
b5c5161
1 Parent(s): 470e18a

Fixes any potential overflow when calculating attention weights.

Files changed (1)
  1. modeling_phi.py +76 -14
modeling_phi.py CHANGED
@@ -8,7 +8,8 @@ from __future__ import annotations
 
 import math
 from dataclasses import dataclass, field
-from typing import Any, Dict, Optional, Tuple, Union
+from functools import wraps
+from typing import Any, Callable, Dict, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -31,6 +32,15 @@ except:
     FusedDense = None
 
 
+def disable_autocast(func: Callable) -> Callable:
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        with torch.cuda.amp.autocast(enabled=False):
+            return func(*args, **kwargs)
+
+    return wrapper
+
+
 @dataclass
 class InferenceParams:
     """Inference parameters passed to model to efficiently calculate
@@ -218,7 +228,10 @@ class RotaryEmbedding(nn.Module):
         return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
 
     def _update_cos_sin_cache(
-        self, seqlen: int, device: Optional[str] = None, dtype: Optional[torch.dtype] = None
+        self,
+        seqlen: int,
+        device: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
     ) -> None:
         self._seq_len_cached = seqlen
 
@@ -261,14 +274,30 @@ class RotaryEmbedding(nn.Module):
         seq_start = seqlen_offset
         seq_end = seq_start + qkv.shape[1]
 
-        if self._cos_cached.device != qkv.device or self._cos_cached.dtype != qkv.dtype or (self.training and self._cos_cached.is_inference()):
+        if (
+            self._cos_cached.device != qkv.device
+            or self._cos_cached.dtype != qkv.dtype
+            or (self.training and self._cos_cached.is_inference())
+        ):
             self._update_cos_sin_cache(self.max_position_embeddings, device=qkv.device, dtype=qkv.dtype)
-
+
         if kv is None:
-            return _apply_rotary_emb_qkv(qkv, self._cos_cached[seq_start:seq_end], self._sin_cached[seq_start:seq_end])
+            return _apply_rotary_emb_qkv(
+                qkv,
+                self._cos_cached[seq_start:seq_end],
+                self._sin_cached[seq_start:seq_end],
+            )
         else:
-            q = _apply_rotary_emb(qkv, self._cos_cached[seq_start:seq_end], self._sin_cached[seq_start:seq_end])
-            kv = _apply_rotary_emb_kv(kv, self._cos_cached[seq_start:seq_end], self._sin_cached[seq_start:seq_end])
+            q = _apply_rotary_emb(
+                qkv,
+                self._cos_cached[seq_start:seq_end],
+                self._sin_cached[seq_start:seq_end],
+            )
+            kv = _apply_rotary_emb_kv(
+                kv,
+                self._cos_cached[seq_start:seq_end],
+                self._sin_cached[seq_start:seq_end],
+            )
 
         return q, kv
 
@@ -327,6 +356,7 @@ class SelfAttention(nn.Module):
         self.softmax_scale = softmax_scale
         self.drop = nn.Dropout(attention_dropout)
 
+    @disable_autocast
     def forward(
         self,
         qkv: torch.FloatTensor,
@@ -337,9 +367,14 @@ class SelfAttention(nn.Module):
         batch_size, seqlen = qkv.shape[0], qkv.shape[1]
         q, k, v = qkv.unbind(dim=2)
 
+        q = q.to(torch.float32)
+        k = k.to(torch.float32)
+
         causal = self.causal if causal is None else causal
         softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
 
+        # Autocast is manually disabled to avoid `torch.einsum` performing the operation
+        # using float16, which might lead to overflow
         scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
 
         if key_padding_mask is not None:
@@ -352,7 +387,7 @@ class SelfAttention(nn.Module):
             causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
             scores = scores + causal_mask.to(dtype=scores.dtype)
 
-        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
+        attention = torch.softmax(scores, dim=-1).to(v.dtype)
         attention = self.drop(attention)
 
         output = torch.einsum("bhts,bshd->bthd", attention, v)
@@ -380,6 +415,7 @@ class CrossAttention(nn.Module):
         self.softmax_scale = softmax_scale
         self.drop = nn.Dropout(attention_dropout)
 
+    @disable_autocast
     def forward(
         self,
         q: torch.FloatTensor,
@@ -395,9 +431,14 @@ class CrossAttention(nn.Module):
             kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
         k, v = kv.unbind(dim=2)
 
+        q = q.to(torch.float32)
+        k = k.to(torch.float32)
+
         causal = self.causal if causal is None else causal
         softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
 
+        # Autocast is manually disabled to avoid `torch.einsum` performing the operation
+        # using float16, which might lead to overflow
         scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
 
         if key_padding_mask is not None:
@@ -418,7 +459,7 @@ class CrossAttention(nn.Module):
 
             scores = scores.masked_fill(causal_mask, -10000.0)
 
-        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
+        attention = torch.softmax(scores, dim=-1).to(v.dtype)
         attention = self.drop(attention)
 
         output = torch.einsum("bhts,bshd->bthd", attention, v)
@@ -507,7 +548,13 @@ class MHA(nn.Module):
         if rotary_cls is RotaryEmbedding:
            rotary_kwargs["max_position_embeddings"] = config.n_positions
 
-        self.rotary_emb = rotary_cls(self.rotary_dim, base=rotary_base, scale_base=rotary_scale_base, device=device, **rotary_kwargs)
+        self.rotary_emb = rotary_cls(
+            self.rotary_dim,
+            base=rotary_base,
+            scale_base=rotary_scale_base,
+            device=device,
+            **rotary_kwargs,
+        )
 
         # MLP
         self.n_head, self.n_head_kv, self.head_dim = _find_mha_dims(
@@ -532,9 +579,15 @@ class MHA(nn.Module):
         if cross_attn_cls is None:
             cross_attn_cls = CrossAttention
 
-        self.inner_attn = attn_cls(causal=causal, softmax_scale=softmax_scale, attention_dropout=config.attn_pdrop)
+        self.inner_attn = attn_cls(
+            causal=causal,
+            softmax_scale=softmax_scale,
+            attention_dropout=config.attn_pdrop,
+        )
         self.inner_cross_attn = cross_attn_cls(
-            causal=causal, softmax_scale=softmax_scale, attention_dropout=config.attn_pdrop
+            causal=causal,
+            softmax_scale=softmax_scale,
+            attention_dropout=config.attn_pdrop,
         )
 
         self.flash_attn = config.flash_attn and attn_cls is FlashSelfAttention
@@ -603,7 +656,12 @@ class MHA(nn.Module):
         batch_size, seqlen_q = q.shape[0], q.shape[1]
         seqlen_k = kv.shape[1]
 
-        cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = None, None, None, None
+        cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = (
+            None,
+            None,
+            None,
+            None,
+        )
         if key_padding_mask is not None:
             kv, _, cu_seqlens_k, max_seqlen_k = unpad_input(kv, key_padding_mask)
 
@@ -644,7 +702,11 @@ class MHA(nn.Module):
 
         if self.checkpointing:
             return torch.utils.checkpoint.checkpoint(
-                self.inner_cross_attn, q, kv, key_padding_mask=key_padding_mask, causal=causal
+                self.inner_cross_attn,
+                q,
+                kv,
+                key_padding_mask=key_padding_mask,
+                causal=causal,
             )
 
         return self.inner_cross_attn(q, kv, key_padding_mask=key_padding_mask, causal=causal)
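For context on the failure mode this change guards against, here is a minimal standalone sketch (not part of the commit; the score values are made up for illustration). It shows how an attention score above the float16 maximum (about 65504) overflows to inf and poisons the softmax with NaN, and how keeping the computation in float32 and casting the probabilities back afterwards, which is what the patch does via disable_autocast, the q/k upcast, and .to(v.dtype), stays finite.

import torch

# A single row of attention scores whose largest entry exceeds the float16 range.
scores_fp32 = torch.tensor([80000.0, 100.0, 0.0])

# In float16 the large score overflows to inf, so the softmax row contains NaN.
scores_fp16 = scores_fp32.to(torch.float16)
print(scores_fp16)                                             # first entry is inf
print(torch.isnan(torch.softmax(scores_fp16, dim=-1)).any())   # True

# Computing the softmax in float32 and casting the result back afterwards
# (the approach taken in this commit) keeps the probabilities finite.
attention = torch.softmax(scores_fp32, dim=-1).to(torch.float16)
print(torch.isfinite(attention).all())                         # True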