|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
|
|
class TorchAutocast:
    """Context manager that optionally wraps ``torch.autocast``.

    Allows you to enable and disable autocast. This is especially useful
    when dealing with different architectures and clusters with different
    levels of support.

    When ``enabled`` is false no ``torch.autocast`` object is created and
    entering/exiting the context is a no-op.

    Args:
        enabled (bool): Whether to enable torch.autocast or not.
        args: Additional args for torch.autocast.
        kwargs: Additional kwargs for torch.autocast.
    """

    def __init__(self, enabled: bool, *args, **kwargs):
        # ``None`` marks the disabled state; both __enter__ and __exit__ check it.
        self.autocast = torch.autocast(*args, **kwargs) if enabled else None

    def __enter__(self):
        """Enter the wrapped autocast context, if any.

        Returns:
            None. The ``as`` target of a ``with`` statement is always ``None``.

        Raises:
            RuntimeError: If the wrapped autocast raises ``RuntimeError`` on
                entry. The new error reports the dtype and device in use and
                keeps the original error as its ``__cause__``.
        """
        if self.autocast is None:
            return
        try:
            self.autocast.__enter__()
        except RuntimeError as err:
            device = self.autocast.device
            dtype = self.autocast.fast_dtype
            # Chain explicitly so the underlying torch error stays attached.
            raise RuntimeError(
                f"There was an error autocasting with dtype={dtype} device={device}\n"
                "If you are on the FAIR Cluster, you might need to use autocast_dtype=float16"
            ) from err

    def __exit__(self, *args, **kwargs):
        """Exit the wrapped autocast context, if any.

        Returns:
            Whatever the wrapped context manager's ``__exit__`` returns, or
            ``None`` when autocast is disabled.
        """
        if self.autocast is None:
            return
        # Forward the result so the wrapped manager's decision about
        # exception suppression is honoured.
        return self.autocast.__exit__(*args, **kwargs)
|
|