# HEAT / train.py
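# Training entry point for HEAT: builds the datasets, the ResNet backbone, the corner
# and edge models, then runs the training / validation loop with checkpointing.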
import torch
import torch.nn as nn
import os
import time
import datetime
import argparse
from pathlib import Path
from torch.utils.data import DataLoader
from arguments import get_args_parser
from datasets.outdoor_buildings import OutdoorBuildingDataset
from datasets.s3d_floorplans import S3DFloorplanDataset
from datasets.data_utils import collate_fn, get_pixel_features
from models.corner_models import HeatCorner
from models.edge_models import HeatEdge
from models.resnet import ResNetBackbone
from models.loss import CornerCriterion, EdgeCriterion
from models.corner_to_edge import prepare_edge_data
import utils.misc as utils
def train_one_epoch(image_size, backbone, corner_model, edge_model, corner_criterion, edge_criterion, data_loader,
optimizer,
epoch, max_norm, args):
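    # Train the backbone, corner and edge models jointly for a single epoch,
    # logging running losses and accuracies through the MetricLogger.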
backbone.train()
corner_model.train()
edge_model.train()
corner_criterion.train()
edge_criterion.train()
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter('lr', utils.SmoothedValue(window_size=100, fmt='{value:.6f}'))
header = 'Epoch: [{}]'.format(epoch)
print_freq = args.print_freq
# get the positional encodings for all pixels
pixels, pixel_features = get_pixel_features(image_size)
pixel_features = pixel_features.cuda()
for data in metric_logger.log_every(data_loader, print_freq, header):
corner_outputs, corner_loss, corner_recall, s1_logits, s2_logits_hb, s2_logits_rel, s1_losses, s2_losses_hb, \
s2_losses_rel, s1_acc, s2_acc_hb, s2_acc_rel = run_model(
data,
pixels,
pixel_features,
backbone,
corner_model,
edge_model,
epoch,
corner_criterion,
edge_criterion,
args)
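        # total loss: stage-1 edge loss, the two stage-2 edge losses, and the corner loss weighted by lambda_corner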
loss = s1_losses + s2_losses_hb + s2_losses_rel + corner_loss * args.lambda_corner
loss_dict = {'loss_e_s1': s1_losses, 'loss_e_s2_hb': s2_losses_hb, 'loss_e_s2_rel': s2_losses_rel,
'edge_acc_s1': s1_acc, 'edge_acc_s2_hb': s2_acc_hb, 'edge_acc_s2_rel': s2_acc_rel,
'loss_c_s1': corner_loss, 'corner_recall': corner_recall}
loss_value = loss.item()
optimizer.zero_grad()
loss.backward()
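        # optionally clip the gradients of each sub-network before the optimizer step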
if max_norm > 0:
torch.nn.utils.clip_grad_norm_(backbone.parameters(), max_norm)
torch.nn.utils.clip_grad_norm_(corner_model.parameters(), max_norm)
torch.nn.utils.clip_grad_norm_(edge_model.parameters(), max_norm)
optimizer.step()
metric_logger.update(loss=loss_value, **loss_dict)
metric_logger.update(lr=optimizer.param_groups[0]["lr"])
print("Averaged stats:", metric_logger)
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
def run_model(data, pixels, pixel_features, backbone, corner_model, edge_model, epoch, corner_criterion, edge_criterion,
args):
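    # Forward pass shared by training and evaluation: backbone features -> corner predictions ->
    # edge candidates built from the predicted corners -> edge predictions, plus the corresponding
    # losses and accuracy metrics.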
image = data['img'].cuda()
annots = data['annot']
raw_images = data['raw_img']
pixel_labels = data['pixel_labels'].cuda()
gauss_labels = data['gauss_labels'].cuda()
pixel_features = pixel_features.unsqueeze(0).repeat(image.shape[0], 1, 1, 1)
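    # repeat the shared per-pixel positional encodings for every sample in the batch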
    # get corner predictions from the corner model
image_feats, feat_mask, all_image_feats = backbone(image)
preds_s1 = corner_model(image_feats, feat_mask, pixel_features, pixels, all_image_feats)
corner_loss_s1, corner_recall = corner_criterion(preds_s1, pixel_labels, gauss_labels, epoch)
    # get edge candidates and their corresponding ground-truth labels
c_outputs = preds_s1
edge_data = prepare_edge_data(c_outputs, annots, raw_images, args.max_corner_num)
edge_coords = edge_data['edge_coords'].cuda()
edge_mask = edge_data['edge_coords_mask'].cuda()
edge_lengths = edge_data['edge_coords_lengths'].cuda()
edge_labels = edge_data['edge_labels'].cuda()
corner_nums = edge_data['processed_corners_lengths']
# run the edge model
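    # cap the number of edge candidates per sample based on the largest corner count in the batch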
max_candidates = torch.stack([corner_nums.max() * args.corner_to_edge_multiplier] * len(corner_nums), dim=0)
logits_s1, logits_s2_hb, logits_s2_rel, s2_ids, s2_edge_mask, s2_gt_values = edge_model(image_feats, feat_mask,
pixel_features,
edge_coords, edge_mask,
edge_labels,
corner_nums,
max_candidates)
s1_losses, s1_acc, s2_losses_hb, s2_acc_hb, s2_losses_rel, s2_acc_rel = edge_criterion(logits_s1, logits_s2_hb,
logits_s2_rel, s2_ids,
s2_edge_mask,
edge_labels, edge_lengths,
edge_mask, s2_gt_values)
return c_outputs, corner_loss_s1, corner_recall, logits_s1, logits_s2_hb, logits_s2_rel, s1_losses, s2_losses_hb, \
s2_losses_rel, s1_acc, s2_acc_hb, s2_acc_rel
@torch.no_grad()
def evaluate(image_size, backbone, corner_model, edge_model, corner_criterion, edge_criterion, data_loader, epoch,
args):
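    # Evaluate on the validation set using the same forward pass as training;
    # gradients are disabled by the @torch.no_grad decorator.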
backbone.eval()
corner_model.eval()
edge_model.eval()
corner_criterion.eval()
edge_criterion.eval()
metric_logger = utils.MetricLogger(delimiter=" ")
header = 'Test:'
pixels, pixel_features = get_pixel_features(image_size)
pixel_features = pixel_features.cuda()
for data in metric_logger.log_every(data_loader, 10, header):
c_outputs, corner_loss, corner_recall, s1_logits, \
s2_logits_hb, s2_logits_rel, s1_losses, s2_losses_hb, s2_losses_rel, s1_acc, s2_acc_hb, s2_acc_rel = run_model(
data,
pixels,
pixel_features,
backbone,
corner_model,
edge_model,
epoch,
corner_criterion,
edge_criterion,
args)
loss_dict = {'loss_e_s1': s1_losses,
'loss_e_s2_hb': s2_losses_hb,
'loss_e_s2_rel': s2_losses_rel,
'edge_acc_s1': s1_acc,
'edge_acc_s2_hb': s2_acc_hb,
'edge_acc_s2_rel': s2_acc_rel,
'loss_c_s1': corner_loss,
'corner_recall': corner_recall}
loss = s1_losses + s2_losses_hb + s2_losses_rel + corner_loss * args.lambda_corner
loss_value = loss.item()
metric_logger.update(loss=loss_value, **loss_dict)
print("Averaged stats:", metric_logger)
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
def main():
parser = argparse.ArgumentParser('HEAT training', parents=[get_args_parser()])
args = parser.parse_args()
image_size = args.image_size
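    # build the train/validation datasets for the selected experiment (outdoor buildings or S3D floorplans)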
if args.exp_dataset == 'outdoor':
data_path = './data/outdoor/cities_dataset'
det_path = './data/outdoor/det_final'
train_dataset = OutdoorBuildingDataset(data_path, det_path, phase='train', image_size=image_size, rand_aug=True,
inference=False)
test_dataset = OutdoorBuildingDataset(data_path, det_path, phase='valid', image_size=image_size, rand_aug=False,
inference=False)
elif args.exp_dataset == 's3d_floorplan':
data_path = './data/s3d_floorplan'
train_dataset = S3DFloorplanDataset(data_path, phase='train', rand_aug=True, inference=False)
test_dataset = S3DFloorplanDataset(data_path, phase='valid', rand_aug=False, inference=False)
else:
raise ValueError('Unknown dataset: {}'.format(args.exp_dataset))
train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers,
collate_fn=collate_fn, drop_last=True)
test_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=False, num_workers=args.num_workers,
collate_fn=collate_fn)
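    # build the shared ResNet backbone and the corner/edge models, wrapping each in DataParallel for multi-GPU training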
backbone = ResNetBackbone()
strides = backbone.strides
num_channels = backbone.num_channels
corner_model = HeatCorner(input_dim=128, hidden_dim=256, num_feature_levels=4, backbone_strides=strides,
backbone_num_channels=num_channels)
backbone = nn.DataParallel(backbone)
backbone = backbone.cuda()
corner_model = nn.DataParallel(corner_model)
corner_model = corner_model.cuda()
edge_model = HeatEdge(input_dim=128, hidden_dim=256, num_feature_levels=4, backbone_strides=strides,
backbone_num_channels=num_channels)
edge_model = nn.DataParallel(edge_model)
edge_model = edge_model.cuda()
corner_criterion = CornerCriterion(image_size=image_size)
edge_criterion = EdgeCriterion()
    backbone_params = list(backbone.parameters())
    corner_params = list(corner_model.parameters())
    edge_params = list(edge_model.parameters())
all_params = corner_params + edge_params + backbone_params
optimizer = torch.optim.AdamW(all_params, lr=args.lr, weight_decay=args.weight_decay)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)
start_epoch = args.start_epoch
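    # optionally resume the models, optimizer and LR scheduler from a checkpoint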
if args.resume:
ckpt = torch.load(args.resume)
backbone.load_state_dict(ckpt['backbone'])
corner_model.load_state_dict(ckpt['corner_model'])
edge_model.load_state_dict(ckpt['edge_model'])
optimizer.load_state_dict(ckpt['optimizer'])
lr_scheduler.load_state_dict(ckpt['lr_scheduler'])
lr_scheduler.step_size = args.lr_drop
print('Resume from ckpt file {}, starting from epoch {}'.format(args.resume, ckpt['epoch']))
start_epoch = ckpt['epoch'] + 1
n_backbone_parameters = sum(p.numel() for p in backbone_params if p.requires_grad)
n_corner_parameters = sum(p.numel() for p in corner_params if p.requires_grad)
n_edge_parameters = sum(p.numel() for p in edge_params if p.requires_grad)
n_all_parameters = sum(p.numel() for p in all_params if p.requires_grad)
print('number of trainable backbone params:', n_backbone_parameters)
print('number of trainable corner params:', n_corner_parameters)
print('number of trainable edge params:', n_edge_parameters)
print('number of all trainable params:', n_all_parameters)
print("Start training")
start_time = time.time()
output_dir = Path(args.output_dir)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
best_acc = 0
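    # main loop: train one epoch, step the LR scheduler, optionally validate,
    # and save checkpoints (keeping a separate copy of the best-accuracy model)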
for epoch in range(start_epoch, args.epochs):
train_stats = train_one_epoch(
image_size, backbone, corner_model, edge_model, corner_criterion, edge_criterion, train_dataloader,
optimizer,
epoch, args.clip_max_norm, args)
lr_scheduler.step()
if args.run_validation:
val_stats = evaluate(
image_size, backbone, corner_model, edge_model, corner_criterion, edge_criterion, test_dataloader,
epoch, args
)
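            # use the mean of the two edge accuracies (s1 and s2_hb) to track the best checkpoint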
val_acc = (val_stats['edge_acc_s1'] + val_stats['edge_acc_s2_hb']) / 2
if val_acc > best_acc:
is_best = True
best_acc = val_acc
else:
is_best = False
else:
val_acc = 0
is_best = False
if args.output_dir:
checkpoint_paths = [output_dir / 'checkpoint.pth']
if is_best:
checkpoint_paths.append(output_dir / 'checkpoint_best.pth')
for checkpoint_path in checkpoint_paths:
torch.save({
'backbone': backbone.state_dict(),
'corner_model': corner_model.state_dict(),
'edge_model': edge_model.state_dict(),
'optimizer': optimizer.state_dict(),
'lr_scheduler': lr_scheduler.state_dict(),
'epoch': epoch,
'args': args,
'val_acc': val_acc,
}, checkpoint_path)
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))
if __name__ == '__main__':
main()