# MatchStereo / gradio_app.py
import gradio as gr
import argparse
import numpy as np
import torch
import torch.nn.functional as F
import os
import time
import subprocess
try:
    import cv2  # required by flow_to_color; fail gracefully with a hint if missing
except ImportError:
    cv2 = None
    print("Please install OpenCV for optical flow visualization: pip install opencv-python")
from dataloader.stereo import transforms
from utils.utils import InputPadder, calc_noc_mask
from huggingface_hub import hf_hub_download
from models.match_stereo import MatchStereo
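# Let cuDNN auto-tune convolution algorithms for fixed input shapes (a no-op on CPU).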
torch.backends.cudnn.benchmark = True
class MatchStereoDemo:
def __init__(self):
self.has_cuda = torch.cuda.is_available()
self.device = "cuda" if self.has_cuda else 'cpu'
self.model = None
self.current_variant = None
self.current_mode = None
self.current_precision = None
self.current_mat_impl = None
self.download_model()
def download_model(self):
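        """Fetch the released checkpoints from the Hugging Face Hub into ./checkpoints/, skipping existing files."""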
REPO_ID = 'Tingman/MatchAttention'
        filename_list = ['matchstereo_tiny_fsd.pth', 'matchstereo_small_fsd.pth',
                         'matchstereo_base_fsd.pth', 'matchflow_base_sintel.pth']
        os.makedirs('./checkpoints/', exist_ok=True)
for filename in filename_list:
local_file = os.path.join('./checkpoints/', filename)
if not os.path.exists(local_file):
                hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='./checkpoints/',
                                local_dir_use_symlinks=False)  # deprecated and ignored by recent huggingface_hub; files are copied into local_dir
def load_model(self, mode, variant, precision, mat_impl):
"""load model, skip if the model has been loaded"""
        current_has_cuda = torch.cuda.is_available()
        device_changed = current_has_cuda != self.has_cuda
        if device_changed:
            print(f"CUDA status changed: {self.has_cuda} -> {current_has_cuda}")
            self.has_cuda = current_has_cuda
            self.device = "cuda" if self.has_cuda else "cpu"
        # Reuse the cached model only if nothing, including the device, has changed.
        if (self.model is not None and
                not device_changed and
                self.current_variant == variant and
                self.current_mode == mode and
                self.current_precision == precision and
                self.current_mat_impl == mat_impl):
            return "Model already loaded"
# fixed checkpoint path
checkpoint_base_path = "./checkpoints"
if mode == 'stereo':
checkpoint_name = f"match{mode}_{variant}_fsd.pth"
elif mode == 'flow':
checkpoint_name = f"match{mode}_{variant}_sintel.pth"
else:
raise NotImplementedError
checkpoint_path = os.path.join(checkpoint_base_path, checkpoint_name)
if not os.path.exists(checkpoint_path):
return f"Error: Checkpoint not found at {checkpoint_path}"
        if not self.has_cuda:
            # fp16 and the CUDA kernel are GPU-only; fall back before building args.
            precision = "fp32"
            mat_impl = "pytorch"
        args = argparse.Namespace()
        args.mode = mode
        args.variant = variant
        args.mat_impl = mat_impl
        dtypes = {'fp32': torch.float32, 'fp16': torch.float16}
        self.dtype = dtypes[precision]
self.model = MatchStereo(args)
try:
checkpoint = torch.load(checkpoint_path, map_location='cpu')
self.model.load_state_dict(state_dict=checkpoint['model'], strict=False)
self.model.to(self.device)
self.model.eval()
self.model = self.model.to(self.dtype)
self._warmup_model()
self.current_variant = variant
self.current_mode = mode
self.current_precision = precision
self.current_mat_impl = mat_impl
device_info = "GPU" if self.has_cuda else "CPU"
return f"Successfully loaded {mode} {variant} model on {device_info} (precision: {precision}, mat_impl: {mat_impl})"
except Exception as e:
return f"Error loading model: {str(e)}"
def _warmup_model(self):
"""warmup the model for accurate time measurement"""
if self.model is None:
return
dummy_left = torch.randn(1, 3, 256, 256, device=self.device, dtype=self.dtype)
dummy_right = torch.randn(1, 3, 256, 256, device=self.device, dtype=self.dtype)
with torch.no_grad():
_ = self.model(dummy_left, dummy_right, stereo=(self.current_mode == 'stereo'))
def run_frame(self, left, right, stereo, low_res_init=False, factor=2.):
"""single frame inference"""
if low_res_init:
left_ds = F.interpolate(left, scale_factor=1/factor, mode='bilinear', align_corners=True)
right_ds = F.interpolate(right, scale_factor=1/factor, mode='bilinear', align_corners=True)
padder_ds = InputPadder(left_ds.shape, padding_factor=32)
left_ds, right_ds = padder_ds.pad(left_ds, right_ds)
field_up_ds = self.model(left_ds, right_ds, stereo=stereo)['field_up']
field_up_ds = padder_ds.unpad(field_up_ds.permute(0, 3, 1, 2).contiguous()).contiguous()
field_up_init = F.interpolate(field_up_ds, scale_factor=factor/32, mode='bilinear', align_corners=True)*(factor/32)
field_up_init = field_up_init.permute(0, 2, 3, 1).contiguous()
results_dict = self.model(left, right, stereo=stereo, init_flow=field_up_init)
else:
results_dict = self.model(left, right, stereo=stereo)
return results_dict
def get_inference_size(self, size_name):
if size_name == "Original":
return None
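        # Round each dimension to the nearest multiple of 32 so the network's
        # downsampling stages divide the input evenly.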
def round_to_32(x):
return (x + 16) // 32 * 32
        size_presets = {  # (H, W), matching the size convention of F.interpolate
            "720P": (round_to_32(720), round_to_32(1280)),
            "1080P": (round_to_32(1080), round_to_32(1920)),
            "2K": (round_to_32(1080), round_to_32(2048)),
            ## "4K UHD": (round_to_32(2160), round_to_32(3840))
        }
return size_presets.get(size_name, None)
def process_images(self, left_image, right_image, mode, variant,
low_res_init=False, inference_size_name="Original",
precision="fp32", mat_impl="pytorch"):
current_has_cuda = torch.cuda.is_available()
if current_has_cuda != self.has_cuda:
print(f"CUDA status changed before processing: {self.has_cuda} -> {current_has_cuda}")
self.has_cuda = current_has_cuda
self.device = "cuda" if self.has_cuda else 'cpu'
if not self.has_cuda:
precision = "fp32"
mat_impl = "pytorch"
load_result = self.load_model(mode, variant, precision, mat_impl)
if load_result.startswith("Error"):
return None, None, None, load_result
try:
left = np.array(left_image.convert('RGB')).astype(np.float32)
right = np.array(right_image.convert('RGB')).astype(np.float32)
original_size = left.shape[:2] # (H, W)
inference_size = self.get_inference_size(inference_size_name)
val_transform_list = [transforms.ToTensor(no_normalize=True)]
val_transform = transforms.Compose(val_transform_list)
sample = {'left': left, 'right': right}
sample = val_transform(sample)
left_tensor = sample['left'].to(self.device, dtype=self.dtype).unsqueeze(0)
right_tensor = sample['right'].to(self.device, dtype=self.dtype).unsqueeze(0)
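            # One network handles both tasks: stereo=True predicts disparity,
            # stereo=False predicts optical flow.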
stereo = (mode == 'stereo')
ori_size = left_tensor.shape[-2:]
if inference_size is not None:
left_tensor = F.interpolate(left_tensor, size=inference_size, mode='bilinear', align_corners=True)
right_tensor = F.interpolate(right_tensor, size=inference_size, mode='bilinear', align_corners=True)
padder = None
else:
padder = InputPadder(left_tensor.shape, padding_factor=32)
left_tensor, right_tensor = padder.pad(left_tensor, right_tensor)
device_type = "GPU" if self.has_cuda else "CPU"
actual_size = inference_size if inference_size else ori_size
status_info = f"Device: {device_type} | Resolution: {actual_size[1]}x{actual_size[0]} | Precision: {precision}"
            if self.has_cuda:
                torch.cuda.synchronize()  # flush pending GPU work so the timing starts clean
            start_time = time.time()
            with torch.no_grad():
                results_dict = self.run_frame(left_tensor, right_tensor, stereo, low_res_init)
            if self.has_cuda:
                torch.cuda.synchronize()  # wait for the forward pass before stopping the clock
            inference_time = (time.time() - start_time) * 1000  # ms
field_up = results_dict['field_up'].permute(0, 3, 1, 2).float().contiguous()
if padder is not None:
field_up = padder.unpad(field_up)
elif inference_size is not None:
field_up = F.interpolate(field_up, size=ori_size, mode='bilinear', align_corners=True)
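                # Flow/disparity values are in pixels, so rescale them by the same
                # ratios used to resize the field back to the original resolution.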
field_up[:, 0] = field_up[:, 0] * (ori_size[1] / float(inference_size[1]))
field_up[:, 1] = field_up[:, 1] * (ori_size[0] / float(inference_size[0]))
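            # Non-occluded (NOC) mask; rendered below as 255 (co-visible) vs 128 (occluded).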
noc_mask = calc_noc_mask(field_up.permute(0, 2, 3, 1), A=8)
noc_mask = noc_mask[0].detach().cpu().numpy()
noc_mask = np.where(noc_mask, 255, 128).astype(np.uint8)
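            # Pad a zero third channel (ignored by the visualizers) and split the batch:
            # the model appears to stack both view directions along the batch dim.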
field_up = torch.cat((field_up, torch.zeros_like(field_up[:, :1])), dim=1)
field_up = field_up.permute(0, 2, 3, 1).contiguous()
field, field_r = field_up.chunk(2, dim=0)
            if stereo:
                # For rectified pairs, the horizontal match offset is the negative disparity.
                disparity = (-field[..., 0]).clamp(min=0)
disparity_np = disparity[0].detach().cpu().numpy()
min_val = disparity_np.min()
max_val = disparity_np.max()
if max_val - min_val > 1e-6:
disparity_norm = (disparity_np - min_val) / (max_val - min_val)
else:
disparity_norm = np.zeros_like(disparity_np)
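                # Render the min-max normalized disparity as an 8-bit grayscale image.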
disparity_img = (disparity_norm * 255).astype(np.uint8)
return disparity_img, noc_mask, f"Inference time: {inference_time:.2f} ms. (Please re-run to get accurate time.)", status_info
else:
flow = field[0].detach().cpu().numpy()
flow_rgb = self.flow_to_color(flow)
return flow_rgb, noc_mask, f"Inference time: {inference_time:.2f} ms. (Please re-run to get accurate time.)", status_info
except Exception as e:
device_type = "GPU" if self.has_cuda else "CPU"
return None, None, f"Error during inference: {str(e)}", f"Device: {device_type} | Error occurred"
def flow_to_color(self, flow):
"""visualization of flow"""
u = flow[..., 0]
v = flow[..., 1]
rad = np.sqrt(u**2 + v**2)
rad_max = np.max(rad)
epsilon = 1e-8
if rad_max > epsilon:
u = u / (rad_max + epsilon)
v = v / (rad_max + epsilon)
h, w = u.shape
hsv = np.zeros((h, w, 3), dtype=np.uint8)
hsv[..., 1] = 255
mag, ang = cv2.cartToPolar(u, v)
hsv[..., 0] = ang * 180 / np.pi / 2
hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX)
flow_rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
return flow_rgb
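# Instantiate once at import time so the checkpoints are downloaded before the UI starts.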
demo_model = MatchStereoDemo()
def compile_cuda_extensions():
    """Best-effort build of the optional CUDA MatchAttention extension via models/compile.sh."""
    try:
        print("Compiling CUDA extensions...")
        current_dir = os.path.dirname(os.path.abspath(__file__))
        models_dir = os.path.join(current_dir, "models")
        compile_script = os.path.join(models_dir, "compile.sh")
        if os.path.exists(compile_script):
            # Run the script from models/ so its relative paths resolve correctly.
            result = subprocess.run(["bash", "compile.sh"],
                                    capture_output=True, text=True, cwd=models_dir)
            if result.returncode == 0:
                print("CUDA extension compiled successfully!")
                print("output:", result.stdout)
            else:
                print("CUDA extension compilation failed!")
                print(result.stderr)
                print(result.stdout)
        else:
            print(f"No compile script found: {compile_script}")
    except Exception as e:
        print(f"Error during compilation: {e}")
compile_cuda_extensions()
# example images
examples = [
["examples/staircase_q_left.png", "examples/staircase_q_right.png", "stereo", "tiny"],
["examples/booster_bathroom_left.png", "examples/booster_bathroom_right.png", "stereo", "tiny"],
["examples/frame_0031_clean.png", "examples/frame_0032_clean.png", "flow", "base"],
]
def process_inference(left_img, right_img, mode, variant,
low_res_init, inference_size, precision, mat_impl):
"""Gradio function"""
if left_img is None or right_img is None:
return None, None, "Please upload both left and right images", "Waiting for input..."
try:
result = demo_model.process_images(
left_img, right_img, mode, variant,
low_res_init, inference_size, precision, mat_impl
)
return result
except Exception as e:
return None, None, f"Error during inference: {str(e)}", f"Error: {str(e)}"
def update_variant_choices(mode):
    # Only the base variant has a released optical-flow checkpoint (matchflow_base_sintel.pth).
    if mode == "flow":
        return gr.Radio(choices=["base"], value="base")
    else:
        return gr.Radio(choices=["tiny", "small", "base"], value="tiny")
# Gradio UI
with gr.Blocks(title="MatchStereo/MatchFlow Demo") as demo:
gr.Markdown("# MatchStereo/MatchFlow Demo")
gr.Markdown("Upload stereo images for disparity estimation or consecutive frames for optical flow estimation.")
current_has_cuda = torch.cuda.is_available()
if not current_has_cuda:
gr.Markdown("> Note: Running on CPU. Some options (fp16, cuda) are disabled.")
else:
gr.Markdown(f"> Note: Running on GPU ({torch.cuda.get_device_name(0)}).")
with gr.Row():
with gr.Column():
left_image = gr.Image(label="Left Image / Frame 1", type="pil")
right_image = gr.Image(label="Right Image / Frame 2", type="pil")
with gr.Row():
mode = gr.Radio(
choices=["stereo", "flow"],
label="Mode",
value="stereo",
info="Select stereo for disparity estimation or flow for optical flow"
)
variant = gr.Radio(
choices=["tiny", "small", "base"],
label="Model Variant",
value="tiny",
info="Model size variant"
)
with gr.Row():
low_res_init = gr.Checkbox(
label="Low Resolution Init",
value=False,
info="Use low-resolution initialization for high-res images (>=2K)"
)
inference_size = gr.Dropdown(
choices=["Original", "720P", "1080P", "2K"],
label="Inference Size",
value="Original",
info="Rounded to multiples of 32"
)
with gr.Row():
precision = gr.Radio(
choices=["fp32", "fp16"],
label="Precision",
value="fp32",
info="Model precision",
interactive=current_has_cuda
)
mat_impl = gr.Radio(
choices=["pytorch", "cuda"],
label="MatchAttention Implementation",
value="pytorch",
info="MatchAttention implementations",
interactive=current_has_cuda
)
run_btn = gr.Button("Run Inference", variant="primary")
with gr.Column():
output_image = gr.Image(label="Output Result", interactive=False)
noc_mask = gr.Image(label="NOC Mask", interactive=False)
time_output = gr.Textbox(label="Inference Time", interactive=False)
status = gr.Textbox(label="Status Info", interactive=False, lines=2)
gr.Markdown("## Examples")
gr.Examples(
examples=examples,
inputs=[left_image, right_image, mode, variant],
outputs=[output_image, noc_mask, time_output, status],
fn=process_inference,
cache_examples=False,
label="Click any example below to load it"
)
run_btn.click(
fn=process_inference,
inputs=[left_image, right_image, mode, variant,
low_res_init, inference_size, precision, mat_impl],
outputs=[output_image, noc_mask, time_output, status]
)
mode.change(
fn=update_variant_choices,
inputs=[mode],
outputs=[variant]
)
if __name__ == "__main__":
try:
import cv2
except ImportError:
print("Please install OpenCV for optical flow visualization: pip install opencv-python")
demo.launch()