# test/modules/dml/backend.py
# pylint: disable=no-member,no-self-argument,no-method-argument
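# DirectML backend shim: exposes a torch.cuda-like interface for
# torch-directml ("privateuseone") devices.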
from typing import Optional, Callable
import torch
import torch_directml # pylint: disable=import-error
import modules.dml.amp as amp
from .utils import rDevice, get_device
from .device import Device
from .Generator import Generator
from .device_properties import DeviceProperties
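# Free/total VRAM as reported by the AMD-specific memory provider.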
def amd_mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]:
from .memory_amd import AMDMemoryProvider
return AMDMemoryProvider.mem_get_info(get_device(device).index)
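# Free/total memory derived from Windows PDH (Performance Data Helper) counters;
# free is estimated as total committed minus dedicated usage.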
def pdh_mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]:
mem_info = DirectML.memory_provider.get_memory(get_device(device).index)
return (mem_info["total_committed"] - mem_info["dedicated_usage"], mem_info["total_committed"])
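# Fallback when no memory provider is available: report a fixed 8 GiB as both free and total.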
def mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]: # pylint: disable=unused-argument
return (8589934592, 8589934592)
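# torch.cuda-compatible facade for DirectML; methods take no self/cls and are called on the class itself.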
class DirectML:
amp = amp
device = Device
Generator = Generator
context_device: Optional[torch.device] = None
is_autocast_enabled = False
autocast_gpu_dtype = torch.float16
memory_provider = None
def is_available() -> bool:
return torch_directml.is_available()
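    # torch-directml tensors live on the "privateuseone" device type.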
def is_directml_device(device: torch.device) -> bool:
return device.type == "privateuseone"
def has_float64_support(device: Optional[rDevice]=None) -> bool:
return torch_directml.has_float64_support(get_device(device).index)
def device_count() -> int:
return torch_directml.device_count()
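    # Prefer an explicitly set context device, otherwise fall back to the default adapter.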
def current_device() -> torch.device:
return DirectML.context_device or DirectML.default_device()
def default_device() -> torch.device:
return torch_directml.device(torch_directml.default_device())
def get_device_string(device: Optional[rDevice]=None) -> str:
return f"privateuseone:{get_device(device).index}"
def get_device_name(device: Optional[rDevice]=None) -> str:
return torch_directml.device_name(get_device(device).index)
def get_device_properties(device: Optional[rDevice]=None) -> DeviceProperties:
return DeviceProperties(get_device(device))
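    # Minimal stub of torch.cuda.memory_stats; these counters are not tracked here.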
def memory_stats(device: Optional[rDevice]=None):
return {
"num_ooms": 0,
"num_alloc_retries": 0,
}
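    # Defaults to the fixed-size fallback; the AMD or PDH variant can be assigned here at runtime.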
mem_get_info: Callable = mem_get_info
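    # Sum of per-heap usage reported by torch_directml.gpu_memory (in MB), converted to bytes.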
def memory_allocated(device: Optional[rDevice]=None) -> int:
return sum(torch_directml.gpu_memory(get_device(device).index)) * (1 << 20)
def max_memory_allocated(device: Optional[rDevice]=None):
return DirectML.memory_allocated(device) # DirectML does not empty GPU memory
def reset_peak_memory_stats(device: Optional[rDevice]=None):
return