# zdou0830's picture
# desco
# 749745d
# raw
# history blame
# 16.7 kB
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import datetime
import logging
import sys
import os
import math
import time
import torch
import torch.distributed as dist
from maskrcnn_benchmark.utils.comm import get_world_size, all_gather, is_main_process, broadcast_data, get_rank
from maskrcnn_benchmark.utils.metric_logger import MetricLogger
from maskrcnn_benchmark.utils.ema import ModelEma
from maskrcnn_benchmark.utils.amp import autocast, GradScaler
from maskrcnn_benchmark.data.datasets.evaluation import evaluate
from .inference import inference
from .tsv_saver import TSVResultWriter
import wandb
import pdb, random
def reduce_loss_dict(loss_dict):
    """Average a dict of scalar loss tensors onto rank 0 for logging.

    With fewer than two processes the input dict is returned unchanged.
    Otherwise the values are stacked in sorted-key order, summed onto rank 0
    with ``dist.reduce`` and divided by the world size on rank 0 only, so the
    returned values are true averages only on rank 0.

    Args:
        loss_dict: mapping from loss name to a loss tensor (stackable with
            ``torch.stack``).

    Returns:
        A new dict with the same keys mapping to the reduced tensors, or
        ``loss_dict`` itself when the world size is below 2.
    """
    num_procs = get_world_size()
    if num_procs < 2:
        return loss_dict
    with torch.no_grad():
        ordered_names = sorted(loss_dict)
        stacked = torch.stack([loss_dict[name] for name in ordered_names], dim=0)
        dist.reduce(stacked, dst=0)
        if dist.get_rank() == 0:
            # Only the destination rank holds the accumulated sum, so only it
            # is divided by the world size.
            stacked /= num_procs
        averaged = dict(zip(ordered_names, stacked))
    return averaged
def _freeze_language_backbone(model):
    """Put the language backbone of ``model`` (or of its wrapped ``module``) in eval mode.

    Supports both layouts used here: a ``fusion_backbone`` that owns the
    ``language_backbone``, or a ``language_backbone`` directly on the model.
    """
    net = model.module if hasattr(model, "module") else model
    if hasattr(net, "fusion_backbone"):
        net.fusion_backbone.language_backbone.eval()
    else:
        net.language_backbone.eval()


def _evaluate_on_loader(forward_fn, val_data_loader, device, cfg):
    """Run ``forward_fn`` over ``val_data_loader`` and score the predictions.

    Each batch must start with ``(images, targets, image_ids, positive_map)``;
    any further elements are ignored. Predictions from every rank are
    collected with ``all_gather`` (a collective, so every rank must call
    this), then evaluated on the main process only.

    Returns:
        On the main process: ``results["box_proposal"]["AR@100"]`` when
        ``cfg.DATASETS.CLASS_AGNOSTIC`` is set, otherwise
        ``results["bbox"]["AP"]``. On other ranks: ``None``.
    """
    results_dict = {}
    cpu_device = torch.device("cpu")
    for batch in val_data_loader:
        images, targets, image_ids, positive_map, *_ = batch
        with torch.no_grad():
            images = images.to(device)
            if positive_map is None:
                output = forward_fn(images)
            else:
                captions = [t.get_field("caption") for t in targets if "caption" in t.fields()]
                output = forward_fn(images, captions, positive_map)
            output = [o.to(cpu_device) for o in output]
        results_dict.update({img_id: result for img_id, result in zip(image_ids, output)})

    all_predictions = all_gather(results_dict)
    if not is_main_process():
        return None

    merged = {}
    for per_rank in all_predictions:
        merged.update(per_rank)
    predictions = [merged[i] for i in sorted(merged.keys())]
    eval_result, _ = evaluate(
        val_data_loader.dataset, predictions, output_folder=None, box_only=cfg.DATASETS.CLASS_AGNOSTIC
    )
    if cfg.DATASETS.CLASS_AGNOSTIC:
        return eval_result.results["box_proposal"]["AR@100"]
    return eval_result.results["bbox"]["AP"]


