# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import datetime
import logging
import time
import random
import torch
import torch.distributed as dist
from maskrcnn_benchmark.utils.comm import get_world_size, synchronize, broadcast_data
from maskrcnn_benchmark.utils.metric_logger import MetricLogger
from maskrcnn_benchmark.utils.ema import ModelEma


def reduce_loss_dict(loss_dict):
    """
    Reduce the loss dictionary from all processes so that process with rank
    0 has the averaged results. Returns a dict with the same fields as
    loss_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        return loss_dict
    with torch.no_grad():
        loss_names = []
        all_losses = []
        for k in sorted(loss_dict.keys()):
            loss_names.append(k)
            all_losses.append(loss_dict[k])
        all_losses = torch.stack(all_losses, dim=0)
        dist.reduce(all_losses, dst=0)
        if dist.get_rank() == 0:
            # only main process gets accumulated, so only divide by
            # world_size in this case
            all_losses /= world_size
        reduced_losses = {k: v for k, v in zip(loss_names, all_losses)}
    return reduced_losses
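
# Illustrative usage (a sketch, not part of the training flow; the loss names
# below are made up for the example):
#
#   loss_dict = {"loss_classifier": loss_cls, "loss_box_reg": loss_reg}
#   loss_dict_reduced = reduce_loss_dict(loss_dict)
#   # on rank 0, loss_dict_reduced now holds the losses averaged over all ranks
#   total = sum(loss.item() for loss in loss_dict_reduced.values())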


def do_train(
    cfg,
    model,
    data_loader,
    optimizer,
    scheduler,
    checkpointer,
    device,
    checkpoint_period,
    arguments,
    rngs=None,
):
    logger = logging.getLogger("maskrcnn_benchmark.trainer")
    logger.info("Start training")
    meters = MetricLogger(delimiter=" ")
    max_iter = len(data_loader)
    start_iter = arguments["iteration"]
    model.train()

    model_ema = None
    if cfg.SOLVER.MODEL_EMA > 0:
        model_ema = ModelEma(model, decay=cfg.SOLVER.MODEL_EMA)

    start_training_time = time.time()
    end = time.time()
    for iteration, (images, targets, _) in enumerate(data_loader, start_iter):
        if any(len(target) < 1 for target in targets):
            logger.error(
                f"Iteration={iteration + 1} || Image Ids used for training {_} || targets Length={[len(target) for target in targets]}"
            )
            continue
        data_time = time.time() - end
        iteration = iteration + 1
        arguments["iteration"] = iteration

        images = images.to(device)
        targets = [target.to(device) for target in targets]

        # synchronize rngs: sample the mixed-path indices once and broadcast them
        # so that every process uses the same sub-path selection this iteration
        if rngs is None:
            if isinstance(model, torch.nn.parallel.DistributedDataParallel):
                mix_nums = model.module.mix_nums
            else:
                mix_nums = model.mix_nums
            rngs = [random.randint(0, mix - 1) for mix in mix_nums]
        rngs = broadcast_data(rngs)

        # freeze all parameters before the forward pass with the sampled rngs
        for param in model.parameters():
            param.requires_grad = False
        loss_dict = model(images, targets, rngs)
        losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_loss_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        meters.update(loss=losses_reduced, **loss_dict_reduced)

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
        scheduler.step()

        if model_ema is not None:
            model_ema.update(model)
            arguments["model_ema"] = model_ema.state_dict()

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)

        eta_seconds = meters.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
        if iteration % 20 == 0 or iteration == max_iter:
            logger.info(
                meters.delimiter.join(
                    [
                        "eta: {eta}",
                        "iter: {iter}",
                        "{meters}",
                        "lr: {lr:.6f}",
                        "max mem: {memory:.0f}",
                    ]
                ).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                )
            )
        if iteration % checkpoint_period == 0:
            checkpointer.save("model_{:07d}".format(iteration), **arguments)
        if iteration == max_iter:
            if model_ema is not None:
                model.load_state_dict(model_ema.state_dict())
            checkpointer.save("model_final", **arguments)

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info(
        "Total training time: {} ({:.4f} s / it)".format(
            total_time_str, total_training_time / (max_iter)
        )
    )