|
import argparse
import datetime
import math
import os
import time
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from arguments import get_args_parser
from datasets.outdoor_buildings import OutdoorBuildingDataset
from datasets.s3d_floorplans import S3DFloorplanDataset
from datasets.data_utils import collate_fn, get_pixel_features
from models.corner_models import HeatCorner
from models.edge_models import HeatEdge
from models.resnet import ResNetBackbone
from models.loss import CornerCriterion, EdgeCriterion
from models.corner_to_edge import prepare_edge_data
import utils.misc as utils
|
|
|
|
|
def train_one_epoch(image_size, backbone, corner_model, edge_model, corner_criterion, edge_criterion, data_loader,
                    optimizer,
                    epoch, max_norm, args):
    """Train the backbone, corner model and edge model for one pass over ``data_loader``.

    Args:
        image_size: Passed to ``get_pixel_features`` to build the pixel grid that
            the corner model is queried on.
        backbone, corner_model, edge_model: Modules put into train mode; their
            gradients are computed from the combined loss.
        corner_criterion, edge_criterion: Loss modules (also put into train mode).
        data_loader: Iterable of training batches, each consumed by ``run_model``.
        optimizer: Optimizer stepped once per batch.
        epoch: Current epoch index; used in the log header and forwarded to
            ``run_model``.
        max_norm: Gradient-clipping threshold; clipping is skipped when <= 0.
        args: Parsed arguments; ``print_freq`` and ``lambda_corner`` are read
            here, the rest is forwarded to ``run_model``.

    Returns:
        Dict mapping every logged metric name to its global average over the epoch.

    Raises:
        FloatingPointError: If the total loss becomes NaN or infinite.
    """
    for module in (backbone, corner_model, edge_model, corner_criterion, edge_criterion):
        module.train()

    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=100, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = args.print_freq

    # The pixel grid depends only on image_size, so build it once here rather
    # than once per batch.
    pixels, pixel_features = get_pixel_features(image_size)
    pixel_features = pixel_features.cuda()

    for data in metric_logger.log_every(data_loader, print_freq, header):
        (_, corner_loss, corner_recall, _, _, _,
         s1_losses, s2_losses_hb, s2_losses_rel,
         s1_acc, s2_acc_hb, s2_acc_rel) = run_model(
            data, pixels, pixel_features, backbone, corner_model, edge_model, epoch,
            corner_criterion, edge_criterion, args)

        loss = s1_losses + s2_losses_hb + s2_losses_rel + corner_loss * args.lambda_corner

        loss_dict = {'loss_e_s1': s1_losses, 'loss_e_s2_hb': s2_losses_hb, 'loss_e_s2_rel': s2_losses_rel,
                     'edge_acc_s1': s1_acc, 'edge_acc_s2_hb': s2_acc_hb, 'edge_acc_s2_rel': s2_acc_rel,
                     'loss_c_s1': corner_loss, 'corner_recall': corner_recall}
        loss_value = loss.item()

        # A NaN/inf loss would flow into every parameter on optimizer.step() and
        # then be written into the checkpoint; stop while the weights are still valid.
        if not math.isfinite(loss_value):
            raise FloatingPointError(
                'Loss is {}, stopping training. Loss components: {}'.format(loss_value, loss_dict))

        optimizer.zero_grad()
        loss.backward()

        if max_norm > 0:
            # Each module is clipped independently: every module's gradient norm
            # is bounded by max_norm, but their combined norm can exceed it.
            for module in (backbone, corner_model, edge_model):
                torch.nn.utils.clip_grad_norm_(module.parameters(), max_norm)

        optimizer.step()
        metric_logger.update(loss=loss_value, **loss_dict)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
