|
from typing import * |
|
|
|
import torch |
|
import torch.distributed.rpc as rpc |
|
from torch import Tensor |
|
from torch._jit_internal import Future |
|
from torch.distributed.rpc import RRef |
|
from typing import Tuple |
|
|
|
|
|
# Placeholder for the module interface class used in the ``RRef[...]``
# annotation of ``_remote_forward`` below.
# NOTE(review): presumably overwritten with a concrete interface class by a
# template instantiator at code-generation time — confirm against the
# generator that produces this file.
module_interface_cls = None
|
|
|
|
|
def forward_async(self, *args, **kwargs):
    """Asynchronously run the wrapped module's ``forward`` on its owner.

    Prepends the module RRef, the target device string, and the
    device-map flag to the caller's positional arguments, then dispatches
    ``_remote_forward`` to the RRef's owner via ``rpc.rpc_async``.

    Returns:
        The Future produced by ``rpc.rpc_async``; its value is the remote
        ``forward`` result.
    """
    rpc_args = (self.module_rref, self.device, self.is_device_map_set) + args
    rpc_kwargs = dict(kwargs)
    owner = self.module_rref.owner()
    return rpc.rpc_async(owner, _remote_forward, rpc_args, rpc_kwargs)
|
|
|
|
|
def forward(self, *args, **kwargs):
    """Synchronously run the wrapped module's ``forward`` on its owner.

    Same dispatch as ``forward_async`` — the module RRef, device string,
    and device-map flag are prepended to the positional arguments and
    ``_remote_forward`` is invoked on the owner — but this variant blocks
    on the returned Future and returns the remote result directly.
    """
    full_args = (self.module_rref, self.device, self.is_device_map_set) + args
    fut = rpc.rpc_async(
        self.module_rref.owner(),
        _remote_forward,
        full_args,
        dict(kwargs),
    )
    return fut.wait()
|
|
|
|
|
# Methods collected here are read by whatever consumes this module.
# NOTE(review): presumably attached to a generated remote-module class by
# the code that instantiates this template — confirm against the caller.
_generated_methods = [
    forward_async,
    forward,
]
|
|
|
|
|
|
|
|
|
def _remote_forward(
    module_rref: RRef[module_interface_cls], device: str, is_device_map_set: bool, *args, **kwargs):
    """Run the wrapped module's ``forward`` on its owner, moving tensors as needed.

    Executes on the RRef owner (the module is resolved with
    ``local_value()``).

    Args:
        module_rref: RRef holding the module whose ``forward`` is invoked.
        device: Target device string, e.g. ``"cpu"`` or ``"cuda:0"``.
        is_device_map_set: When True, outputs are returned as-is because the
            RPC layer is expected to handle tensor transport itself.
        *args: Positional arguments for ``module.forward``; Tensor values are
            moved to ``device`` when it is a CUDA device.
        **kwargs: Keyword arguments, treated the same way.

    Returns:
        The module's output. When ``device`` is CUDA and no device map is
        set, the output is iterated and returned as a tuple with every
        Tensor element moved back to CPU.
    """
    module = module_rref.local_value()
    device = torch.device(device)

    # Non-CUDA target: arguments can be passed through untouched.
    if device.type != "cuda":
        return module.forward(*args, **kwargs)

    # Move every Tensor argument onto the CUDA device, leaving non-Tensor
    # arguments as-is. (Replaces the original O(n^2) tuple-concatenation
    # loops and their incorrect ``Tuple[()]`` empty-tuple annotations.)
    out_args = tuple(
        a.to(device) if isinstance(a, Tensor) else a for a in args
    )
    kwargs = {
        k: v.to(device) if isinstance(v, Tensor) else v
        for k, v in kwargs.items()
    }

    # With a device map configured, CUDA tensors can be returned directly.
    if is_device_map_set:
        return module.forward(*out_args, **kwargs)

    # No device map: copy every Tensor in the output back to CPU before
    # returning. NOTE(review): this branch assumes ``module.forward``
    # returns an iterable — matches the original code's behavior.
    return tuple(
        r.cpu() if isinstance(r, Tensor) else r
        for r in module.forward(*out_args, **kwargs)
    )
|
|