gugarosa commited on
Commit
67ecc75
1 Parent(s): b5c5161

Adds disable_autocast support for different device types.

Browse files
Files changed (1) hide show
  1. modeling_phi.py +10 -7
modeling_phi.py CHANGED
@@ -32,13 +32,16 @@ except:
32
  FusedDense = None
33
 
34
 
35
- def disable_autocast(func: Callable) -> Callable:
36
- @wraps(func)
37
- def wrapper(*args, **kwargs):
38
- with torch.cuda.amp.autocast(enabled=False):
39
- return func(*args, **kwargs)
40
-
41
- return wrapper
 
 
 
42
 
43
 
44
  @dataclass
 
32
  FusedDense = None
33
 
34
 
35
+ def disable_autocast(device_type: str = "cuda") -> None:
36
+ def _disable_autocast(f: Callable) -> Callable:
37
+ @wraps(f)
38
+ def __disable_autocast(*args, **kwargs) -> Callable:
39
+ with torch.autocast(device_type, enabled=False):
40
+ return f(*args, **kwargs)
41
+
42
+ return __disable_autocast
43
+
44
+ return _disable_autocast
45
 
46
 
47
  @dataclass