Spaces:
Runtime error
Runtime error
# pylint: disable=no-member,no-self-argument,no-method-argument | |
from typing import Optional, Callable | |
import torch | |
import torch_directml # pylint: disable=import-error | |
import modules.dml.amp as amp | |
from .utils import rDevice, get_device | |
from .device import Device | |
from .Generator import Generator | |
from .device_properties import DeviceProperties | |
def amd_mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]:
    """Query memory information for ``device`` through ``AMDMemoryProvider``.

    ``device`` is resolved with ``get_device`` and its ``index`` is handed to
    ``AMDMemoryProvider.mem_get_info``; that call's result is returned as-is.

    The provider is imported inside the function rather than at module level
    (presumably so the AMD-specific module is only loaded on demand - confirm).
    """
    from .memory_amd import AMDMemoryProvider
    query = AMDMemoryProvider.mem_get_info
    return query(get_device(device).index)
def pdh_mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]:
    """Build a memory pair from the counters of ``DirectML.memory_provider``.

    Returns ``(total_committed - dedicated_usage, total_committed)`` for the
    index of the resolved device (presumably a "free, total" pair - confirm
    against callers).

    ``DirectML.memory_provider`` must have been assigned beforehand; it is
    ``None`` by default, in which case the attribute access below fails.
    """
    read_counters = DirectML.memory_provider.get_memory
    counters = read_counters(get_device(device).index)
    committed = counters["total_committed"]
    dedicated = counters["dedicated_usage"]
    return (committed - dedicated, committed)
def mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]: # pylint: disable=unused-argument
    """Return a fixed placeholder memory pair of 8 GiB / 8 GiB.

    ``device`` is accepted but ignored; the same constant pair is returned
    for every call.
    """
    placeholder_bytes = 8 << 30  # 8 * 2**30 == 8589934592
    return (placeholder_bytes, placeholder_bytes)
class DirectML:
    """Namespace of DirectML device helpers built on ``torch_directml``.

    The class carries only class-level state and helper functions and is used
    through the class object itself (``DirectML.device_count()``). Every
    helper is a ``staticmethod`` so that access through an instance behaves
    the same as access through the class instead of binding the instance as
    the first argument.
    """

    # Sibling module/classes re-exported as attributes of this namespace.
    amp = amp
    device = Device
    Generator = Generator

    # Explicitly selected device; while None, ``current_device`` falls back
    # to ``default_device()``.
    context_device: Optional[torch.device] = None

    # Autocast defaults (read and written outside this class body).
    is_autocast_enabled = False
    autocast_gpu_dtype = torch.float16

    # Object exposing ``get_memory(index)``, consumed by ``pdh_mem_get_info``.
    # None until assigned elsewhere.
    memory_provider = None

    @staticmethod
    def is_available() -> bool:
        """Return whether ``torch_directml`` reports itself as available."""
        return torch_directml.is_available()

    @staticmethod
    def is_directml_device(device: torch.device) -> bool:
        """Return True when ``device.type`` is ``"privateuseone"``."""
        return device.type == "privateuseone"

    @staticmethod
    def has_float64_support(device: Optional[rDevice]=None) -> bool:
        """Return ``torch_directml.has_float64_support`` for the resolved device index."""
        return torch_directml.has_float64_support(get_device(device).index)

    @staticmethod
    def device_count() -> int:
        """Return the device count reported by ``torch_directml``."""
        return torch_directml.device_count()

    @staticmethod
    def current_device() -> torch.device:
        """Return ``context_device`` when set, otherwise the default device."""
        return DirectML.context_device or DirectML.default_device()

    @staticmethod
    def default_device() -> torch.device:
        """Return the device object for ``torch_directml``'s default device."""
        return torch_directml.device(torch_directml.default_device())

    @staticmethod
    def get_device_string(device: Optional[rDevice]=None) -> str:
        """Return ``"privateuseone:<index>"`` for the resolved device."""
        return f"privateuseone:{get_device(device).index}"

    @staticmethod
    def get_device_name(device: Optional[rDevice]=None) -> str:
        """Return ``torch_directml.device_name`` for the resolved device index."""
        return torch_directml.device_name(get_device(device).index)

    @staticmethod
    def get_device_properties(device: Optional[rDevice]=None) -> DeviceProperties:
        """Return a ``DeviceProperties`` built from the resolved device."""
        return DeviceProperties(get_device(device))

    @staticmethod
    def memory_stats(device: Optional[rDevice]=None) -> dict[str, int]: # pylint: disable=unused-argument
        """Return constant zeroed counters; ``device`` is ignored."""
        return {
            "num_ooms": 0,
            "num_alloc_retries": 0,
        }

    # Defaults to the module-level placeholder ``mem_get_info``. Wrapped in
    # staticmethod so that instance access does not bind the instance as
    # ``device``; class-level access still yields the plain function.
    mem_get_info: Callable = staticmethod(mem_get_info)

    @staticmethod
    def memory_allocated(device: Optional[rDevice]=None) -> int:
        """Return the summed ``torch_directml.gpu_memory`` values scaled by 2**20.

        NOTE(review): the scaling suggests ``gpu_memory`` reports mebibytes
        and this converts to bytes - confirm against torch_directml.
        """
        return sum(torch_directml.gpu_memory(get_device(device).index)) * (1 << 20)

    @staticmethod
    def max_memory_allocated(device: Optional[rDevice]=None) -> int:
        """Return the same value as ``memory_allocated``; no separate peak is tracked."""
        return DirectML.memory_allocated(device) # DirectML does not empty GPU memory

    @staticmethod
    def reset_peak_memory_stats(device: Optional[rDevice]=None) -> None: # pylint: disable=unused-argument
        """Do nothing; there are no peak statistics to reset."""
        return