booxel-cpu / BOOXEL /utils /devices.py
yanranxiaoxi's picture
First commit
00fc29f verified
import sys
import contextlib
from functools import lru_cache
import torch
#from modules import errors
if sys.platform == "darwin":
from modules import mac_specific
def has_mps() -> bool:
    """Report whether the Apple MPS backend is usable.

    Always ``False`` when not running on macOS; on macOS the answer is
    whatever ``mac_specific.has_mps`` holds.
    """
    # The short-circuit keeps mac_specific untouched on platforms where it
    # was never imported.
    return sys.platform == "darwin" and mac_specific.has_mps
def get_cuda_device_string():
    """Return the device string used to address CUDA (always ``"cuda"``)."""
    cuda_name = "cuda"
    return cuda_name
def get_optimal_device_name():
    """Return the name of the preferred torch backend.

    Preference order: CUDA (via ``get_cuda_device_string()``), then
    ``"mps"`` when ``has_mps()`` is true, and finally ``"cpu"``.
    """
    if torch.cuda.is_available():
        backend_name = get_cuda_device_string()
    elif has_mps():
        backend_name = "mps"
    else:
        backend_name = "cpu"
    return backend_name
def get_optimal_device():
    """Return a ``torch.device`` built from ``get_optimal_device_name()``."""
    preferred_name = get_optimal_device_name()
    return torch.device(preferred_name)
def get_device_for(task):
    """Return the device to use for *task*.

    The *task* argument is currently ignored: every task receives the
    result of ``get_optimal_device()``.
    """
    chosen_device = get_optimal_device()
    return chosen_device
def torch_gc():
    """Ask PyTorch to release cached accelerator memory.

    When CUDA is available, calls ``torch.cuda.empty_cache()`` and
    ``torch.cuda.ipc_collect()`` inside a ``torch.cuda.device`` context for
    the CUDA device string. Independently, when ``has_mps()`` is true,
    delegates to ``mac_specific.torch_mps_gc()``.
    """
    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        cuda_name = get_cuda_device_string()
        with torch.cuda.device(cuda_name):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

    mps_ready = has_mps()
    if mps_ready:
        mac_specific.torch_mps_gc()
def enable_tf32():
    """Enable TF32 for CUDA matmul and cuDNN when CUDA is available.

    Additionally switches on cuDNN benchmarking if any visible CUDA device
    reports compute capability (7, 5). Does nothing when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return

    # Enabling the benchmark option seems to let a range of cards use fp16
    # when they otherwise could not.
    # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
    for device_index in range(torch.cuda.device_count()):
        if torch.cuda.get_device_capability(device_index) == (7, 5):
            torch.backends.cudnn.benchmark = True
            break

    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
# Apply the CUDA backend tweaks once, at import time.
enable_tf32()
# Error-wrapped variant, kept for reference (the `errors` import is disabled):
#errors.run(enable_tf32, "Enabling TF32")

# Explicit handle on the CPU device.
cpu = torch.device("cpu")

# Pick the best backend that is actually present instead of assuming CUDA:
# on a CUDA host this is still torch.device("cuda"), while hosts without CUDA
# now get a usable "mps" or "cpu" device rather than one that fails on first use.
device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = get_optimal_device()

# Default compute dtypes.
# NOTE(review): float16 is kept unchanged; confirm it is appropriate when the
# selected device is not CUDA.
dtype = torch.float16
dtype_vae = torch.float16
dtype_unet = torch.float16

# When True, cond_cast_unet / cond_cast_float perform their casts.
unet_needs_upcast = False
def cond_cast_unet(input):
    """Cast *input* to ``dtype_unet`` when ``unet_needs_upcast`` is set.

    Returns *input* itself, untouched, when the flag is off.
    """
    if not unet_needs_upcast:
        return input
    return input.to(dtype_unet)
def cond_cast_float(input):
    """Return ``input.float()`` when ``unet_needs_upcast`` is set.

    Returns *input* itself, untouched, when the flag is off.
    """
    if not unet_needs_upcast:
        return input
    return input.float()
def randn(seed, shape):
    """Draw a standard-normal tensor of *shape* on ``device`` after seeding.

    Calls ``torch.manual_seed(seed)`` first, so the same seed and shape give
    a repeatable draw.
    """
    torch.manual_seed(seed)
    noise = torch.randn(shape, device=device)
    return noise
def randn_without_seed(shape):
    """Draw a standard-normal tensor of *shape* on ``device`` without reseeding."""
    noise = torch.randn(shape, device=device)
    return noise
def autocast(disable=False):
    """Return a context manager for mixed-precision execution.

    A truthy *disable* yields a no-op ``contextlib.nullcontext()``; otherwise
    the result is ``torch.autocast("cuda")``.
    """
    return contextlib.nullcontext() if disable else torch.autocast("cuda")
def without_autocast(disable=False):
    """Return a context manager that suspends autocast if it is active.

    Yields ``torch.autocast("cuda", enabled=False)`` only when autocast is
    currently enabled and *disable* is falsy; in every other case the result
    is a no-op ``contextlib.nullcontext()``.
    """
    if torch.is_autocast_enabled() and not disable:
        return torch.autocast("cuda", enabled=False)
    return contextlib.nullcontext()
class NansException(Exception):
    """Raised by the NaN check when a tensor consists entirely of NaNs."""
def test_for_nans(x, where):
    """Raise ``NansException`` when every element of *x* is NaN.

    Args:
        x: tensor to inspect.
        where: ``"unet"`` or ``"vae"`` selects a stage-specific message; any
            other value produces a generic one.

    Raises:
        NansException: if *x* is non-empty and consists entirely of NaNs.
    """
    # torch.all() over an empty tensor is vacuously True, so an empty tensor
    # would otherwise be misreported as "all NaN"; treat it as clean.
    if x.numel() == 0:
        return

    if not torch.all(torch.isnan(x)).item():
        return

    if where == "unet":
        message = "在 Unet 中生成了一个包含所有 NaNs 的张量。"
    elif where == "vae":
        message = "在 VAE 中生成了一个包含所有 NaN 的张量。"
    else:
        message = "产生了一个包含所有 NaN 的张量。"

    # NOTE(review): the message advertises --disable-nan-check, but this
    # function does not consult any such option; confirm it is handled
    # by the caller.
    message += " 使用 --disable-nan-check 命令行参数禁用此检查。"

    raise NansException(message)
@lru_cache
def first_time_calculation():
    """
    Run one tiny Linear and one tiny Conv2d forward pass on ``device``/``dtype``.

    Just doing any calculation with pytorch layers: the first time this is
    done it allocates about 700MB of memory and takes about 2.7 seconds, at
    least on NVIDIA. The ``lru_cache`` decorator caches the result, so
    repeated calls do not redo the work.
    """
    # Linear layer pass.
    linear_input = torch.zeros((1, 1)).to(device, dtype)
    linear_layer = torch.nn.Linear(1, 1).to(device, dtype)
    linear_layer(linear_input)

    # Convolution pass.
    conv_input = torch.zeros((1, 1, 3, 3)).to(device, dtype)
    conv_layer = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
    conv_layer(conv_input)