Spaces:
Sleeping
Sleeping
File size: 7,382 Bytes
749745d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 |
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import datetime
import logging
import time
import torch
import torch.distributed as dist
from maskrcnn_benchmark.utils.comm import get_world_size
from maskrcnn_benchmark.utils.metric_logger import MetricLogger
def reduce_loss_dict(all_loss_dict):
    """
    Average the losses of several loss dictionaries for logging.

    The values of every dict in ``all_loss_dict`` are stacked into a single
    tensor. When more than one process is running, that tensor is reduced
    onto rank 0 with ``dist.reduce`` and divided by the world size there, so
    only rank 0 holds the cross-process average.

    Entries that share a key across the dicts are then summed, each weighted
    by ``1 / len(all_loss_dict)``. Every occurrence contributes exactly once.

    Args:
        all_loss_dict: sequence of dicts mapping a loss name to a loss
            tensor. The tensors must be accepted together by ``torch.stack``
            (presumably scalar losses -- confirm against the model outputs).

    Returns:
        dict mapping each loss name to its averaged tensor.
    """
    world_size = get_world_size()
    num_dicts = len(all_loss_dict)
    with torch.no_grad():
        loss_names = []
        all_losses = []
        for loss_dict in all_loss_dict:
            # Sorted keys give a deterministic stacking order, which the
            # element-wise reduce across processes relies on.
            for k in sorted(loss_dict.keys()):
                loss_names.append(k)
                all_losses.append(loss_dict[k])
        all_losses = torch.stack(all_losses, dim=0)
        if world_size > 1:
            dist.reduce(all_losses, dst=0)
            if dist.get_rank() == 0:
                # only main process gets accumulated, so only divide by
                # world_size in this case
                all_losses /= world_size

        reduced_losses = {}
        for k, v in zip(loss_names, all_losses):
            share = v / num_dicts
            # Add each occurrence once; the first occurrence must not be
            # counted twice.
            if k in reduced_losses:
                reduced_losses[k] = reduced_losses[k] + share
            else:
                reduced_losses[k] = share
    return reduced_losses
def _iter_training_batches(data_loader, scheduler, epoch_per_stage):
    """
    Yield one ``(lr_scheduler, indexed_batches)`` pair per training iteration.

    Single-stage phases come first: stage ``i`` is iterated alone for
    ``epoch_per_stage[i]`` epochs and paired with ``scheduler[i]``. The joint
    phase follows: ``epoch_per_stage[-1]`` epochs over ``zip(*data_loader)``
    (each epoch therefore ends with the shortest loader), paired with
    ``scheduler[-1]``.

    ``indexed_batches`` is a list of ``(stage_index, batch)`` tuples: one
    entry during a single-stage phase, one entry per stage in the joint phase.
    """
    for stage_i, stage_loader in enumerate(data_loader):
        for _ in range(epoch_per_stage[stage_i]):
            for batch in stage_loader:
                yield scheduler[stage_i], [(stage_i, batch)]

    for _ in range(epoch_per_stage[-1]):
        for stage_batches in zip(*data_loader):
            yield scheduler[-1], list(enumerate(stage_batches))


def do_train(
    model,
    data_loader,
    optimizer,
    scheduler,
    checkpointer,
    device,
    checkpoint_period,
    arguments,
):
    """
    Train ``model`` on each stage in turn, then jointly on all stages.

    Every iteration steps the learning-rate scheduler, runs the model on the
    batch(es) of the current phase, sums all returned losses, back-propagates
    and steps the optimizer. Progress is logged every 20 iterations, a
    checkpoint is saved every ``checkpoint_period`` iterations, and
    ``model_final`` is saved when the iteration counter reaches the computed
    total.

    Args:
        model: callable as ``model(images, targets, stage_index)`` returning a
            dict of loss tensors; also provides ``train()``.
        data_loader: sequence of per-stage loaders. Each must support
            ``len()`` and yield ``(images, targets, extra)`` batches; the
            third item is ignored.
        optimizer: optimizer exposing ``zero_grad``, ``step`` and
            ``param_groups``.
        scheduler: indexable collection of LR schedulers; entry ``i`` is used
            while stage ``i`` trains alone, entry ``-1`` for the joint phase.
        checkpointer: object with ``save(name, **kwargs)``.
        device: target passed to ``.to(...)`` for images and targets.
        checkpoint_period: number of iterations between periodic checkpoints.
        arguments: dict holding ``"epoch_per_stage"`` (epochs per stage, last
            entry used for the joint phase) and ``"iteration"`` (starting
            iteration counter). ``"iteration"`` is updated in place and the
            whole dict is forwarded to ``checkpointer.save``.
    """
    logger = logging.getLogger("maskrcnn_benchmark.trainer")
    logger.info("Start training")
    meters = MetricLogger(delimiter=" ")
    epoch_per_stage = arguments["epoch_per_stage"]
    max_iter = sum(len(stage_loader) * epoch_per_stage[si] for si, stage_loader in enumerate(data_loader))
    max_iter += epoch_per_stage[-1] * min(len(stage_loader) for stage_loader in data_loader)
    model.train()
    start_training_time = time.time()
    end = time.time()

    # The iteration counter is 1-based and continues from the stored value.
    training_batches = _iter_training_batches(data_loader, scheduler, epoch_per_stage)
    for iteration, (lr_scheduler, indexed_batches) in enumerate(training_batches, arguments["iteration"] + 1):
        data_time = time.time() - end
        arguments["iteration"] = iteration

        # NOTE(review): the scheduler is stepped before the optimizer, as in
        # the original loop. PyTorch >= 1.1 documents the opposite order;
        # left unchanged so the learning-rate schedule is not altered.
        lr_scheduler.step()

        stage_loss_dicts = []
        for stage_i, (images, targets, _) in indexed_batches:
            images = images.to(device)
            targets = [target.to(device) for target in targets]
            stage_loss_dicts.append(model(images, targets, stage_i))

        losses = sum(loss for loss_dict in stage_loss_dicts for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_loss_dict(stage_loss_dicts)
        losses_reduced = sum(loss_dict_reduced.values())
        meters.update(loss=losses_reduced, **loss_dict_reduced)

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)

        eta_seconds = meters.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        if iteration % 20 == 0 or iteration == max_iter:
            logger.info(
                meters.delimiter.join(
                    [
                        "eta: {eta}",
                        "iter: {iter}",
                        "{meters}",
                        "lr: {lr:.6f}",
                        "max mem: {memory:.0f}",
                    ]
                ).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                )
            )
        if iteration % checkpoint_period == 0:
            checkpointer.save("model_{:07d}".format(iteration), **arguments)
        if iteration == max_iter:
            checkpointer.save("model_final", **arguments)

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    # max(..., 1) avoids a ZeroDivisionError when no iterations were scheduled.
    logger.info(
        "Total training time: {} ({:.4f} s / it)".format(
            total_time_str, total_training_time / max(max_iter, 1)
        )
    )
|