|
|
|
|
def run_model(data, pixels, pixel_features, backbone, corner_model, edge_model, epoch, corner_criterion, edge_criterion,
              args):
    """Run one batch through the backbone, corner model and edge model, and score it.

    Args:
        data: Batch dict; reads 'img', 'annot', 'raw_img', 'pixel_labels' and
            'gauss_labels'.
        pixels: Pixel coordinates from ``get_pixel_features``, forwarded to the
            corner model.
        pixel_features: Per-pixel features for a single image; repeated along a
            new leading batch dimension (presumably 3-D before that — confirm
            against ``get_pixel_features``).
        backbone, corner_model, edge_model: The three network stages.
        epoch: Forwarded to ``corner_criterion``.
        corner_criterion, edge_criterion: Loss modules for corners and edges.
        args: Reads ``max_corner_num`` and ``corner_to_edge_multiplier``.

    Returns:
        A 12-tuple: (corner predictions, corner loss, corner recall,
        stage-1 edge logits, stage-2 hb logits, stage-2 rel logits,
        stage-1 edge loss, stage-2 hb loss, stage-2 rel loss,
        stage-1 edge acc, stage-2 hb acc, stage-2 rel acc).
    """
    image = data['img'].cuda()
    annots, raw_images = data['annot'], data['raw_img']
    pixel_labels = data['pixel_labels'].cuda()
    gauss_labels = data['gauss_labels'].cuda()

    # Give every image in the batch its own copy of the shared pixel features.
    batched_pixel_features = pixel_features.unsqueeze(0).repeat(image.shape[0], 1, 1, 1)

    # Stage 1: image features -> corner predictions -> corner loss.
    image_feats, feat_mask, all_image_feats = backbone(image)
    corner_preds = corner_model(image_feats, feat_mask, batched_pixel_features, pixels, all_image_feats)
    corner_loss, corner_recall = corner_criterion(corner_preds, pixel_labels, gauss_labels, epoch)

    # Stage 2: turn predicted corners into candidate edges and classify them.
    edge_data = prepare_edge_data(corner_preds, annots, raw_images, args.max_corner_num)
    edge_coords, edge_mask, edge_lengths, edge_labels = (
        edge_data[key].cuda()
        for key in ('edge_coords', 'edge_coords_mask', 'edge_coords_lengths', 'edge_labels')
    )
    corner_nums = edge_data['processed_corners_lengths']

    # Every image gets the same candidate budget: the batch's largest corner
    # count times the multiplier, repeated once per image.
    max_candidates = (corner_nums.max() * args.corner_to_edge_multiplier).repeat(len(corner_nums))

    (logits_s1, logits_s2_hb, logits_s2_rel,
     s2_ids, s2_edge_mask, s2_gt_values) = edge_model(image_feats, feat_mask, batched_pixel_features,
                                                       edge_coords, edge_mask, edge_labels,
                                                       corner_nums, max_candidates)

    (s1_losses, s1_acc,
     s2_losses_hb, s2_acc_hb,
     s2_losses_rel, s2_acc_rel) = edge_criterion(logits_s1, logits_s2_hb, logits_s2_rel, s2_ids, s2_edge_mask,
                                                 edge_labels, edge_lengths, edge_mask, s2_gt_values)

    return (corner_preds, corner_loss, corner_recall,
            logits_s1, logits_s2_hb, logits_s2_rel,
            s1_losses, s2_losses_hb, s2_losses_rel,
            s1_acc, s2_acc_hb, s2_acc_rel)
