| """ | |
| This file implements the training process and all the summaries | |
| """ | |
| import os | |
| import numpy as np | |
| import cv2 | |
| import torch | |
| from torch.nn.functional import pixel_shuffle, softmax | |
| from torch.utils.data import DataLoader | |
| import torch.utils.data.dataloader as torch_loader | |
| from tensorboardX import SummaryWriter | |
| from .dataset.dataset_util import get_dataset | |
| from .model.model_util import get_model | |
| from .model.loss import TotalLoss, get_loss_and_weights | |
| from .model.metrics import AverageMeter, Metrics, super_nms | |
| from .model.lr_scheduler import get_lr_scheduler | |
| from .misc.train_utils import ( | |
| convert_image, | |
| get_latest_checkpoint, | |
| remove_old_checkpoints, | |
| ) | |


def customized_collate_fn(batch):
    """Customized collate_fn for batches with variable-length entries."""
    batch_keys = ["image", "junction_map", "heatmap", "valid_mask"]
    list_keys = ["junctions", "line_map"]

    outputs = {}
    for key in batch_keys:
        outputs[key] = torch_loader.default_collate([b[key] for b in batch])
    for key in list_keys:
        outputs[key] = [b[key] for b in batch]
    return outputs
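
# Illustration of the collate behavior on a hypothetical two-sample batch
# (the shapes below are made up for the example): the stackable keys are
# batched into tensors by default_collate, while the variable-length keys are
# kept as plain Python lists, since junction counts differ between samples.
#
#   batch = [
#       {"image": torch.zeros(1, 8, 8), "junction_map": torch.zeros(1, 8, 8),
#        "heatmap": torch.zeros(1, 8, 8), "valid_mask": torch.ones(1, 8, 8),
#        "junctions": torch.zeros(3, 2), "line_map": torch.zeros(3, 3)},
#       {"image": torch.zeros(1, 8, 8), "junction_map": torch.zeros(1, 8, 8),
#        "heatmap": torch.zeros(1, 8, 8), "valid_mask": torch.ones(1, 8, 8),
#        "junctions": torch.zeros(5, 2), "line_map": torch.zeros(5, 5)},
#   ]
#   out = customized_collate_fn(batch)
#   # out["image"].shape == (2, 1, 8, 8); out["junctions"] is a list of 2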


def restore_weights(model, state_dict, strict=True):
    """Restore weights in a version-compatible mode."""
    # Try to directly load the state dict
    try:
        model.load_state_dict(state_dict, strict=strict)
    # Deal with version incompatibilities (e.g. renamed parameters)
    except Exception:
        err = model.load_state_dict(state_dict, strict=False)
        # Missing keys are those in the model but not in the state_dict
        missing_keys = err.missing_keys
        # Unexpected keys are those in the state_dict but not in the model
        unexpected_keys = err.unexpected_keys

        # Load the mismatched keys manually, pairing missing and unexpected
        # keys positionally (batch-norm "num_batches_tracked" buffers excluded)
        model_dict = model.state_dict()
        dict_keys = [k for k in unexpected_keys if "tracked" not in k]
        for idx, key in enumerate(missing_keys):
            model_dict[key] = state_dict[dict_keys[idx]]
        model.load_state_dict(model_dict)
    return model
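
# Usage sketch (hedged, paths are placeholders): checkpoints saved by
# train_net below store the weights under "model_state_dict", so restoring
# them looks roughly like:
#
#   checkpoint = torch.load("experiments/run0/checkpoint-epoch000-end.tar")
#   model = restore_weights(model, checkpoint["model_state_dict"], strict=False)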


def train_net(args, dataset_cfg, model_cfg, output_path):
    """Main training function."""
    # Version compatibility: older configs may not define a weighting policy
    if model_cfg.get("weighting_policy") is None:
        # Default to static weighting
        model_cfg["weighting_policy"] = "static"

    # Get the train and test configs
    train_cfg = model_cfg["train"]
    test_cfg = model_cfg["test"]

    # Create the train and test datasets
    print("\t Initializing dataset...")
    train_dataset, train_collate_fn = get_dataset("train", dataset_cfg)
    test_dataset, test_collate_fn = get_dataset("test", dataset_cfg)

    # Create the dataloaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=train_cfg["batch_size"],
        num_workers=8,
        shuffle=True,
        pin_memory=True,
        collate_fn=train_collate_fn,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=test_cfg.get("batch_size", 1),
        num_workers=test_cfg.get("num_workers", 1),
        shuffle=False,
        pin_memory=False,
        collate_fn=test_collate_fn,
    )
    print("\t Successfully initialized dataloaders.")

    # Get the loss functions and their weights first
    loss_funcs, loss_weights = get_loss_and_weights(model_cfg)

    # Resume: restore the model, optimizer, and scheduler states
    if args.resume:
        # Create the model and load the state dict
        checkpoint = get_latest_checkpoint(args.resume_path, args.checkpoint_name)
        model = get_model(model_cfg, loss_weights)
        model = restore_weights(model, checkpoint["model_state_dict"])
        model = model.cuda()
        optimizer = torch.optim.Adam(
            [{"params": model.parameters(), "initial_lr": model_cfg["learning_rate"]}],
            model_cfg["learning_rate"],
            amsgrad=True,
        )
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        # Optionally get the learning rate scheduler
        scheduler = get_lr_scheduler(
            lr_decay=model_cfg.get("lr_decay", False),
            lr_decay_cfg=model_cfg.get("lr_decay_cfg", None),
            optimizer=optimizer,
        )
        # If we start to use the learning rate scheduler from the middle
        if (scheduler is not None) and (
            checkpoint.get("scheduler_state_dict", None) is not None
        ):
            scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        start_epoch = checkpoint["epoch"] + 1
    # Otherwise, initialize all the components from scratch
    else:
        # Create the model and optimizer
        model = get_model(model_cfg, loss_weights)
        # Optionally load the pretrained weights
        if args.pretrained:
            print("\t [Debug] Loading pretrained weights...")
            checkpoint = get_latest_checkpoint(
                args.pretrained_path, args.checkpoint_name
            )
            # strict=False so that auto weighting can restore from
            # non-auto weighting checkpoints
            model = restore_weights(model, checkpoint["model_state_dict"], strict=False)
            print("\t [Debug] Finished loading pretrained weights!")
        model = model.cuda()
        optimizer = torch.optim.Adam(
            [{"params": model.parameters(), "initial_lr": model_cfg["learning_rate"]}],
            model_cfg["learning_rate"],
            amsgrad=True,
        )
        # Optionally get the learning rate scheduler
        scheduler = get_lr_scheduler(
            lr_decay=model_cfg.get("lr_decay", False),
            lr_decay_cfg=model_cfg.get("lr_decay_cfg", None),
            optimizer=optimizer,
        )
        start_epoch = 0
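
    # Note: the explicit "initial_lr" entry in the optimizer's param groups
    # above is what lets a PyTorch LR scheduler be (re)created mid-training
    # (schedulers constructed with last_epoch != -1 require it).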
| print("\t Successfully initialized model") | |
| # Define the total loss | |
| policy = model_cfg.get("weighting_policy", "static") | |
| loss_func = TotalLoss(loss_funcs, loss_weights, policy).cuda() | |
| if "descriptor_decoder" in model_cfg: | |
| metric_func = Metrics( | |
| model_cfg["detection_thresh"], | |
| model_cfg["prob_thresh"], | |
| model_cfg["descriptor_loss_cfg"]["grid_size"], | |
| desc_metric_lst="all", | |
| ) | |
| else: | |
| metric_func = Metrics( | |
| model_cfg["detection_thresh"], | |
| model_cfg["prob_thresh"], | |
| model_cfg["grid_size"], | |
| ) | |
| # Define the summary writer | |
| logdir = os.path.join(output_path, "log") | |
| writer = SummaryWriter(logdir=logdir) | |
| # Start the training loop | |
| for epoch in range(start_epoch, model_cfg["epochs"]): | |
| # Record the learning rate | |
| current_lr = optimizer.state_dict()["param_groups"][0]["lr"] | |
| writer.add_scalar("LR/lr", current_lr, epoch) | |
| # Train for one epochs | |
| print("\n\n================== Training ====================") | |
| train_single_epoch( | |
| model=model, | |
| model_cfg=model_cfg, | |
| optimizer=optimizer, | |
| loss_func=loss_func, | |
| metric_func=metric_func, | |
| train_loader=train_loader, | |
| writer=writer, | |
| epoch=epoch, | |
| ) | |
| # Do the validation | |
| print("\n\n================== Validation ==================") | |
| validate( | |
| model=model, | |
| model_cfg=model_cfg, | |
| loss_func=loss_func, | |
| metric_func=metric_func, | |
| val_loader=test_loader, | |
| writer=writer, | |
| epoch=epoch, | |
| ) | |
| # Update the scheduler | |
| if scheduler is not None: | |
| scheduler.step() | |
| # Save checkpoints | |
| file_name = os.path.join(output_path, "checkpoint-epoch%03d-end.tar" % (epoch)) | |
| print("[Info] Saving checkpoint %s ..." % file_name) | |
| save_dict = { | |
| "epoch": epoch, | |
| "model_state_dict": model.state_dict(), | |
| "optimizer_state_dict": optimizer.state_dict(), | |
| "model_cfg": model_cfg, | |
| } | |
| if scheduler is not None: | |
| save_dict.update({"scheduler_state_dict": scheduler.state_dict()}) | |
| torch.save(save_dict, file_name) | |
| # Remove the outdated checkpoints | |
| remove_old_checkpoints(output_path, model_cfg.get("max_ckpt", 15)) | |
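
# Entry-point sketch (hedged; the real launcher lives elsewhere in the repo
# and the config layout below is a guess): train_net expects a parsed `args`
# namespace with the attributes used above, plus dataset/model config dicts.
#
#   import argparse, yaml
#   args = argparse.Namespace(resume=False, pretrained=False, resume_path=None,
#                             pretrained_path=None, checkpoint_name=None)
#   with open("config/train.yaml") as f:      # hypothetical config path
#       cfg = yaml.safe_load(f)
#   train_net(args, cfg["dataset"], cfg["model"], "experiments/run0")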


def train_single_epoch(
    model, model_cfg, optimizer, loss_func, metric_func, train_loader, writer, epoch
):
    """Train for one epoch."""
    # Switch the model to training mode
    model.train()

    # Initialize the average meter
    compute_descriptors = loss_func.compute_descriptors
    if compute_descriptors:
        average_meter = AverageMeter(is_training=True, desc_metric_lst="all")
    else:
        average_meter = AverageMeter(is_training=True)

    # The training loop
    for idx, data in enumerate(train_loader):
        if compute_descriptors:
            junc_map = data["ref_junction_map"].cuda()
            junc_map2 = data["target_junction_map"].cuda()
            heatmap = data["ref_heatmap"].cuda()
            heatmap2 = data["target_heatmap"].cuda()
            line_points = data["ref_line_points"].cuda()
            line_points2 = data["target_line_points"].cuda()
            line_indices = data["ref_line_indices"].cuda()
            valid_mask = data["ref_valid_mask"].cuda()
            valid_mask2 = data["target_valid_mask"].cuda()
            input_images = data["ref_image"].cuda()
            input_images2 = data["target_image"].cuda()

            # Run the forward pass on both views
            outputs = model(input_images)
            outputs2 = model(input_images2)

            # Compute the losses
            losses = loss_func.forward_descriptors(
                outputs["junctions"],
                outputs2["junctions"],
                junc_map,
                junc_map2,
                outputs["heatmap"],
                outputs2["heatmap"],
                heatmap,
                heatmap2,
                line_points,
                line_points2,
                line_indices,
                outputs["descriptors"],
                outputs2["descriptors"],
                epoch,
                valid_mask,
                valid_mask2,
            )
        else:
            junc_map = data["junction_map"].cuda()
            heatmap = data["heatmap"].cuda()
            valid_mask = data["valid_mask"].cuda()
            input_images = data["image"].cuda()

            # Run the forward pass
            outputs = model(input_images)

            # Compute the losses
            losses = loss_func(
                outputs["junctions"], junc_map, outputs["heatmap"], heatmap, valid_mask
            )

        total_loss = losses["total_loss"]

        # Update the model
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # Compute the global step
        global_step = epoch * len(train_loader) + idx

        ############## Measure the metric error #########################
        # Only compute the metrics when they will be displayed or logged
        if ((idx % model_cfg["disp_freq"]) == 0) or (
            (idx % model_cfg["summary_freq"]) == 0
        ):
            junc_np = convert_junc_predictions(
                outputs["junctions"],
                model_cfg["grid_size"],
                model_cfg["detection_thresh"],
                300,
            )
            junc_map_np = junc_map.cpu().numpy().transpose(0, 2, 3, 1)
            # Always fetch only one channel (compatible with L1, L2, and CE)
            if outputs["heatmap"].shape[1] == 2:
                heatmap_np = softmax(outputs["heatmap"].detach(), dim=1).cpu().numpy()
                heatmap_np = heatmap_np.transpose(0, 2, 3, 1)[:, :, :, 1:]
            else:
                heatmap_np = torch.sigmoid(outputs["heatmap"].detach())
                heatmap_np = heatmap_np.cpu().numpy().transpose(0, 2, 3, 1)
            heatmap_gt_np = heatmap.cpu().numpy().transpose(0, 2, 3, 1)
            valid_mask_np = valid_mask.cpu().numpy().transpose(0, 2, 3, 1)

            # Evaluate the metric results
            if compute_descriptors:
                metric_func.evaluate(
                    junc_np["junc_pred"],
                    junc_np["junc_pred_nms"],
                    junc_map_np,
                    heatmap_np,
                    heatmap_gt_np,
                    valid_mask_np,
                    line_points,
                    line_points2,
                    outputs["descriptors"],
                    outputs2["descriptors"],
                    line_indices,
                )
            else:
                metric_func.evaluate(
                    junc_np["junc_pred"],
                    junc_np["junc_pred_nms"],
                    junc_map_np,
                    heatmap_np,
                    heatmap_gt_np,
                    valid_mask_np,
                )

            # Update the average meter
            junc_loss = losses["junc_loss"].item()
            heatmap_loss = losses["heatmap_loss"].item()
            loss_dict = {
                "junc_loss": junc_loss,
                "heatmap_loss": heatmap_loss,
                "total_loss": total_loss.item(),
            }
            if compute_descriptors:
                descriptor_loss = losses["descriptor_loss"].item()
                loss_dict["descriptor_loss"] = descriptor_loss
            average_meter.update(metric_func, loss_dict, num_samples=junc_map.shape[0])

        # Display the progress
        if (idx % model_cfg["disp_freq"]) == 0:
            results = metric_func.metric_results
            average = average_meter.average()
            # Get the GPU memory usage in GB
            gpu_mem_usage = torch.cuda.max_memory_allocated() / (1024**3)
            if compute_descriptors:
                print(
                    "Epoch [%d / %d] Iter [%d / %d] loss=%.4f (%.4f), junc_loss=%.4f (%.4f), heatmap_loss=%.4f (%.4f), descriptor_loss=%.4f (%.4f), gpu_mem=%.4fGB"
                    % (
                        epoch,
                        model_cfg["epochs"],
                        idx,
                        len(train_loader),
                        total_loss.item(),
                        average["total_loss"],
                        junc_loss,
                        average["junc_loss"],
                        heatmap_loss,
                        average["heatmap_loss"],
                        descriptor_loss,
                        average["descriptor_loss"],
                        gpu_mem_usage,
                    )
                )
            else:
                print(
                    "Epoch [%d / %d] Iter [%d / %d] loss=%.4f (%.4f), junc_loss=%.4f (%.4f), heatmap_loss=%.4f (%.4f), gpu_mem=%.4fGB"
                    % (
                        epoch,
                        model_cfg["epochs"],
                        idx,
                        len(train_loader),
                        total_loss.item(),
                        average["total_loss"],
                        junc_loss,
                        average["junc_loss"],
                        heatmap_loss,
                        average["heatmap_loss"],
                        gpu_mem_usage,
                    )
                )
            print(
                "\t Junction precision=%.4f (%.4f) / recall=%.4f (%.4f)"
                % (
                    results["junc_precision"],
                    average["junc_precision"],
                    results["junc_recall"],
                    average["junc_recall"],
                )
            )
            print(
                "\t Junction nms precision=%.4f (%.4f) / recall=%.4f (%.4f)"
                % (
                    results["junc_precision_nms"],
                    average["junc_precision_nms"],
                    results["junc_recall_nms"],
                    average["junc_recall_nms"],
                )
            )
            print(
                "\t Heatmap precision=%.4f (%.4f) / recall=%.4f (%.4f)"
                % (
                    results["heatmap_precision"],
                    average["heatmap_precision"],
                    results["heatmap_recall"],
                    average["heatmap_recall"],
                )
            )
            if compute_descriptors:
                print(
                    "\t Descriptors matching score=%.4f (%.4f)"
                    % (results["matching_score"], average["matching_score"])
                )

        # Record summaries
        if (idx % model_cfg["summary_freq"]) == 0:
            results = metric_func.metric_results
            average = average_meter.average()
            # Add the shared losses
            scalar_summaries = {
                "junc_loss": junc_loss,
                "heatmap_loss": heatmap_loss,
                "total_loss": total_loss.detach().cpu().numpy(),
                "metrics": results,
                "average": average,
            }
            # Add the descriptor terms
            if compute_descriptors:
                scalar_summaries["descriptor_loss"] = descriptor_loss
                scalar_summaries["w_desc"] = losses["w_desc"]
            # Add the weighting terms (even for static weighting)
            scalar_summaries["w_junc"] = losses["w_junc"]
            scalar_summaries["w_heatmap"] = losses["w_heatmap"]
            scalar_summaries["reg_loss"] = losses["reg_loss"].item()

            # Record the first few images of the batch
            num_images = 3
            junc_pred_binary = (
                junc_np["junc_pred"][:num_images, ...] > model_cfg["detection_thresh"]
            )
            junc_pred_nms_binary = (
                junc_np["junc_pred_nms"][:num_images, ...]
                > model_cfg["detection_thresh"]
            )
            image_summaries = {
                "image": input_images.cpu().numpy()[:num_images, ...],
                "valid_mask": valid_mask_np[:num_images, ...],
                "junc_map_pred": junc_pred_binary,
                "junc_map_pred_nms": junc_pred_nms_binary,
                "junc_map_gt": junc_map_np[:num_images, ...],
                "junc_prob_map": junc_np["junc_prob"][:num_images, ...],
                "heatmap_pred": heatmap_np[:num_images, ...],
                "heatmap_gt": heatmap_gt_np[:num_images, ...],
            }

            # Record the training summary
            record_train_summaries(
                writer, global_step, scalars=scalar_summaries, images=image_summaries
            )
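
# Note on the two frequencies used in train_single_epoch above: metrics and
# the average meter are only updated on iterations where idx is a multiple of
# disp_freq or summary_freq, so the reported "average" values smooth over
# that subsample of iterations rather than over every batch.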


def validate(model, model_cfg, loss_func, metric_func, val_loader, writer, epoch):
    """Validation."""
    # Switch the model to eval mode
    model.eval()

    # Initialize the average meter
    compute_descriptors = loss_func.compute_descriptors
    if compute_descriptors:
        average_meter = AverageMeter(is_training=True, desc_metric_lst="all")
    else:
        average_meter = AverageMeter(is_training=True)

    # The validation loop
    for idx, data in enumerate(val_loader):
        if compute_descriptors:
            junc_map = data["ref_junction_map"].cuda()
            junc_map2 = data["target_junction_map"].cuda()
            heatmap = data["ref_heatmap"].cuda()
            heatmap2 = data["target_heatmap"].cuda()
            line_points = data["ref_line_points"].cuda()
            line_points2 = data["target_line_points"].cuda()
            line_indices = data["ref_line_indices"].cuda()
            valid_mask = data["ref_valid_mask"].cuda()
            valid_mask2 = data["target_valid_mask"].cuda()
            input_images = data["ref_image"].cuda()
            input_images2 = data["target_image"].cuda()

            # Run the forward pass on both views
            with torch.no_grad():
                outputs = model(input_images)
                outputs2 = model(input_images2)

            # Compute the losses
            losses = loss_func.forward_descriptors(
                outputs["junctions"],
                outputs2["junctions"],
                junc_map,
                junc_map2,
                outputs["heatmap"],
                outputs2["heatmap"],
                heatmap,
                heatmap2,
                line_points,
                line_points2,
                line_indices,
                outputs["descriptors"],
                outputs2["descriptors"],
                epoch,
                valid_mask,
                valid_mask2,
            )
        else:
            junc_map = data["junction_map"].cuda()
            heatmap = data["heatmap"].cuda()
            valid_mask = data["valid_mask"].cuda()
            input_images = data["image"].cuda()

            # Run the forward pass
            with torch.no_grad():
                outputs = model(input_images)

            # Compute the losses
            losses = loss_func(
                outputs["junctions"],
                junc_map,
                outputs["heatmap"],
                heatmap,
                valid_mask,
            )

        total_loss = losses["total_loss"]

        ############## Measure the metric error #########################
        junc_np = convert_junc_predictions(
            outputs["junctions"],
            model_cfg["grid_size"],
            model_cfg["detection_thresh"],
            300,
        )
        junc_map_np = junc_map.cpu().numpy().transpose(0, 2, 3, 1)
        # Always fetch only one channel (compatible with L1, L2, and CE)
        if outputs["heatmap"].shape[1] == 2:
            heatmap_np = (
                softmax(outputs["heatmap"].detach(), dim=1)
                .cpu()
                .numpy()
                .transpose(0, 2, 3, 1)
            )
            heatmap_np = heatmap_np[:, :, :, 1:]
        else:
            heatmap_np = torch.sigmoid(outputs["heatmap"].detach())
            heatmap_np = heatmap_np.cpu().numpy().transpose(0, 2, 3, 1)
        heatmap_gt_np = heatmap.cpu().numpy().transpose(0, 2, 3, 1)
        valid_mask_np = valid_mask.cpu().numpy().transpose(0, 2, 3, 1)

        # Evaluate the metric results
        if compute_descriptors:
            metric_func.evaluate(
                junc_np["junc_pred"],
                junc_np["junc_pred_nms"],
                junc_map_np,
                heatmap_np,
                heatmap_gt_np,
                valid_mask_np,
                line_points,
                line_points2,
                outputs["descriptors"],
                outputs2["descriptors"],
                line_indices,
            )
        else:
            metric_func.evaluate(
                junc_np["junc_pred"],
                junc_np["junc_pred_nms"],
                junc_map_np,
                heatmap_np,
                heatmap_gt_np,
                valid_mask_np,
            )

        # Update the average meter
        junc_loss = losses["junc_loss"].item()
        heatmap_loss = losses["heatmap_loss"].item()
        loss_dict = {
            "junc_loss": junc_loss,
            "heatmap_loss": heatmap_loss,
            "total_loss": total_loss.item(),
        }
        if compute_descriptors:
            descriptor_loss = losses["descriptor_loss"].item()
            loss_dict["descriptor_loss"] = descriptor_loss
        average_meter.update(metric_func, loss_dict, num_samples=junc_map.shape[0])

        # Display the progress
        if (idx % model_cfg["disp_freq"]) == 0:
            results = metric_func.metric_results
            average = average_meter.average()
            if compute_descriptors:
                print(
                    "Iter [%d / %d] loss=%.4f (%.4f), junc_loss=%.4f (%.4f), heatmap_loss=%.4f (%.4f), descriptor_loss=%.4f (%.4f)"
                    % (
                        idx,
                        len(val_loader),
                        total_loss.item(),
                        average["total_loss"],
                        junc_loss,
                        average["junc_loss"],
                        heatmap_loss,
                        average["heatmap_loss"],
                        descriptor_loss,
                        average["descriptor_loss"],
                    )
                )
            else:
                print(
                    "Iter [%d / %d] loss=%.4f (%.4f), junc_loss=%.4f (%.4f), heatmap_loss=%.4f (%.4f)"
                    % (
                        idx,
                        len(val_loader),
                        total_loss.item(),
                        average["total_loss"],
                        junc_loss,
                        average["junc_loss"],
                        heatmap_loss,
                        average["heatmap_loss"],
                    )
                )
            print(
                "\t Junction precision=%.4f (%.4f) / recall=%.4f (%.4f)"
                % (
                    results["junc_precision"],
                    average["junc_precision"],
                    results["junc_recall"],
                    average["junc_recall"],
                )
            )
            print(
                "\t Junction nms precision=%.4f (%.4f) / recall=%.4f (%.4f)"
                % (
                    results["junc_precision_nms"],
                    average["junc_precision_nms"],
                    results["junc_recall_nms"],
                    average["junc_recall_nms"],
                )
            )
            print(
                "\t Heatmap precision=%.4f (%.4f) / recall=%.4f (%.4f)"
                % (
                    results["heatmap_precision"],
                    average["heatmap_precision"],
                    results["heatmap_recall"],
                    average["heatmap_recall"],
                )
            )
            if compute_descriptors:
                print(
                    "\t Descriptors matching score=%.4f (%.4f)"
                    % (results["matching_score"], average["matching_score"])
                )

    # Record summaries once per epoch
    average = average_meter.average()
    scalar_summaries = {"average": average}
    # Record the validation summary
    record_test_summaries(writer, epoch, scalar_summaries)


def convert_junc_predictions(predictions, grid_size, detect_thresh=1 / 65, topk=300):
    """Convert torch junction predictions to numpy arrays for evaluation."""
    # Convert the logits to probabilities first
    junc_prob = softmax(predictions.detach(), dim=1).cpu()
    junc_pred = junc_prob[:, :-1, :, :]

    # Per-cell junction probability: sum over all channels except the last
    # (background) bin
    junc_prob_np = junc_prob.numpy().transpose(0, 2, 3, 1)[:, :, :, :-1]
    junc_prob_np = np.sum(junc_prob_np, axis=-1)

    # Recover the full-resolution junction map, then apply NMS
    junc_pred_np = (
        pixel_shuffle(junc_pred, grid_size).cpu().numpy().transpose(0, 2, 3, 1)
    )
    junc_pred_np_nms = super_nms(junc_pred_np, grid_size, detect_thresh, topk)
    junc_pred_np = junc_pred_np.squeeze(-1)

    return {
        "junc_pred": junc_pred_np,
        "junc_pred_nms": junc_pred_np_nms,
        "junc_prob": junc_prob_np,
    }
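
# Shape walkthrough (illustrative, assuming the usual 8x8 grid so that the
# junction head outputs 64 + 1 channels, consistent with the 1 / 65 default
# threshold above): for a 512x512 image, predictions are (N, 65, 64, 64)
# logits; dropping the last "no junction" bin gives (N, 64, 64, 64), and
# pixel_shuffle with upscale factor 8 recovers an (N, 512, 512) junction
# probability map that super_nms then sparsifies.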


def record_train_summaries(writer, global_step, scalars, images):
    """Record training summaries."""
    # Record the scalar summaries
    results = scalars["metrics"]
    average = scalars["average"]

    # GPU memory part: get the GPU memory usage in GB
    gpu_mem_usage = torch.cuda.max_memory_allocated() / (1024**3)
    writer.add_scalar("GPU/GPU_memory_usage", gpu_mem_usage, global_step)

    # Loss part
    writer.add_scalar("Train_loss/junc_loss", scalars["junc_loss"], global_step)
    writer.add_scalar("Train_loss/heatmap_loss", scalars["heatmap_loss"], global_step)
    writer.add_scalar("Train_loss/total_loss", scalars["total_loss"], global_step)
    # Add the regularization loss
    if "reg_loss" in scalars.keys():
        writer.add_scalar("Train_loss/reg_loss", scalars["reg_loss"], global_step)
    # Add the descriptor loss
    if "descriptor_loss" in scalars.keys():
        key = "descriptor_loss"
        writer.add_scalar("Train_loss/%s" % key, scalars[key], global_step)
        writer.add_scalar("Train_loss_average/%s" % key, average[key], global_step)
    # Record the weighting terms
    for key in scalars.keys():
        if "w_" in key:
            writer.add_scalar("Train_weight/%s" % key, scalars[key], global_step)

    # Smoothed losses
    writer.add_scalar("Train_loss_average/junc_loss", average["junc_loss"], global_step)
    writer.add_scalar(
        "Train_loss_average/heatmap_loss", average["heatmap_loss"], global_step
    )
    writer.add_scalar(
        "Train_loss_average/total_loss", average["total_loss"], global_step
    )
    # Add the smoothed descriptor loss
    if "descriptor_loss" in average.keys():
        writer.add_scalar(
            "Train_loss_average/descriptor_loss",
            average["descriptor_loss"],
            global_step,
        )

    # Metrics part
    writer.add_scalar(
        "Train_metrics/junc_precision", results["junc_precision"], global_step
    )
    writer.add_scalar(
        "Train_metrics/junc_precision_nms", results["junc_precision_nms"], global_step
    )
    writer.add_scalar("Train_metrics/junc_recall", results["junc_recall"], global_step)
    writer.add_scalar(
        "Train_metrics/junc_recall_nms", results["junc_recall_nms"], global_step
    )
    writer.add_scalar(
        "Train_metrics/heatmap_precision", results["heatmap_precision"], global_step
    )
    writer.add_scalar(
        "Train_metrics/heatmap_recall", results["heatmap_recall"], global_step
    )
    # Add the descriptor metric
    if "matching_score" in results.keys():
        writer.add_scalar(
            "Train_metrics/matching_score", results["matching_score"], global_step
        )

    # Average part
    writer.add_scalar(
        "Train_metrics_average/junc_precision", average["junc_precision"], global_step
    )
    writer.add_scalar(
        "Train_metrics_average/junc_precision_nms",
        average["junc_precision_nms"],
        global_step,
    )
    writer.add_scalar(
        "Train_metrics_average/junc_recall", average["junc_recall"], global_step
    )
    writer.add_scalar(
        "Train_metrics_average/junc_recall_nms", average["junc_recall_nms"], global_step
    )
    writer.add_scalar(
        "Train_metrics_average/heatmap_precision",
        average["heatmap_precision"],
        global_step,
    )
    writer.add_scalar(
        "Train_metrics_average/heatmap_recall", average["heatmap_recall"], global_step
    )
    # Add the smoothed descriptor metric
    if "matching_score" in average.keys():
        writer.add_scalar(
            "Train_metrics_average/matching_score",
            average["matching_score"],
            global_step,
        )

    # Record the image summaries
    # Image part
    image_tensor = convert_image(images["image"], 1)
    valid_masks = convert_image(images["valid_mask"], -1)
    writer.add_images("Train/images", image_tensor, global_step, dataformats="NCHW")
    writer.add_images("Train/valid_map", valid_masks, global_step, dataformats="NHWC")

    # Heatmap part
    writer.add_images(
        "Train/heatmap_gt",
        convert_image(images["heatmap_gt"], -1),
        global_step,
        dataformats="NHWC",
    )
    writer.add_images(
        "Train/heatmap_pred",
        convert_image(images["heatmap_pred"], -1),
        global_step,
        dataformats="NHWC",
    )

    # Junction prediction part
    junc_plots = plot_junction_detection(
        image_tensor,
        images["junc_map_pred"],
        images["junc_map_pred_nms"],
        images["junc_map_gt"],
    )
    writer.add_images(
        "Train/junc_gt",
        junc_plots["junc_gt_plot"] / 255.0,
        global_step,
        dataformats="NHWC",
    )
    writer.add_images(
        "Train/junc_pred",
        junc_plots["junc_pred_plot"] / 255.0,
        global_step,
        dataformats="NHWC",
    )
    writer.add_images(
        "Train/junc_pred_nms",
        junc_plots["junc_pred_nms_plot"] / 255.0,
        global_step,
        dataformats="NHWC",
    )
    writer.add_images(
        "Train/junc_prob_map",
        convert_image(images["junc_prob_map"][..., None], -1),
        global_step,
        dataformats="NHWC",
    )
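
# (The mixed dataformats above follow from the layouts passed in: the images
# stay channel-first, hence dataformats="NCHW", while the masks, heatmaps,
# and the uint8 junction plots are channel-last, hence dataformats="NHWC".)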


def record_test_summaries(writer, epoch, scalars):
    """Record testing summaries."""
    average = scalars["average"]

    # Average losses
    writer.add_scalar("Val_loss/junc_loss", average["junc_loss"], epoch)
    writer.add_scalar("Val_loss/heatmap_loss", average["heatmap_loss"], epoch)
    writer.add_scalar("Val_loss/total_loss", average["total_loss"], epoch)
    # Add the descriptor loss
    if "descriptor_loss" in average.keys():
        key = "descriptor_loss"
        writer.add_scalar("Val_loss/%s" % key, average[key], epoch)

    # Average metrics
    writer.add_scalar("Val_metrics/junc_precision", average["junc_precision"], epoch)
    writer.add_scalar(
        "Val_metrics/junc_precision_nms", average["junc_precision_nms"], epoch
    )
    writer.add_scalar("Val_metrics/junc_recall", average["junc_recall"], epoch)
    writer.add_scalar("Val_metrics/junc_recall_nms", average["junc_recall_nms"], epoch)
    writer.add_scalar(
        "Val_metrics/heatmap_precision", average["heatmap_precision"], epoch
    )
    writer.add_scalar("Val_metrics/heatmap_recall", average["heatmap_recall"], epoch)
    # Add the descriptor metric
    if "matching_score" in average.keys():
        writer.add_scalar(
            "Val_metrics/matching_score", average["matching_score"], epoch
        )


def plot_junction_detection(
    image_tensor, junc_pred_tensor, junc_pred_nms_tensor, junc_gt_tensor
):
    """Plot the junction points on images."""
    # Get the batch size
    batch_size = image_tensor.shape[0]

    # Process through the batch dimension
    junc_pred_lst = []
    junc_pred_nms_lst = []
    junc_gt_lst = []
    for i in range(batch_size):
        # Convert the image to uint8 in [0, 255] and HWC layout
        image = (image_tensor[i, :, :, :] * 255.0).astype(np.uint8).transpose(1, 2, 0)

        # Plot the ground truth onto the image. np.flip converts the
        # (row, col) coordinates to the (x, y) order expected by cv2.circle.
        junc_gt = junc_gt_tensor[i, ...]
        coord_gt = np.where(junc_gt.squeeze() > 0)
        points_gt = np.concatenate(
            (coord_gt[0][..., None], coord_gt[1][..., None]), axis=1
        )
        plot_gt = image.copy()
        for pt_idx in range(points_gt.shape[0]):
            cv2.circle(
                plot_gt,
                tuple(np.flip(points_gt[pt_idx, :])),
                3,
                color=(255, 0, 0),
                thickness=2,
            )
        junc_gt_lst.append(plot_gt[None, ...])

        # Plot the raw junction predictions
        junc_pred = junc_pred_tensor[i, ...]
        coord_pred = np.where(junc_pred > 0)
        points_pred = np.concatenate(
            (coord_pred[0][..., None], coord_pred[1][..., None]), axis=1
        )
        plot_pred = image.copy()
        for pt_idx in range(points_pred.shape[0]):
            cv2.circle(
                plot_pred,
                tuple(np.flip(points_pred[pt_idx, :])),
                3,
                color=(0, 255, 0),
                thickness=2,
            )
        junc_pred_lst.append(plot_pred[None, ...])

        # Plot the junction predictions after NMS
        junc_pred_nms = junc_pred_nms_tensor[i, ...]
        coord_pred_nms = np.where(junc_pred_nms > 0)
        points_pred_nms = np.concatenate(
            (coord_pred_nms[0][..., None], coord_pred_nms[1][..., None]), axis=1
        )
        plot_pred_nms = image.copy()
        for pt_idx in range(points_pred_nms.shape[0]):
            cv2.circle(
                plot_pred_nms,
                tuple(np.flip(points_pred_nms[pt_idx, :])),
                3,
                color=(0, 255, 0),
                thickness=2,
            )
        junc_pred_nms_lst.append(plot_pred_nms[None, ...])

    return {
        "junc_gt_plot": np.concatenate(junc_gt_lst, axis=0),
        "junc_pred_plot": np.concatenate(junc_pred_lst, axis=0),
        "junc_pred_nms_plot": np.concatenate(junc_pred_nms_lst, axis=0),
    }