import copy
import math

import numpy as np
import cv2
import h5py
from tqdm import tqdm
import torch
from torch.nn.functional import pixel_shuffle, softmax
from torch.utils.data import DataLoader
from kornia.geometry import warp_perspective

from .dataset.dataset_util import get_dataset
from .model.model_util import get_model
from .misc.train_utils import get_latest_checkpoint
from .train import convert_junc_predictions
from .dataset.transforms.homographic_transforms import sample_homography


def restore_weights(model, state_dict):
    """Restore weights in compatible mode.

    First tries a strict ``load_state_dict``. If that fails, the state dict
    is loaded non-strictly and every key missing from ``state_dict`` is
    filled, by position, from the unexpected keys whose name does not
    contain ``"tracked"``.

    Arguments:
        model: The model whose weights are restored (modified in place).
        state_dict: The state dict to load.

    Returns:
        The same ``model`` object.
    """
    # Try to directly load state dict
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        # Only a loading failure triggers the compatible path; interrupts and
        # unrelated errors propagate instead of being swallowed.
        err = model.load_state_dict(state_dict, strict=False)
        # Missing keys are those in model but not in state_dict
        missing_keys = err.missing_keys
        # Unexpected keys are those in state_dict but not in model
        unexpected_keys = err.unexpected_keys

        # Candidate source keys, built once instead of once per missing key
        candidate_keys = [key for key in unexpected_keys if "tracked" not in key]

        # Load mismatched keys manually, pairing the two lists by position
        model_dict = model.state_dict()
        for idx, key in enumerate(missing_keys):
            model_dict[key] = state_dict[candidate_keys[idx]]
        model.load_state_dict(model_dict)

    return model


def get_padded_filename(num_pad, idx):
    """Get the filename padded with 0.

    Arguments:
        num_pad: Minimum width of the returned string.
        idx: The index, formatted with ``%d``.

    Returns:
        ``idx`` as a decimal string, left-filled with ``"0"`` up to
        ``num_pad`` characters.
    """
    # rjust (not zfill) keeps the original "prepend zeros" behavior exactly.
    return ("%d" % (idx)).rjust(num_pad, "0")


def export_predictions(args, dataset_cfg, model_cfg, output_path, export_dataset_mode):
    """Export predictions.

    Runs the model over the dataset selected by ``export_dataset_mode`` and
    writes one HDF5 group per sample to ``output_path + ".h5"``. Groups are
    named with a zero-padded running sample index.

    Arguments:
        args: Object providing ``resume_path`` and ``checkpoint_name``.
        dataset_cfg: Dataset configuration passed to ``get_dataset``.
        model_cfg: Model configuration; ``"test"``, ``"grid_size"`` and
            ``"detection_thresh"`` are read here.
        output_path: Output path without the ``.h5`` extension.
        export_dataset_mode: Dataset mode passed to ``get_dataset``.
    """
    # Get the test configuration
    test_cfg = model_cfg["test"]

    # Create the dataset and dataloader based on the export_dataset_mode
    print("\t Initializing dataset and dataloader")
    batch_size = 4
    export_dataset, collate_fn = get_dataset(export_dataset_mode, dataset_cfg)
    export_loader = DataLoader(
        export_dataset,
        batch_size=batch_size,
        num_workers=test_cfg.get("num_workers", 4),
        shuffle=False,
        pin_memory=False,
        collate_fn=collate_fn,
    )
    print("\t Successfully intialized dataset and dataloader.")

    # Initialize model and load the checkpoint
    model = get_model(model_cfg, mode="test")
    checkpoint = get_latest_checkpoint(args.resume_path, args.checkpoint_name)
    model = restore_weights(model, checkpoint["model_state_dict"])
    model = model.cuda()
    model.eval()
    print("\t Successfully initialized model")

    # Start the export process
    print("[Info] Start exporting predictions")
    output_dataset_path = output_path + ".h5"
    filename_idx = 0

    # Key width is loop-invariant, so compute it once. It is derived from the
    # number of batches (as before); max(..., 1) only protects log10 when the
    # loader is empty, in which case nothing is written anyway.
    num_pad = math.ceil(math.log10(max(len(export_loader), 1))) + 1

    # NOTE(review): the h5py documentation describes ``swmr=True`` as a
    # read-mode option; its effect with mode "w" should be confirmed.
    with h5py.File(output_dataset_path, "w", libver="latest", swmr=True) as f:
        # Iterate through all the data in dataloader
        for data in tqdm(export_loader, ascii=True):
            # Fetch the data
            junc_map = data["junction_map"]
            heatmap = data["heatmap"]
            valid_mask = data["valid_mask"]
            input_images = data["image"].cuda()

            # Run the forward pass
            with torch.no_grad():
                outputs = model(input_images)

            # Convert predictions
            junc_np = convert_junc_predictions(
                outputs["junctions"],
                model_cfg["grid_size"],
                model_cfg["detection_thresh"],
                300,
            )
            junc_map_np = junc_map.numpy().transpose(0, 2, 3, 1)
            heatmap_np = (
                softmax(outputs["heatmap"].detach(), dim=1)
                .cpu()
                .numpy()
                .transpose(0, 2, 3, 1)
            )
            heatmap_gt_np = heatmap.numpy().transpose(0, 2, 3, 1)
            valid_mask_np = valid_mask.numpy().transpose(0, 2, 3, 1)
            # Copy the image batch to the host once per batch rather than
            # once per sample.
            images_np = input_images.cpu().numpy().transpose(0, 2, 3, 1)

            # Data entries to save
            current_batch_size = input_images.shape[0]
            for batch_idx in range(current_batch_size):
                output_data = {
                    "image": images_np[batch_idx],
                    "junc_gt": junc_map_np[batch_idx],
                    "junc_pred": junc_np["junc_pred"][batch_idx],
                    "junc_pred_nms": junc_np["junc_pred_nms"][batch_idx].astype(
                        np.float32
                    ),
                    "heatmap_gt": heatmap_gt_np[batch_idx],
                    "heatmap_pred": heatmap_np[batch_idx],
                    "valid_mask": valid_mask_np[batch_idx],
                    "junc_points": data["junctions"][batch_idx]
                    .numpy()[0]
                    .round()
                    .astype(np.int32),
                    "line_map": data["line_map"][batch_idx]
                    .numpy()[0]
                    .astype(np.int32),
                }

                # Save data to h5 dataset
                output_key = get_padded_filename(num_pad, filename_idx)
                f_group = f.create_group(output_key)

                # Store data (the loop variable no longer shadows output_data)
                for key, value in output_data.items():
                    f_group.create_dataset(key, data=value, compression="gzip")

                filename_idx += 1


