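# NOTE(review): the following was page-header residue saved alongside this
# file (it is not Python): "Pinwheel's picture", "HF Demo", "128757a", "raw",
# "history blame", "16.1 kB". Kept as a comment so the module can be imported.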
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import datetime
import logging
import sys
import os
import math
import time
import torch
import torch.distributed as dist
from maskrcnn_benchmark.utils.comm import get_world_size, all_gather, is_main_process, broadcast_data, get_rank
from maskrcnn_benchmark.utils.metric_logger import MetricLogger
from maskrcnn_benchmark.utils.ema import ModelEma
from maskrcnn_benchmark.utils.amp import autocast, GradScaler
from maskrcnn_benchmark.data.datasets.evaluation import evaluate
from .inference import inference
import pdb
def reduce_loss_dict(loss_dict):
    """Average a dict of loss tensors across processes, for logging.

    The values are stacked in sorted-key order and summed onto rank 0 with
    ``dist.reduce``. Rank 0 then divides the sum by the world size. Values
    returned on other ranks are not averaged and should not be relied on.
    With fewer than two processes the input dict is returned unchanged.

    Args:
        loss_dict: mapping of loss name -> tensor. Presumably scalar losses
            so they can be stacked (TODO confirm with callers).

    Returns:
        A dict with the same keys as ``loss_dict``, holding reduced values.
    """
    world_size = get_world_size()
    if world_size < 2:
        return loss_dict

    with torch.no_grad():
        names = sorted(loss_dict)
        stacked = torch.stack([loss_dict[name] for name in names], dim=0)
        dist.reduce(stacked, dst=0)
        # Only the destination rank receives the accumulated sum, so only it
        # divides by world_size.
        if dist.get_rank() == 0:
            stacked /= world_size
        return dict(zip(names, stacked))
def _compute_loss_dict(model, images, targets, captions, positive_map, greenlight_map):
    """Run one training forward pass and return the model's loss dict.

    Batches that carry at least one caption are fed captions, ``positive_map``
    and ``greenlight_map``. Otherwise the model gets images and targets only.
    Both the AMP and full-precision paths use this so they stay identical.
    """
    if captions:
        return model(images, targets, captions, positive_map, greenlight_map=greenlight_map)
    return model(images, targets)


def _zero_non_finite_(losses):
    """Zero every NaN/Inf entry of ``losses`` in place.

    Returns True when at least one entry was non-finite, so the caller can
    log it.
    """
    non_finite = ~torch.isfinite(losses)
    if bool(non_finite.any()):
        losses[non_finite] = 0
        return True
    return False


def _evaluate_val_loader(model, val_data_loader, device, cfg):
    """Predict on every validation batch with ``model`` and score the result.

    Predictions from all ranks are merged with ``all_gather``, so every rank
    must call this. Only the main process runs ``evaluate``.

    Returns:
        On the main process: box-proposal AR@100 when
        ``cfg.DATASETS.CLASS_AGNOSTIC`` is set, otherwise bbox AP.
        On every other rank: ``None``.
    """
    results_dict = {}
    cpu_device = torch.device("cpu")
    for batch in val_data_loader:
        # Loaders may yield extra trailing items (e.g. positive_map_eval);
        # ignore them instead of requiring an exact tuple length.
        images, targets, image_ids, positive_map, *_ = batch
        with torch.no_grad():
            images = images.to(device)
            if positive_map is None:
                output = model(images)
            else:
                captions = [t.get_field("caption") for t in targets if "caption" in t.fields()]
                output = model(images, captions, positive_map)
            output = [o.to(cpu_device) for o in output]
        results_dict.update(
            {img_id: result for img_id, result in zip(image_ids, output)}
        )

    all_predictions = all_gather(results_dict)
    if not is_main_process():
        return None

    predictions = {}
    for partial in all_predictions:
        predictions.update(partial)
    predictions = [predictions[i] for i in sorted(predictions.keys())]
    eval_result, _ = evaluate(val_data_loader.dataset, predictions, output_folder=None,
                              box_only=cfg.DATASETS.CLASS_AGNOSTIC)
    if cfg.DATASETS.CLASS_AGNOSTIC:
        return eval_result.results['box_proposal']['AR@100']
    return eval_result.results['bbox']['AP']


def do_train(
    cfg,
    model,
    data_loader,
    optimizer,
    scheduler,
    checkpointer,
    device,
    checkpoint_period,
    arguments,
    val_data_loader=None,
    meters=None,
    zero_shot=False
):
    """Run the training loop with optional periodic validation and checkpoints.

    Args:
        cfg: config node. Reads ``SOLVER.*``, ``MODEL.LANGUAGE_BACKBONE.FREEZE``,
            ``TEST.EXPECTED_RESULTS*`` and ``DATASETS.CLASS_AGNOSTIC``.
        model: the model to train (may be wrapped, exposing ``.module``).
        data_loader: yields ``(images, targets, idxs, positive_map,
            positive_map_eval, greenlight_map)``. Its length is used as the
            total iteration count ``max_iter``.
        optimizer, scheduler: stepped once per trained batch. The scheduler is
            also stepped with the validation metric when
            ``SOLVER.USE_AUTOSTEP`` is set.
        checkpointer: its ``save(name, **arguments)`` is called for periodic,
            best and final checkpoints.
        device: device that images and targets are moved to.
        checkpoint_period: iterations between validation/checkpoints.
            Overridden from ``SOLVER.CHECKPOINT_PER_EPOCH`` when that is
            enabled and ``SOLVER.MAX_EPOCH >= 1``.
        arguments: mutable dict that must contain ``"iteration"`` (the resume
            point). Updated with ``"iteration"``, ``"model_ema"`` and
            ``"eval_result"``, and forwarded to ``checkpointer.save``.
        val_data_loader: optional validation loader.
        meters: metric logger. A ``MetricLogger`` is created when ``None``.
        zero_shot: accepted for interface compatibility; unused here.
    """
    logger = logging.getLogger("maskrcnn_benchmark.trainer")
    logger.info("Start training")
    if meters is None:
        meters = MetricLogger(delimiter="  ")
    max_iter = len(data_loader)
    start_iter = arguments["iteration"]
    model.train()
    model_ema = None
    if cfg.SOLVER.MODEL_EMA > 0:
        model_ema = ModelEma(model, decay=cfg.SOLVER.MODEL_EMA)
    start_training_time = time.time()
    end = time.time()

    if cfg.SOLVER.USE_AMP:
        scaler = GradScaler()

    global_rank = get_rank()

    if cfg.SOLVER.CHECKPOINT_PER_EPOCH != -1 and cfg.SOLVER.MAX_EPOCH >= 1:
        # Integer division can yield 0 for small loaders. Keep it >= 1 so the
        # `iteration % checkpoint_period` checks below never divide by zero.
        checkpoint_period = max(
            1, len(data_loader) * cfg.SOLVER.CHECKPOINT_PER_EPOCH // cfg.SOLVER.MAX_EPOCH
        )

    if global_rank <= 0 and cfg.SOLVER.MAX_EPOCH >= 1:
        print("Iter per epoch ", len(data_loader) // cfg.SOLVER.MAX_EPOCH)

    if cfg.SOLVER.AUTO_TERMINATE_PATIENCE != -1:
        patience_counter = 0
        previous_best = 0.0

    # Adapt the weight decay: on resume, skip milestones already passed.
    if cfg.SOLVER.WEIGHT_DECAY_SCHEDULE and hasattr(scheduler, 'milestones'):
        milestone_target = 0
        for i, milestone in enumerate(list(scheduler.milestones)):
            if scheduler.last_epoch >= milestone * cfg.SOLVER.WEIGHT_DECAY_SCHEDULE_RATIO:
                milestone_target = i + 1

    for iteration, (images, targets, idxs, positive_map, positive_map_eval, greenlight_map) in enumerate(data_loader, start_iter):
        # Skip batches with too many target-less (negative) images.
        nnegative = sum(len(target) < 1 for target in targets)
        nsample = len(targets)
        if nsample == nnegative or nnegative > nsample * cfg.SOLVER.MAX_NEG_PER_BATCH:
            logger.info('[WARNING] Sampled {} negative in {} in a batch, greater the allowed ratio {}, skip'.
                        format(nnegative, nsample, cfg.SOLVER.MAX_NEG_PER_BATCH))
            continue

        data_time = time.time() - end
        iteration = iteration + 1
        arguments["iteration"] = iteration

        images = images.to(device)
        captions = None
        try:
            targets = [target.to(device) for target in targets]
            captions = [t.get_field("caption") for t in targets if "caption" in t.fields()]
        except Exception:
            # Best effort, as before. Log the real cause, then fall back to the
            # caption-less forward (captions stays None, which is falsy).
            logger.warning("Could not prepare targets/captions for this batch; "
                           "running without captions", exc_info=True)

        # Freeze language backbone
        if cfg.MODEL.LANGUAGE_BACKBONE.FREEZE:
            getattr(model, "module", model).language_backbone.eval()

        if cfg.SOLVER.USE_AMP:
            with autocast():
                loss_dict = _compute_loss_dict(model, images, targets, captions,
                                               positive_map, greenlight_map)
            losses = sum(loss for loss in loss_dict.values())
            if _zero_non_finite_(losses):
                logger.error("Non-finite loss (NaN/Inf) encountered, ignoring")
            optimizer.zero_grad()
            scaler.scale(losses).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
        else:
            loss_dict = _compute_loss_dict(model, images, targets, captions,
                                           positive_map, greenlight_map)
            losses = sum(loss for loss in loss_dict.values())
            if _zero_non_finite_(losses):
                logger.error("Non-finite loss (NaN/Inf) encountered, ignoring")
            optimizer.zero_grad()
            losses.backward()
            optimizer.step()
            scheduler.step()

        # Adapt the weight decay: only support multiStepLR
        if cfg.SOLVER.WEIGHT_DECAY_SCHEDULE and hasattr(scheduler, 'milestones'):
            if milestone_target < len(scheduler.milestones):
                next_milestone = list(scheduler.milestones)[milestone_target]
            else:
                next_milestone = float('inf')
            if scheduler.last_epoch >= next_milestone * cfg.SOLVER.WEIGHT_DECAY_SCHEDULE_RATIO:
                gamma = scheduler.gamma
                logger.info("Drop the weight decay by {}!".format(gamma))
                for param in optimizer.param_groups:
                    if 'weight_decay' in param:
                        param['weight_decay'] *= gamma
                # move the target forward
                milestone_target += 1

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_loss_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        meters.update(loss=losses_reduced, **loss_dict_reduced)
        if model_ema is not None:
            model_ema.update(model)
            arguments["model_ema"] = model_ema.state_dict()

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)
        eta_seconds = meters.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        if (iteration % 20 == 0 or iteration == max_iter) and global_rank <= 0:
            print(
                meters.delimiter.join(
                    [
                        "eta: {eta}",
                        "iter: {iter}",
                        "{meters}",
                        "lr: {lr:.6f}",
                        "wd: {wd:.6f}",
                        "max mem: {memory:.0f}",
                    ]
                ).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer.param_groups[0]["lr"],
                    wd=optimizer.param_groups[0]["weight_decay"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                )
            )

        if val_data_loader and (iteration % checkpoint_period == 0 or iteration == max_iter):
            if is_main_process():
                print("Evaluating")
            # Non-main ranks keep this placeholder; they only learn the real
            # metric through the all_gather below.
            eval_result = 0.0
            model.eval()
            if cfg.SOLVER.TEST_WITH_INFERENCE:
                with torch.no_grad():
                    _model = getattr(model, "module", model)
                    _result = inference(
                        model=_model,
                        data_loader=val_data_loader,
                        dataset_name="val",
                        device=device,
                        expected_results=cfg.TEST.EXPECTED_RESULTS,
                        expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
                        output_folder=None,
                        cfg=cfg,
                        verbose=False
                    )
                    if is_main_process():
                        eval_result = _result[0].results['bbox']['AP']
            else:
                main_result = _evaluate_val_loader(model, val_data_loader, device, cfg)
                if main_result is not None:
                    eval_result = main_result
            model.train()

            if model_ema is not None and cfg.SOLVER.USE_EMA_FOR_MONITOR:
                model_ema.ema.eval()
                ema_result = _evaluate_val_loader(model_ema.ema, val_data_loader, device, cfg)
                if ema_result is not None:
                    eval_result = ema_result

            arguments.update(eval_result=eval_result)

            # Share rank 0's metric whenever a decision is made from it. Every
            # rank then steps the scheduler and auto-terminates in lockstep,
            # so one rank cannot `break` alone and deadlock the others.
            if cfg.SOLVER.USE_AUTOSTEP or cfg.SOLVER.AUTO_TERMINATE_PATIENCE != -1:
                eval_result = all_gather(eval_result)[0]
            if cfg.SOLVER.USE_AUTOSTEP:
                scheduler.step(eval_result)

            if cfg.SOLVER.AUTO_TERMINATE_PATIENCE != -1:
                if eval_result < previous_best:
                    patience_counter += 1
                else:
                    patience_counter = 0
                    previous_best = eval_result
                    checkpointer.save("model_best", **arguments)
                if is_main_process():
                    print("Previous Best", previous_best, "Patience Counter", patience_counter, "Eval Result", eval_result)
                if patience_counter >= cfg.SOLVER.AUTO_TERMINATE_PATIENCE:
                    if is_main_process():
                        print("\n\n\n\nAuto Termination at {}, current best {}\n\n\n".format(iteration, previous_best))
                    break

        if iteration % checkpoint_period == 0:
            checkpointer.save("model_{:07d}".format(iteration), **arguments)
        if iteration == max_iter:
            checkpointer.save("model_final", **arguments)
            break

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    # Average over the iterations this call actually covered. It may have
    # resumed from start_iter or stopped early; max(..., 1) guards an empty run.
    iterations_run = max(arguments["iteration"] - start_iter, 1)
    logger.info(
        "Total training time: {} ({:.4f} s / it)".format(
            total_time_str, total_training_time / iterations_run
        )
    )