BFZD233
initial
5b3b0f4
from __future__ import print_function, division
import sys
sys.path.insert(0,'core')
sys.path.append('core/utils')
import os
import argparse
import gradio as gr
import cv2
from core.raft_stereo_depthbeta_refine import RAFTStereoDepthBetaRefine
import torch
import torch.nn as nn
from core.utils.utils import InputPadder
import matplotlib.pyplot as plt
from huggingface_hub import hf_hub_download
# Command-line configuration for the demo. Most flags mirror the training /
# evaluation options of the stereo network; the parsed namespace is passed
# straight into the model constructor below.
# NOTE(review): --restore_ckpt is not read anywhere in this script; the weights
# are fetched with hf_hub_download instead.
parser = argparse.ArgumentParser()
parser.add_argument('--root', help="dataset root", default=None)
parser.add_argument('--sv_root', help="visualization root", default=None)
parser.add_argument('--test_exp_name', default='', help="name your experiment in testing")
parser.add_argument('--mast3r_model_path', default='MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth', help="pretrained model path for MaSt3R")
parser.add_argument('--depthany_model_dir', default='./dav2_models', help="directory of pretrained model path for DepthAnything")
parser.add_argument('--restore_ckpt', help="restore checkpoint", default="./ckpts/diving_stereo.pth")
parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
parser.add_argument('--valid_iters', type=int, default=32, help='number of flow-field updates during forward pass')
parser.add_argument('--eval', action='store_true', help='evaluation mode')
parser.add_argument('--is_test', action='store_true', help='on testing')

# Architecture choices
parser.add_argument('--hidden_dims', nargs='+', type=int, default=[128]*3, help="hidden state and context dimensions")
parser.add_argument('--corr_implementation', choices=["reg", "alt", "reg_cuda", "alt_cuda"], default="reg", help="correlation volume implementation")
parser.add_argument('--shared_backbone', action='store_true', help="use a single backbone for the context and feature encoders")
parser.add_argument('--corr_levels', type=int, default=4, help="number of levels in the correlation pyramid")
parser.add_argument('--corr_radius', type=int, default=4, help="width of the correlation pyramid")
parser.add_argument('--n_downsample', type=int, default=2, help="resolution of the disparity field (1/2^K)")
parser.add_argument('--context_norm', type=str, default="batch", choices=['group', 'batch', 'instance', 'none'], help="normalization of context encoder")
parser.add_argument('--slow_fast_gru', action='store_true', help="iterate the low-res GRUs more frequently")
parser.add_argument('--n_gru_layers', type=int, default=3, help="number of hidden GRU levels")
# Kept as a raw string; presumably parsed into (dy, dx) offsets by the model — TODO confirm.
parser.add_argument('--lbp_neighbor_offsets', default='(-5,-5), (5,5), (5,-5), (-5,5), (-3,0), (3,0), (0,-3), (0,3)', help="determine the neighbors used in LBP encoder")
parser.add_argument('--modulation_ratio', type=float, default=1., help="hyperparameters for modulation")
parser.add_argument('--modulation_alg', choices=["linear", "sigmoid"], default="linear", help="rescale modulation")
parser.add_argument('--conf_from_fea', action='store_true', help="confidence in refinement not only from cost volume but also from other features")
parser.add_argument('--refine_pool', action='store_true', help="use pooling in refinement")
parser.add_argument('--refine_unet', action='store_true', help="use EfficientUnet in refinement")

# Visualization options
parser.add_argument('--improvement', action='store_true', help="visualize improvement map (error_map[i] - error_map[i-1])")
parser.add_argument('--movement', action='store_true', help="visualize movement map (flow_pr[i] - flow_pr[i-1])")
parser.add_argument('--acceleration', action='store_true', help="visualize acceleration map (movement_map[i] - movement_map[i-1])")
parser.add_argument('--mask', action='store_true', help="visualize mask")
parser.add_argument('--binary_thold', type=float, default=0.5, help="visualize binary mask")

# parse_known_args (not parse_args): hosts such as Jupyter kernels or test
# runners inject their own argv entries, which would otherwise make the demo
# exit at import time. Ignored entries are reported rather than dropped silently.
args, _unknown_args = parser.parse_known_args()
if _unknown_args:
    print(f"[warning] ignoring unrecognized arguments: {_unknown_args}", file=sys.stderr)

