|
import torch |
|
from modules import paths |
|
from modules.sd_hijack_utils import CondFunc |
|
from packaging import version |
|
|
|
|
|
|
|
|
|
def check_for_mps() -> bool:
    """Return True when a usable Apple MPS device is present.

    Availability is queried through ``torch.backends.mps.is_available()``.
    Only torch builds that lack ``torch.backends.mps`` fall back to the legacy
    ``torch.has_mps`` attribute. If the backend reports itself available, a
    one-element tensor is moved to the ``"mps"`` device as a smoke test. Any
    exception from that probe is treated as "MPS not usable".

    Returns:
        True if MPS is available and an allocation on it succeeds, else False.
    """
    mps_backend = getattr(getattr(torch, "backends", None), "mps", None)
    if mps_backend is not None and hasattr(mps_backend, "is_available"):
        # torch.has_mps is deprecated in recent PyTorch releases (reading it warns)
        # and may be removed, which would make getattr(..., False) silently
        # disable MPS. torch.backends.mps is the supported query API.
        if not mps_backend.is_available():
            return False
    elif not getattr(torch, "has_mps", False):
        # Legacy fallback for torch builds that predate torch.backends.mps.
        return False

    try:
        # The backend can claim availability and still fail at runtime,
        # so verify that a real allocation works.
        torch.zeros(1).to(torch.device("mps"))
        return True
    except Exception:
        # Deliberate best-effort probe: any failure means "no usable MPS".
        return False


# Evaluated once at import time; the MPS workarounds in this module are gated on it.
has_mps = check_for_mps()
|
|
|
|
|
|
|
def cumsum_fix(input, cumsum_func, *args, **kwargs):
    """Call ``cumsum_func`` on ``input``, routing around MPS cumsum defects.

    Tensors that are not on the ``mps`` device are passed straight through.
    For MPS tensors:

    * an int64 result is computed on the CPU and moved back to the device;
    * a bool result (when ``cumsum_needs_bool_fix`` is set) or an int8/int16
      result (when ``cumsum_needs_int_fix`` is set) is computed from an int32
      copy of the input. It is then cast to the ``dtype`` the caller asked for,
      or to int64 when no ``dtype`` was requested.

    Args:
        input: tensor to accumulate. The parameter name is kept for callers.
        cumsum_func: the original cumsum callable being wrapped (``torch.cumsum``
            or ``torch.Tensor.cumsum``). It is called as
            ``cumsum_func(tensor, *args, **kwargs)``.
        *args, **kwargs: forwarded to ``cumsum_func`` (e.g. ``dim``, ``dtype``).

    Returns:
        The cumulative-sum tensor.
    """
    if input.device.type != 'mps':
        return cumsum_func(input, *args, **kwargs)

    # An explicit dtype=None means "no dtype requested", the same as omitting it.
    requested_dtype = kwargs.get('dtype')
    output_dtype = requested_dtype if requested_dtype is not None else input.dtype

    if output_dtype == torch.int64:
        return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)

    # The probe flags are module globals that are assigned only under some torch
    # versions. Treat a missing flag as "no fix needed" instead of raising NameError.
    needs_bool_fix = globals().get('cumsum_needs_bool_fix', False)
    needs_int_fix = globals().get('cumsum_needs_int_fix', False)

    if (needs_bool_fix and output_dtype == torch.bool) or (
        needs_int_fix and output_dtype in (torch.int8, torch.int16)
    ):
        # cumsum casts its input to `dtype` before accumulating, so forwarding an
        # explicit dtype would undo the int32 workaround. Drop it here and apply
        # the requested dtype to the result instead.
        int32_kwargs = {key: value for key, value in kwargs.items() if key != 'dtype'}
        result = cumsum_func(input.to(torch.int32), *args, **int32_kwargs)
        return result.to(torch.int64 if requested_dtype is None else requested_dtype)

    return cumsum_func(input, *args, **kwargs)
|
|
|
|
|
# Install MPS-specific monkey-patches only when check_for_mps() succeeded at import time.
# NOTE(review): CondFunc (modules.sd_hijack_utils) is not visible here. From its use
# below, it presumably replaces the callable at the given dotted path with the second
# argument (which receives the original callable as its first parameter) whenever the
# third argument returns true. Passing None for that third argument presumably means
# "always" — confirm in sd_hijack_utils.
if has_mps:
    # torchsde Brownian-interval sampling: when the target device is MPS, draw the
    # random numbers on the CPU with a CPU generator seeded from `seed`, then move the
    # result to `device`. NOTE(review): presumably because seeded generation on MPS is
    # unsupported or not reproducible — confirm.
    CondFunc('torchsde._brownian.brownian_interval._randn', lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device), lambda _, size, dtype, device, seed: device.type == 'mps')

    # Version-gated workarounds. torch 1.13.0 through 1.13.1 match neither branch
    # and get none of the patches below.
    if version.parse(torch.__version__) < version.parse("1.13"):
        # Workarounds applied only on torch releases older than 1.13.

        # Tensor.to: when a non-MPS tensor is moved onto an MPS device (passed either as
        # the first positional argument or as the `device` keyword), make it contiguous first.
        CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
                 lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))

        # layer_norm: make the input (first positional argument) contiguous when it lives on MPS.
        CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
                 lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')

        # Tensor.numpy: detach tensors that require grad before converting them.
        CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)

    elif version.parse(torch.__version__) > version.parse("1.13.1"):
        # Probe the runtime once. The int16 cumsum of [1, 1] should equal [1, 2];
        # if it does not, int8/int16 cumsum needs the int32 detour in cumsum_fix.
        cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))

        # Same idea for bool: cumsum of [True, False] should equal int64 [1, 1].
        cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))

        # Route both cumsum entry points through cumsum_fix, which reads the two flags above
        # by their module-global names (so they must not be renamed).
        cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)

        CondFunc('torch.cumsum', cumsum_fix_func, None)

        CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)

        # torch.narrow: return a clone of the narrowed result instead of the view.
        # NOTE(review): with no condition this presumably applies to every call, not only
        # MPS tensors — confirm that is intended.
        CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
|
|
|
|