"""
This module implements the training loop and all the TensorBoard summaries.
"""
import os
import numpy as np
import cv2
import torch
from torch.nn.functional import pixel_shuffle, softmax
from torch.utils.data import DataLoader
import torch.utils.data.dataloader as torch_loader
from tensorboardX import SummaryWriter

from .dataset.dataset_util import get_dataset
from .model.model_util import get_model
from .model.loss import TotalLoss, get_loss_and_weights
from .model.metrics import AverageMeter, Metrics, super_nms
from .model.lr_scheduler import get_lr_scheduler
from .misc.train_utils import (
    convert_image,
    get_latest_checkpoint,
    remove_old_checkpoints,
)


def customized_collate_fn(batch):
    """Customized collate_fn."""
    batch_keys = ["image", "junction_map", "heatmap", "valid_mask"]
    list_keys = ["junctions", "line_map"]
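    # Tensor-valued fields share a fixed shape and are stacked by the default
    # collate; per-image fields with a variable length (junctions, line maps)
    # are kept as plain Python lists.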

    outputs = {}
    for key in batch_keys:
        outputs[key] = torch_loader.default_collate([b[key] for b in batch])
    for key in list_keys:
        outputs[key] = [b[key] for b in batch]

    return outputs


def restore_weights(model, state_dict, strict=True):
    """Restore weights in compatible mode."""
    # Try to directly load state dict
    try:
        model.load_state_dict(state_dict, strict=strict)
    # Handle version compatibility issues (key mismatches between versions)
    except RuntimeError:
        err = model.load_state_dict(state_dict, strict=False)

        # Missing keys are those in the model but not in the state_dict
        missing_keys = err.missing_keys
        # Unexpected keys are those in the state_dict but not in the model
        unexpected_keys = err.unexpected_keys

        # Load the mismatched keys manually, assuming the missing and
        # unexpected keys line up one-to-one in order once the batch-norm
        # "num_batches_tracked" buffers are filtered out
        model_dict = model.state_dict()
        dict_keys = [k for k in unexpected_keys if "tracked" not in k]
        for idx, key in enumerate(missing_keys):
            model_dict[key] = state_dict[dict_keys[idx]]
        model.load_state_dict(model_dict)

    return model


def train_net(args, dataset_cfg, model_cfg, output_path):
    """Main training function."""
    # Version compatibility: older configs may not specify a weighting policy
    if model_cfg.get("weighting_policy") is None:
        # Default to static weighting
        model_cfg["weighting_policy"] = "static"

    # Get the train, val, test config
    train_cfg = model_cfg["train"]
    test_cfg = model_cfg["test"]

    # Create train and test dataset
    print("\t Initializing dataset...")
    train_dataset, train_collate_fn = get_dataset("train", dataset_cfg)
    test_dataset, test_collate_fn = get_dataset("test", dataset_cfg)

    # Create the dataloader
    train_loader = DataLoader(
        train_dataset,
        batch_size=train_cfg["batch_size"],
        num_workers=8,
        shuffle=True,
        pin_memory=True,
        collate_fn=train_collate_fn,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=test_cfg.get("batch_size", 1),
        num_workers=test_cfg.get("num_workers", 1),
        shuffle=False,
        pin_memory=False,
        collate_fn=test_collate_fn,
    )
    print("\t Successfully initialized dataloaders.")

    # Get the loss function and weight first
    loss_funcs, loss_weights = get_loss_and_weights(model_cfg)
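    # Note: the loss weights are also passed to get_model() below; this lets
    # any learnable weighting parameters be registered with the model (and
    # hence with the optimizer).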

    # If resuming, restore from the latest checkpoint.
    if args.resume:
        # Create model and load the state dict
        checkpoint = get_latest_checkpoint(args.resume_path, args.checkpoint_name)
        model = get_model(model_cfg, loss_weights)
        model = restore_weights(model, checkpoint["model_state_dict"])
        model = model.cuda()
        optimizer = torch.optim.Adam(
            [{"params": model.parameters(), "initial_lr": model_cfg["learning_rate"]}],
            model_cfg["learning_rate"],
            amsgrad=True,
        )
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        # Optionally get the learning rate scheduler
        scheduler = get_lr_scheduler(
            lr_decay=model_cfg.get("lr_decay", False),
            lr_decay_cfg=model_cfg.get("lr_decay_cfg", None),
            optimizer=optimizer,
        )
        # Restore the scheduler state if the checkpoint contains one
        # (e.g. the scheduler may have been enabled partway through training)
        if (scheduler is not None) and (
            checkpoint.get("scheduler_state_dict", None) is not None
        ):
            scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
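        # Resume at the epoch following the checkpointed one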
        start_epoch = checkpoint["epoch"] + 1
    # Initialize all the components.
    else:
        # Create model and optimizer
        model = get_model(model_cfg, loss_weights)
        # Optionally load the pretrained weights
        if args.pretrained:
            print("\t [Debug] Loading pretrained weights...")
            checkpoint = get_latest_checkpoint(
                args.pretrained_path, args.checkpoint_name
            )
            # Non-strict loading, so that an auto-weighting model can restore
            # from a non-auto-weighting checkpoint
            model = restore_weights(model, checkpoint["model_state_dict"], strict=False)
            print("\t [Debug] Finished loading pretrained weights!")

        model = model.cuda()
        optimizer = torch.optim.Adam(
            [{"params": model.parameters(), "initial_lr": model_cfg["learning_rate"]}],
            model_cfg["learning_rate"],
            amsgrad=True,
        )
        # Optionally get the learning rate scheduler
        scheduler = get_lr_scheduler(
            lr_decay=model_cfg.get("lr_decay", False),
            lr_decay_cfg=model_cfg.get("lr_decay_cfg", None),
            optimizer=optimizer,
        )
        start_epoch = 0

    print("\t Successfully initialized model")

    # Define the total loss
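    # With the "static" policy the losses are combined with fixed weights;
    # other policies are expected to adapt the weights during training.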
    policy = model_cfg.get("weighting_policy", "static")
    loss_func = TotalLoss(loss_funcs, loss_weights, policy).cuda()
    if "descriptor_decoder" in model_cfg:
        metric_func = Metrics(
            model_cfg["detection_thresh"],
            model_cfg["prob_thresh"],
            model_cfg["descriptor_loss_cfg"]["grid_size"],
            desc_metric_lst="all",
        )
    else:
        metric_func = Metrics(
            model_cfg["detection_thresh"],
            model_cfg["prob_thresh"],
            model_cfg["grid_size"],
        )

    # Define the summary writer
    logdir = os.path.join(output_path, "log")
    writer = SummaryWriter(logdir=logdir)

    # Start the training loop
    for epoch in range(start_epoch, model_cfg["epochs"]):
        # Record the learning rate
        current_lr = optimizer.state_dict()["param_groups"][0]["lr"]
        writer.add_scalar("LR/lr", current_lr, epoch)

        # Train for one epoch
        print("\n\n================== Training ====================")
        train_single_epoch(
            model=model,
            model_cfg=model_cfg,
            optimizer=optimizer,
            loss_func=loss_func,
            metric_func=metric_func,
            train_loader=train_loader,
            writer=writer,
            epoch=epoch,
        )

        # Do the validation
        print("\n\n================== Validation ==================")
        validate(
            model=model,
            model_cfg=model_cfg,
            loss_func=loss_func,
            metric_func=metric_func,
            val_loader=test_loader,
            writer=writer,
            epoch=epoch,
        )

        # Update the scheduler
        if scheduler is not None:
            scheduler.step()

        # Save checkpoints
        file_name = os.path.join(output_path, "checkpoint-epoch%03d-end.tar" % (epoch))
        print("[Info] Saving checkpoint %s ..." % file_name)
        save_dict = {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "model_cfg": model_cfg,
        }
        if scheduler is not None:
            save_dict.update({"scheduler_state_dict": scheduler.state_dict()})
        torch.save(save_dict, file_name)

        # Remove the outdated checkpoints
        remove_old_checkpoints(output_path, model_cfg.get("max_ckpt", 15))


def train_single_epoch(
    model, model_cfg, optimizer, loss_func, metric_func, train_loader, writer, epoch
):
    """Train for one epoch."""
    # Switch the model to training mode
    model.train()

    # Initialize the average meter
    compute_descriptors = loss_func.compute_descriptors
    if compute_descriptors:
        average_meter = AverageMeter(is_training=True, desc_metric_lst="all")
    else:
        average_meter = AverageMeter(is_training=True)

    # The training loop
    for idx, data in enumerate(train_loader):
        if compute_descriptors:
            junc_map = data["ref_junction_map"].cuda()
            junc_map2 = data["target_junction_map"].cuda()
            heatmap = data["ref_heatmap"].cuda()
            heatmap2 = data["target_heatmap"].cuda()
            line_points = data["ref_line_points"].cuda()
            line_points2 = data["target_line_points"].cuda()
            line_indices = data["ref_line_indices"].cuda()
            valid_mask = data["ref_valid_mask"].cuda()
            valid_mask2 = data["target_valid_mask"].cuda()
            input_images = data["ref_image"].cuda()
            input_images2 = data["target_image"].cuda()

            # Run the forward pass
            outputs = model(input_images)
            outputs2 = model(input_images2)

            # Compute losses
            losses = loss_func.forward_descriptors(
                outputs["junctions"],
                outputs2["junctions"],
                junc_map,
                junc_map2,
                outputs["heatmap"],
                outputs2["heatmap"],
                heatmap,
                heatmap2,
                line_points,
                line_points2,
                line_indices,
                outputs["descriptors"],
                outputs2["descriptors"],
                epoch,
                valid_mask,
                valid_mask2,
            )
        else:
            junc_map = data["junction_map"].cuda()
            heatmap = data["heatmap"].cuda()
            valid_mask = data["valid_mask"].cuda()
            input_images = data["image"].cuda()

            # Run the forward pass
            outputs = model(input_images)

            # Compute losses
            losses = loss_func(
                outputs["junctions"], junc_map, outputs["heatmap"], heatmap, valid_mask
            )

        total_loss = losses["total_loss"]

        # Update the model
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # Compute the global step
        global_step = epoch * len(train_loader) + idx
        ############## Measure the metric error #########################
        # Only compute the metrics when they will be displayed or logged
        if ((idx % model_cfg["disp_freq"]) == 0) or (
            (idx % model_cfg["summary_freq"]) == 0
        ):
            junc_np = convert_junc_predictions(
                outputs["junctions"],
                model_cfg["grid_size"],
                model_cfg["detection_thresh"],
                300,
            )
            junc_map_np = junc_map.cpu().numpy().transpose(0, 2, 3, 1)

            # Reduce the heatmap to a single probability channel: softmax
            # (positive class) for the 2-channel CE head, sigmoid otherwise
            # (L1/L2 heads)
            if outputs["heatmap"].shape[1] == 2:
                heatmap_np = softmax(outputs["heatmap"].detach(), dim=1).cpu().numpy()
                heatmap_np = heatmap_np.transpose(0, 2, 3, 1)[:, :, :, 1:]
            else:
                heatmap_np = torch.sigmoid(outputs["heatmap"].detach())
                heatmap_np = heatmap_np.cpu().numpy().transpose(0, 2, 3, 1)

            heatmap_gt_np = heatmap.cpu().numpy().transpose(0, 2, 3, 1)
            valid_mask_np = valid_mask.cpu().numpy().transpose(0, 2, 3, 1)

            # Evaluate metric results
            if compute_descriptors:
                metric_func.evaluate(
                    junc_np["junc_pred"],
                    junc_np["junc_pred_nms"],
                    junc_map_np,
                    heatmap_np,
                    heatmap_gt_np,
                    valid_mask_np,
                    line_points,
                    line_points2,
                    outputs["descriptors"],
                    outputs2["descriptors"],
                    line_indices,
                )
            else:
                metric_func.evaluate(
                    junc_np["junc_pred"],
                    junc_np["junc_pred_nms"],
                    junc_map_np,
                    heatmap_np,
                    heatmap_gt_np,
                    valid_mask_np,
                )
            # Update average meter
            junc_loss = losses["junc_loss"].item()
            heatmap_loss = losses["heatmap_loss"].item()
            loss_dict = {
                "junc_loss": junc_loss,
                "heatmap_loss": heatmap_loss,
                "total_loss": total_loss.item(),
            }
            if compute_descriptors:
                descriptor_loss = losses["descriptor_loss"].item()
                loss_dict["descriptor_loss"] = descriptor_loss

            average_meter.update(metric_func, loss_dict, num_samples=junc_map.shape[0])

        # Display the progress
        if (idx % model_cfg["disp_freq"]) == 0:
            results = metric_func.metric_results
            average = average_meter.average()
            # Get gpu memory usage in GB
            gpu_mem_usage = torch.cuda.max_memory_allocated() / (1024**3)
            if compute_descriptors:
                print(
                    "Epoch [%d / %d] Iter [%d / %d] loss=%.4f (%.4f), junc_loss=%.4f (%.4f), heatmap_loss=%.4f (%.4f), descriptor_loss=%.4f (%.4f), gpu_mem=%.4fGB"
                    % (
                        epoch,
                        model_cfg["epochs"],
                        idx,
                        len(train_loader),
                        total_loss.item(),
                        average["total_loss"],
                        junc_loss,
                        average["junc_loss"],
                        heatmap_loss,
                        average["heatmap_loss"],
                        descriptor_loss,
                        average["descriptor_loss"],
                        gpu_mem_usage,
                    )
                )
            else:
                print(
                    "Epoch [%d / %d] Iter [%d / %d] loss=%.4f (%.4f), junc_loss=%.4f (%.4f), heatmap_loss=%.4f (%.4f), gpu_mem=%.4fGB"
                    % (
                        epoch,
                        model_cfg["epochs"],
                        idx,
                        len(train_loader),
                        total_loss.item(),
                        average["total_loss"],
                        junc_loss,
                        average["junc_loss"],
                        heatmap_loss,
                        average["heatmap_loss"],
                        gpu_mem_usage,
                    )
                )
            print(
                "\t Junction     precision=%.4f (%.4f) / recall=%.4f (%.4f)"
                % (
                    results["junc_precision"],
                    average["junc_precision"],
                    results["junc_recall"],
                    average["junc_recall"],
                )
            )
            print(
                "\t Junction nms precision=%.4f (%.4f) / recall=%.4f (%.4f)"
                % (
                    results["junc_precision_nms"],
                    average["junc_precision_nms"],
                    results["junc_recall_nms"],
                    average["junc_recall_nms"],
                )
            )
            print(
                "\t Heatmap      precision=%.4f (%.4f) / recall=%.4f (%.4f)"
                % (
                    results["heatmap_precision"],
                    average["heatmap_precision"],
                    results["heatmap_recall"],
                    average["heatmap_recall"],
                )
            )
            if compute_descriptors:
                print(
                    "\t Descriptors  matching score=%.4f (%.4f)"
                    % (results["matching_score"], average["matching_score"])
                )

        # Record summaries
        if (idx % model_cfg["summary_freq"]) == 0:
            results = metric_func.metric_results
            average = average_meter.average()
            # Add the shared losses
            scalar_summaries = {
                "junc_loss": junc_loss,
                "heatmap_loss": heatmap_loss,
                "total_loss": total_loss.item(),
                "metrics": results,
                "average": average,
            }
            # Add descriptor terms
            if compute_descriptors:
                scalar_summaries["descriptor_loss"] = descriptor_loss
                scalar_summaries["w_desc"] = losses["w_desc"]

            # Add weighting terms (even for static terms)
            scalar_summaries["w_junc"] = losses["w_junc"]
            scalar_summaries["w_heatmap"] = losses["w_heatmap"]
            scalar_summaries["reg_loss"] = losses["reg_loss"].item()

            num_images = 3
            junc_pred_binary = (
                junc_np["junc_pred"][:num_images, ...] > model_cfg["detection_thresh"]
            )
            junc_pred_nms_binary = (
                junc_np["junc_pred_nms"][:num_images, ...]
                > model_cfg["detection_thresh"]
            )
            image_summaries = {
                "image": input_images.cpu().numpy()[:num_images, ...],
                "valid_mask": valid_mask_np[:num_images, ...],
                "junc_map_pred": junc_pred_binary,
                "junc_map_pred_nms": junc_pred_nms_binary,
                "junc_map_gt": junc_map_np[:num_images, ...],
                "junc_prob_map": junc_np["junc_prob"][:num_images, ...],
                "heatmap_pred": heatmap_np[:num_images, ...],
                "heatmap_gt": heatmap_gt_np[:num_images, ...],
            }
            # Record the training summary
            record_train_summaries(
                writer, global_step, scalars=scalar_summaries, images=image_summaries
            )


def validate(model, model_cfg, loss_func, metric_func, val_loader, writer, epoch):
    """Validation."""
    # Switch the model to eval mode
    model.eval()

    # Initialize the average meter
    compute_descriptors = loss_func.compute_descriptors
    if compute_descriptors:
        average_meter = AverageMeter(is_training=True, desc_metric_lst="all")
    else:
        average_meter = AverageMeter(is_training=True)

    # The validation loop
    for idx, data in enumerate(val_loader):
        if compute_descriptors:
            junc_map = data["ref_junction_map"].cuda()
            junc_map2 = data["target_junction_map"].cuda()
            heatmap = data["ref_heatmap"].cuda()
            heatmap2 = data["target_heatmap"].cuda()
            line_points = data["ref_line_points"].cuda()
            line_points2 = data["target_line_points"].cuda()
            line_indices = data["ref_line_indices"].cuda()
            valid_mask = data["ref_valid_mask"].cuda()
            valid_mask2 = data["target_valid_mask"].cuda()
            input_images = data["ref_image"].cuda()
            input_images2 = data["target_image"].cuda()

            # Run the forward pass
            with torch.no_grad():
                outputs = model(input_images)
                outputs2 = model(input_images2)

                # Compute losses
                losses = loss_func.forward_descriptors(
                    outputs["junctions"],
                    outputs2["junctions"],
                    junc_map,
                    junc_map2,
                    outputs["heatmap"],
                    outputs2["heatmap"],
                    heatmap,
                    heatmap2,
                    line_points,
                    line_points2,
                    line_indices,
                    outputs["descriptors"],
                    outputs2["descriptors"],
                    epoch,
                    valid_mask,
                    valid_mask2,
                )
        else:
            junc_map = data["junction_map"].cuda()
            heatmap = data["heatmap"].cuda()
            valid_mask = data["valid_mask"].cuda()
            input_images = data["image"].cuda()

            # Run the forward pass
            with torch.no_grad():
                outputs = model(input_images)

                # Compute losses
                losses = loss_func(
                    outputs["junctions"],
                    junc_map,
                    outputs["heatmap"],
                    heatmap,
                    valid_mask,
                )
        total_loss = losses["total_loss"]

        ############## Measure the metric error #########################
        junc_np = convert_junc_predictions(
            outputs["junctions"],
            model_cfg["grid_size"],
            model_cfg["detection_thresh"],
            300,
        )
        junc_map_np = junc_map.cpu().numpy().transpose(0, 2, 3, 1)
        # Reduce the heatmap to a single probability channel: softmax (positive
        # class) for the 2-channel CE head, sigmoid otherwise (L1/L2 heads)
        if outputs["heatmap"].shape[1] == 2:
            heatmap_np = (
                softmax(outputs["heatmap"].detach(), dim=1)
                .cpu()
                .numpy()
                .transpose(0, 2, 3, 1)
            )
            heatmap_np = heatmap_np[:, :, :, 1:]
        else:
            heatmap_np = torch.sigmoid(outputs["heatmap"].detach())
            heatmap_np = heatmap_np.cpu().numpy().transpose(0, 2, 3, 1)

        heatmap_gt_np = heatmap.cpu().numpy().transpose(0, 2, 3, 1)
        valid_mask_np = valid_mask.cpu().numpy().transpose(0, 2, 3, 1)

        # Evaluate metric results
        if compute_descriptors:
            metric_func.evaluate(
                junc_np["junc_pred"],
                junc_np["junc_pred_nms"],
                junc_map_np,
                heatmap_np,
                heatmap_gt_np,
                valid_mask_np,
                line_points,
                line_points2,
                outputs["descriptors"],
                outputs2["descriptors"],
                line_indices,
            )
        else:
            metric_func.evaluate(
                junc_np["junc_pred"],
                junc_np["junc_pred_nms"],
                junc_map_np,
                heatmap_np,
                heatmap_gt_np,
                valid_mask_np,
            )
        # Update average meter
        junc_loss = losses["junc_loss"].item()
        heatmap_loss = losses["heatmap_loss"].item()
        loss_dict = {
            "junc_loss": junc_loss,
            "heatmap_loss": heatmap_loss,
            "total_loss": total_loss.item(),
        }
        if compute_descriptors:
            descriptor_loss = losses["descriptor_loss"].item()
            loss_dict["descriptor_loss"] = descriptor_loss
        average_meter.update(metric_func, loss_dict, num_samples=junc_map.shape[0])

        # Display the progress
        if (idx % model_cfg["disp_freq"]) == 0:
            results = metric_func.metric_results
            average = average_meter.average()
            if compute_descriptors:
                print(
                    "Iter [%d / %d] loss=%.4f (%.4f), junc_loss=%.4f (%.4f), heatmap_loss=%.4f (%.4f), descriptor_loss=%.4f (%.4f)"
                    % (
                        idx,
                        len(val_loader),
                        total_loss.item(),
                        average["total_loss"],
                        junc_loss,
                        average["junc_loss"],
                        heatmap_loss,
                        average["heatmap_loss"],
                        descriptor_loss,
                        average["descriptor_loss"],
                    )
                )
            else:
                print(
                    "Iter [%d / %d] loss=%.4f (%.4f), junc_loss=%.4f (%.4f), heatmap_loss=%.4f (%.4f)"
                    % (
                        idx,
                        len(val_loader),
                        total_loss.item(),
                        average["total_loss"],
                        junc_loss,
                        average["junc_loss"],
                        heatmap_loss,
                        average["heatmap_loss"],
                    )
                )
            print(
                "\t Junction     precision=%.4f (%.4f) / recall=%.4f (%.4f)"
                % (
                    results["junc_precision"],
                    average["junc_precision"],
                    results["junc_recall"],
                    average["junc_recall"],
                )
            )
            print(
                "\t Junction nms precision=%.4f (%.4f) / recall=%.4f (%.4f)"
                % (
                    results["junc_precision_nms"],
                    average["junc_precision_nms"],
                    results["junc_recall_nms"],
                    average["junc_recall_nms"],
                )
            )
            print(
                "\t Heatmap      precision=%.4f (%.4f) / recall=%.4f (%.4f)"
                % (
                    results["heatmap_precision"],
                    average["heatmap_precision"],
                    results["heatmap_recall"],
                    average["heatmap_recall"],
                )
            )
            if compute_descriptors:
                print(
                    "\t Descriptors  matching score=%.4f (%.4f)"
                    % (results["matching_score"], average["matching_score"])
                )

    # Record summaries
    average = average_meter.average()
    scalar_summaries = {"average": average}
    # Record the validation summary
    record_test_summaries(writer, epoch, scalar_summaries)


def convert_junc_predictions(predictions, grid_size, detect_thresh=1 / 65, topk=300):
    """Convert torch predictions to numpy arrays for evaluation."""
    # Convert to probability outputs first
    junc_prob = softmax(predictions.detach(), dim=1).cpu()
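    # The junction head predicts grid_size**2 + 1 channels per cell; the last
    # channel acts as a "no junction" dustbin bin and is dropped below before
    # pixel_shuffle recovers the full-resolution probability map.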
    junc_pred = junc_prob[:, :-1, :, :]

    junc_prob_np = junc_prob.numpy().transpose(0, 2, 3, 1)[:, :, :, :-1]
    junc_prob_np = np.sum(junc_prob_np, axis=-1)
    junc_pred_np = (
        pixel_shuffle(junc_pred, grid_size).cpu().numpy().transpose(0, 2, 3, 1)
    )
    junc_pred_np_nms = super_nms(junc_pred_np, grid_size, detect_thresh, topk)
    junc_pred_np = junc_pred_np.squeeze(-1)

    return {
        "junc_pred": junc_pred_np,
        "junc_pred_nms": junc_pred_np_nms,
        "junc_prob": junc_prob_np,
    }


def record_train_summaries(writer, global_step, scalars, images):
    """Record training summaries."""
    # Record the scalar summaries
    results = scalars["metrics"]
    average = scalars["average"]

    # GPU memory part
    # Get gpu memory usage in GB
    gpu_mem_usage = torch.cuda.max_memory_allocated() / (1024**3)
    writer.add_scalar("GPU/GPU_memory_usage", gpu_mem_usage, global_step)

    # Loss part
    writer.add_scalar("Train_loss/junc_loss", scalars["junc_loss"], global_step)
    writer.add_scalar("Train_loss/heatmap_loss", scalars["heatmap_loss"], global_step)
    writer.add_scalar("Train_loss/total_loss", scalars["total_loss"], global_step)
    # Add regularization loss
    if "reg_loss" in scalars.keys():
        writer.add_scalar("Train_loss/reg_loss", scalars["reg_loss"], global_step)
    # Add descriptor loss
    if "descriptor_loss" in scalars.keys():
        key = "descriptor_loss"
        writer.add_scalar("Train_loss/%s" % (key), scalars[key], global_step)
        writer.add_scalar("Train_loss_average/%s" % (key), average[key], global_step)

    # Record weighting
    for key in scalars.keys():
        if "w_" in key:
            writer.add_scalar("Train_weight/%s" % (key), scalars[key], global_step)

    # Smoothed loss
    writer.add_scalar("Train_loss_average/junc_loss", average["junc_loss"], global_step)
    writer.add_scalar(
        "Train_loss_average/heatmap_loss", average["heatmap_loss"], global_step
    )
    writer.add_scalar(
        "Train_loss_average/total_loss", average["total_loss"], global_step
    )
    # Add smoothed descriptor loss
    if "descriptor_loss" in average.keys():
        writer.add_scalar(
            "Train_loss_average/descriptor_loss",
            average["descriptor_loss"],
            global_step,
        )

    # Metrics part
    writer.add_scalar(
        "Train_metrics/junc_precision", results["junc_precision"], global_step
    )
    writer.add_scalar(
        "Train_metrics/junc_precision_nms", results["junc_precision_nms"], global_step
    )
    writer.add_scalar("Train_metrics/junc_recall", results["junc_recall"], global_step)
    writer.add_scalar(
        "Train_metrics/junc_recall_nms", results["junc_recall_nms"], global_step
    )
    writer.add_scalar(
        "Train_metrics/heatmap_precision", results["heatmap_precision"], global_step
    )
    writer.add_scalar(
        "Train_metrics/heatmap_recall", results["heatmap_recall"], global_step
    )
    # Add descriptor metric
    if "matching_score" in results.keys():
        writer.add_scalar(
            "Train_metrics/matching_score", results["matching_score"], global_step
        )

    # Average part
    writer.add_scalar(
        "Train_metrics_average/junc_precision", average["junc_precision"], global_step
    )
    writer.add_scalar(
        "Train_metrics_average/junc_precision_nms",
        average["junc_precision_nms"],
        global_step,
    )
    writer.add_scalar(
        "Train_metrics_average/junc_recall", average["junc_recall"], global_step
    )
    writer.add_scalar(
        "Train_metrics_average/junc_recall_nms", average["junc_recall_nms"], global_step
    )
    writer.add_scalar(
        "Train_metrics_average/heatmap_precision",
        average["heatmap_precision"],
        global_step,
    )
    writer.add_scalar(
        "Train_metrics_average/heatmap_recall", average["heatmap_recall"], global_step
    )
    # Add smoothed descriptor metric
    if "matching_score" in average.keys():
        writer.add_scalar(
            "Train_metrics_average/matching_score",
            average["matching_score"],
            global_step,
        )

    # Record the image summary
    # Image part
    image_tensor = convert_image(images["image"], 1)
    valid_masks = convert_image(images["valid_mask"], -1)
    writer.add_images("Train/images", image_tensor, global_step, dataformats="NCHW")
    writer.add_images("Train/valid_map", valid_masks, global_step, dataformats="NHWC")

    # Heatmap part
    writer.add_images(
        "Train/heatmap_gt",
        convert_image(images["heatmap_gt"], -1),
        global_step,
        dataformats="NHWC",
    )
    writer.add_images(
        "Train/heatmap_pred",
        convert_image(images["heatmap_pred"], -1),
        global_step,
        dataformats="NHWC",
    )

    # Junction prediction part
    junc_plots = plot_junction_detection(
        image_tensor,
        images["junc_map_pred"],
        images["junc_map_pred_nms"],
        images["junc_map_gt"],
    )
    writer.add_images(
        "Train/junc_gt",
        junc_plots["junc_gt_plot"] / 255.0,
        global_step,
        dataformats="NHWC",
    )
    writer.add_images(
        "Train/junc_pred",
        junc_plots["junc_pred_plot"] / 255.0,
        global_step,
        dataformats="NHWC",
    )
    writer.add_images(
        "Train/junc_pred_nms",
        junc_plots["junc_pred_nms_plot"] / 255.0,
        global_step,
        dataformats="NHWC",
    )
    writer.add_images(
        "Train/junc_prob_map",
        convert_image(images["junc_prob_map"][..., None], axis=-1),
        global_step,
        dataformats="NHWC",
    )


def record_test_summaries(writer, epoch, scalars):
    """Record testing summaries."""
    average = scalars["average"]

    # Average loss
    writer.add_scalar("Val_loss/junc_loss", average["junc_loss"], epoch)
    writer.add_scalar("Val_loss/heatmap_loss", average["heatmap_loss"], epoch)
    writer.add_scalar("Val_loss/total_loss", average["total_loss"], epoch)
    # Add descriptor loss
    if "descriptor_loss" in average.keys():
        key = "descriptor_loss"
        writer.add_scalar("Val_loss/%s" % (key), average[key], epoch)

    # Average metrics
    writer.add_scalar("Val_metrics/junc_precision", average["junc_precision"], epoch)
    writer.add_scalar(
        "Val_metrics/junc_precision_nms", average["junc_precision_nms"], epoch
    )
    writer.add_scalar("Val_metrics/junc_recall", average["junc_recall"], epoch)
    writer.add_scalar("Val_metrics/junc_recall_nms", average["junc_recall_nms"], epoch)
    writer.add_scalar(
        "Val_metrics/heatmap_precision", average["heatmap_precision"], epoch
    )
    writer.add_scalar("Val_metrics/heatmap_recall", average["heatmap_recall"], epoch)
    # Add descriptor metric
    if "matching_score" in average.keys():
        writer.add_scalar(
            "Val_metrics/matching_score", average["matching_score"], epoch
        )


def plot_junction_detection(
    image_tensor, junc_pred_tensor, junc_pred_nms_tensor, junc_gt_tensor
):
    """Plot the junction points on images."""
    # Get the batch_size
    batch_size = image_tensor.shape[0]
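    # image_tensor is expected in NCHW format with values in [0, 1]; the
    # junction maps are per-image binary masks.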

    # Process through batch dimension
    junc_pred_lst = []
    junc_pred_nms_lst = []
    junc_gt_lst = []
    for i in range(batch_size):
        # Convert image to 255 uint8
        image = (image_tensor[i, :, :, :] * 255.0).astype(np.uint8).transpose(1, 2, 0)

        # Plot groundtruth onto image
        junc_gt = junc_gt_tensor[i, ...]
        coord_gt = np.where(junc_gt.squeeze() > 0)
        points_gt = np.concatenate(
            (coord_gt[0][..., None], coord_gt[1][..., None]), axis=1
        )
        plot_gt = image.copy()
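        # np.where returns (row, col) coordinates; flip to (x, y) for cv2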
        for pt_idx in range(points_gt.shape[0]):
            cv2.circle(
                plot_gt,
                tuple(np.flip(points_gt[pt_idx, :])),
                3,
                color=(255, 0, 0),
                thickness=2,
            )
        junc_gt_lst.append(plot_gt[None, ...])

        # Plot junc_pred
        junc_pred = junc_pred_tensor[i, ...]
        coord_pred = np.where(junc_pred > 0)
        points_pred = np.concatenate(
            (coord_pred[0][..., None], coord_pred[1][..., None]), axis=1
        )
        plot_pred = image.copy()
        for pt_idx in range(points_pred.shape[0]):
            cv2.circle(
                plot_pred,
                tuple(np.flip(points_pred[pt_idx, :])),
                3,
                color=(0, 255, 0),
                thickness=2,
            )
        junc_pred_lst.append(plot_pred[None, ...])

        # Plot junc_pred_nms
        junc_pred_nms = junc_pred_nms_tensor[i, ...]
        coord_pred_nms = np.where(junc_pred_nms > 0)
        points_pred_nms = np.concatenate(
            (coord_pred_nms[0][..., None], coord_pred_nms[1][..., None]), axis=1
        )
        plot_pred_nms = image.copy()
        for pt_idx in range(points_pred_nms.shape[0]):
            cv2.circle(
                plot_pred_nms,
                tuple(np.flip(points_pred_nms[pt_idx, :])),
                3,
                color=(0, 255, 0),
                thickness=2,
            )
        junc_pred_nms_lst.append(plot_pred_nms[None, ...])

    return {
        "junc_gt_plot": np.concatenate(junc_gt_lst, axis=0),
        "junc_pred_plot": np.concatenate(junc_pred_lst, axis=0),
        "junc_pred_nms_plot": np.concatenate(junc_pred_nms_lst, axis=0),
    }