# The demo always uses feature-based confidence and evaluation mode,
# overriding whatever was passed on the command line.
args.conf_from_fea = True
args.eval = True
# Build the stereo network. It is wrapped in DataParallel, presumably so that
# the state-dict keys carry the "module." prefix used by the released
# checkpoint — TODO confirm.
model = RAFTStereoDepthBetaRefine(args)
model = torch.nn.DataParallel(model, device_ids=[0])

checkpoint_path = hf_hub_download(
    repo_id="BFZD/Diving-into-the-Fusion-of-Monocular-Priors-for-Generalized-Stereo-Matching",
    filename="ckpts/diving_stereo.pth",
)
# NOTE(review): torch.load unpickles the file. The source is a fixed Hub repo,
# but consider weights_only=True once the checkpoint is confirmed tensors-only.
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))

# Checkpoint entries for the LBP encoder's conv are deliberately not restored.
# Presumably the model builds them itself from --lbp_neighbor_offsets
# — TODO confirm.
_SKIPPED_KEY_FRAGMENT = "lbp_encoder.lbp_conv"
new_state_dict = {
    key: value
    for key, value in checkpoint.items()
    if _SKIPPED_KEY_FRAGMENT not in key
}

# strict=False is required because of the keys dropped above. It would also
# hide any genuine mismatch, so report everything beyond those expected gaps.
load_result = model.load_state_dict(new_state_dict, strict=False)
unexpected_missing = [
    key for key in load_result.missing_keys if _SKIPPED_KEY_FRAGMENT not in key
]
if unexpected_missing or load_result.unexpected_keys:
    print(
        f"[warning] checkpoint/model mismatch: missing={unexpected_missing}, "
        f"unexpected={list(load_result.unexpected_keys)}",
        file=sys.stderr,
    )
model.eval()
def predict(image1, image2):
    """Estimate disparity for a stereo pair and return it as a colour-mapped image.

    Args:
        image1: left image as delivered by ``gr.Image``, a NumPy array laid out
            H x W x C. Presumably uint8 RGB(A); TODO confirm for non-default
            image modes.
        image2: right image, in the same format and of the same size as ``image1``.

    Returns:
        RGB float array with values in [0, 1]. It is built from the absolute
        value of the last entry of the model's ``disp_predictions``, min-max
        normalised and mapped through the ``jet`` colour map. A constant
        disparity map renders as a single flat colour.

    Raises:
        gr.Error: if an image is missing or the two images differ in size.
    """
    if image1 is None or image2 is None:
        raise gr.Error("Please provide both a left and a right image.")
    # The padder is built from the left image's shape. A differently sized
    # right image would be padded inconsistently and fail inside the model.
    if image1.shape[:2] != image2.shape[:2]:
        raise gr.Error("Left and right images must have the same size.")

    with torch.no_grad():
        # H x W x C array -> 1 x C x H x W float tensor. Keep only the first
        # three channels so RGBA uploads are accepted.
        left = torch.from_numpy(image1).permute(2, 0, 1).float()[None, :3]
        right = torch.from_numpy(image2).permute(2, 0, 1).float()[None, :3]

        # Pad both images so their spatial size is divisible by 32.
        padder = InputPadder(left.shape, divis_by=32)
        left, right = padder.pad(left, right)

        atom_dict = model(left, right, iters=args.valid_iters, test_mode=False, vis_mode=True)
        output = atom_dict['disp_predictions'][-1].abs().cpu().numpy()
        disp = padder.unpad(output).squeeze()

    lo = disp.min()
    span = disp.max() - lo
    # Guard the min-max division: a constant map would otherwise give 0/0 = NaN.
    normalized_disp = (disp - lo) / span if span > 0 else disp - lo
    cmap = plt.get_cmap('jet')
    colored_disp = cmap(normalized_disp)[:, :, :3]  # drop the alpha channel
    return colored_disp
# Expose predict() as a Gradio app that takes a left/right image pair and
# shows the colour-mapped disparity, then start the server.
stereo_pair_inputs = [
    gr.Image(label="Left Image"),
    gr.Image(label="Right Image"),
]
interface = gr.Interface(
    fn=predict,
    inputs=stereo_pair_inputs,
    outputs="image",
)
interface.launch()