"""This file implements the training process and all the summaries."""
import os

import numpy as np
import cv2
import torch
from torch.nn.functional import pixel_shuffle, softmax
from torch.utils.data import DataLoader
import torch.utils.data.dataloader as torch_loader
from tensorboardX import SummaryWriter

from .dataset.dataset_util import get_dataset
from .model.model_util import get_model
from .model.loss import TotalLoss, get_loss_and_weights
from .model.metrics import AverageMeter, Metrics, super_nms
from .model.lr_scheduler import get_lr_scheduler
from .misc.train_utils import (
    convert_image,
    get_latest_checkpoint,
    remove_old_checkpoints,
)


# Detection metrics reported (in this order) in the printed progress and in
# the training / validation tensorboard summaries.
_METRIC_KEYS = (
    "junc_precision",
    "junc_precision_nms",
    "junc_recall",
    "junc_recall_nms",
    "heatmap_precision",
    "heatmap_recall",
)


def customized_collate_fn(batch):
    """Customized collate_fn.

    Keys in ``batch_keys`` are stacked with PyTorch's default collate, while
    keys in ``list_keys`` are kept as plain per-sample lists (presumably
    because their sizes differ between samples — confirm against the dataset).
    """
    batch_keys = ["image", "junction_map", "heatmap", "valid_mask"]
    list_keys = ["junctions", "line_map"]

    outputs = {}
    for key in batch_keys:
        outputs[key] = torch_loader.default_collate([b[key] for b in batch])
    for key in list_keys:
        outputs[key] = [b[key] for b in batch]
    return outputs


def restore_weights(model, state_dict, strict=True):
    """Restore weights in compatible mode.

    First tries a regular ``load_state_dict``. If that raises a
    ``RuntimeError`` (e.g. parameter names changed between versions), the
    weights are matched positionally: the i-th key the model expects but
    ``state_dict`` lacks receives the tensor of the i-th unexpected key of
    ``state_dict``, ignoring keys that contain "tracked" (presumably
    BatchNorm ``num_batches_tracked`` buffers).

    Args:
        model: ``torch.nn.Module`` to load the weights into.
        state_dict: State dict read from a checkpoint.
        strict: Passed to the first ``load_state_dict`` attempt.

    Returns:
        The same ``model`` with the weights loaded.

    Raises:
        RuntimeError: If the fallback cannot find a source tensor for every
            missing key, or if the remapped weights still do not fit.
    """
    # Try to directly load state dict
    try:
        model.load_state_dict(state_dict, strict=strict)
    # Deal with some version compatibility issue. Only RuntimeError (what
    # load_state_dict raises on key/shape mismatch) is handled here, so
    # KeyboardInterrupt and unrelated errors are no longer swallowed.
    except RuntimeError:
        err = model.load_state_dict(state_dict, strict=False)
        # Missing keys are those in model but not in state_dict
        missing_keys = err.missing_keys
        # Unexpected keys are those in state_dict but not in model
        candidate_keys = [key for key in err.unexpected_keys if "tracked" not in key]
        if len(candidate_keys) < len(missing_keys):
            raise RuntimeError(
                "Cannot restore weights: %d keys are missing from the checkpoint "
                "but only %d unexpected keys are available to map onto them."
                % (len(missing_keys), len(candidate_keys))
            )
        # Load mismatched keys manually (positional matching)
        model_dict = model.state_dict()
        for missing_key, source_key in zip(missing_keys, candidate_keys):
            model_dict[missing_key] = state_dict[source_key]
        model.load_state_dict(model_dict)
    return model