def do_train(
    cfg,
    model,
    data_loader,
    optimizer,
    scheduler,
    checkpointer,
    device,
    checkpoint_period,
    arguments,
    val_data_loader=None,
    meters=None,
    use_wandb=False,
):
    """Run the main training loop, with optional periodic validation and checkpointing.

    Args:
        cfg: config node; reads ``SOLVER.*``, ``MODEL.LANGUAGE_BACKBONE.FREEZE``,
            ``DATASETS.CLASS_AGNOSTIC``, ``TEST.EXPECTED_RESULTS*`` and ``OUTPUT_DIR``.
        model: the detector, optionally wrapped (``model.module``). Must expose
            ``tokenizer``.
        data_loader: yields ``(images, targets, idxs, positive_map,
            positive_map_eval, greenlight_map)``; ``len(data_loader)`` is used
            as the final iteration number.
        optimizer: torch optimizer; its ``weight_decay`` may be scaled when
            ``SOLVER.WEIGHT_DECAY_SCHEDULE`` is on and the scheduler has
            ``milestones``.
        scheduler: stepped once per iteration, and additionally with the
            validation metric when ``SOLVER.USE_AUTOSTEP`` is on.
        checkpointer: provides ``save(name, **arguments)`` and ``save_dir``.
        device: device that images/targets are moved to.
        checkpoint_period: validation/checkpoint interval in iterations.
            Overridden when ``SOLVER.CHECKPOINT_PER_EPOCH != -1`` and
            ``SOLVER.MAX_EPOCH >= 1``.
        arguments: mutable dict. ``"iteration"`` is read as the start
            iteration and updated each step. ``"model_ema"`` and
            ``"eval_result"`` are written. The dict is forwarded to
            ``checkpointer.save``.
        val_data_loader: optional validation loader; validation runs every
            ``checkpoint_period`` iterations and at the last iteration.
        meters: metric logger. A ``MetricLogger`` is created when ``None``.
        use_wandb: when true, the main process logs losses/lr/wd to wandb.

    Without AMP, a non-finite loss dumps the batch and model state to
    ``checkpointer.save_dir`` and terminates the process with ``sys.exit(-1)``.
    """
    logger = logging.getLogger("maskrcnn_benchmark.trainer")
    logger.info("Start training")
    if meters is None:
        # The default used to crash on the first ``meters.update`` call.
        meters = MetricLogger(delimiter=" ")
    max_iter = len(data_loader)
    start_iter = arguments["iteration"]
    model.train()
    model_ema = None
    if cfg.SOLVER.MODEL_EMA > 0:
        model_ema = ModelEma(model, decay=cfg.SOLVER.MODEL_EMA)
    start_training_time = time.time()
    end = time.time()
    if cfg.SOLVER.USE_AMP:
        scaler = GradScaler()
    global_rank = get_rank()

    if cfg.SOLVER.CHECKPOINT_PER_EPOCH != -1 and cfg.SOLVER.MAX_EPOCH >= 1:
        checkpoint_period = max_iter * cfg.SOLVER.CHECKPOINT_PER_EPOCH // cfg.SOLVER.MAX_EPOCH
    if global_rank <= 0 and cfg.SOLVER.MAX_EPOCH >= 1:
        print("Iter per epoch ", max_iter // cfg.SOLVER.MAX_EPOCH)

    auto_terminate = cfg.SOLVER.AUTO_TERMINATE_PATIENCE != -1
    if auto_terminate:
        patience_counter = 0
        previous_best = 0.0

    # Adapt the weight decay: when resuming, skip milestones already passed.
    if cfg.SOLVER.WEIGHT_DECAY_SCHEDULE and hasattr(scheduler, "milestones"):
        milestone_target = 0
        for i, milestone in enumerate(list(scheduler.milestones)):
            if scheduler.last_epoch >= milestone * cfg.SOLVER.WEIGHT_DECAY_SCHEDULE_RATIO:
                milestone_target = i + 1

    # The TSV writer dumps training samples (main process only) for inspection.
    tokenizer = model.module.tokenizer if hasattr(model, "module") else model.tokenizer
    tsv_visualizer = TSVResultWriter(
        tokenizer=tokenizer,
        max_visualize_num=1000,
        file_name=cfg.OUTPUT_DIR + "/train_visualize/train.tsv",
        write_freq=100,
    )

    for iteration, (images, targets, idxs, positive_map, positive_map_eval, greenlight_map) in enumerate(
        data_loader, start_iter
    ):
        # Skip batches containing too many images without any target.
        nnegative = sum(len(target) < 1 for target in targets)
        nsample = len(targets)
        if nnegative > nsample * cfg.SOLVER.MAX_NEG_PER_BATCH:
            logger.info(
                "[WARNING] Sampled {} negative in {} in a batch, greater the allowed ratio {}, skip".format(
                    nnegative, nsample, cfg.SOLVER.MAX_NEG_PER_BATCH
                )
            )
            continue

        data_time = time.time() - end
        iteration = iteration + 1
        arguments["iteration"] = iteration

        images = images.to(device)
        captions = None
        try:
            targets = [target.to(device) for target in targets]
            captions = [t.get_field("caption") for t in targets if "caption" in t.fields()]
        except Exception:
            # Best effort: targets that cannot be moved / have no fields are used as-is.
            pass
        # ``captions`` stays None if extraction failed; treat that as "no captions".
        has_captions = bool(captions)

        if cfg.MODEL.LANGUAGE_BACKBONE.FREEZE:
            _freeze_language_backbone(model)

        if is_main_process():  # only visualize for the main process
            tsv_visualizer.update_train_data(images, targets)

        if cfg.SOLVER.USE_AMP:
            with autocast():
                if has_captions:
                    loss_dict = model(images, targets, captions, positive_map, greenlight_map=greenlight_map)
                else:
                    loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
            loss_value = losses.item()
            if not math.isfinite(loss_value):
                logger.error("NaN encountered, ignoring")
                # ``losses != losses`` only matches NaN; also zero out +/-inf.
                losses[~torch.isfinite(losses)] = 0
            # if loss is too large, ignore it
            if loss_value > 10 and iteration > 10000:
                losses[losses == losses] = 0  # this is a bad example
                print("Loss is too large, ignore it, loss: ", loss_value)
            optimizer.zero_grad()
            scaler.scale(losses).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
        else:
            # NOTE(review): unlike the AMP branch, greenlight_map is not passed
            # here — confirm whether that is intentional.
            if has_captions:
                loss_dict = model(images, targets, captions, positive_map)
            else:
                loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
            loss_value = losses.item()
            if not math.isfinite(loss_value):
                # Save the offending batch and weights for debugging, then stop.
                logger.error(f"=> loss is {loss_value}, stopping training")
                time_str = time.strftime("%Y-%m-%d-%H-%M")
                fname = os.path.join(checkpointer.save_dir, f"{time_str}_states.pth")
                logger.info(f"=> save error state to {fname}")
                dict_to_save = {
                    "x": images,
                    "y": targets,
                    "loss": losses,
                    "states": model.module.state_dict() if hasattr(model, "module") else model.state_dict(),
                }
                if has_captions:
                    dict_to_save["captions"] = captions
                    dict_to_save["positive_map"] = positive_map
                torch.save(dict_to_save, fname)
                sys.exit(-1)
            optimizer.zero_grad()
            losses.backward()
            optimizer.step()
            scheduler.step()

        # Adapt the weight decay: only support multiStepLR
        if cfg.SOLVER.WEIGHT_DECAY_SCHEDULE and hasattr(scheduler, "milestones"):
            milestones = list(scheduler.milestones)
            next_milestone = milestones[milestone_target] if milestone_target < len(milestones) else float("inf")
            if scheduler.last_epoch >= next_milestone * cfg.SOLVER.WEIGHT_DECAY_SCHEDULE_RATIO:
                gamma = scheduler.gamma
                logger.info("Drop the weight decay by {}!".format(gamma))
                for param in optimizer.param_groups:
                    if "weight_decay" in param:
                        param["weight_decay"] *= gamma
                # move the target forward
                milestone_target += 1

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_loss_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        meters.update(loss=losses_reduced, **loss_dict_reduced)
        if model_ema is not None:
            model_ema.update(model)
            arguments["model_ema"] = model_ema.state_dict()

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)
        eta_seconds = meters.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        if iteration % 20 == 0 or iteration == max_iter:
            if global_rank <= 0:
                print(
                    meters.delimiter.join(
                        [
                            "eta: {eta}",
                            "iter: {iter}",
                            "{meters}",
                            "lr: {lr:.6f}",
                            "wd: {wd:.6f}",
                            "max mem: {memory:.0f}",
                        ]
                    ).format(
                        eta=eta_string,
                        iter=iteration,
                        meters=str(meters),
                        lr=optimizer.param_groups[0]["lr"],
                        wd=optimizer.param_groups[0]["weight_decay"],
                        memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                    )
                )
            if use_wandb and is_main_process():
                wandb.log(
                    {
                        "train_loss": losses_reduced,
                        "lr": optimizer.param_groups[0]["lr"],
                        "wd": optimizer.param_groups[0]["weight_decay"],
                        **loss_dict_reduced,
                    }
                )

        if val_data_loader and (iteration % checkpoint_period == 0 or iteration == max_iter):
            if is_main_process():
                print("Evaluating")
            # Only the main process computes a real metric; other ranks keep 0.0.
            eval_result = 0.0
            model.eval()
            if cfg.SOLVER.TEST_WITH_INFERENCE:
                with torch.no_grad():
                    _model = model.module if hasattr(model, "module") else model
                    _result = inference(
                        model=_model,
                        data_loader=val_data_loader,
                        dataset_name="val",
                        device=device,
                        expected_results=cfg.TEST.EXPECTED_RESULTS,
                        expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
                        output_folder=None,
                        cfg=cfg,
                        verbose=False,
                    )
                if is_main_process():
                    try:
                        eval_result = _result[0].results["bbox"]["AP"]
                    except (TypeError, IndexError, KeyError, AttributeError):
                        logger.warning("Could not read bbox AP from the inference result; using 0.0")
            else:
                main_result = _evaluate_on_loader(model, val_data_loader, device, cfg)
                if is_main_process():
                    eval_result = main_result
            model.train()

            if model_ema is not None and cfg.SOLVER.USE_EMA_FOR_MONITOR:
                model_ema.ema.eval()
                ema_result = _evaluate_on_loader(model_ema.ema, val_data_loader, device, cfg)
                if is_main_process():
                    eval_result = ema_result

            if eval_result is not None:
                arguments.update(eval_result=eval_result)

            # Share rank 0's metric with every rank so the scheduler and the
            # auto-termination decision (which ``break``s) agree across ranks;
            # otherwise only rank 0 could terminate and the others would hang.
            if cfg.SOLVER.USE_AUTOSTEP or auto_terminate:
                eval_result = all_gather(eval_result)[0]

            if cfg.SOLVER.USE_AUTOSTEP:
                assert eval_result is not None
                scheduler.step(eval_result)

            if auto_terminate:
                if eval_result < previous_best:
                    patience_counter += 1
                else:
                    patience_counter = 0
                    previous_best = eval_result
                    checkpointer.save("model_best", **arguments)
                print("Previous Best", previous_best, "Patience Counter", patience_counter, "Eval Result", eval_result)
                if patience_counter >= cfg.SOLVER.AUTO_TERMINATE_PATIENCE:
                    if is_main_process():
                        print("\n\n\n\nAuto Termination at {}, current best {}\n\n\n".format(iteration, previous_best))
                    break

        if iteration % checkpoint_period == 0 and checkpoint_period > 100:
            checkpointer.save("model_{:07d}".format(iteration), **arguments)
        if iteration == max_iter:
            checkpointer.save("model_final", **arguments)
            break

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    # Guard against an empty data loader (max_iter == 0).
    logger.info(
        "Total training time: {} ({:.4f} s / it)".format(total_time_str, total_training_time / max(max_iter, 1))
    )