import gradio as gr
from matplotlib import pyplot as plt
from mapper.utils.io import read_image
from mapper.utils.exif import EXIF
from mapper.utils.wrappers import Camera
from mapper.data.image import rectify_image, pad_image, resize_image
from mapper.utils.viz_2d import one_hot_argmax_to_rgb, plot_images
from mapper.module import GenericModule
from perspective2d import PerspectiveFields
import torch
import numpy as np
from typing import Optional, Tuple
from omegaconf import OmegaConf
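
# HTML header rendered at the top of the Gradio demo page.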
description = """
<h1 align="center">
<ins>MapItAnywhere (MIA) </ins>
<br>
Empowering Bird’s Eye View Mapping using Large-scale Public Data
</h1>
<h3 align="center">
<a href="https://mapitanywhere.github.io" target="_blank">Project Page</a> |
<a href="https://arxiv.org/abs/2109.08203" target="_blank">Paper</a> |
<a href="https://github.com/MapItAnywhere/MapItAnywhere" target="_blank">Code</a>
</h3>
<p align="center">
Mapper generates bird's-eye-view maps from in-the-wild monocular first-person-view images. Try the demo by uploading your own images or by using the provided examples. Tip: you can also grab images from around the world on <a href="https://www.mapillary.com/app" target="_blank">Mapillary</a> 😉 and try examples taken in cities the model was not trained on!
</p>
"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cfg = OmegaConf.load("config.yaml")
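

# Thin wrapper around PerspectiveFields that estimates the camera roll/pitch
# and recovers a focal length from EXIF metadata or the predicted vertical FoV.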
class ImageCalibrator(PerspectiveFields):
def __init__(self, version: str = "Paramnet-360Cities-edina-centered"):
super().__init__(version)
self.eval()
def run(
self,
image_rgb: np.ndarray,
focal_length: Optional[float] = None,
exif: Optional[EXIF] = None,
) -> Tuple[Tuple[float, float], Camera]:
        h, w, *_ = image_rgb.shape
        # Prefer a focal length recovered from the EXIF metadata, if available.
        if focal_length is None and exif is not None:
            _, focal_ratio = exif.extract_focal()
            if focal_ratio != 0:
                focal_length = focal_ratio * max(h, w)
        # PerspectiveFields expects a BGR image and predicts roll, pitch, and vertical FoV.
        calib = self.inference(img_bgr=image_rgb[..., ::-1])
        roll_pitch = (calib["pred_roll"].item(), calib["pred_pitch"].item())
        # Otherwise, derive the focal length from the predicted vertical FoV.
        if focal_length is None:
            vfov = calib["pred_vfov"].item()
            focal_length = h / 2 / np.tan(np.deg2rad(vfov) / 2)
        # Simple pinhole camera with the principal point at the image center.
        camera = Camera.from_dict(
            {
                "model": "SIMPLE_PINHOLE",
                "width": w,
                "height": h,
                "params": [focal_length, w / 2 + 0.5, h / 2 + 0.5],
            }
        )
return roll_pitch, camera
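

# Convert a calibrated RGB image into the gravity-rectified, resized tensor
# batch expected by the Mapper model.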
def preprocess_pipeline(image, roll_pitch, camera):
image = torch.from_numpy(image).float() / 255
image = image.permute(2, 0, 1).to(device)
camera = camera.to(device)
    # Rectify the image so it is gravity-aligned; the estimated roll/pitch
    # are compensated here and therefore zeroed afterwards.
    image, valid = rectify_image(image, camera.float(), -roll_pitch[0], -roll_pitch[1])
    roll_pitch = (0.0, 0.0)
image, _, camera, valid = resize_image(
image=image,
size=512,
camera=camera,
fn=max,
valid=valid
)
# image, valid, camera = pad_image(
# image, 512, camera, valid
# )
camera = torch.stack([camera])
return {
"image": image.unsqueeze(0).to(device),
"valid": valid.unsqueeze(0).to(device),
"camera": camera.float().to(device),
}
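

# Instantiate the calibrator and the Mapper model and load the trained weights.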
calibrator = ImageCalibrator().to(device)
model = GenericModule(cfg)
model = model.load_from_checkpoint("trained_weights/mapper-excl-ood.ckpt", strict=False, cfg=cfg)
model = model.to(device)
model = model.eval()
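

# End-to-end inference for one uploaded image: read it, calibrate, preprocess,
# predict the BEV semantic map, and render the result with matplotlib.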
def run(input_img):
image_path = input_img.name
image = read_image(image_path)
with open(image_path, "rb") as fid:
exif = EXIF(fid, lambda: image.shape[:2])
gravity, camera = calibrator.run(image, exif=exif)
data = preprocess_pipeline(image, gravity, camera)
    res = model(data)
    prediction = res["output"]
    # Map the predicted class scores to an RGB visualization (argmax over the 6 classes).
    rgb_prediction = one_hot_argmax_to_rgb(prediction, 6).squeeze(0).permute(1, 2, 0).cpu().long().numpy()
    # Mask out BEV cells outside the visible area (drawn in white).
    valid = res["valid_bev"].squeeze(0)[..., :-1]
    rgb_prediction[~valid.cpu().numpy()] = 255
# TODO: add legend here
plot_images([image, rgb_prediction], titles=["Input Image", "Top-Down Prediction"], pad=2, adaptive=True)
return plt.gcf()
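

# Example images bundled with the demo.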
examples = [
["examples/left_crossing.jpg"],
["examples/crossing.jpg"],
["examples/two_roads.jpg"],
["examples/japan_narrow_road.jpeg"],
["examples/zurich_crossing.jpg"],
["examples/night_road.jpg"],
["examples/night_crossing.jpg"],
]
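

# Build the Gradio interface and launch it on all network interfaces.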
demo = gr.Interface(
fn=run,
inputs=[
gr.File(file_types=["image"], label="Input Image")
],
outputs=[
gr.Plot(label="Prediction", format="png"),
],
description=description,
examples=examples)
demo.launch(share=True, server_name="0.0.0.0")