File size: 4,609 Bytes
f474bfd
 
 
 
 
50318d8
 
 
f474bfd
50318d8
f474bfd
 
50318d8
f474bfd
 
 
 
 
 
 
 
 
 
 
 
 
ca01e50
f474bfd
 
 
50318d8
 
 
 
f474bfd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50318d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f474bfd
50318d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f474bfd
 
 
 
 
 
 
 
50318d8
 
 
 
 
 
 
dfc3164
 
50318d8
ca01e50
 
 
 
50318d8
ca01e50
dfc3164
 
 
 
 
ca01e50
f474bfd
 
 
 
 
 
 
ca01e50
f474bfd
ca01e50
 
50318d8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import gradio as gr
from matplotlib import pyplot as plt
from mapper.utils.io import read_image
from mapper.utils.exif import EXIF
from mapper.utils.wrappers import Camera
from mapper.data.image import rectify_image, pad_image, resize_image
from mapper.utils.viz_2d import one_hot_argmax_to_rgb, plot_images
from mapper.module import GenericModule
from perspective2d import PerspectiveFields
import torch
import numpy as np
from typing import Optional, Tuple
from omegaconf import OmegaConf

# HTML header rendered above the Gradio interface (passed as `description=`).
# NOTE(review): verify that the arXiv ID below points to the MapItAnywhere
# paper; it may be a placeholder.
description = """
<h1 align="center">
  <ins>MapItAnywhere (MIA)</ins>
  <br>
  Empowering Bird’s Eye View Mapping using Large-scale Public Data
</h1>
<h3 align="center">
    <a href="https://mapitanywhere.github.io" target="_blank">Project Page</a> |
    <a href="https://arxiv.org/abs/2109.08203" target="_blank">Paper</a> |
    <a href="https://github.com/MapItAnywhere/MapItAnywhere" target="_blank">Code</a>
</h3>
<p align="center">
Mapper generates bird's-eye-view maps from in-the-wild monocular first-person view images. You can try our demo by uploading your images or using the examples provided. Tip: You can also try out images across the world using <a href="https://www.mapillary.com/app" target="_blank">Mapillary</a> &#128521;
</p>
"""

# Run on the default CUDA device when one is visible, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Model/data configuration, loaded once at import time. The path is relative,
# so the app must be launched from the directory that contains config.yaml.
cfg = OmegaConf.load("config.yaml")

class ImageCalibrator(PerspectiveFields):
    """Estimate camera gravity (roll, pitch) and a pinhole camera for one image.

    Thin wrapper around a PerspectiveFields model; the network is put in eval
    mode as soon as it is constructed.
    """

    def __init__(self, version: str = "Paramnet-360Cities-edina-centered"):
        super().__init__(version)
        self.eval()

    def run(
        self,
        image_rgb: np.ndarray,
        focal_length: Optional[float] = None,
        exif: Optional[EXIF] = None,
    ) -> Tuple[Tuple[float, float], Camera]:
        """Calibrate ``image_rgb`` and build a SIMPLE_PINHOLE camera for it.

        Args:
            image_rgb: image array indexed as (height, width, ...); the channel
                axis is reversed before inference because the network takes BGR.
            focal_length: focal length in pixels. If omitted, it is taken from
                EXIF (focal ratio times the longer image side) when a non-zero
                ratio is available, otherwise derived from the predicted
                vertical field of view.
            exif: optional EXIF reader, consulted only when ``focal_length``
                is None.

        Returns:
            ``((roll, pitch), camera)``: the predicted roll and pitch as Python
            floats, and a camera whose principal point is at
            ``(width / 2 + 0.5, height / 2 + 0.5)``.
        """
        height, width, *_ = image_rgb.shape

        # EXIF is only a fallback for a missing focal length; a ratio of 0 is ignored.
        if focal_length is None and exif is not None:
            _, ratio = exif.extract_focal()
            if ratio != 0:
                focal_length = ratio * max(height, width)

        prediction = self.inference(img_bgr=image_rgb[..., ::-1])
        roll = prediction["pred_roll"].item()
        pitch = prediction["pred_pitch"].item()

        if focal_length is None:
            # Pinhole relation f = (h / 2) / tan(vfov / 2); vfov is in degrees.
            half_vfov = np.deg2rad(prediction["pred_vfov"].item()) / 2
            focal_length = height / 2 / np.tan(half_vfov)

        intrinsics = {
            "model": "SIMPLE_PINHOLE",
            "width": width,
            "height": height,
            "params": [focal_length, width / 2 + 0.5, height / 2 + 0.5],
        }
        camera = Camera.from_dict(intrinsics)
        return (roll, pitch), camera

