from typing import Optional, Union import torch class _remote_device: """ Represents a device on a remote worker. Args: remote_device (str or torch.device): Represents a device on a remote worker. The string format should be one of the following: 1. "/", where the device field can be parsed as torch.device type. E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0". In addition, the device field can be optional and the default value is "cpu". 2. "rank:/", where is the rank of the process and device can be parsed as torch.device type. E.g., "rank:0/cpu", "rank:0", "rank:0/cuda:0" 3. and are optional and formats like "cpu" and "cuda:1", just represent local devices. """ def __init__(self, remote_device: Union[str, torch.device]): PARSE_ERROR = ( f"Could not parse remote_device: {remote_device}. The valid format is " "'/' or 'rank:/' or ''" ) self._worker_name = None self._rank = None self._device: Optional[Union[str, int, torch.device]] = None if isinstance(remote_device, torch.device): self._device = remote_device elif isinstance(remote_device, str): fields = remote_device.split("/") if len(fields) == 2: self._worker_name, self._device = fields elif len(fields) == 1: # Check if this is a valid device. if _remote_device._is_valid_local_device(fields[0]): self._device = fields[0] else: self._worker_name = fields[0] self._device = "cpu" else: raise ValueError(PARSE_ERROR) else: raise TypeError(f'Invalid type for remote_device: {type(remote_device)}') # Do some basic sanity check (no empty string) if self._worker_name is not None and not self._worker_name: raise ValueError(PARSE_ERROR) # Validate the device. self._device = torch.device(self._device) # Check for rank based format. if self._worker_name is not None: fields = self._worker_name.split(":") if len(fields) == 2: # rank:/device format, extract rank if fields[0] == "rank" and fields[1].isdigit(): self._rank = int(fields[1]) # type: ignore[assignment] self._worker_name = None else: raise ValueError(PARSE_ERROR) elif len(fields) > 2: raise ValueError(PARSE_ERROR) @staticmethod def _is_valid_local_device(device): # Check for torch.device try: torch.device(device) return True except Exception: return False def worker_name(self) -> Optional[str]: """Return the name of remote worker representing the remote device and ``None`` if no worker name is available.""" return self._worker_name def rank(self) -> Optional[int]: """ Returns the rank of remote worker representing the remote device. Returns ``None`` if no rank is available. """ return self._rank def device(self) -> torch.device: """Return the local device on the remote worker.""" return self._device # type: ignore[return-value] def __repr__(self): if self._device is not None: if self._worker_name is not None: return f'{self._worker_name}/{self._device}' elif self._rank is not None: return f'rank:{self._rank}/{self._device}' else: return str(self._device) else: if self._worker_name is not None: return f'{self._worker_name}' elif self._rank is not None: return f'{self._rank}' else: raise RuntimeError('Invalid state!') def __eq__(self, other): if not isinstance(other, _remote_device): return False if ( self._worker_name == other._worker_name and self._device == other._device and self._rank == other._rank ): return True return False def __hash__(self): return hash(self._worker_name) ^ \ hash(self._device) ^ \ hash(self._rank)