def train_net(args, dataset_cfg, model_cfg, output_path):
    """Main training function.

    Args:
        args: Parsed arguments; reads ``resume``, ``resume_path``,
            ``pretrained``, ``pretrained_path`` and ``checkpoint_name``.
        dataset_cfg: Dataset configuration passed to ``get_dataset``.
        model_cfg: Model / training configuration. Mutated in place: a missing
            ``weighting_policy`` is set to ``"static"``.
        output_path: Directory receiving the ``log`` folder and checkpoints.
    """
    # Add some version compatibility check
    if model_cfg.get("weighting_policy") is None:
        # Default to static
        model_cfg["weighting_policy"] = "static"

    # Get the train, val, test config
    train_cfg = model_cfg["train"]
    test_cfg = model_cfg["test"]

    # Create train and test dataset
    print("\t Initializing dataset...")
    train_dataset, train_collate_fn = get_dataset("train", dataset_cfg)
    test_dataset, test_collate_fn = get_dataset("test", dataset_cfg)

    # Create the dataloader. The number of training workers is configurable
    # through ``train.num_workers`` and defaults to the previous fixed value 8.
    train_loader = DataLoader(
        train_dataset,
        batch_size=train_cfg["batch_size"],
        num_workers=train_cfg.get("num_workers", 8),
        shuffle=True,
        pin_memory=True,
        collate_fn=train_collate_fn,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=test_cfg.get("batch_size", 1),
        num_workers=test_cfg.get("num_workers", 1),
        shuffle=False,
        pin_memory=False,
        collate_fn=test_collate_fn,
    )
    print("\t Successfully intialized dataloaders.")

    # Get the loss function and weight first
    loss_funcs, loss_weights = get_loss_and_weights(model_cfg)

    # Create the model, optionally initialized from a checkpoint.
    checkpoint = None
    if args.resume:
        # Resume: restore model, optimizer and scheduler from the checkpoint.
        checkpoint = get_latest_checkpoint(args.resume_path, args.checkpoint_name)
        model = get_model(model_cfg, loss_weights)
        model = restore_weights(model, checkpoint["model_state_dict"])
    else:
        model = get_model(model_cfg, loss_weights)
        # Optionally get the pretrained weights
        if args.pretrained:
            print("\t [Debug] Loading pretrained weights...")
            pretrained_ckpt = get_latest_checkpoint(
                args.pretrained_path, args.checkpoint_name
            )
            # If auto weighting restore from non-auto weighting
            model = restore_weights(
                model, pretrained_ckpt["model_state_dict"], strict=False
            )
            print("\t [Debug] Finished loading pretrained weights!")
    model = model.cuda()

    optimizer = torch.optim.Adam(
        [{"params": model.parameters(), "initial_lr": model_cfg["learning_rate"]}],
        model_cfg["learning_rate"],
        amsgrad=True,
    )
    # Keep the original order: the optimizer state is restored before the
    # scheduler (which is built from this optimizer) is created.
    if args.resume:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

    # Optionally get the learning rate scheduler
    scheduler = get_lr_scheduler(
        lr_decay=model_cfg.get("lr_decay", False),
        lr_decay_cfg=model_cfg.get("lr_decay_cfg", None),
        optimizer=optimizer,
    )

    if args.resume:
        # If we start to use learning rate scheduler from the middle, the
        # checkpoint may not contain any scheduler state.
        if (scheduler is not None) and (
            checkpoint.get("scheduler_state_dict", None) is not None
        ):
            scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        start_epoch = checkpoint["epoch"] + 1
    else:
        start_epoch = 0
    print("\t Successfully initialized model")

    # Define the total loss
    policy = model_cfg.get("weighting_policy", "static")
    loss_func = TotalLoss(loss_funcs, loss_weights, policy).cuda()
    if "descriptor_decoder" in model_cfg:
        metric_func = Metrics(
            model_cfg["detection_thresh"],
            model_cfg["prob_thresh"],
            model_cfg["descriptor_loss_cfg"]["grid_size"],
            desc_metric_lst="all",
        )
    else:
        metric_func = Metrics(
            model_cfg["detection_thresh"],
            model_cfg["prob_thresh"],
            model_cfg["grid_size"],
        )

    # Define the summary writer
    logdir = os.path.join(output_path, "log")
    writer = SummaryWriter(logdir=logdir)

    # Start the training loop. The writer is always closed afterwards so that
    # pending events are flushed even if training is interrupted.
    try:
        for epoch in range(start_epoch, model_cfg["epochs"]):
            # Record the learning rate
            current_lr = optimizer.param_groups[0]["lr"]
            writer.add_scalar("LR/lr", current_lr, epoch)

            # Train for one epochs
            print("\n\n================== Training ====================")
            train_single_epoch(
                model=model,
                model_cfg=model_cfg,
                optimizer=optimizer,
                loss_func=loss_func,
                metric_func=metric_func,
                train_loader=train_loader,
                writer=writer,
                epoch=epoch,
            )

            # Do the validation
            print("\n\n================== Validation ==================")
            validate(
                model=model,
                model_cfg=model_cfg,
                loss_func=loss_func,
                metric_func=metric_func,
                val_loader=test_loader,
                writer=writer,
                epoch=epoch,
            )

            # Update the scheduler
            if scheduler is not None:
                scheduler.step()

            # Save checkpoints
            file_name = os.path.join(
                output_path, "checkpoint-epoch%03d-end.tar" % (epoch)
            )
            print("[Info] Saving checkpoint %s ..." % file_name)
            save_dict = {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "model_cfg": model_cfg,
            }
            if scheduler is not None:
                save_dict.update({"scheduler_state_dict": scheduler.state_dict()})
            torch.save(save_dict, file_name)

            # Remove the outdated checkpoints
            remove_old_checkpoints(output_path, model_cfg.get("max_ckpt", 15))
    finally:
        writer.close()