def preprocess_pipeline(image, roll_pitch, camera, size: int = 512):
    """Turn a raw image plus its calibration into a batched model input dict.

    Steps: convert to a float CHW tensor scaled by 1/255 on ``device``, rectify
    it with the negated roll/pitch, resize it (``fn=max``, presumably scaling
    the longer side to ``size``), and pad it to ``size``.

    Args:
        image: HxWxC numpy image; values are divided by 255.
        roll_pitch: ``(roll, pitch)`` pair, e.g. from ``ImageCalibrator.run``.
            It is only read here, never modified.
        camera: camera object matching ``image``.
        size: target size passed to both resizing and padding (default 512).

    Returns:
        Dict with ``"image"``, ``"valid"`` and ``"camera"`` entries, each with
        a leading batch dimension of 1 and placed on ``device``.
    """
    image = torch.from_numpy(image).float() / 255
    image = image.permute(2, 0, 1).to(device)  # HWC -> CHW
    camera = camera.to(device)

    # Rectify using the negated estimated roll and pitch.
    image, valid = rectify_image(
        image, camera.float(), -roll_pitch[0], -roll_pitch[1]
    )

    image, _, camera, valid = resize_image(
        image=image,
        size=size,
        camera=camera,
        fn=max,
        valid=valid,
    )
    image, valid, camera = pad_image(image, size, camera, valid)

    # Give the camera the same leading batch dimension as image/valid.
    camera = torch.stack([camera])

    return {
        "image": image.unsqueeze(0).to(device),
        "valid": valid.unsqueeze(0).to(device),
        "camera": camera.float().to(device),
    }


# Perspective-field calibrator used to estimate roll/pitch/focal for uploads.
calibrator = ImageCalibrator().to(device)

# Assuming GenericModule is a PyTorch Lightning module: `load_from_checkpoint`
# is a classmethod that builds its own instance (forwarding `cfg=cfg` to the
# constructor). Call it on the class rather than on a throwaway GenericModule(cfg);
# recent Lightning versions reject the instance-call form. `map_location` lets a
# GPU-saved checkpoint load on a CPU-only host.
model = GenericModule.load_from_checkpoint(
    "trained_weights/mapper-excl-ood.ckpt",
    strict=False,
    cfg=cfg,
    map_location=device,
)
model = model.to(device).eval()

def run(input_img):
    """Gradio callback: predict a BEV map for an uploaded image and plot it.

    Args:
        input_img: value produced by the ``gr.File`` input. Depending on the
            Gradio version this is an object exposing ``.name`` or a plain
            path string; both are accepted.

    Returns:
        The current matplotlib figure, showing the input image next to the
        colorized prediction.
    """
    image_path = getattr(input_img, "name", input_img)

    image = read_image(image_path)
    with open(image_path, "rb") as fid:
        exif = EXIF(fid, lambda: image.shape[:2])

    gravity, camera = calibrator.run(image, exif=exif)
    data = preprocess_pipeline(image, gravity, camera)

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        res = model(data)

    prediction = res["output"]
    # NOTE(review): 6 is presumably the number of semantic classes — confirm.
    rgb_prediction = (
        one_hot_argmax_to_rgb(prediction, 6)
        .squeeze(0)
        .permute(1, 2, 0)
        .cpu()
        .long()
        .numpy()
    )
    # Cells outside the valid BEV mask are set to 255 (white, assuming 0-255 RGB).
    # NOTE(review): `[..., :-1]` drops the last column of the mask — intent
    # is not visible here.
    valid = res["valid_bev"].squeeze(0)[..., :-1]
    rgb_prediction[~valid.cpu().numpy()] = 255

    # TODO: add legend here
    plot_images(
        [image, rgb_prediction],
        titles=["Input Image", "Prediction"],
        pad=2,
        adaptive=True,
    )
    return plt.gcf()


# Example inputs listed under the interface: one row per example, one entry per
# input component (here a single image file).
examples = [
    ["examples/left_crossing.jpg"],
    ["examples/crossing.jpg"],
    ["examples/two_roads.jpg"],
    ["examples/night_road.jpg"],
    ["examples/night_crossing.jpg"],
]

# Single-input / single-output app: an image file goes in, and `run` returns a
# matplotlib figure that is rendered as PNG.
demo = gr.Interface(
    run,
    [gr.File(file_types=["image"], label="Input Image")],
    [gr.Plot(label="Prediction", format="png")],
    examples=examples,
    description=description,
)

# Listen on all network interfaces, without creating a public share link.
demo.launch(server_name="0.0.0.0", share=False)