def export_homograpy_adaptation(
    args, dataset_cfg, model_cfg, output_path, export_dataset_mode, device
):
    """Export homography adaptation results.

    Runs ``homography_adaptation`` on every batch of the selected dataset and
    writes one HDF5 group per sample, keyed by the sample's ``file_key``, to
    ``output_path + ".h5"``.

    Arguments:
        args: Object providing ``export_batch_size``, ``resume_path`` and
            ``checkpoint_name``.
        dataset_cfg: Dataset configuration; must contain a
            ``"homography_adaptation"`` entry.
        model_cfg: Model configuration; ``"test"`` and ``"grid_size"`` are
            read here.
        output_path: Output path without the ``.h5`` extension.
        export_dataset_mode: Either ``"train"`` or ``"test"``.
        device: Device the model and the input images are moved to.

    Raises:
        ValueError: If ``export_dataset_mode`` is unsupported or the
            ``"homography_adaptation"`` config entry is missing.
    """
    # Check if the export_dataset_mode is supported
    supported_modes = ["train", "test"]
    if export_dataset_mode not in supported_modes:
        raise ValueError("[Error] The specified export_dataset_mode is not supported.")

    # Get the test configuration
    test_cfg = model_cfg["test"]

    # Get the homography adaptation configurations
    homography_cfg = dataset_cfg.get("homography_adaptation", None)
    if homography_cfg is None:
        raise ValueError("[Error] Empty homography_adaptation entry in config.")

    # Create the dataset and dataloader based on the export_dataset_mode
    print("\t Initializing dataset and dataloader")
    batch_size = args.export_batch_size

    export_dataset, collate_fn = get_dataset(export_dataset_mode, dataset_cfg)
    export_loader = DataLoader(
        export_dataset,
        batch_size=batch_size,
        num_workers=test_cfg.get("num_workers", 4),
        shuffle=False,
        pin_memory=False,
        collate_fn=collate_fn,
    )
    print("\t Successfully intialized dataset and dataloader.")

    # Initialize model and load the checkpoint
    model = get_model(model_cfg, mode="test")
    checkpoint = get_latest_checkpoint(args.resume_path, args.checkpoint_name, device)
    model = restore_weights(model, checkpoint["model_state_dict"])
    model = model.to(device).eval()
    print("\t Successfully initialized model")

    # Start the export process
    print("[Info] Start exporting predictions")
    output_dataset_path = output_path + ".h5"
    with h5py.File(output_dataset_path, "w", libver="latest") as f:
        # NOTE(review): the h5py SWMR documentation states that new groups and
        # datasets cannot be created once SWMR mode is enabled, yet groups are
        # created below. Behavior is left unchanged here; please confirm.
        f.swmr_mode = True
        for data in tqdm(export_loader, ascii=True):
            input_images = data["image"].to(device)
            file_keys = data["file_key"]
            current_batch_size = input_images.shape[0]

            # Run the homograpy adaptation
            outputs = homography_adaptation(
                input_images, model, model_cfg["grid_size"], homography_cfg
            )

            # Move every exported tensor to the host once per batch instead of
            # once per sample. The HDF5 dataset names are kept exactly as
            # before (including the "heatmap_cout" spelling) so existing
            # readers keep working.
            batch_tensors = {
                "image": input_images,
                "junc_prob_mean": outputs["junc_probs_mean"],
                "junc_prob_max": outputs["junc_probs_max"],
                "junc_count": outputs["junc_counts"],
                "heatmap_prob_mean": outputs["heatmap_probs_mean"],
                "heatmap_prob_max": outputs["heatmap_probs_max"],
                "heatmap_cout": outputs["heatmap_counts"],
            }
            batch_arrays = {
                key: tensor.cpu().numpy().transpose(0, 2, 3, 1)
                for key, tensor in batch_tensors.items()
            }

            # Save the entries
            for batch_idx in range(current_batch_size):
                # Get the save key
                save_key = file_keys[batch_idx]

                # Create group and write data
                f_group = f.create_group(save_key)
                for key, array in batch_arrays.items():
                    f_group.create_dataset(
                        key, data=array[batch_idx], compression="gzip"
                    )