def _make_average_meter(compute_descriptors):
    """Create the running-average meter, with descriptor metrics if needed."""
    # NOTE(review): validation also uses is_training=True; kept as-is, but
    # confirm this is intended for AverageMeter.
    if compute_descriptors:
        return AverageMeter(is_training=True, desc_metric_lst="all")
    return AverageMeter(is_training=True)


def _load_batch(data, compute_descriptors):
    """Move the tensors needed for one step to the GPU under short names.

    With descriptors, the "ref_*" entries map to the plain names and the
    "target_*" entries map to the names suffixed with "2".
    """
    if compute_descriptors:
        key_map = {
            "junc_map": "ref_junction_map",
            "junc_map2": "target_junction_map",
            "heatmap": "ref_heatmap",
            "heatmap2": "target_heatmap",
            "line_points": "ref_line_points",
            "line_points2": "target_line_points",
            "line_indices": "ref_line_indices",
            "valid_mask": "ref_valid_mask",
            "valid_mask2": "target_valid_mask",
            "image": "ref_image",
            "image2": "target_image",
        }
    else:
        key_map = {
            "junc_map": "junction_map",
            "heatmap": "heatmap",
            "valid_mask": "valid_mask",
            "image": "image",
        }
    return {name: data[key].cuda() for name, key in key_map.items()}


def _forward_and_loss(model, loss_func, batch, epoch, compute_descriptors):
    """Run the forward pass(es) on a batch and compute the losses.

    Returns:
        ``(outputs, outputs2, losses)``. ``outputs2`` holds the predictions
        for the target image and is ``None`` when descriptors are not used.
    """
    outputs = model(batch["image"])
    if not compute_descriptors:
        losses = loss_func(
            outputs["junctions"],
            batch["junc_map"],
            outputs["heatmap"],
            batch["heatmap"],
            batch["valid_mask"],
        )
        return outputs, None, losses

    outputs2 = model(batch["image2"])
    losses = loss_func.forward_descriptors(
        outputs["junctions"],
        outputs2["junctions"],
        batch["junc_map"],
        batch["junc_map2"],
        outputs["heatmap"],
        outputs2["heatmap"],
        batch["heatmap"],
        batch["heatmap2"],
        batch["line_points"],
        batch["line_points2"],
        batch["line_indices"],
        outputs["descriptors"],
        outputs2["descriptors"],
        epoch,
        batch["valid_mask"],
        batch["valid_mask2"],
    )
    return outputs, outputs2, losses


def _heatmap_to_numpy(heatmap_logits):
    """Convert heatmap logits to a single-channel NHWC numpy probability map.

    Always fetch only one channel (compatible with L1, L2, and CE): a
    2-channel output is softmaxed and channel 1 kept; otherwise a sigmoid is
    applied.
    """
    logits = heatmap_logits.detach()
    if logits.shape[1] == 2:
        prob = softmax(logits, dim=1).cpu().numpy().transpose(0, 2, 3, 1)
        return prob[:, :, :, 1:]
    return torch.sigmoid(logits).cpu().numpy().transpose(0, 2, 3, 1)


