# MatchStereo / MatchFlow — Hugging Face Spaces demo app.
import argparse
import os
import subprocess
import time

import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download

from dataloader.stereo import transforms
from models.match_stereo import MatchStereo
from utils.utils import InputPadder, calc_noc_mask
# Let cuDNN auto-tune convolution algorithms; pays off because the demo
# repeatedly runs the model at a small set of input resolutions.
torch.backends.cudnn.benchmark = True
class MatchStereoDemo:
    """Shared inference backend for the MatchStereo / MatchFlow Gradio demo.

    Downloads the published checkpoints once at construction, lazily
    (re)loads a model whenever the requested (mode, variant, precision,
    mat_impl) combination changes, and converts an image pair into a
    disparity map (stereo) or an optical-flow visualization (flow), plus a
    non-occlusion (NOC) mask.
    """

    def __init__(self):
        # Device state is cached and re-checked before every load/inference
        # because CUDA availability can change on shared Space hardware.
        self.has_cuda = torch.cuda.is_available()
        self.device = "cuda" if self.has_cuda else 'cpu'
        self.model = None
        # Configuration of the currently loaded model (None until load_model succeeds).
        self.current_variant = None
        self.current_mode = None
        self.current_precision = None
        self.current_mat_impl = None
        self.download_model()

    def download_model(self):
        """Download all published checkpoints into ./checkpoints/, skipping existing files."""
        REPO_ID = 'Tingman/MatchAttention'
        filename_list = ['matchstereo_tiny_fsd.pth', 'matchstereo_small_fsd.pth',
                         'matchstereo_base_fsd.pth', 'matchflow_base_sintel.pth']
        os.makedirs('./checkpoints/', exist_ok=True)
        for filename in filename_list:
            local_file = os.path.join('./checkpoints/', filename)
            if not os.path.exists(local_file):
                hf_hub_download(repo_id=REPO_ID, filename=filename,
                                local_dir='./checkpoints/', local_dir_use_symlinks=False)

    def _refresh_device(self, context=""):
        """Re-sync self.has_cuda / self.device with the actual CUDA runtime state."""
        current_has_cuda = torch.cuda.is_available()
        if current_has_cuda != self.has_cuda:
            print(f"CUDA status changed{context}: {self.has_cuda} -> {current_has_cuda}")
            self.has_cuda = current_has_cuda
            self.device = "cuda" if self.has_cuda else 'cpu'

    def load_model(self, mode, variant, precision, mat_impl):
        """Load a model; skip when the same effective configuration is already loaded.

        Returns a human-readable status string (starts with "Error" on failure).
        """
        self._refresh_device()
        # CPU cannot run fp16 or the custom CUDA kernels. Downgrade FIRST so
        # the cache check, args, and recorded state all use the effective
        # settings (previously args.mat_impl could keep 'cuda' on a CPU host).
        if not self.has_cuda:
            precision = "fp32"
            mat_impl = "pytorch"
        if (self.model is not None and
                self.current_variant == variant and
                self.current_mode == mode and
                self.current_precision == precision and
                self.current_mat_impl == mat_impl):
            return "Model already loaded"
        # fixed checkpoint path
        checkpoint_base_path = "./checkpoints"
        if mode == 'stereo':
            checkpoint_name = f"match{mode}_{variant}_fsd.pth"
        elif mode == 'flow':
            checkpoint_name = f"match{mode}_{variant}_sintel.pth"
        else:
            raise NotImplementedError
        checkpoint_path = os.path.join(checkpoint_base_path, checkpoint_name)
        if not os.path.exists(checkpoint_path):
            return f"Error: Checkpoint not found at {checkpoint_path}"
        args = argparse.Namespace()
        args.mode = mode
        args.variant = variant
        args.mat_impl = mat_impl
        dtypes = {'fp32': torch.float32, 'fp16': torch.float16}
        self.dtype = dtypes[precision]
        self.model = MatchStereo(args)
        try:
            checkpoint = torch.load(checkpoint_path, map_location='cpu')
            # strict=False tolerates auxiliary checkpoint keys the demo does not need.
            self.model.load_state_dict(state_dict=checkpoint['model'], strict=False)
            self.model.to(self.device)
            self.model.eval()
            self.model = self.model.to(self.dtype)
            # Pass the NEW mode explicitly: self.current_mode still holds the
            # previous mode (or None) here, so warming up from it would
            # exercise the wrong branch of the network.
            self._warmup_model(mode)
            self.current_variant = variant
            self.current_mode = mode
            self.current_precision = precision
            self.current_mat_impl = mat_impl
            device_info = "GPU" if self.has_cuda else "CPU"
            return f"Successfully loaded {mode} {variant} model on {device_info} (precision: {precision}, mat_impl: {mat_impl})"
        except Exception as e:
            return f"Error loading model: {str(e)}"

    def _warmup_model(self, mode=None):
        """Run one throwaway forward pass so later timings exclude warmup cost.

        mode defaults to the currently recorded mode for backward compatibility.
        """
        if self.model is None:
            return
        if mode is None:
            mode = self.current_mode
        dummy_left = torch.randn(1, 3, 256, 256, device=self.device, dtype=self.dtype)
        dummy_right = torch.randn(1, 3, 256, 256, device=self.device, dtype=self.dtype)
        with torch.no_grad():
            _ = self.model(dummy_left, dummy_right, stereo=(mode == 'stereo'))

    def run_frame(self, left, right, stereo, low_res_init=False, factor=2.):
        """Single frame inference.

        With low_res_init, a first pass at 1/factor resolution produces an
        initial field that is rescaled to 1/32 of full resolution
        (presumably the network's internal init resolution — the second call
        accepts it as init_flow) to guide the full-resolution pass.
        """
        if low_res_init:
            left_ds = F.interpolate(left, scale_factor=1/factor, mode='bilinear', align_corners=True)
            right_ds = F.interpolate(right, scale_factor=1/factor, mode='bilinear', align_corners=True)
            padder_ds = InputPadder(left_ds.shape, padding_factor=32)
            left_ds, right_ds = padder_ds.pad(left_ds, right_ds)
            field_up_ds = self.model(left_ds, right_ds, stereo=stereo)['field_up']
            field_up_ds = padder_ds.unpad(field_up_ds.permute(0, 3, 1, 2).contiguous()).contiguous()
            # Resize (H/factor -> H/32) and scale the displacement values by
            # the same ratio so they stay expressed in pixels of the new grid.
            field_up_init = F.interpolate(field_up_ds, scale_factor=factor/32, mode='bilinear', align_corners=True)*(factor/32)
            field_up_init = field_up_init.permute(0, 2, 3, 1).contiguous()
            results_dict = self.model(left, right, stereo=stereo, init_flow=field_up_init)
        else:
            results_dict = self.model(left, right, stereo=stereo)
        return results_dict

    def get_inference_size(self, size_name):
        """Map a preset name to an (H, W) inference size; None keeps the original size.

        The result is consumed as (height, width) by F.interpolate and by the
        displacement rescaling in process_images, so presets are stored (H, W).
        (The previous (W, H) ordering swapped the axes for every preset.)
        Both dimensions are rounded to multiples of 32 for the network padding.
        """
        if size_name == "Original":
            return None

        def round_to_32(x):
            # Round to the nearest multiple of 32 (ties round up).
            return (x + 16) // 32 * 32

        size_presets = {
            "720P": (round_to_32(720), round_to_32(1280)),
            "1080P": (round_to_32(1080), round_to_32(1920)),
            "2K": (round_to_32(1080), round_to_32(2048)),
        }
        return size_presets.get(size_name, None)

    def process_images(self, left_image, right_image, mode, variant,
                       low_res_init=False, inference_size_name="Original",
                       precision="fp32", mat_impl="pytorch"):
        """Full pipeline: load model, preprocess, infer, postprocess, visualize.

        left_image/right_image are PIL images. Returns a 4-tuple
        (result image, NOC mask, timing text, status text); the image entries
        are None when loading or inference fails.
        """
        self._refresh_device(" before processing")
        if not self.has_cuda:
            precision = "fp32"
            mat_impl = "pytorch"
        load_result = self.load_model(mode, variant, precision, mat_impl)
        if load_result.startswith("Error"):
            return None, None, None, load_result
        try:
            # PIL -> float32 RGB arrays.
            left = np.array(left_image.convert('RGB')).astype(np.float32)
            right = np.array(right_image.convert('RGB')).astype(np.float32)
            inference_size = self.get_inference_size(inference_size_name)  # (H, W) or None
            val_transform_list = [transforms.ToTensor(no_normalize=True)]
            val_transform = transforms.Compose(val_transform_list)
            sample = {'left': left, 'right': right}
            sample = val_transform(sample)
            left_tensor = sample['left'].to(self.device, dtype=self.dtype).unsqueeze(0)
            right_tensor = sample['right'].to(self.device, dtype=self.dtype).unsqueeze(0)
            stereo = (mode == 'stereo')
            ori_size = left_tensor.shape[-2:]  # (H, W)
            if inference_size is not None:
                left_tensor = F.interpolate(left_tensor, size=inference_size, mode='bilinear', align_corners=True)
                right_tensor = F.interpolate(right_tensor, size=inference_size, mode='bilinear', align_corners=True)
                padder = None
            else:
                # Original resolution: pad to a multiple of 32 instead of resizing.
                padder = InputPadder(left_tensor.shape, padding_factor=32)
                left_tensor, right_tensor = padder.pad(left_tensor, right_tensor)
            device_type = "GPU" if self.has_cuda else "CPU"
            actual_size = inference_size if inference_size else ori_size
            status_info = f"Device: {device_type} | Resolution: {actual_size[1]}x{actual_size[0]} | Precision: {precision}"
            start_time = time.time()
            with torch.no_grad():
                results_dict = self.run_frame(left_tensor, right_tensor, stereo, low_res_init)
            inference_time = (time.time() - start_time) * 1000  # ms
            field_up = results_dict['field_up'].permute(0, 3, 1, 2).float().contiguous()
            if padder is not None:
                field_up = padder.unpad(field_up)
            elif inference_size is not None:
                # Resize back and rescale displacements: channel 0 is x
                # (width ratio), channel 1 is y (height ratio).
                field_up = F.interpolate(field_up, size=ori_size, mode='bilinear', align_corners=True)
                field_up[:, 0] = field_up[:, 0] * (ori_size[1] / float(inference_size[1]))
                field_up[:, 1] = field_up[:, 1] * (ori_size[0] / float(inference_size[0]))
            noc_mask = calc_noc_mask(field_up.permute(0, 2, 3, 1), A=8)
            noc_mask = noc_mask[0].detach().cpu().numpy()
            noc_mask = np.where(noc_mask, 255, 128).astype(np.uint8)
            # Append a zero channel, then split batch into forward/backward
            # fields (assumes the model stacks both along the batch dim —
            # TODO confirm against MatchStereo.forward). field_r is unused.
            field_up = torch.cat((field_up, torch.zeros_like(field_up[:, :1])), dim=1)
            field_up = field_up.permute(0, 2, 3, 1).contiguous()
            field, field_r = field_up.chunk(2, dim=0)
            if stereo:
                # Disparity = -x displacement clamped to >= 0, then min-max
                # normalized to an 8-bit grayscale image.
                disparity = (-field[..., 0]).clamp(min=0)
                disparity_np = disparity[0].detach().cpu().numpy()
                min_val = disparity_np.min()
                max_val = disparity_np.max()
                if max_val - min_val > 1e-6:
                    disparity_norm = (disparity_np - min_val) / (max_val - min_val)
                else:
                    disparity_norm = np.zeros_like(disparity_np)
                disparity_img = (disparity_norm * 255).astype(np.uint8)
                return disparity_img, noc_mask, f"Inference time: {inference_time:.2f} ms. (Please re-run to get accurate time.)", status_info
            else:
                flow = field[0].detach().cpu().numpy()
                flow_rgb = self.flow_to_color(flow)
                return flow_rgb, noc_mask, f"Inference time: {inference_time:.2f} ms. (Please re-run to get accurate time.)", status_info
        except Exception as e:
            device_type = "GPU" if self.has_cuda else "CPU"
            return None, None, f"Error during inference: {str(e)}", f"Device: {device_type} | Error occurred"

    def flow_to_color(self, flow):
        """HSV visualization of flow: hue encodes direction, value encodes magnitude.

        NOTE(review): requires cv2 to be importable at call time; ensure
        OpenCV is installed and imported at module level.
        """
        u = flow[..., 0]
        v = flow[..., 1]
        rad = np.sqrt(u**2 + v**2)
        rad_max = np.max(rad)
        epsilon = 1e-8
        if rad_max > epsilon:
            u = u / (rad_max + epsilon)
            v = v / (rad_max + epsilon)
        h, w = u.shape
        hsv = np.zeros((h, w, 3), dtype=np.uint8)
        hsv[..., 1] = 255  # full saturation
        mag, ang = cv2.cartToPolar(u, v)
        hsv[..., 0] = ang * 180 / np.pi / 2  # OpenCV hue range is [0, 180)
        hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX)
        flow_rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
        return flow_rgb
