gugarosa committed
Commit 5819d04
1 Parent(s): 67ecc75

Uses native torch decorator for disabling autocast.
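For context, `torch.autocast` instances can be applied directly as decorators, entering the autocast context on every call, so the hand-rolled `disable_autocast` wrapper removed below was re-implementing built-in behavior. A minimal sketch of the native pattern (the function name is illustrative, not from the repo):

```python
import torch

# Each torch.autocast instance targets a single device type, so two stacked
# decorators are needed to disable autocast on both CPU and CUDA.
@torch.autocast("cpu", enabled=False)
@torch.autocast("cuda", enabled=False)
def attention_scores(q, k):
    # Always runs with autocast off, even if the caller is inside an
    # autocast-enabled region.
    return q @ k.transpose(-1, -2)
```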

Files changed (1): modeling_phi.py (+5 -16)
modeling_phi.py CHANGED
```diff
@@ -8,8 +8,7 @@ from __future__ import annotations
 
 import math
 from dataclasses import dataclass, field
-from functools import wraps
-from typing import Any, Callable, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -32,18 +31,6 @@ except:
 FusedDense = None
 
 
-def disable_autocast(device_type: str = "cuda") -> None:
-    def _disable_autocast(f: Callable) -> Callable:
-        @wraps(f)
-        def __disable_autocast(*args, **kwargs) -> Callable:
-            with torch.autocast(device_type, enabled=False):
-                return f(*args, **kwargs)
-
-        return __disable_autocast
-
-    return _disable_autocast
-
-
 @dataclass
 class InferenceParams:
     """Inference parameters passed to model to efficiently calculate
@@ -359,7 +346,8 @@ class SelfAttention(nn.Module):
         self.softmax_scale = softmax_scale
         self.drop = nn.Dropout(attention_dropout)
 
-    @disable_autocast
+    @torch.autocast("cpu", enabled=False)
+    @torch.autocast("cuda", enabled=False)
     def forward(
         self,
         qkv: torch.FloatTensor,
@@ -418,7 +406,8 @@ class CrossAttention(nn.Module):
         self.softmax_scale = softmax_scale
         self.drop = nn.Dropout(attention_dropout)
 
-    @disable_autocast
+    @torch.autocast("cpu", enabled=False)
+    @torch.autocast("cuda", enabled=False)
     def forward(
         self,
         q: torch.FloatTensor,
```
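A quick check of what the new decorators do, outside the model — a sketch assuming a PyTorch build where CPU autocast supports bfloat16 (the snippet is not part of the commit):

```python
import torch

@torch.autocast("cpu", enabled=False)
def fp32_matmul(a, b):
    # Autocast is disabled in here, so float32 inputs stay float32.
    return a @ b

a, b = torch.randn(4, 4), torch.randn(4, 4)
with torch.autocast("cpu", dtype=torch.bfloat16):
    under_autocast = a @ b              # autocast lowers this matmul to bfloat16
    full_precision = fp32_matmul(a, b)  # decorator opts back out of autocast

print(under_autocast.dtype)   # torch.bfloat16
print(full_precision.dtype)   # torch.float32
```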
 