# NOTE: page-header residue from the hosting site ("Spaces: Sleeping"); not part of the module.
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
import datetime | |
import logging | |
import sys | |
import os | |
import math | |
import time | |
import torch | |
import torch.distributed as dist | |
from maskrcnn_benchmark.utils.comm import get_world_size, all_gather, is_main_process, broadcast_data, get_rank | |
from maskrcnn_benchmark.utils.metric_logger import MetricLogger | |
from maskrcnn_benchmark.utils.ema import ModelEma | |
from maskrcnn_benchmark.utils.amp import autocast, GradScaler | |
from maskrcnn_benchmark.data.datasets.evaluation import evaluate | |
from .inference import inference | |
from .tsv_saver import TSVResultWriter | |
import wandb | |
import pdb, random | |
def reduce_loss_dict(loss_dict):
    """
    Combine the per-process loss values so that rank 0 ends up with their
    average. The returned dict has the same keys as ``loss_dict``.

    When fewer than two processes are running, ``loss_dict`` itself is
    returned unchanged. Only rank 0 divides the reduced tensor by the world
    size; the values returned on other ranks are whatever ``dist.reduce``
    leaves in their local buffer.
    """
    num_procs = get_world_size()
    if num_procs < 2:
        return loss_dict
    with torch.no_grad():
        # Sort the keys so every process stacks its losses in the same order.
        names = sorted(loss_dict)
        stacked = torch.stack([loss_dict[name] for name in names], dim=0)
        dist.reduce(stacked, dst=0)
        if dist.get_rank() == 0:
            # Only the destination rank receives the accumulated values,
            # so it is the only one that divides by the world size.
            stacked /= num_procs
        averaged = dict(zip(names, stacked))
    return averaged
def _unwrap_model(model):
    """Return ``model.module`` when the model exposes a ``module`` attribute, else ``model``."""
    return model.module if hasattr(model, "module") else model


def _evaluate_batches(eval_model, val_data_loader, device, cfg):
    """Run ``eval_model`` over ``val_data_loader`` and score the gathered predictions.

    Returns the monitored metric on the main process (``AR@100`` of
    ``box_proposal`` when ``cfg.DATASETS.CLASS_AGNOSTIC`` is set, otherwise
    ``AP`` of ``bbox``) and ``None`` on every other process.
    """
    results_dict = {}
    cpu_device = torch.device("cpu")
    for batch in val_data_loader:
        # Only the first four fields are used; tolerate loaders that append more.
        images, targets, image_ids, positive_map, *_ = batch
        with torch.no_grad():
            images = images.to(device)
            if positive_map is None:
                output = eval_model(images)
            else:
                captions = [t.get_field("caption") for t in targets if "caption" in t.fields()]
                output = eval_model(images, captions, positive_map)
            output = [o.to(cpu_device) for o in output]
        results_dict.update(zip(image_ids, output))

    # Every process takes part in the gather; only the main one scores.
    all_predictions = all_gather(results_dict)
    if not is_main_process():
        return None

    merged = {}
    for part in all_predictions:
        merged.update(part)
    predictions = [merged[key] for key in sorted(merged)]
    eval_result, _ = evaluate(
        val_data_loader.dataset, predictions, output_folder=None, box_only=cfg.DATASETS.CLASS_AGNOSTIC
    )
    if cfg.DATASETS.CLASS_AGNOSTIC:
        return eval_result.results["box_proposal"]["AR@100"]
    return eval_result.results["bbox"]["AP"]


