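"""Tests for the FairseqBMUF optimizer wrapper.

Each test spawns one worker process per configured distributed world size,
trains a small linear model for a few iterations, and then collects the
flattened model parameters from each rank; the multi-GPU tests verify that
BMUF keeps the workers in sync.
"""
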
import argparse
import functools
import random
import unittest
from multiprocessing import Manager

import torch
import torch.nn as nn
from fairseq import optim
from fairseq.distributed import utils as distributed_utils
from omegaconf import OmegaConf


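# A minimal linear model: the tests only need a module with trainable
# parameters that can be flattened and compared across ranks.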
class Model(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.fc = nn.Linear(input_size, output_size)

    def forward(self, input):
        output = self.fc(input)
        return output


def setup_model_loss_criterion(cfg, args, rank, is_cuda):
    """
    Set up the model, criterion and optimizer based on the input args/cfg.
    """
    args.distributed_rank = rank
    cfg.distributed_training.distributed_rank = args.distributed_rank
    if cfg.distributed_training.distributed_world_size > 1:
        distributed_utils.distributed_init(cfg)
    torch.manual_seed(1)
    model = Model(args.input_size, args.nb_classes)
    loss_fn = nn.CrossEntropyLoss()
    if is_cuda:
        model = model.cuda()
        loss_fn = loss_fn.cuda()

    # Wrap plain SGD with FairseqBMUF (blockwise model update filtering),
    # which periodically synchronizes model parameters across workers.
    optimizer = optim.sgd.SGD(args, model.parameters())
    optimizer = optim.FairseqBMUF(
        cfg=cfg.bmuf,
        optimizer=optimizer,
    )

    return model, loss_fn, optimizer


def train_step(input, target, model, loss_fn, optimizer, **unused):
    """Do forward, backward and parameter update."""
    model.train()
    output = model(input)
    loss = loss_fn(output, target)
    optimizer.backward(loss)
    optimizer.step()


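# Per-process entry point: trains a local copy of the model for `iterations`
# steps and stores the flattened parameters in the shared results dict,
# keyed by rank, so the parent process can compare workers.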
def single_gpu_training(cfg, args, rank, iterations, shared_results):
    is_cuda = torch.cuda.is_available()
    if is_cuda:
        torch.cuda.set_device(rank)

    model, loss_fn, optimizer = setup_model_loss_criterion(cfg, args, rank, is_cuda)

    for _ in range(iterations):
        input = torch.randn(args.batch_size, args.input_size)
        target = torch.empty(args.batch_size, dtype=torch.long).random_(args.nb_classes)
        if is_cuda:
            input = input.cuda()
            target = target.cuda()
        train_step(input, target, model, loss_fn, optimizer)

    # Flatten all parameters into a single vector so the ranks can be
    # compared elementwise by the parent process.
    results = torch.cat(
        [param.flatten().cpu().data for param in model.parameters()]
    )
    shared_results[rank] = results


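# Build a legacy argparse.Namespace together with an OmegaConf config that
# mirrors it: the SGD optimizer is constructed from `args`, while FairseqBMUF
# and the distributed setup read from `cfg`.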
def setup_args():
    args = argparse.Namespace()
    args.block_momentum = 0.875
    args.block_lr = 0.5
    args.input_size = 5
    args.nb_classes = 2
    args.batch_size = 1
    args.lr = [1e-3]
    args.momentum = 0
    args.weight_decay = 0
    args.warmup_iterations = 0
    args.use_nbm = True
    args.average_sync = True
    args.global_sync_iter = 1
    args.model_parallel_size = 1
    args.distributed_backend = "gloo"

    args.distributed_world_size = 2
    port = random.randint(10000, 20000)
    args.distributed_init_method = "tcp://localhost:{port}".format(port=port)
    args.distributed_init_host = "localhost"
    args.distributed_port = port + 1
    args.local_world_size = args.distributed_world_size

    cfg = OmegaConf.create()
    cfg.optimization = OmegaConf.create()
    cfg.common = OmegaConf.create()
    cfg.distributed_training = OmegaConf.create()
    cfg.dataset = OmegaConf.create()
    cfg.bmuf = OmegaConf.create()
    cfg.optimizer = OmegaConf.create()

    # Mirror the relevant args into the structured config consumed by
    # FairseqBMUF and the distributed utilities.
    cfg.bmuf.global_sync_iter = args.global_sync_iter
    cfg.bmuf.block_momentum = args.block_momentum
    cfg.bmuf.block_lr = args.block_lr
    cfg.dataset.batch_size = args.batch_size
    cfg.optimization.lr = args.lr
    cfg.optimizer.momentum = args.momentum
    cfg.optimizer.weight_decay = args.weight_decay
    cfg.bmuf.warmup_iterations = args.warmup_iterations
    cfg.bmuf.use_nbm = args.use_nbm
    cfg.bmuf.average_sync = args.average_sync
    cfg.common.model_parallel_size = args.model_parallel_size
    cfg.distributed_training.distributed_backend = args.distributed_backend
    cfg.distributed_training.distributed_world_size = args.distributed_world_size
    cfg.bmuf.distributed_world_size = args.distributed_world_size
    cfg.distributed_training.distributed_init_method = args.distributed_init_method
    cfg.distributed_training.distributed_port = args.distributed_port

    return cfg, args


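# Requires at least two GPUs; `bmuf_process` spawns `distributed_world_size`
# workers and returns the per-rank parameter vectors collected in a
# multiprocessing Manager dict.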
@unittest.skipIf(torch.cuda.device_count() < 2, "test requires 2 GPUs")
class TestBMUF(unittest.TestCase):
    def bmuf_process(self, cfg, args, iterations):
        results = Manager().dict()
        # torch.multiprocessing.spawn calls fn(rank, *args), so each worker's
        # rank is prepended to (iterations, results) when invoking the partial.
        torch.multiprocessing.spawn(
            fn=functools.partial(single_gpu_training, cfg, args),
            args=(iterations, results),
            nprocs=args.distributed_world_size,
            join=True,
        )
        return results

    def test_bmuf_sync(self):
        # Train for a single iteration and check that a BMUF sync (without
        # warmup) leaves both workers with identical parameters.
        cfg, args = setup_args()
        iterations = 1
        results = self.bmuf_process(cfg, args, iterations)
        assert len(results) == 2
        self.assertAlmostEqual(results[0], results[1])

    def test_warmup_sync(self):
        # Train for 20 iterations with warmup_iterations=20, so only the
        # warmup sync path is exercised.
        cfg, args = setup_args()
        args.warmup_iterations = 20
        cfg.bmuf.warmup_iterations = args.warmup_iterations
        iterations = 20
        results = self.bmuf_process(cfg, args, iterations)
        assert len(results) == 2
        self.assertAlmostEqual(results[0], results[1])

    def test_warmup_sync_bmuf_sync(self):
        # Train for 25 iterations: warmup sync after 20 iterations, then a
        # BMUF sync every 5 iterations.
        cfg, args = setup_args()
        args.warmup_iterations = 20
        args.global_sync_iter = 5
        cfg.bmuf.warmup_iterations = args.warmup_iterations
        cfg.bmuf.global_sync_iter = args.global_sync_iter
        iterations = 25
        results = self.bmuf_process(cfg, args, iterations)
        assert len(results) == 2
        self.assertAlmostEqual(results[0], results[1])

    def test_single_gpu_bmuf(self):
        # BMUF with a single worker: training runs without distributed
        # initialization and produces exactly one result.
        cfg, args = setup_args()
        args.distributed_world_size = 1
        args.warmup_iterations = 5
        cfg.distributed_training.distributed_world_size = args.distributed_world_size
        cfg.bmuf.distributed_world_size = args.distributed_world_size
        cfg.bmuf.warmup_iterations = args.warmup_iterations
        iterations = 20
        results = self.bmuf_process(cfg, args, iterations)
        assert len(results) == 1

    def assertAlmostEqual(self, t1, t2):
        # Overrides unittest.TestCase.assertAlmostEqual for tensors: shapes
        # must match and all elements must agree to within 1e-4.
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        self.assertLess((t1 - t2).abs().max(), 1e-4)


if __name__ == "__main__":
    unittest.main()