# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch


class TorchAutocast:
    """TorchAutocast utility class.

    Allows you to enable and disable autocast. This is especially useful
    when dealing with different architectures and clusters with different
    levels of support.

    Args:
        enabled (bool): Whether to enable torch.autocast or not.
        args: Additional args for torch.autocast.
        kwargs: Additional kwargs for torch.autocast.
    """

    def __init__(self, enabled: bool, *args, **kwargs):
        # torch.autocast is only constructed when enabled, so a disabled
        # instance never builds an autocast object at all; the wrapper then
        # acts as a no-op context manager.
        self.autocast = torch.autocast(*args, **kwargs) if enabled else None

    def __enter__(self):
        """Enter the wrapped autocast context, if one was constructed.

        Returns:
            TorchAutocast: this instance, so ``with TorchAutocast(...) as ac``
            binds the wrapper.

        Raises:
            RuntimeError: if entering torch.autocast raises a RuntimeError.
                The re-raised error names the configured dtype/device and
                chains the original error as its ``__cause__``.
        """
        if self.autocast is None:
            return self
        try:
            self.autocast.__enter__()
        except RuntimeError as err:
            device = self.autocast.device
            dtype = self.autocast.fast_dtype
            # Chain explicitly so the backend's own message is kept as the cause.
            raise RuntimeError(
                f"There was an error autocasting with dtype={dtype} device={device}\n"
                "If you are on the FAIR Cluster, you might need to use autocast_dtype=float16"
            ) from err
        return self

    def __exit__(self, *args, **kwargs):
        """Exit the wrapped autocast context, if one was constructed.

        Returns:
            The value returned by the wrapped ``torch.autocast.__exit__``
            (forwarded per the context-manager protocol), or ``False`` when
            autocast is disabled so exceptions from the ``with`` body propagate.
        """
        if self.autocast is None:
            return False
        return self.autocast.__exit__(*args, **kwargs)