def _evaluate_batch(metric_func, model_cfg, batch, outputs, outputs2, compute_descriptors):
    """Convert predictions/targets to numpy and feed them to ``metric_func``.

    Returns:
        Dict with the converted arrays ("junc_np", "junc_map_np",
        "heatmap_np", "heatmap_gt_np", "valid_mask_np") for reuse in summaries.
    """
    junc_np = convert_junc_predictions(
        outputs["junctions"],
        model_cfg["grid_size"],
        model_cfg["detection_thresh"],
        300,
    )
    arrays = {
        "junc_np": junc_np,
        "junc_map_np": batch["junc_map"].cpu().numpy().transpose(0, 2, 3, 1),
        "heatmap_np": _heatmap_to_numpy(outputs["heatmap"]),
        "heatmap_gt_np": batch["heatmap"].cpu().numpy().transpose(0, 2, 3, 1),
        "valid_mask_np": batch["valid_mask"].cpu().numpy().transpose(0, 2, 3, 1),
    }

    # Evaluate metric results
    metric_args = [
        junc_np["junc_pred"],
        junc_np["junc_pred_nms"],
        arrays["junc_map_np"],
        arrays["heatmap_np"],
        arrays["heatmap_gt_np"],
        arrays["valid_mask_np"],
    ]
    if compute_descriptors:
        metric_args += [
            batch["line_points"],
            batch["line_points2"],
            outputs["descriptors"],
            outputs2["descriptors"],
            batch["line_indices"],
        ]
    metric_func.evaluate(*metric_args)
    return arrays


def _loss_values(losses, compute_descriptors):
    """Extract the scalar loss values fed to the average meter."""
    loss_dict = {
        "junc_loss": losses["junc_loss"].item(),
        "heatmap_loss": losses["heatmap_loss"].item(),
        "total_loss": losses["total_loss"].item(),
    }
    if compute_descriptors:
        loss_dict["descriptor_loss"] = losses["descriptor_loss"].item()
    return loss_dict


def _format_losses(loss_dict, average):
    """Format current and running-average losses for the progress line."""
    text = "loss=%.4f (%.4f), junc_loss=%.4f (%.4f), heatmap_loss=%.4f (%.4f)" % (
        loss_dict["total_loss"],
        average["total_loss"],
        loss_dict["junc_loss"],
        average["junc_loss"],
        loss_dict["heatmap_loss"],
        average["heatmap_loss"],
    )
    if "descriptor_loss" in loss_dict:
        text += ", descriptor_loss=%.4f (%.4f)" % (
            loss_dict["descriptor_loss"],
            average["descriptor_loss"],
        )
    return text


def _print_metric_results(results, average, compute_descriptors):
    """Print current and running-average detection (and descriptor) metrics."""
    print(
        "\t Junction precision=%.4f (%.4f) / recall=%.4f (%.4f)"
        % (
            results["junc_precision"],
            average["junc_precision"],
            results["junc_recall"],
            average["junc_recall"],
        )
    )
    print(
        "\t Junction nms precision=%.4f (%.4f) / recall=%.4f (%.4f)"
        % (
            results["junc_precision_nms"],
            average["junc_precision_nms"],
            results["junc_recall_nms"],
            average["junc_recall_nms"],
        )
    )
    print(
        "\t Heatmap precision=%.4f (%.4f) / recall=%.4f (%.4f)"
        % (
            results["heatmap_precision"],
            average["heatmap_precision"],
            results["heatmap_recall"],
            average["heatmap_recall"],
        )
    )
    if compute_descriptors:
        print(
            "\t Descriptors matching score=%.4f (%.4f)"
            % (results["matching_score"], average["matching_score"])
        )