|
|
|
|
|
@torch.no_grad()
def evaluate(image_size, backbone, corner_model, edge_model, corner_criterion, edge_criterion, data_loader, epoch,
             args):
    """Score the models on ``data_loader`` without gradients and average the metrics.

    All models and criteria are switched to eval mode. Every batch goes
    through ``run_model``. Progress is logged every 10 iterations under the
    'Test:' header.

    Args:
        image_size: Passed to ``get_pixel_features`` to build the pixel grid.
        backbone, corner_model, edge_model: Modules to evaluate.
        corner_criterion, edge_criterion: Loss modules.
        data_loader: Iterable of validation batches.
        epoch: Forwarded to ``run_model``.
        args: Reads ``lambda_corner``; the rest is forwarded to ``run_model``.

    Returns:
        Dict mapping every logged metric name to its global average.
    """
    for module in (backbone, corner_model, edge_model, corner_criterion, edge_criterion):
        module.eval()

    logger = utils.MetricLogger(delimiter=" ")

    pixels, grid_features = get_pixel_features(image_size)
    grid_features = grid_features.cuda()

    for batch in logger.log_every(data_loader, 10, 'Test:'):
        outputs = run_model(batch, pixels, grid_features, backbone, corner_model, edge_model, epoch,
                            corner_criterion, edge_criterion, args)
        (_, corner_loss, corner_recall, _, _, _,
         s1_losses, s2_losses_hb, s2_losses_rel,
         s1_acc, s2_acc_hb, s2_acc_rel) = outputs

        metrics = {
            'loss_e_s1': s1_losses,
            'loss_e_s2_hb': s2_losses_hb,
            'loss_e_s2_rel': s2_losses_rel,
            'edge_acc_s1': s1_acc,
            'edge_acc_s2_hb': s2_acc_hb,
            'edge_acc_s2_rel': s2_acc_rel,
            'loss_c_s1': corner_loss,
            'corner_recall': corner_recall,
        }

        total_loss = s1_losses + s2_losses_hb + s2_losses_rel + corner_loss * args.lambda_corner
        logger.update(loss=total_loss.item(), **metrics)

    print("Averaged stats:", logger)
    return {name: meter.global_avg for name, meter in logger.meters.items()}
