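"""Differentiable all_gather for DistributedDataParallel training.

torch.distributed.all_gather is not tracked by autograd, so gathering model
outputs from every rank would normally cut the gradient path back to the
local input. Wrapping the collective in an autograd.Function and returning
only this rank's slice of the incoming gradient in backward() keeps the
graph intact.

Run with one process per GPU, e.g.:
    python -m torch.distributed.launch --nproc_per_node=2 <this_script>.py
"""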
import torch
from torch import distributed
from torch import autograd
from torch.nn.parallel import DistributedDataParallel as DDP


def print_if_rank0(*args):
    # Only rank 0 prints, so multi-process runs do not duplicate output.
    if distributed.get_rank() == 0:
        print(*args)


class awesome_allgather_function(autograd.Function):
    """all_gather that participates in autograd: forward gathers and
    concatenates the per-rank tensors, backward hands each rank the gradient
    slice that corresponds to its own input."""

    @staticmethod
    def forward(ctx, input):
        world_size = distributed.get_world_size()
        # all_gather fills one tensor per rank; the input must have the same
        # shape on every rank.
        allgather_list = [torch.empty_like(input) for _ in range(world_size)]
        distributed.all_gather(allgather_list, input)
        return torch.cat(allgather_list, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        # forward() concatenated world_size equally sized chunks along dim 0,
        # so the gradient w.r.t. the local input is this rank's chunk.
        grads_per_rank = grad_output.shape[0] // distributed.get_world_size()
        rank = distributed.get_rank()
        sl = slice(rank * grads_per_rank, (rank + 1) * grads_per_rank)
        return grad_output[sl]
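# Self-test / demo: gather a random tensor from every rank, then run a small
# DDP-wrapped convolution and gather its outputs before computing the loss.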
if __name__ == "__main__":
    import argparse
    import torch.distributed as dist
    from torch import nn
    from torch.optim import Adam

    # torch.distributed.launch passes --local_rank to every spawned process.
    argumentparser = argparse.ArgumentParser()
    argumentparser.add_argument("--local_rank", type=int)
    args = argumentparser.parse_args()

    # One GPU per process; the NCCL backend reads the rendezvous information
    # (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE) from the environment.
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')

    rnd = torch.rand((5, 2)).cuda()
    rnd_gathered = awesome_allgather_function.apply(rnd)
    print("gathering random tensors\nbefore\n", rnd, "\nafter\n", rnd_gathered)

    print("now running a DDP model")
    c = nn.Conv2d(2, 3, kernel_size=3, stride=1, padding=1, dilation=1,
                  groups=1, bias=True).cuda()
    c = DDP(c, device_ids=[args.local_rank])
    opt = Adam(c.parameters())
    # all_gather requires identically shaped tensors on every rank, so the
    # batch size must be the same in every process.
    bs = 5
    inp = torch.rand((bs, 2, 5, 5)).cuda()
    out = c(inp)
    print("output_shape", out.shape)

    out_gathered = awesome_allgather_function.apply(out)
    print("output_shape_after_gather", out_gathered.shape)

    # Every rank's output contributes to the loss; backward() routes each rank
    # the gradient slice for its own outputs, and DDP averages the parameter
    # gradients across processes.
    loss = out_gathered.sum()
    loss.backward()
    opt.step()
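# With two processes (--nproc_per_node=2) and bs = 5, out has shape
# (5, 3, 5, 5) on each rank and out_gathered has shape (10, 3, 5, 5).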