def train_single_epoch(
    model, model_cfg, optimizer, loss_func, metric_func, train_loader, writer, epoch
):
    """Train for one epoch.

    Metrics are only computed on iterations where ``idx`` is a multiple of
    ``disp_freq`` (progress print) or ``summary_freq`` (tensorboard summary).
    """
    # Switch the model to training mode
    model.train()

    # Initialize the average meter
    compute_descriptors = loss_func.compute_descriptors
    average_meter = _make_average_meter(compute_descriptors)
    num_iters = len(train_loader)

    # The training loop
    for idx, data in enumerate(train_loader):
        batch = _load_batch(data, compute_descriptors)
        outputs, outputs2, losses = _forward_and_loss(
            model, loss_func, batch, epoch, compute_descriptors
        )
        total_loss = losses["total_loss"]

        # Update the model
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # Compute the global step
        global_step = epoch * num_iters + idx

        ############## Measure the metric error #########################
        # Only do this when needed
        display = (idx % model_cfg["disp_freq"]) == 0
        summarize = (idx % model_cfg["summary_freq"]) == 0
        if not (display or summarize):
            continue

        arrays = _evaluate_batch(
            metric_func, model_cfg, batch, outputs, outputs2, compute_descriptors
        )

        # Update average meter
        loss_dict = _loss_values(losses, compute_descriptors)
        average_meter.update(
            metric_func, loss_dict, num_samples=batch["junc_map"].shape[0]
        )

        # Display the progress
        if display:
            results = metric_func.metric_results
            average = average_meter.average()
            # Get gpu memory usage in GB
            gpu_mem_usage = torch.cuda.max_memory_allocated() / (1024**3)
            print(
                "Epoch [%d / %d] Iter [%d / %d] %s, gpu_mem=%.4fGB"
                % (
                    epoch,
                    model_cfg["epochs"],
                    idx,
                    num_iters,
                    _format_losses(loss_dict, average),
                    gpu_mem_usage,
                )
            )
            _print_metric_results(results, average, compute_descriptors)

        # Record summaries
        if summarize:
            results = metric_func.metric_results
            average = average_meter.average()
            # Add the shared losses
            scalar_summaries = {
                "junc_loss": loss_dict["junc_loss"],
                "heatmap_loss": loss_dict["heatmap_loss"],
                "total_loss": total_loss.detach().cpu().numpy(),
                "metrics": results,
                "average": average,
            }
            # Add descriptor terms
            if compute_descriptors:
                scalar_summaries["descriptor_loss"] = loss_dict["descriptor_loss"]
                scalar_summaries["w_desc"] = losses["w_desc"]

            # Add weighting terms (even for static terms)
            scalar_summaries["w_junc"] = losses["w_junc"]
            scalar_summaries["w_heatmap"] = losses["w_heatmap"]
            scalar_summaries["reg_loss"] = losses["reg_loss"].item()

            num_images = 3
            junc_np = arrays["junc_np"]
            detection_thresh = model_cfg["detection_thresh"]
            image_summaries = {
                "image": batch["image"].cpu().numpy()[:num_images, ...],
                "valid_mask": arrays["valid_mask_np"][:num_images, ...],
                "junc_map_pred": (
                    junc_np["junc_pred"][:num_images, ...] > detection_thresh
                ),
                "junc_map_pred_nms": (
                    junc_np["junc_pred_nms"][:num_images, ...] > detection_thresh
                ),
                "junc_map_gt": arrays["junc_map_np"][:num_images, ...],
                "junc_prob_map": junc_np["junc_prob"][:num_images, ...],
                "heatmap_pred": arrays["heatmap_np"][:num_images, ...],
                "heatmap_gt": arrays["heatmap_gt_np"][:num_images, ...],
            }
            # Record the training summary
            record_train_summaries(
                writer, global_step, scalars=scalar_summaries, images=image_summaries
            )


def validate(model, model_cfg, loss_func, metric_func, val_loader, writer, epoch):
    """Validation: evaluate every batch and log the epoch averages."""
    # Switch the model to eval mode
    model.eval()

    # Initialize the average meter
    compute_descriptors = loss_func.compute_descriptors
    average_meter = _make_average_meter(compute_descriptors)
    num_iters = len(val_loader)

    # The validation loop
    for idx, data in enumerate(val_loader):
        batch = _load_batch(data, compute_descriptors)
        # Run the forward pass and compute losses without building a graph
        with torch.no_grad():
            outputs, outputs2, losses = _forward_and_loss(
                model, loss_func, batch, epoch, compute_descriptors
            )

        ############## Measure the metric error #########################
        _evaluate_batch(
            metric_func, model_cfg, batch, outputs, outputs2, compute_descriptors
        )

        # Update average meter
        loss_dict = _loss_values(losses, compute_descriptors)
        average_meter.update(
            metric_func, loss_dict, num_samples=batch["junc_map"].shape[0]
        )

        # Display the progress
        if (idx % model_cfg["disp_freq"]) == 0:
            results = metric_func.metric_results
            average = average_meter.average()
            print(
                "Iter [%d / %d] %s"
                % (idx, num_iters, _format_losses(loss_dict, average))
            )
            _print_metric_results(results, average, compute_descriptors)

    # Record the validation (test) summary
    average = average_meter.average()
    scalar_summaries = {"average": average}
    record_test_summaries(writer, epoch, scalar_summaries)


