""" Source url: https://github.com/OPHoperHPO/image-background-remove-tool Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. License: Apache License 2.0 """ import random import warnings from typing import Union, Tuple, Any import torch from torch import autocast class EmptyAutocast(object): """ Empty class for disable any autocasting. """ def __enter__(self): return None def __exit__(self, exc_type, exc_val, exc_tb): return def __call__(self, func): return def get_precision_autocast( device="cpu", fp16=True, override_dtype=None ) -> Union[ Tuple[EmptyAutocast, Union[torch.dtype, Any]], Tuple[autocast, Union[torch.dtype, Any]], ]: """ Returns precision and autocast settings for given device and fp16 settings. Args: device: Device to get precision and autocast settings for. fp16: Whether to use fp16 precision. override_dtype: Override dtype for autocast. Returns: Autocast object, dtype """ dtype = torch.float32 cache_enabled = None if device == "cpu" and fp16: warnings.warn('FP16 is not supported on CPU. Using FP32 instead.') dtype = torch.float32 # TODO: Implement BFP16 on CPU. There are unexpected slowdowns on cpu on a clean environment. # warnings.warn( # "Accuracy BFP16 has experimental support on the CPU. " # "This may result in an unexpected reduction in quality." # ) # dtype = ( # torch.bfloat16 # ) # Using bfloat16 for CPU, since autocast is not supported for float16 if "cuda" in device and fp16: dtype = torch.float16 cache_enabled = True if override_dtype is not None: dtype = override_dtype if dtype == torch.float32 and device == "cpu": return EmptyAutocast(), dtype return ( torch.autocast( device_type=device, dtype=dtype, enabled=True, cache_enabled=cache_enabled ), dtype, ) def cast_network(network: torch.nn.Module, dtype: torch.dtype): """Cast network to given dtype Args: network: Network to be casted dtype: Dtype to cast network to """ if dtype == torch.float16: network.half() elif dtype == torch.bfloat16: network.bfloat16() elif dtype == torch.float32: network.float() else: raise ValueError(f"Unknown dtype {dtype}") def fix_seed(seed=42): """Sets fixed random seed Args: seed: Random seed to be set """ random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) # noinspection PyUnresolvedReferences torch.backends.cudnn.deterministic = True # noinspection PyUnresolvedReferences torch.backends.cudnn.benchmark = False return True def suppress_warnings(): # Suppress PyTorch 1.11.0 warning associated with changing order of args in nn.MaxPool2d layer, # since source code is not affected by this issue and there aren't any other correct way to hide this message. warnings.filterwarnings( "ignore", category=UserWarning, message="Note that order of the arguments: ceil_mode and " "return_indices will changeto match the args list " "in nn.MaxPool2d in a future release.", module="torch", )