# supermomo668's picture
# Duplicate from facebook/MusicGen
# 1ccfe17
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
class TorchAutocast:
    """TorchAutocast utility class.

    Allows you to enable and disable autocast. This is specially useful
    when dealing with different architectures and clusters with different
    levels of support.

    When disabled, no ``torch.autocast`` object is created and entering or
    exiting this context manager does nothing.

    Args:
        enabled (bool): Whether to enable torch.autocast or not.
        args: Additional args for torch.autocast.
        kwargs: Additional kwargs for torch.autocast
    """

    def __init__(self, enabled: bool, *args, **kwargs) -> None:
        # ``None`` marks the disabled state; both dunder methods check for it.
        self.autocast = torch.autocast(*args, **kwargs) if enabled else None

    def __enter__(self) -> None:
        """Enter the wrapped autocast context, if any.

        Returns:
            None in every case, so ``with TorchAutocast(...) as x`` binds None.

        Raises:
            RuntimeError: If the wrapped autocast fails to enter. The message
                reports the configured dtype and device, and the original
                error is attached as ``__cause__``.
        """
        if self.autocast is None:
            return
        try:
            self.autocast.__enter__()
        except RuntimeError as err:
            device = self.autocast.device
            dtype = self.autocast.fast_dtype
            # Chain explicitly so the traceback shows the underlying failure
            # as the direct cause rather than as an incidental second error.
            raise RuntimeError(
                f"There was an error autocasting with dtype={dtype} device={device}\n"
                "If you are on the FAIR Cluster, you might need to use autocast_dtype=float16"
            ) from err

    def __exit__(self, *args, **kwargs) -> None:
        """Exit the wrapped autocast context, if any.

        The arguments are forwarded unchanged to the wrapped ``__exit__``.
        Its return value is discarded, so this method returns None and never
        suppresses an exception raised inside the ``with`` body.
        """
        if self.autocast is None:
            return
        self.autocast.__exit__(*args, **kwargs)