def convert_junc_predictions(predictions, grid_size, detect_thresh=1 / 65, topk=300):
    """Convert torch predictions to numpy arrays for evaluation.

    Returns:
        Dict with "junc_pred" (pixel-shuffled probabilities with the trailing
        channel axis squeezed), "junc_pred_nms" (``super_nms`` output) and
        "junc_prob" (per-cell sum of all channels except the last one).
    """
    # Convert to probability outputs first
    junc_prob = softmax(predictions.detach(), dim=1).cpu()
    # Drop the last channel (presumably the "no junction" bin — confirm)
    junc_pred = junc_prob[:, :-1, :, :]

    junc_prob_np = junc_prob.numpy().transpose(0, 2, 3, 1)[:, :, :, :-1]
    junc_prob_np = np.sum(junc_prob_np, axis=-1)
    # junc_pred is already on the CPU, so no extra .cpu() is needed here
    junc_pred_np = pixel_shuffle(junc_pred, grid_size).numpy().transpose(0, 2, 3, 1)
    junc_pred_np_nms = super_nms(junc_pred_np, grid_size, detect_thresh, topk)
    junc_pred_np = junc_pred_np.squeeze(-1)

    return {
        "junc_pred": junc_pred_np,
        "junc_pred_nms": junc_pred_np_nms,
        "junc_prob": junc_prob_np,
    }


def record_train_summaries(writer, global_step, scalars, images):
    """Record training summaries.

    Args:
        writer: tensorboardX ``SummaryWriter``.
        global_step: Step index attached to every summary.
        scalars: Dict with "junc_loss", "heatmap_loss", "total_loss",
            "metrics" (current results), "average" (running averages), and
            optionally "reg_loss", "descriptor_loss" and "w_*" loss weights.
        images: Dict of numpy arrays built by ``train_single_epoch``.
    """
    # Record the scalar summaries
    results = scalars["metrics"]
    average = scalars["average"]

    # GPU memory part (in GB)
    gpu_mem_usage = torch.cuda.max_memory_allocated() / (1024**3)
    writer.add_scalar("GPU/GPU_memory_usage", gpu_mem_usage, global_step)

    # Loss part: always-present terms, then the optional ones
    for key in ("junc_loss", "heatmap_loss", "total_loss"):
        writer.add_scalar("Train_loss/%s" % key, scalars[key], global_step)
    for key in ("reg_loss", "descriptor_loss"):
        if key in scalars:
            writer.add_scalar("Train_loss/%s" % key, scalars[key], global_step)

    # Record weighting
    for key in scalars:
        if "w_" in key:
            writer.add_scalar("Train_weight/%s" % key, scalars[key], global_step)

    # Smoothed loss (each tag is written exactly once per step)
    for key in ("junc_loss", "heatmap_loss", "total_loss"):
        writer.add_scalar("Train_loss_average/%s" % key, average[key], global_step)
    if "descriptor_loss" in average:
        writer.add_scalar(
            "Train_loss_average/descriptor_loss",
            average["descriptor_loss"],
            global_step,
        )

    # Metrics part
    for key in _METRIC_KEYS:
        writer.add_scalar("Train_metrics/%s" % key, results[key], global_step)
    # Add descriptor metric
    if "matching_score" in results:
        writer.add_scalar(
            "Train_metrics/matching_score", results["matching_score"], global_step
        )

    # Average part
    for key in _METRIC_KEYS:
        writer.add_scalar("Train_metrics_average/%s" % key, average[key], global_step)
    # Add smoothed descriptor metric
    if "matching_score" in average:
        writer.add_scalar(
            "Train_metrics_average/matching_score",
            average["matching_score"],
            global_step,
        )

    # Record the image summary
    # Image part
    image_tensor = convert_image(images["image"], 1)
    valid_masks = convert_image(images["valid_mask"], -1)
    writer.add_images("Train/images", image_tensor, global_step, dataformats="NCHW")
    writer.add_images("Train/valid_map", valid_masks, global_step, dataformats="NHWC")

    # Heatmap part
    writer.add_images(
        "Train/heatmap_gt",
        convert_image(images["heatmap_gt"], -1),
        global_step,
        dataformats="NHWC",
    )
    writer.add_images(
        "Train/heatmap_pred",
        convert_image(images["heatmap_pred"], -1),
        global_step,
        dataformats="NHWC",
    )

    # Junction prediction part
    junc_plots = plot_junction_detection(
        image_tensor,
        images["junc_map_pred"],
        images["junc_map_pred_nms"],
        images["junc_map_gt"],
    )
    writer.add_images(
        "Train/junc_gt",
        junc_plots["junc_gt_plot"] / 255.0,
        global_step,
        dataformats="NHWC",
    )
    writer.add_images(
        "Train/junc_pred",
        junc_plots["junc_pred_plot"] / 255.0,
        global_step,
        dataformats="NHWC",
    )
    writer.add_images(
        "Train/junc_pred_nms",
        junc_plots["junc_pred_nms_plot"] / 255.0,
        global_step,
        dataformats="NHWC",
    )
    writer.add_images(
        "Train/junc_prob_map",
        convert_image(images["junc_prob_map"][..., None], axis=-1),
        global_step,
        dataformats="NHWC",
    )