|
|
|
|
|
def main():
    """Parse arguments, build data, models and optimizer, then run the training loop.

    The dataset is chosen by ``args.exp_dataset`` ('outdoor' or
    's3d_floorplan'). Training can resume from ``args.resume``, and
    validation runs every epoch when ``args.run_validation`` is set.

    When ``args.output_dir`` is set, ``checkpoint.pth`` is written every
    epoch. ``checkpoint_best.pth`` is written whenever the mean of the
    stage-1 and stage-2 (hb) edge accuracies beats the best value seen so
    far. That best value is stored in the checkpoint so it survives a resume.

    Raises:
        ValueError: If ``args.exp_dataset`` is not a supported dataset name.
    """
    parser = argparse.ArgumentParser('HEAT training', parents=[get_args_parser()])
    args = parser.parse_args()
    image_size = args.image_size

    if args.exp_dataset == 'outdoor':
        data_path = './data/outdoor/cities_dataset'
        det_path = './data/outdoor/det_final'
        train_dataset = OutdoorBuildingDataset(data_path, det_path, phase='train', image_size=image_size, rand_aug=True,
                                               inference=False)
        test_dataset = OutdoorBuildingDataset(data_path, det_path, phase='valid', image_size=image_size, rand_aug=False,
                                              inference=False)
    elif args.exp_dataset == 's3d_floorplan':
        data_path = './data/s3d_floorplan'
        train_dataset = S3DFloorplanDataset(data_path, phase='train', rand_aug=True, inference=False)
        test_dataset = S3DFloorplanDataset(data_path, phase='valid', rand_aug=False, inference=False)
    else:
        raise ValueError('Unknown dataset: {}'.format(args.exp_dataset))

    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers,
                                  collate_fn=collate_fn, drop_last=True)
    test_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=False, num_workers=args.num_workers,
                                 collate_fn=collate_fn)

    backbone = ResNetBackbone()
    strides = backbone.strides
    num_channels = backbone.num_channels

    corner_model = HeatCorner(input_dim=128, hidden_dim=256, num_feature_levels=4, backbone_strides=strides,
                              backbone_num_channels=num_channels)
    backbone = nn.DataParallel(backbone).cuda()
    corner_model = nn.DataParallel(corner_model).cuda()

    edge_model = HeatEdge(input_dim=128, hidden_dim=256, num_feature_levels=4, backbone_strides=strides,
                          backbone_num_channels=num_channels)
    edge_model = nn.DataParallel(edge_model).cuda()

    corner_criterion = CornerCriterion(image_size=image_size)
    edge_criterion = EdgeCriterion()

    backbone_params = list(backbone.parameters())
    corner_params = list(corner_model.parameters())
    edge_params = list(edge_model.parameters())
    all_params = corner_params + edge_params + backbone_params

    optimizer = torch.optim.AdamW(all_params, lr=args.lr, weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)
    start_epoch = args.start_epoch
    best_acc = 0

    if args.resume:
        # Load onto the CPU: load_state_dict copies every tensor onto the device
        # of the already-built module/optimizer, so the devices the checkpoint
        # was saved from do not matter.
        # NOTE(review): the checkpoint also pickles `args` (argparse.Namespace);
        # PyTorch releases whose torch.load defaults to weights_only=True may
        # refuse to load it without weights_only=False — confirm for your version.
        ckpt = torch.load(args.resume, map_location='cpu')
        backbone.load_state_dict(ckpt['backbone'])
        corner_model.load_state_dict(ckpt['corner_model'])
        edge_model.load_state_dict(ckpt['edge_model'])
        optimizer.load_state_dict(ckpt['optimizer'])
        lr_scheduler.load_state_dict(ckpt['lr_scheduler'])
        # load_state_dict restores the saved step_size; re-apply args.lr_drop so
        # the decay schedule can be changed when resuming.
        lr_scheduler.step_size = args.lr_drop

        print('Resume from ckpt file {}, starting from epoch {}'.format(args.resume, ckpt['epoch']))
        start_epoch = ckpt['epoch'] + 1
        # Without this, best_acc restarts at 0 and the first validated epoch after
        # resuming always overwrites checkpoint_best.pth, even with a worse model.
        # Older checkpoints lack 'best_acc'; the resumed model's own val_acc is
        # then the best lower bound available.
        best_acc = ckpt.get('best_acc', ckpt.get('val_acc', 0))

    n_backbone_parameters = sum(p.numel() for p in backbone_params if p.requires_grad)
    n_corner_parameters = sum(p.numel() for p in corner_params if p.requires_grad)
    n_edge_parameters = sum(p.numel() for p in edge_params if p.requires_grad)
    n_all_parameters = sum(p.numel() for p in all_params if p.requires_grad)
    print('number of trainable backbone params:', n_backbone_parameters)
    print('number of trainable corner params:', n_corner_parameters)
    print('number of trainable edge params:', n_edge_parameters)
    print('number of all trainable params:', n_all_parameters)

    print("Start training")
    start_time = time.time()

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    for epoch in range(start_epoch, args.epochs):
        train_one_epoch(
            image_size, backbone, corner_model, edge_model, corner_criterion, edge_criterion, train_dataloader,
            optimizer,
            epoch, args.clip_max_norm, args)
        lr_scheduler.step()

        if args.run_validation:
            val_stats = evaluate(
                image_size, backbone, corner_model, edge_model, corner_criterion, edge_criterion, test_dataloader,
                epoch, args
            )
            # Model selection metric: mean of stage-1 and stage-2 (hb) edge accuracy.
            val_acc = (val_stats['edge_acc_s1'] + val_stats['edge_acc_s2_hb']) / 2
            is_best = val_acc > best_acc
            if is_best:
                best_acc = val_acc
        else:
            val_acc = 0
            is_best = False

        if args.output_dir:
            # Build the checkpoint once; it is identical for every path it is saved to.
            checkpoint = {
                'backbone': backbone.state_dict(),
                'corner_model': corner_model.state_dict(),
                'edge_model': edge_model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'args': args,
                'val_acc': val_acc,
                'best_acc': best_acc,
            }
            checkpoint_paths = [output_dir / 'checkpoint.pth']
            if is_best:
                checkpoint_paths.append(output_dir / 'checkpoint_best.pth')
            for checkpoint_path in checkpoint_paths:
                torch.save(checkpoint, checkpoint_path)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
|
|
|
|
|
if __name__ == "__main__":
    # Start training only when this file is executed as a script, not on import.
    main()
|
|