# Module-level singleton shared by all Gradio requests; constructing it also
# triggers the one-time checkpoint download.
demo_model = MatchStereoDemo()
def compile_cuda_extensions():
    """Best-effort build of the custom CUDA MatchAttention extension.

    Runs models/compile.sh if it exists. Failures are only logged — the demo
    always has the pure-PyTorch implementation to fall back on.
    """
    try:
        print("Start compiling CUDA extension...")
        current_dir = os.path.dirname(os.path.abspath(__file__))
        models_dir = os.path.join(current_dir, "models")
        compile_script = os.path.join(models_dir, "compile.sh")
        if os.path.exists(compile_script):
            # cwd= scopes the working-directory change to the child process;
            # the original os.chdir() pair leaked the chdir if run() raised.
            result = subprocess.run(["bash", "compile.sh"],
                                    capture_output=True, text=True,
                                    cwd=models_dir)
            if result.returncode == 0:
                print("CUDA extension compile succeed!")
                print("output:", result.stdout)
            else:
                print("CUDA extension compile failed!")
                print(result.stderr)
                print(result.stdout)
        else:
            print(f"no compile scripts found: {compile_script}")
    except Exception as e:
        print(f"Error during compile: {e}")
# Attempt the CUDA-extension build once at import time (logged best-effort).
compile_cuda_extensions()
# example images
# Each row: [left image / frame 1, right image / frame 2, mode, model variant].
examples = [
    ["examples/staircase_q_left.png", "examples/staircase_q_right.png", "stereo", "tiny"],
    ["examples/booster_bathroom_left.png", "examples/booster_bathroom_right.png", "stereo", "tiny"],
    ["examples/frame_0031_clean.png", "examples/frame_0032_clean.png", "flow", "base"],
]
def process_inference(left_img, right_img, mode, variant,
                      low_res_init, inference_size, precision, mat_impl):
    """Gradio callback: validate inputs, then delegate to the shared demo model.

    Returns a 4-tuple (output image, NOC mask, timing text, status text);
    the image entries are None whenever no result could be produced.
    """
    # Guard clause: both images must be present before any work happens.
    if left_img is None or right_img is None:
        return None, None, "Please upload both left and right images", "Waiting for input..."
    try:
        return demo_model.process_images(
            left_img, right_img, mode, variant,
            low_res_init, inference_size, precision, mat_impl,
        )
    except Exception as e:
        return None, None, f"Error during inference: {str(e)}", f"Error: {str(e)}"
