import torch

from ..utils.tensor import batch_to_device
from .viz2d import cm_RdGn, plot_heatmaps, plot_image_grid, plot_keypoints, plot_matches


def make_match_figures(pred_, data_, n_pairs=2):
    """Draw the predicted matches of the first image pairs of a batch.

    For each of the first ``n_pairs`` pairs, the two images are shown side by
    side with an optional heatmap/depth overlay, all keypoints, and the
    predicted matches colored by whether they agree with the ground truth.

    Args:
        pred_: Mapping of predictions. If it contains the key ``"0to1"``, the
            value stored under that key is used instead. Must provide
            ``"keypoints0"``, ``"keypoints1"``, ``"matches0"`` and
            ``"gt_matches0"``; ``"heatmap0"`` / ``"heatmap1"`` are optional.
        data_: Mapping of input data with ``"view0"`` and ``"view1"`` entries,
            each holding an ``"image"`` and optionally a ``"depth"``.
        n_pairs: Maximum number of pairs to draw; clamped to the batch size
            (first dimension of ``view0["image"]``).

    Returns:
        A dict ``{"matching": fig}`` where ``fig`` is the figure returned by
        ``plot_image_grid``.
    """
    if "0to1" in pred_:
        pred_ = pred_["0to1"]

    # Everything below indexes and converts to numpy, so work on CPU copies.
    pred = batch_to_device(pred_, "cpu", non_blocking=False)
    data = batch_to_device(data_, "cpu", non_blocking=False)

    view0, view1 = data["view0"], data["view1"]

    # Never draw more pairs than the batch contains.
    n_pairs = min(n_pairs, view0["image"].shape[0])

    kp0, kp1 = pred["keypoints0"], pred["keypoints1"]
    m0 = pred["matches0"]
    gtm0 = pred["gt_matches0"]

    # Which overlay to draw does not depend on the pair index, so decide once.
    # Predicted heatmaps take precedence over depth maps.
    has_heatmaps = "heatmap0" in pred
    has_depth = "depth" in view0 and view0["depth"] is not None

    images, kpts, matches, mcolors = [], [], [], []
    heatmaps = []
    for i in range(n_pairs):
        # Keep predicted matches (index > -1) whose ground-truth label is
        # >= -1; presumably lower labels mark entries to ignore -- TODO confirm.
        valid = (m0[i] > -1) & (gtm0[i] >= -1)
        kpm0 = kp0[i][valid].numpy()
        kpm1 = kp1[i][m0[i][valid]].numpy()

        # Move the first axis last for plotting; assumes images are stored
        # channel-first (C, H, W) -- TODO confirm.
        images.append(
            [view0["image"][i].permute(1, 2, 0), view1["image"][i].permute(1, 2, 0)]
        )
        kpts.append([kp0[i], kp1[i]])
        matches.append((kpm0, kpm1))

        if has_heatmaps:
            # Heatmaps are passed through a sigmoid before display.
            heatmaps.append(
                [
                    torch.sigmoid(pred["heatmap0"][i, 0]),
                    torch.sigmoid(pred["heatmap1"][i, 0]),
                ]
            )
        elif has_depth:
            heatmaps.append([view0["depth"][i], view1["depth"][i]])

        # A match is correct when it equals the ground-truth match index.
        correct = gtm0[i][valid] == m0[i][valid]
        mcolors.append(cm_RdGn(correct).tolist())

    fig, axes = plot_image_grid(images, return_fig=True, set_lim=True)

    # Draw layer by layer (overlay, then keypoints, then matches) across all
    # pairs, using explicit loops since the plotting calls are side effects.
    if heatmaps:
        for i, pair_heatmaps in enumerate(heatmaps):
            plot_heatmaps(pair_heatmaps, axes=axes[i], a=1.0)
    for i, pair_kpts in enumerate(kpts):
        plot_keypoints(pair_kpts, axes=axes[i], colors="royalblue")
    for i, (pair_matches, pair_colors) in enumerate(zip(matches, mcolors)):
        plot_matches(
            *pair_matches, color=pair_colors, axes=axes[i], a=0.5, lw=1.0, ps=0.0
        )

    return {"matching": fig}