Jiaye Zou committed on
Commit 50318d8
1 Parent(s): f474bfd

update: gradio app with docker

Files changed (7)
  1. Dockerfile +30 -0
  2. README.md +2 -3
  3. app.py +62 -12
  4. config.yaml +36 -0
  5. get_weights.sh +9 -0
  6. mapper/utils/viz_2d.py +43 -13
  7. requirements.txt +23 -0
Dockerfile ADDED
@@ -0,0 +1,30 @@
+ FROM pytorch/pytorch:2.3.1-cuda11.8-cudnn8-runtime
+
+ # Set working directory
+ WORKDIR /mapper
+
+ # Install dependencies
+ RUN apt-get update && apt-get install -y \
+     git \
+     wget \
+     unzip \
+     vim \
+     ffmpeg \
+     libsm6 \
+     libxext6
+
+ RUN pip install --no-cache-dir gradio[oauth]==4.36.1 "uvicorn>=0.14.0" spaces
+
+ COPY . /mapper
+
+ # Install Python dependencies
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Get Weights
+ RUN bash get_weights.sh
+
+ # Clear APT and pip cache
+ RUN apt-get clean && rm -rf /var/lib/apt/lists/* && rm -rf /tmp/pip-reqs
+
+ # Start the app
+ CMD ["python", "app.py"]
README.md CHANGED
@@ -3,10 +3,9 @@ title: "Map It Anywhere (MIA): Empowering Bird’s Eye View Mapping using Large-
  emoji: 🌍
  colorFrom: green
  colorTo: blue
- sdk: gradio
- sdk_version: "4.36.1"
- app_file: app.py
+ sdk: docker
  pinned: true
+ app_port: 7860
  ---
  <p align="center">
  <h1 align="center">Map It Anywhere (MIA): Empowering Bird’s Eye View Mapping using Large-scale Public Data</h1>
app.py CHANGED
@@ -3,9 +3,14 @@ from matplotlib import pyplot as plt
  from mapper.utils.io import read_image
  from mapper.utils.exif import EXIF
  from mapper.utils.wrappers import Camera
+ from mapper.data.image import rectify_image, pad_image, resize_image
+ from mapper.utils.viz_2d import one_hot_argmax_to_rgb, plot_images
+ from mapper.module import GenericModule
  from perspective2d import PerspectiveFields
+ import torch
  import numpy as np
  from typing import Optional, Tuple
+ from omegaconf import OmegaConf
 
  description = """
  <h1 align="center">
@@ -24,6 +29,10 @@ Mapper generates birds-eye-view maps from first person view monocular images. Tr
  </p>
  """
 
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ cfg = OmegaConf.load("config.yaml")
+
  class ImageCalibrator(PerspectiveFields):
      def __init__(self, version: str = "Paramnet-360Cities-edina-centered"):
          super().__init__(version)
@@ -40,7 +49,6 @@ class ImageCalibrator(PerspectiveFields):
          _, focal_ratio = exif.extract_focal()
          if focal_ratio != 0:
              focal_length = focal_ratio * max(h, w)
-
          calib = self.inference(img_bgr=image_rgb[..., ::-1])
          roll_pitch = (calib["pred_roll"].item(), calib["pred_pitch"].item())
          if focal_length is None:
@@ -57,26 +65,67 @@
          )
          return roll_pitch, camera
 
- def run(input_img):
-     calibrator = ImageCalibrator().to("cuda")
+ def preprocess_pipeline(image, roll_pitch, camera):
+     image = torch.from_numpy(image).float() / 255
+     image = image.permute(2, 0, 1).to(device)
+     camera = camera.to(device)
+
+     image, valid = rectify_image(image, camera.float(), -roll_pitch[0], -roll_pitch[1])
+
+     roll_pitch *= 0
+
+     image, _, camera, valid = resize_image(
+         image=image,
+         size=512,
+         camera=camera,
+         fn=max,
+         valid=valid
+     )
+
+     image, valid, camera = pad_image(
+         image, 512, camera, valid
+     )
 
+     camera = torch.stack([camera])
+
+     return {
+         "image": image.unsqueeze(0).to(device),
+         "valid": valid.unsqueeze(0).to(device),
+         "camera": camera.float().to(device),
+     }
+
+
+ calibrator = ImageCalibrator().to(device)
+ model = GenericModule(cfg)
+ model = model.load_from_checkpoint("trained_weights/mapper-excl-ood.ckpt", strict=False, cfg=cfg)
+ model = model.to(device)
+ model = model.eval()
+
+ def run(input_img):
      image_path = input_img.name
 
      image = read_image(image_path)
-     image = image.to("cuda")
      with open(image_path, "rb") as fid:
          exif = EXIF(fid, lambda: image.shape[:2])
 
      gravity, camera = calibrator.run(image, exif=exif)
 
-     print(f"Gravity: {gravity}")
-     print(f"Camera: {camera._data}")
-
-     plt.imshow(image)
-     plt.axis('off')
+     data = preprocess_pipeline(image, gravity, camera)
+
+     res = model(data)
+
+     plot_images([image], pad=0., adaptive=True)
      fig1 = plt.gcf()
 
-     return fig1
+     prediction = res['output']
+     rgb_prediction = one_hot_argmax_to_rgb(prediction, 6).squeeze(0).permute(1, 2, 0).cpu().long().numpy()
+     valid = res['valid_bev'].squeeze(0)[..., :-1]
+     rgb_prediction[~valid.cpu().numpy()] = 255
+
+     plot_images([rgb_prediction], pad=0., adaptive=True)
+     fig2 = plt.gcf()
+
+     return fig1, fig2
 
  demo = gr.Interface(
      fn=run,
@@ -84,7 +133,8 @@ demo = gr.Interface(
          gr.File(file_types=["image"], label="Input Image")
      ],
      outputs=[
-         gr.Plot(label="Inputs", format="png")
+         gr.Plot(label="Inputs", format="png"),
+         gr.Plot(label="Outputs", format="png"),
      ],
      description=description,)
- demo.launch(share=True)
+ demo.launch(share=False, server_name="0.0.0.0")
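
Note: since the new `run` handler only reads `.name` from its argument and returns the two matplotlib figures built above, it can be smoke-tested without the Gradio UI. A minimal sketch, assuming the weights fetched by get_weights.sh are in place, using a hypothetical `examples/street.jpg` input path, and skipping `demo.launch` while testing:

    from types import SimpleNamespace

    # Any object with a .name attribute works; Gradio normally passes a file wrapper here.
    fig_input, fig_bev = run(SimpleNamespace(name="examples/street.jpg"))
    fig_bev.savefig("bev_prediction.png")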
config.yaml ADDED
@@ -0,0 +1,36 @@
+ model:
+   image_encoder:
+     backbone:
+       pretrained: true
+       frozen: true
+       output_dim: 128
+     name: feature_extractor_DPT
+   segmentation_head:
+     dropout_rate: 0.2
+   name: map_perception_net
+   num_classes: 6
+   latent_dim: 128
+   z_max: 50
+   x_max: 25
+   pixel_per_meter: 2
+   num_scale_bins: 32
+   loss:
+     num_classes: 6
+     xent_weight: 1.0
+     dice_weight: 1.0
+     focal_loss: false
+     focal_loss_gamma: 2.0
+     requires_frustrum: true
+     requires_flood_mask: false
+     class_weights:
+     - 1.00351229
+     - 4.34782609
+     - 1.00110121
+     - 1.03124678
+     - 6.69792364
+     - 7.55857899
+     label_smoothing: 0.1
+   scale_range:
+   - 0
+   - 9
+   z_min: null
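
Note: a quick way to sanity-check the new config is to load it with OmegaConf, as app.py does. A sketch only, assuming the nesting shown above; the `z_max` override is purely illustrative:

    from omegaconf import OmegaConf

    cfg = OmegaConf.load("config.yaml")
    print(cfg.model.num_classes)         # 6
    print(cfg.model.loss.class_weights)  # the six per-class weights listed above
    cfg.model.z_max = 40                 # fields can be overridden before building GenericModule(cfg)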
get_weights.sh ADDED
@@ -0,0 +1,9 @@
+ #!/bin/bash
+
+ # URL of the file to download
+ ood_weights="https://huggingface.co/mapitanywhere/mapper/resolve/main/weights/mapper-excl-ood/model.ckpt"
+
+ mkdir -p trained_weights
+
+ # Download the file using wget
+ wget $ood_weights -O trained_weights/mapper-excl-ood.ckpt
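
Note: if the huggingface_hub client is installed, the same checkpoint can be fetched without wget. A sketch only, with the repo id and filename derived from the URL above:

    import os, shutil
    from huggingface_hub import hf_hub_download

    path = hf_hub_download(
        repo_id="mapitanywhere/mapper",
        filename="weights/mapper-excl-ood/model.ckpt",
    )
    os.makedirs("trained_weights", exist_ok=True)
    shutil.copy(path, "trained_weights/mapper-excl-ood.ckpt")  # location app.py expects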
mapper/utils/viz_2d.py CHANGED
@@ -6,6 +6,7 @@
 
  import numpy as np
  import torch
+ import matplotlib.pyplot as plt
 
 
  def features_to_RGB(*Fs, masks=None, skip=1):
@@ -69,22 +70,18 @@ def one_hot_argmax_to_rgb(y, num_class):
 
      '''
 
-
      class_colors = {
-         'road': (0, 0, 0),                       # 0: Black
-         'crossing': (255, 0, 0),                 # 1; Red
-         'explicit_pedestrian': (255, 255, 0),    # 2: Yellow
+         'road': (68, 68, 68),                    # 0: Black
+         'crossing': (244, 162, 97),              # 1; Red
+         'explicit_pedestrian': (233, 196, 106),  # 2: Yellow
          # 'explicit_void': (128, 128, 128),      # 3: White
-         'park': (0, 255, 0),                     # 4: Green
-         'building': (255, 0, 255),               # 5: Magenta
-         'water': (0, 0, 255),                    # 6: Blue
-         'terrain': (0, 255, 255),                # 7: Cyan
-         'parking': (170, 170, 170),              # 8: Dark Grey
-         'train': (85, 85, 85),                   # 9: Light Grey
-         'predicted_void': (256, 256, 256)
+         'building': (231, 111, 81),              # 5: Magenta
+         'terrain': (42, 157, 143),               # 7: Cyan
+         'parking': (204, 204, 204),              # 8: Dark Grey
+         'predicted_void': (255, 255, 255)
      }
      class_colors = class_colors.values()
-     class_colors = [torch.tensor(x) for x in class_colors]
+     class_colors = [torch.tensor(x).float() for x in class_colors]
 
      argmaxed = torch.argmax((y > 0.5).float(), dim=1)  # Take argmax
      argmaxed[torch.all(y <= 0.5, dim=1)] = num_class
@@ -97,10 +94,43 @@ def one_hot_argmax_to_rgb(y, num_class):
              argmaxed.shape[1],
              argmaxed.shape[2],
          )
-     ) * 256
+     ) * 255
      for i in range(num_class + 1):
          seg_rgb[:, 0, :, :][argmaxed == i] = class_colors[i][0]
          seg_rgb[:, 1, :, :][argmaxed == i] = class_colors[i][1]
          seg_rgb[:, 2, :, :][argmaxed == i] = class_colors[i][2]
 
      return seg_rgb
+
+ def plot_images(imgs, titles=None, cmaps="gray", dpi=100, pad=0.5, adaptive=True):
+     """Plot a set of images horizontally.
+     Args:
+         imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W).
+         titles: a list of strings, as titles for each image.
+         cmaps: colormaps for monochrome images.
+         adaptive: whether the figure size should fit the image aspect ratios.
+     """
+     n = len(imgs)
+     if not isinstance(cmaps, (list, tuple)):
+         cmaps = [cmaps] * n
+
+     if adaptive:
+         ratios = [i.shape[1] / i.shape[0] for i in imgs]  # W / H
+     else:
+         ratios = [4 / 3] * n
+     figsize = [sum(ratios) * 4.5, 4.5]
+     fig, ax = plt.subplots(
+         1, n, figsize=figsize, dpi=dpi, gridspec_kw={"width_ratios": ratios}
+     )
+     if n == 1:
+         ax = [ax]
+     for i in range(n):
+         ax[i].imshow(imgs[i], cmap=plt.get_cmap(cmaps[i]))
+         ax[i].get_yaxis().set_ticks([])
+         ax[i].get_xaxis().set_ticks([])
+         ax[i].set_axis_off()
+         for spine in ax[i].spines.values():  # remove frame
+             spine.set_visible(False)
+         if titles:
+             ax[i].set_title(titles[i])
+     fig.tight_layout(pad=pad)
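
Note: the two helpers can be exercised together on a random tensor to preview the new palette; a sketch only, with shapes mirroring how app.py calls them:

    import torch
    import matplotlib.pyplot as plt
    from mapper.utils.viz_2d import one_hot_argmax_to_rgb, plot_images

    probs = torch.rand(1, 6, 64, 64)  # fake (B, C, H, W) class scores, 6 classes
    rgb = one_hot_argmax_to_rgb(probs, 6).squeeze(0).permute(1, 2, 0).to(torch.uint8).numpy()
    plot_images([rgb], titles=["palette check"], pad=0., adaptive=True)
    plt.gcf().savefig("palette_check.png")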
requirements.txt ADDED
@@ -0,0 +1,23 @@
+ torch
+ torchvision
+ numpy
+ opencv-python
+ Pillow
+ tqdm>=4.36.0
+ matplotlib
+ plotly
+ scipy
+ omegaconf
+ pytorch-lightning
+ torchmetrics
+ lxml
+ rtree
+ scikit-learn
+ geopy
+ exifread
+ hydra-core
+ umsgpack
+ nuscenes-devkit
+ perspective2d @ git+https://github.com/jinlinyi/PerspectiveFields.git
+ urllib3>=2
+ wandb