| import torch | |
# Module-level stack of nested LowPrecisionMode flags; the top of the stack
# (last element) is the currently active setting. NOTE(review): plain list,
# not thread-local — assumes single-threaded use; confirm with callers.
_is_low_precision_mode_stack = []
class LowPrecisionMode:
    """Context manager that toggles low-precision (bfloat16 round-trip) mode.

    Entering pushes ``enabled`` onto ``_is_low_precision_mode_stack`` and
    exiting pops it, so contexts nest correctly and the innermost setting
    wins (queried via ``is_low_precision_mode``).
    """

    def __init__(self, enabled: bool = True):
        # Value pushed onto the mode stack while this context is active.
        self.enabled = enabled

    def __enter__(self):
        # No `global` statement needed: the list is mutated in place,
        # the module-level name is never rebound.
        _is_low_precision_mode_stack.append(self.enabled)
        # Fix: return self so `with LowPrecisionMode() as mode:` binds the
        # manager (the original returned None; this is backward-compatible).
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Pop our own entry; implicit None return (falsy) propagates any
        # exception raised inside the `with` body, same as before.
        _is_low_precision_mode_stack.pop()
def is_low_precision_mode() -> bool:
    """Return the innermost active LowPrecisionMode flag, or False if none.

    Reads the top of ``_is_low_precision_mode_stack``; an empty stack means
    no ``LowPrecisionMode`` context is currently entered.
    """
    # Truthiness check replaces `len(...) == 0`; no `global` statement is
    # needed for a read-only access of a module-level name.
    if not _is_low_precision_mode_stack:
        return False
    return _is_low_precision_mode_stack[-1]
def optional_cast_to_bf16_and_cast_back(tensor: torch.Tensor) -> torch.Tensor:
    """Simulate bfloat16 precision loss on a float32 tensor when enabled.

    When low-precision mode is active (see ``is_low_precision_mode``), the
    tensor is round-tripped through bfloat16 — losing mantissa bits — and
    returned as float32 so downstream code sees an unchanged dtype.
    Otherwise the input tensor is returned as-is (same object, no copy).

    Args:
        tensor: Input tensor; must have dtype ``torch.float32``.

    Returns:
        A float32 tensor: either the original object, or a new tensor with
        bfloat16-rounded values.

    Raises:
        TypeError: If ``tensor`` is not float32. (Was an ``assert``, which
        is silently stripped under ``python -O``.)
    """
    if tensor.dtype != torch.float32:
        raise TypeError(
            "Input tensor must be of dtype torch.float32 for optional casting."
        )
    if is_low_precision_mode():
        # Round-trip through bfloat16 to emulate reduced-precision rounding
        # while keeping float32 as the externally visible dtype.
        return tensor.to(torch.bfloat16).to(torch.float32)
    return tensor