genius_bgremover / carvekit /utils /models_utils.py
Towsif7's picture
firrst commit
59e40e1
"""
Source url: https://github.com/OPHoperHPO/image-background-remove-tool
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
License: Apache License 2.0
"""
import random
import warnings
from typing import Union, Tuple, Any
import torch
from torch import autocast
class EmptyAutocast(object):
"""
Empty class for disable any autocasting.
"""
def __enter__(self):
return None
def __exit__(self, exc_type, exc_val, exc_tb):
return
def __call__(self, func):
return
def get_precision_autocast(
device="cpu", fp16=True, override_dtype=None
) -> Union[
Tuple[EmptyAutocast, Union[torch.dtype, Any]],
Tuple[autocast, Union[torch.dtype, Any]],
]:
"""
Returns precision and autocast settings for given device and fp16 settings.
Args:
device: Device to get precision and autocast settings for.
fp16: Whether to use fp16 precision.
override_dtype: Override dtype for autocast.
Returns:
Autocast object, dtype
"""
dtype = torch.float32
cache_enabled = None
if device == "cpu" and fp16:
warnings.warn('FP16 is not supported on CPU. Using FP32 instead.')
dtype = torch.float32
# TODO: Implement BFP16 on CPU. There are unexpected slowdowns on cpu on a clean environment.
# warnings.warn(
# "Accuracy BFP16 has experimental support on the CPU. "
# "This may result in an unexpected reduction in quality."
# )
# dtype = (
# torch.bfloat16
# ) # Using bfloat16 for CPU, since autocast is not supported for float16
if "cuda" in device and fp16:
dtype = torch.float16
cache_enabled = True
if override_dtype is not None:
dtype = override_dtype
if dtype == torch.float32 and device == "cpu":
return EmptyAutocast(), dtype
return (
torch.autocast(
device_type=device, dtype=dtype, enabled=True, cache_enabled=cache_enabled
),
dtype,
)
def cast_network(network: torch.nn.Module, dtype: torch.dtype):
"""Cast network to given dtype
Args:
network: Network to be casted
dtype: Dtype to cast network to
"""
if dtype == torch.float16:
network.half()
elif dtype == torch.bfloat16:
network.bfloat16()
elif dtype == torch.float32:
network.float()
else:
raise ValueError(f"Unknown dtype {dtype}")
def fix_seed(seed=42):
"""Sets fixed random seed
Args:
seed: Random seed to be set
"""
random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# noinspection PyUnresolvedReferences
torch.backends.cudnn.deterministic = True
# noinspection PyUnresolvedReferences
torch.backends.cudnn.benchmark = False
return True
def suppress_warnings():
# Suppress PyTorch 1.11.0 warning associated with changing order of args in nn.MaxPool2d layer,
# since source code is not affected by this issue and there aren't any other correct way to hide this message.
warnings.filterwarnings(
"ignore",
category=UserWarning,
message="Note that order of the arguments: ceil_mode and "
"return_indices will changeto match the args list "
"in nn.MaxPool2d in a future release.",
module="torch",
)