def record_test_summaries(writer, epoch, scalars):
    """Record testing summaries (epoch averages under the "Val_*" tags)."""
    average = scalars["average"]

    # Average loss
    for key in ("junc_loss", "heatmap_loss", "total_loss"):
        writer.add_scalar("Val_loss/%s" % key, average[key], epoch)
    # Add descriptor loss
    if "descriptor_loss" in average:
        writer.add_scalar("Val_loss/descriptor_loss", average["descriptor_loss"], epoch)

    # Average metrics
    for key in _METRIC_KEYS:
        writer.add_scalar("Val_metrics/%s" % key, average[key], epoch)
    # Add descriptor metric
    if "matching_score" in average:
        writer.add_scalar(
            "Val_metrics/matching_score", average["matching_score"], epoch
        )


def _draw_junctions(image, junc_map, color):
    """Draw a circle on a copy of ``image`` at every positive entry of ``junc_map``.

    Returns the drawn copy with a new leading axis so results can be
    concatenated along the batch dimension.
    """
    canvas = image.copy()
    coords = np.where(junc_map > 0)
    # coords[0] / coords[1] are the row / column indices; cv2 wants the center
    # as plain Python ints in (x, y) = (col, row) order.
    for row, col in zip(coords[0], coords[1]):
        cv2.circle(canvas, (int(col), int(row)), 3, color=color, thickness=2)
    return canvas[None, ...]


def plot_junction_detection(
    image_tensor, junc_pred_tensor, junc_pred_nms_tensor, junc_gt_tensor
):
    """Plot the junction points on images.

    Args:
        image_tensor: NCHW numpy images; scaled by 255 and cast to uint8 here
            (presumably values lie in [0, 1] — confirm with ``convert_image``).
        junc_pred_tensor: Per-image junction prediction maps (positive = drawn).
        junc_pred_nms_tensor: Per-image junction maps after NMS.
        junc_gt_tensor: Per-image ground-truth junction maps.

    Returns:
        Dict with "junc_gt_plot" (red circles), "junc_pred_plot" and
        "junc_pred_nms_plot" (green circles), each an NHWC uint8 array.
    """
    junc_pred_lst = []
    junc_pred_nms_lst = []
    junc_gt_lst = []
    # Process through batch dimension
    for i in range(image_tensor.shape[0]):
        # Convert image to 255 uint8 HWC
        image = (image_tensor[i, :, :, :] * 255.0).astype(np.uint8).transpose(1, 2, 0)

        junc_gt_lst.append(
            _draw_junctions(image, junc_gt_tensor[i, ...].squeeze(), (255, 0, 0))
        )
        junc_pred_lst.append(
            _draw_junctions(image, junc_pred_tensor[i, ...], (0, 255, 0))
        )
        junc_pred_nms_lst.append(
            _draw_junctions(image, junc_pred_nms_tensor[i, ...], (0, 255, 0))
        )

    return {
        "junc_gt_plot": np.concatenate(junc_gt_lst, axis=0),
        "junc_pred_plot": np.concatenate(junc_pred_lst, axis=0),
        "junc_pred_nms_plot": np.concatenate(junc_pred_nms_lst, axis=0),
    }