def homography_adaptation(input_images, model, grid_size, homography_cfg):
    """The homography adaptation process.

    For ``num_iter`` random homographies, the images are warped, passed
    through the model, and the junction / heatmap probabilities are warped
    back and accumulated together with per-pixel visit counts.

    Arguments:
        input_images: The images to be evaluated (4-D tensor whose last two
            dimensions are taken as H and W).
        model: The pytorch model in evaluation mode.
        grid_size: Grid size of the junction decoder.
        homography_cfg: Homography adaptation configurations. The keys
            ``"num_iter"``, ``"valid_border_margin"``, ``"homographies"`` and
            ``"min_counts"`` are read.

    Returns:
        A dict with the keys ``"junc_probs_mean"``, ``"junc_probs_max"``,
        ``"junc_counts"``, ``"heatmap_probs_mean"``, ``"heatmap_probs_max"``
        and ``"heatmap_counts"``, each a tensor with a single channel.
    """
    # Get the device of the current model
    device = next(model.parameters()).device

    # Define some constants and placeholder
    batch_size, _, H, W = input_images.shape
    num_iter = homography_cfg["num_iter"]
    junc_probs = torch.zeros([batch_size, num_iter, H, W], device=device)
    junc_counts = torch.zeros([batch_size, 1, H, W], device=device)
    heatmap_probs = torch.zeros([batch_size, num_iter, H, W], device=device)
    heatmap_counts = torch.zeros([batch_size, 1, H, W], device=device)
    margin = homography_cfg["valid_border_margin"]

    # Keep a config with no artifacts
    homography_cfg_no_artifacts = copy.copy(homography_cfg["homographies"])
    homography_cfg_no_artifacts["allow_artifacts"] = False

    # All-ones mask reused by every warp below (it is never modified in
    # place), instead of being re-allocated three times per iteration.
    ones_mask = torch.ones([batch_size, 1, H, W], device=device)

    for idx in range(num_iter):
        if idx <= num_iter // 5:
            # Ensure that 20% of the homographies have no artifact
            sampling_cfg = homography_cfg_no_artifacts
        else:
            sampling_cfg = homography_cfg["homographies"]
        H_mat_lst = [
            sample_homography([H, W], **sampling_cfg)[0][None]
            for _ in range(batch_size)
        ]

        H_mats = np.concatenate(H_mat_lst, axis=0)
        H_tensor = torch.tensor(H_mats, dtype=torch.float, device=device)
        H_inv_tensor = torch.inverse(H_tensor)

        # Perform the homography warp
        images_warped = warp_perspective(
            input_images, H_tensor, (H, W), flags="bilinear"
        )

        # Warp the mask. The junction and heatmap masks were produced by two
        # identical warps of the same input, so the warp is done once.
        valid_mask_warped = warp_perspective(
            ones_mask, H_tensor, (H, W), flags="nearest"
        )
        masks_junc_warped = valid_mask_warped
        masks_heatmap_warped = valid_mask_warped

        # Run the network forward pass
        with torch.no_grad():
            outputs = model(images_warped)

        # Unwarp and mask the junction prediction. The last softmax channel
        # is dropped before the pixel shuffle.
        junc_prob_warped = pixel_shuffle(
            softmax(outputs["junctions"], dim=1)[:, :-1, :, :], grid_size
        )
        junc_prob = warp_perspective(
            junc_prob_warped, H_inv_tensor, (H, W), flags="bilinear"
        )

        # Create the out of boundary mask
        out_boundary_mask = warp_perspective(
            ones_mask, H_inv_tensor, (H, W), flags="nearest"
        )
        out_boundary_mask = adjust_border(out_boundary_mask, device, margin)

        junc_prob = junc_prob * out_boundary_mask
        junc_count = warp_perspective(
            masks_junc_warped * out_boundary_mask,
            H_inv_tensor,
            (H, W),
            flags="nearest",
        )

        # Unwarp the mask and heatmap prediction
        # Always fetch only one channel
        if outputs["heatmap"].shape[1] == 2:
            # Convert to single channel directly from here
            heatmap_prob_warped = softmax(outputs["heatmap"], dim=1)[:, 1:, :, :]
        else:
            heatmap_prob_warped = torch.sigmoid(outputs["heatmap"])

        heatmap_prob_warped = heatmap_prob_warped * masks_heatmap_warped
        heatmap_prob = warp_perspective(
            heatmap_prob_warped, H_inv_tensor, (H, W), flags="bilinear"
        )
        heatmap_count = warp_perspective(
            masks_heatmap_warped, H_inv_tensor, (H, W), flags="nearest"
        )

        # Record the results
        junc_probs[:, idx : idx + 1, :, :] = junc_prob
        heatmap_probs[:, idx : idx + 1, :, :] = heatmap_prob
        junc_counts += junc_count
        heatmap_counts += heatmap_count

    # Perform the accumulation operation
    if homography_cfg["min_counts"] > 0:
        min_counts = homography_cfg["min_counts"]
        junc_count_mask = junc_counts < min_counts
        heatmap_count_mask = heatmap_counts < min_counts
        junc_counts[junc_count_mask] = 0
        heatmap_counts[heatmap_count_mask] = 0
    else:
        # Build the (all-False) masks with torch so they live on the same
        # device as the counts; numpy cannot wrap CUDA tensors.
        junc_count_mask = torch.zeros_like(junc_counts, dtype=torch.bool)
        heatmap_count_mask = torch.zeros_like(heatmap_counts, dtype=torch.bool)

    # Compute the mean accumulation. Pixels whose count was zeroed above
    # divide by zero here and are reset to 0 through the mask right after.
    junc_probs_mean = torch.sum(junc_probs, dim=1, keepdim=True) / junc_counts
    junc_probs_mean[junc_count_mask] = 0.0
    heatmap_probs_mean = (
        torch.sum(heatmap_probs, dim=1, keepdim=True) / heatmap_counts
    )
    heatmap_probs_mean[heatmap_count_mask] = 0.0

    # Compute the max accumulation
    junc_probs_max = torch.max(junc_probs, dim=1, keepdim=True)[0]
    junc_probs_max[junc_count_mask] = 0.0
    heatmap_probs_max = torch.max(heatmap_probs, dim=1, keepdim=True)[0]
    heatmap_probs_max[heatmap_count_mask] = 0.0

    return {
        "junc_probs_mean": junc_probs_mean,
        "junc_probs_max": junc_probs_max,
        "junc_counts": junc_counts,
        "heatmap_probs_mean": heatmap_probs_mean,
        "heatmap_probs_max": heatmap_probs_max,
        "heatmap_counts": heatmap_counts,
    }


def adjust_border(input_masks, device, margin=3):
    """Adjust the border of the counts and valid_mask.

    Each mask is eroded with an elliptical structuring element of size
    ``(2 * margin, 2 * margin)``.

    Arguments:
        input_masks: Tensor of masks with a singleton dimension at axis 1.
        device: Device of the returned tensor.
        margin: Half of the erosion kernel size, in pixels.

    Returns:
        The eroded masks as a tensor with the dtype of ``input_masks`` and
        the singleton axis 1 restored.
    """
    # Convert the mask to numpy array
    dtype = input_masks.dtype
    masks_np = np.squeeze(input_masks.cpu().numpy(), axis=1)

    erosion_kernel = cv2.getStructuringElement(
        cv2.MORPH_ELLIPSE, (margin * 2, margin * 2)
    )

    # Erode all the masks
    eroded_masks = [cv2.erode(mask, erosion_kernel) for mask in masks_np]

    # Stack back along the batch dimension and upload in a single transfer.
    output_masks = torch.tensor(
        np.stack(eroded_masks, axis=0), dtype=dtype, device=device
    )

    return output_masks.unsqueeze(dim=1)