import gradio as gr
from matplotlib import pyplot as plt
from mapper.utils.io import read_image
from mapper.utils.exif import EXIF
from mapper.utils.wrappers import Camera
from mapper.data.image import rectify_image, pad_image, resize_image
from mapper.utils.viz_2d import one_hot_argmax_to_rgb, plot_images
from mapper.module import GenericModule
from perspective2d import PerspectiveFields
import torch
import numpy as np
from typing import Optional, Tuple
from omegaconf import OmegaConf
description = """
<h1 align="center">
<ins>MapItAnywhere (MIA) </ins>
<br>
Empowering Bird’s Eye View Mapping using Large-scale Public Data
<br>
</h1>
<h3 align="center">
<a href="https://mapitanywhere.github.io" target="_blank">Project Page</a> |
<a href="https://arxiv.org/abs/2407.08726" target="_blank">Paper</a> |
<a href="https://github.com/MapItAnywhere/MapItAnywhere" target="_blank">Code</a>
</h3>
<p align="center">
Mapper generates bird's-eye-view maps from in-the-wild monocular first-person-view images. You can try the demo by uploading your own images or using the examples provided. Tip: you can also try images from across the world using <a href="https://www.mapillary.com/app" target="_blank">Mapillary</a> 😉
</p>
"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cfg = OmegaConf.load("config.yaml")
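# ImageCalibrator wraps PerspectiveFields to estimate the camera's roll, pitch,
# and focal length from a single image. The focal length is taken from EXIF
# metadata when available, otherwise it is derived from the predicted vertical FoV.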
class ImageCalibrator(PerspectiveFields):
def __init__(self, version: str = "Paramnet-360Cities-edina-centered"):
super().__init__(version)
self.eval()
def run(
self,
image_rgb: np.ndarray,
focal_length: Optional[float] = None,
exif: Optional[EXIF] = None,
) -> Tuple[Tuple[float, float], Camera]:
h, w, *_ = image_rgb.shape
if focal_length is None and exif is not None:
_, focal_ratio = exif.extract_focal()
if focal_ratio != 0:
focal_length = focal_ratio * max(h, w)
        calib = self.inference(img_bgr=image_rgb[..., ::-1])  # PerspectiveFields expects BGR input
roll_pitch = (calib["pred_roll"].item(), calib["pred_pitch"].item())
if focal_length is None:
vfov = calib["pred_vfov"].item()
focal_length = h / 2 / np.tan(np.deg2rad(vfov) / 2)
camera = Camera.from_dict(
{
"model": "SIMPLE_PINHOLE",
"width": w,
"height": h,
"params": [focal_length, w / 2 + 0.5, h / 2 + 0.5],
}
)
return roll_pitch, camera
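# Prepare a calibrated image for the network: normalize to [0, 1], rectify the view
# so that gravity is aligned (undoing the estimated roll and pitch), resize the longer
# side to 512 px, pad to a square 512x512 crop, and add a batch dimension.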
def preprocess_pipeline(image, roll_pitch, camera):
image = torch.from_numpy(image).float() / 255
image = image.permute(2, 0, 1).to(device)
camera = camera.to(device)
image, valid = rectify_image(image, camera.float(), -roll_pitch[0], -roll_pitch[1])
    roll_pitch = (0.0, 0.0)  # gravity is aligned after rectification
image, _, camera, valid = resize_image(
image=image,
size=512,
camera=camera,
fn=max,
valid=valid
)
image, valid, camera = pad_image(
image, 512, camera, valid
)
camera = torch.stack([camera])
return {
"image": image.unsqueeze(0).to(device),
"valid": valid.unsqueeze(0).to(device),
"camera": camera.float().to(device),
}
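# Instantiate the calibrator and load the pretrained Mapper weights, then switch to eval mode.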
calibrator = ImageCalibrator().to(device)
model = GenericModule(cfg)
model = model.load_from_checkpoint("trained_weights/mapper-excl-ood.ckpt", strict=False, cfg=cfg)
model = model.to(device)
model = model.eval()
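# End-to-end inference for a single uploaded image: read the file, estimate the camera
# from EXIF + PerspectiveFields, preprocess, predict the BEV semantic map, and render
# the argmax classes as an RGB image (invalid BEV cells are shown in white).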
def run(input_img):
image_path = input_img.name
image = read_image(image_path)
with open(image_path, "rb") as fid:
exif = EXIF(fid, lambda: image.shape[:2])
gravity, camera = calibrator.run(image, exif=exif)
data = preprocess_pipeline(image, gravity, camera)
res = model(data)
prediction = res['output']
rgb_prediction = one_hot_argmax_to_rgb(prediction, 6).squeeze(0).permute(1, 2, 0).cpu().long().numpy()
valid = res['valid_bev'].squeeze(0)[..., :-1]
rgb_prediction[~valid.cpu().numpy()] = 255
# TODO: add legend here
plot_images([image, rgb_prediction], titles=["Input Image", "Prediction"], pad=2, adaptive=True)
return plt.gcf()
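# Minimal headless usage sketch (assumes the example image below exists on disk;
# not part of the Gradio demo itself):
#   from types import SimpleNamespace
#   fig = run(SimpleNamespace(name="examples/crossing.jpg"))
#   fig.savefig("prediction.png")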
examples = [
["examples/left_crossing.jpg"],
["examples/crossing.jpg"]
["examples/two_roads.jpg"],
["examples/night_road.jpg"],
["examples/night_crossing.jpg"],
]
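# Wire everything into a simple Gradio interface: a file upload in, a matplotlib
# figure out, with the example images above selectable from the UI.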
demo = gr.Interface(
fn=run,
inputs=[
gr.File(file_types=["image"], label="Input Image")
],
outputs=[
gr.Plot(label="Prediction", format="png"),
],
description=description,
examples=examples)
demo.launch(share=False, server_name="0.0.0.0")