def do_train(
    cfg,
    model,
    data_loader,
    optimizer,
    scheduler,
    checkpointer,
    device,
    checkpoint_period,
    arguments,
    val_data_loader=None,
    meters=None,
    use_wandb=False
):
    """Run the training loop, with periodic validation and checkpointing.

    Args:
        cfg: configuration object; this function reads ``cfg.SOLVER.*``,
            ``cfg.MODEL.LANGUAGE_BACKBONE.FREEZE``, ``cfg.DATASETS.CLASS_AGNOSTIC``,
            ``cfg.TEST.*`` and ``cfg.OUTPUT_DIR``.
        model: network to train. It (or its ``.module``) must expose a
            ``tokenizer``; in train mode it is called as
            ``model(images, targets[, captions, positive_map, ...])`` and must
            return a dict of loss tensors.
        data_loader: iterable yielding
            ``(images, targets, idxs, positive_map, positive_map_eval, greenlight_map)``;
            ``len(data_loader)`` is used as the final iteration number.
        optimizer: optimizer stepped once per processed batch.
        scheduler: LR scheduler stepped once per processed batch (and with the
            evaluation metric when ``cfg.SOLVER.USE_AUTOSTEP`` is set).
        checkpointer: object providing ``save(name, **arguments)`` and ``save_dir``.
        device: device the images and targets are moved to.
        checkpoint_period: iterations between validations/checkpoints; replaced
            by an epoch-derived value when ``cfg.SOLVER.CHECKPOINT_PER_EPOCH != -1``
            and ``cfg.SOLVER.MAX_EPOCH >= 1``.
        arguments: mutable dict with at least ``"iteration"``; updated in place
            (``iteration``, ``model_ema``, ``eval_result``) and passed to
            ``checkpointer.save``.
        val_data_loader: optional validation loader; validation is skipped when falsy.
        meters: metric logger; a new ``MetricLogger`` is created when ``None``.
        use_wandb: when true, the main process also logs metrics to wandb.

    Exits the process with status -1 (after dumping the offending batch and the
    model state) if a non-finite loss is met in the non-AMP path.
    """
    logger = logging.getLogger("maskrcnn_benchmark.trainer")
    logger.info("Start training")
    if meters is None:
        # The default argument used to be dereferenced unconditionally.
        meters = MetricLogger(delimiter=" ")
    max_iter = len(data_loader)
    start_iter = arguments["iteration"]
    model.train()
    model_ema = None
    if cfg.SOLVER.MODEL_EMA > 0:
        model_ema = ModelEma(model, decay=cfg.SOLVER.MODEL_EMA)
    start_training_time = time.time()
    end = time.time()

    if cfg.SOLVER.USE_AMP:
        scaler = GradScaler()

    global_rank = get_rank()

    if cfg.SOLVER.CHECKPOINT_PER_EPOCH != -1 and cfg.SOLVER.MAX_EPOCH >= 1:
        checkpoint_period = len(data_loader) * cfg.SOLVER.CHECKPOINT_PER_EPOCH // cfg.SOLVER.MAX_EPOCH

    if global_rank <= 0 and cfg.SOLVER.MAX_EPOCH >= 1:
        print("Iter per epoch ", len(data_loader) // cfg.SOLVER.MAX_EPOCH)

    if cfg.SOLVER.AUTO_TERMINATE_PATIENCE != -1:
        patience_counter = 0
        previous_best = 0.0

    # Adapt the weight decay: find the first milestone not yet passed so that
    # a resumed run does not re-apply decays that already happened.
    if cfg.SOLVER.WEIGHT_DECAY_SCHEDULE and hasattr(scheduler, "milestones"):
        milestone_target = 0
        for i, milestone in enumerate(list(scheduler.milestones)):
            if scheduler.last_epoch >= milestone * cfg.SOLVER.WEIGHT_DECAY_SCHEDULE_RATIO:
                milestone_target = i + 1

    # Writer used to dump training samples for visual inspection.
    tokenizer = _unwrap_model(model).tokenizer
    tsv_visualizer = TSVResultWriter(
        tokenizer=tokenizer,
        max_visualize_num=1000,
        file_name=cfg.OUTPUT_DIR + "/train_visualize/train.tsv", write_freq=100)

    for iteration, (images, targets, idxs, positive_map, positive_map_eval, greenlight_map) in enumerate(
        data_loader, start_iter
    ):
        # Skip batches that contain too many images without any target.
        nnegative = sum(len(target) < 1 for target in targets)
        nsample = len(targets)
        if nnegative > nsample * cfg.SOLVER.MAX_NEG_PER_BATCH:
            logger.info(
                "[WARNING] Sampled {} negative in {} in a batch, greater the allowed ratio {}, skip".format(
                    nnegative, nsample, cfg.SOLVER.MAX_NEG_PER_BATCH
                )
            )
            continue

        data_time = time.time() - end
        iteration = iteration + 1
        arguments["iteration"] = iteration

        images = images.to(device)
        # Best effort: an empty caption list selects the caption-free model
        # call below (a ``None`` here used to crash on ``len(captions)``).
        captions = []
        try:
            targets = [target.to(device) for target in targets]
            captions = [t.get_field("caption") for t in targets if "caption" in t.fields()]
        except Exception as exc:
            captions = []
            logger.debug("Could not move targets / read captions: %s", exc)

        # Freeze language backbone
        if cfg.MODEL.LANGUAGE_BACKBONE.FREEZE:
            core = _unwrap_model(model)
            if hasattr(core, "fusion_backbone"):
                core.fusion_backbone.language_backbone.eval()
            else:
                core.language_backbone.eval()

        if is_main_process():  # only visualize for the main process
            tsv_visualizer.update_train_data(images, targets)

        if cfg.SOLVER.USE_AMP:
            with autocast():
                if len(captions) > 0:
                    loss_dict = model(images, targets, captions, positive_map, greenlight_map=greenlight_map)
                else:
                    loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())

            loss_value = losses.item()
            if torch.isnan(losses) or torch.isinf(losses):
                logger.error("NaN encountered, ignoring")
                # ``x != x`` holds only for NaN, so this zeroes a NaN loss.
                losses[losses != losses] = 0
            # if loss is too large, ignore it (thresholds are hard-coded)
            if loss_value > 10 and iteration > 10000:
                losses[losses == losses] = 0  # this is a bad example
                print("Loss is too large, ignore it, loss: ", loss_value)
            optimizer.zero_grad()
            scaler.scale(losses).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
        else:
            if len(captions) > 0:
                loss_dict = model(images, targets, captions, positive_map)
            else:
                loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())

            # Dump the batch and weights for debugging, then abort, if the
            # loss is NaN/inf.
            loss_value = losses.item()
            if not math.isfinite(loss_value):
                logger.error(f"=> loss is {loss_value}, stopping training")
                time_str = time.strftime("%Y-%m-%d-%H-%M")
                fname = os.path.join(checkpointer.save_dir, f"{time_str}_states.pth")
                logger.info(f"=> save error state to {fname}")
                dict_to_save = {
                    "x": images,
                    "y": targets,
                    "loss": losses,
                    "states": _unwrap_model(model).state_dict(),
                }
                if len(captions) > 0:
                    dict_to_save["captions"] = captions
                    dict_to_save["positive_map"] = positive_map
                torch.save(dict_to_save, fname)
                sys.exit(-1)
            # A NaN/inf loss never reaches this point, so no zeroing is needed here.
            optimizer.zero_grad()
            losses.backward()
            optimizer.step()
            scheduler.step()

        # Adapt the weight decay: only support multiStepLR
        if cfg.SOLVER.WEIGHT_DECAY_SCHEDULE and hasattr(scheduler, "milestones"):
            if milestone_target < len(scheduler.milestones):
                next_milestone = list(scheduler.milestones)[milestone_target]
            else:
                next_milestone = float("inf")
            if scheduler.last_epoch >= next_milestone * cfg.SOLVER.WEIGHT_DECAY_SCHEDULE_RATIO:
                gamma = scheduler.gamma
                logger.info("Drop the weight decay by {}!".format(gamma))
                for param in optimizer.param_groups:
                    if "weight_decay" in param:
                        param["weight_decay"] *= gamma
                # move the target forward
                milestone_target += 1

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_loss_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        meters.update(loss=losses_reduced, **loss_dict_reduced)
        if model_ema is not None:
            model_ema.update(model)
            arguments["model_ema"] = model_ema.state_dict()

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)
        eta_seconds = meters.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        if iteration % 20 == 0 or iteration == max_iter:
            if global_rank <= 0:
                print(
                    meters.delimiter.join(
                        [
                            "eta: {eta}",
                            "iter: {iter}",
                            "{meters}",
                            "lr: {lr:.6f}",
                            "wd: {wd:.6f}",
                            "max mem: {memory:.0f}",
                        ]
                    ).format(
                        eta=eta_string,
                        iter=iteration,
                        meters=str(meters),
                        lr=optimizer.param_groups[0]["lr"],
                        wd=optimizer.param_groups[0]["weight_decay"],
                        memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                    )
                )
            # NOTE(review): nesting of this wandb call was reconstructed from
            # flattened source; it is assumed to share the console-log cadence.
            if use_wandb and is_main_process():
                wandb.log({"train_loss": losses_reduced, "lr": optimizer.param_groups[0]["lr"], "wd": optimizer.param_groups[0]["weight_decay"], **loss_dict_reduced})

        if val_data_loader and (iteration % checkpoint_period == 0 or iteration == max_iter):
            if is_main_process():
                print("Evaluating")
            # Stays 0.0 on processes that do not compute the metric.
            eval_result = 0.0
            model.eval()
            if cfg.SOLVER.TEST_WITH_INFERENCE:
                with torch.no_grad():
                    _result = inference(
                        model=_unwrap_model(model),
                        data_loader=val_data_loader,
                        dataset_name="val",
                        device=device,
                        expected_results=cfg.TEST.EXPECTED_RESULTS,
                        expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
                        output_folder=None,
                        cfg=cfg,
                        verbose=False,
                    )
                    if is_main_process():
                        try:
                            eval_result = _result[0].results["bbox"]["AP"]
                        except Exception as exc:
                            # Best effort: keep the 0.0 placeholder.
                            logger.warning("Could not read bbox AP from inference result: %s", exc)
            else:
                metric = _evaluate_batches(model, val_data_loader, device, cfg)
                if metric is not None:
                    eval_result = metric
            model.train()

            if model_ema is not None and cfg.SOLVER.USE_EMA_FOR_MONITOR:
                # Monitor the EMA weights instead; this overrides the metric above.
                model_ema.ema.eval()
                metric = _evaluate_batches(model_ema.ema, val_data_loader, device, cfg)
                if metric is not None:
                    eval_result = metric

            if eval_result is not None:
                arguments.update(eval_result=eval_result)

            if cfg.SOLVER.USE_AUTOSTEP:
                assert eval_result is not None
                # Take the first gathered value so every process steps the
                # scheduler with the same metric.
                eval_result = all_gather(eval_result)[0]  # broadcast_data([eval_result])[0]
                scheduler.step(eval_result)

            if cfg.SOLVER.AUTO_TERMINATE_PATIENCE != -1:
                # NOTE(review): without USE_AUTOSTEP the metric is only set on
                # the main process, so other ranks compare 0.0 here — confirm
                # that early termination is meant to be rank-consistent.
                if eval_result < previous_best:
                    patience_counter += 1
                else:
                    patience_counter = 0
                    previous_best = eval_result
                    checkpointer.save("model_best", **arguments)
                print("Previous Best", previous_best, "Patience Counter", patience_counter, "Eval Result", eval_result)
                if patience_counter >= cfg.SOLVER.AUTO_TERMINATE_PATIENCE:
                    if is_main_process():
                        print("\n\n\n\nAuto Termination at {}, current best {}\n\n\n".format(iteration, previous_best))
                    break

        if iteration % checkpoint_period == 0 and checkpoint_period > 100:
            checkpointer.save("model_{:07d}".format(iteration), **arguments)
        if iteration == max_iter:
            checkpointer.save("model_final", **arguments)
            break

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    # Guard the per-iteration average against an empty data loader.
    logger.info(
        "Total training time: {} ({:.4f} s / it)".format(total_time_str, total_training_time / max(max_iter, 1))
    )