from typing import Optional, Tuple

import gradio as gr
import numpy as np
import torch
from matplotlib import pyplot as plt
from omegaconf import OmegaConf
from perspective2d import PerspectiveFields

from mapper.data.image import pad_image, rectify_image, resize_image
from mapper.module import GenericModule
from mapper.utils.exif import EXIF
from mapper.utils.io import read_image
from mapper.utils.viz_2d import one_hot_argmax_to_rgb, plot_images
from mapper.utils.wrappers import Camera

description = """
Mapper generates bird's-eye-view maps from in-the-wild monocular first-person-view images. Try the demo by uploading your own images or using the provided examples. Tip: you can also try images from across the world using Mapillary 😉 Also try examples taken in cities we have not trained on!
""" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") cfg = OmegaConf.load("config.yaml") class ImageCalibrator(PerspectiveFields): def __init__(self, version: str = "Paramnet-360Cities-edina-centered"): super().__init__(version) self.eval() def run( self, image_rgb: np.ndarray, focal_length: Optional[float] = None, exif: Optional[EXIF] = None, ) -> Tuple[Tuple[float, float], Camera]: h, w, *_ = image_rgb.shape if focal_length is None and exif is not None: _, focal_ratio = exif.extract_focal() if focal_ratio != 0: focal_length = focal_ratio * max(h, w) calib = self.inference(img_bgr=image_rgb[..., ::-1]) roll_pitch = (calib["pred_roll"].item(), calib["pred_pitch"].item()) if focal_length is None: vfov = calib["pred_vfov"].item() focal_length = h / 2 / np.tan(np.deg2rad(vfov) / 2) camera = Camera.from_dict( { "model": "SIMPLE_PINHOLE", "width": w, "height": h, "params": [focal_length, w / 2 + 0.5, h / 2 + 0.5], } ) return roll_pitch, camera def preprocess_pipeline(image, roll_pitch, camera): image = torch.from_numpy(image).float() / 255 image = image.permute(2, 0, 1).to(device) camera = camera.to(device) image, valid = rectify_image(image, camera.float(), -roll_pitch[0], -roll_pitch[1]) roll_pitch *= 0 image, _, camera, valid = resize_image( image=image, size=512, camera=camera, fn=max, valid=valid ) # image, valid, camera = pad_image( # image, 512, camera, valid # ) camera = torch.stack([camera]) return { "image": image.unsqueeze(0).to(device), "valid": valid.unsqueeze(0).to(device), "camera": camera.float().to(device), } calibrator = ImageCalibrator().to(device) model = GenericModule(cfg) model = model.load_from_checkpoint("trained_weights/mapper-excl-ood.ckpt", strict=False, cfg=cfg) model = model.to(device) model = model.eval() def run(input_img): image_path = input_img.name image = read_image(image_path) with open(image_path, "rb") as fid: exif = EXIF(fid, lambda: image.shape[:2]) gravity, camera = calibrator.run(image, exif=exif) data = preprocess_pipeline(image, gravity, camera) res = model(data) prediction = res['output'] rgb_prediction = one_hot_argmax_to_rgb(prediction, 6).squeeze(0).permute(1, 2, 0).cpu().long().numpy() valid = res['valid_bev'].squeeze(0)[..., :-1] rgb_prediction[~valid.cpu().numpy()] = 255 # TODO: add legend here plot_images([image, rgb_prediction], titles=["Input Image", "Top-Down Prediction"], pad=2, adaptive=True) return plt.gcf() examples = [ ["examples/left_crossing.jpg"], ["examples/crossing.jpg"], ["examples/two_roads.jpg"], ["examples/japan_narrow_road.jpeg"], ["examples/zurich_crossing.jpg"], ["examples/night_road.jpg"], ["examples/night_crossing.jpg"], ] demo = gr.Interface( fn=run, inputs=[ gr.File(file_types=["image"], label="Input Image") ], outputs=[ gr.Plot(label="Prediction", format="png"), ], description=description, examples=examples) demo.launch(share=True, server_name="0.0.0.0")