def update_variant_choices(mode):
    """Restrict the variant radio to checkpoints published for the chosen mode."""
    # Only the "base" flow checkpoint exists; stereo ships all three sizes.
    choices = ["base"] if mode == "flow" else ["tiny", "small", "base"]
    return gr.Radio(choices=choices, value=choices[0])
# Gradio UI
# Layout: left column = inputs and options, right column = outputs.
with gr.Blocks(title="MatchStereo/MatchFlow Demo") as demo:
    gr.Markdown("# MatchStereo/MatchFlow Demo")
    gr.Markdown("Upload stereo images for disparity estimation or consecutive frames for optical flow estimation.")
    # Checked once at UI build time to enable/disable the GPU-only options.
    current_has_cuda = torch.cuda.is_available()
    if not current_has_cuda:
        gr.Markdown("> Note: Running on CPU. Some options (fp16, cuda) are disabled.")
    else:
        gr.Markdown(f"> Note: Running on GPU ({torch.cuda.get_device_name(0)}).")
    with gr.Row():
        with gr.Column():
            left_image = gr.Image(label="Left Image / Frame 1", type="pil")
            right_image = gr.Image(label="Right Image / Frame 2", type="pil")
            with gr.Row():
                mode = gr.Radio(
                    choices=["stereo", "flow"],
                    label="Mode",
                    value="stereo",
                    info="Select stereo for disparity estimation or flow for optical flow"
                )
                variant = gr.Radio(
                    choices=["tiny", "small", "base"],
                    label="Model Variant",
                    value="tiny",
                    info="Model size variant"
                )
            with gr.Row():
                low_res_init = gr.Checkbox(
                    label="Low Resolution Init",
                    value=False,
                    info="Use low-resolution initialization for high-res images (>=2K)"
                )
                inference_size = gr.Dropdown(
                    choices=["Original", "720P", "1080P", "2K"],
                    label="Inference Size",
                    value="Original",
                    info="Rounded to multiples of 32"
                )
            with gr.Row():
                precision = gr.Radio(
                    choices=["fp32", "fp16"],
                    label="Precision",
                    value="fp32",
                    info="Model precision",
                    interactive=current_has_cuda  # fp16 requires CUDA
                )
                mat_impl = gr.Radio(
                    choices=["pytorch", "cuda"],
                    label="MatchAttention Implementation",
                    value="pytorch",
                    info="MatchAttention implementations",
                    interactive=current_has_cuda  # custom kernels require CUDA
                )
            run_btn = gr.Button("Run Inference", variant="primary")
        with gr.Column():
            output_image = gr.Image(label="Output Result", interactive=False)
            noc_mask = gr.Image(label="NOC Mask", interactive=False)
            time_output = gr.Textbox(label="Inference Time", interactive=False)
            status = gr.Textbox(label="Status Info", interactive=False, lines=2)
    gr.Markdown("## Examples")
    # cache_examples=False: clicking an example always runs a fresh inference.
    gr.Examples(
        examples=examples,
        inputs=[left_image, right_image, mode, variant],
        outputs=[output_image, noc_mask, time_output, status],
        fn=process_inference,
        cache_examples=False,
        label="Click any example below to load it"
    )
    run_btn.click(
        fn=process_inference,
        inputs=[left_image, right_image, mode, variant,
                low_res_init, inference_size, precision, mat_impl],
        outputs=[output_image, noc_mask, time_output, status]
    )
    # Switching mode narrows the variant choices (flow only ships "base").
    mode.change(
        fn=update_variant_choices,
        inputs=[mode],
        outputs=[variant]
    )
if __name__ == "__main__":
    # OpenCV is needed for the optical-flow visualization; warn (but still
    # launch) if it is missing so the stereo path remains usable.
    try:
        import cv2
    except ImportError:
        print("Please install OpenCV for optical flow visualization: pip install opencv-python")
    demo.launch()