hzxie committed
Commit 79df973
Parent: 085810c

feat: citydreamer inference (bugs to be fixed).
app.py CHANGED
@@ -4,80 +4,90 @@
  # @Author: Haozhe Xie
  # @Date: 2024-03-02 16:30:00
  # @Last Modified by: Haozhe Xie
- # @Last Modified at: 2024-03-03 10:39:25
+ # @Last Modified at: 2024-03-03 12:02:23
  # @Email: root@haozhexie.com

+ import gradio as gr
  import logging
+ import numpy as np
  import os
- import torch
- import gradio as gr
- import subprocess
- import urllib.request
  import ssl
+ import subprocess
  import sys
+ import torch
+ import urllib.request
+
+ from PIL import Image

  # Fix: ssl.SSLCertVerificationError: [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed
  ssl._create_default_https_context = ssl._create_unverified_context
-
- sys.path.append(os.path.join(os.path.dirname(__file__), "citydreamer"))
  # Import CityDreamer modules
- # import citydreamer.model
- # import citydreamer.inference
+ sys.path.append(os.path.join(os.path.dirname(__file__), "citydreamer"))


  def setup_runtime_env():
-     subprocess.call(["pip", "freeze"])
+     logging.info("CUDA version is %s" % subprocess.check_output(["nvcc", "--version"]))
+     logging.info("GCC version is %s" % subprocess.check_output(["g++", "--version"]))
+     # Compile CUDA extensions
      ext_dir = os.path.join(os.path.dirname(__file__), "citydreamer", "extensions")
      for e in os.listdir(ext_dir):
-         if not os.path.isdir(e):
+         if not os.path.isdir(os.path.join(ext_dir, e)):
              continue
-         subprocess.call(["pip", "install", "."], workdir=os.path.join(ext_dir, e))
+
+         subprocess.call(["pip", "install", "."], cwd=os.path.join(ext_dir, e))


- def get_models():
-     if not os.path.exists("CityDreamer-Fgnd.pth"):
-         urllib.request.urlretrieve(
-             "https://huggingface.co/hzxie/city-dreamer/resolve/main/CityDreamer-Fgnd.pth",
-             "CityDreamer-Fgnd.pth",
-         )
-     if not os.path.exists("CityDreamer-Bgnd.pth"):
+ def get_models(file_name):
+     import citydreamer.model
+
+     if not os.path.exists(file_name):
          urllib.request.urlretrieve(
-             "https://huggingface.co/hzxie/city-dreamer/resolve/main/CityDreamer-Bgnd.pth",
-             "CityDreamer-Bgnd.pth",
+             "https://huggingface.co/hzxie/city-dreamer/resolve/main/%s" % file_name,
+             file_name,
          )

-     bgm_ckpt = torch.load("CityDreamer-Bgnd.pth")
-     fgm_ckpt = torch.load("CityDreamer-Fgnd.pth")
-     bgm = citydreamer.model.GanCraftGenerator(bgm_ckpt["cfg"])
-     fgm = citydreamer.model.GanCraftGenerator(fgm_ckpt["cfg"])
+     ckpt = torch.load(file_name)
+     model = citydreamer.model.GanCraftGenerator(ckpt["cfg"])
      if torch.cuda.is_available():
-         fgm = torch.nn.DataParallel(fgm).cuda().eval()
-         bgm = torch.nn.DataParallel(bgm).cuda().eval()
-
-     return bgm, fgm
+         model = torch.nn.DataParallel(model).cuda().eval()
+
+     return model
+
+
+ def get_city_layout():
+     hf = np.array(Image.open("assets/NYC-HghtFld.png"))
+     seg = np.array(Image.open("assets/NYC-SegMap.png").convert("P"))
+     return hf, seg


  def get_generated_city(radius, altitude, azimuth):
-     print(radius, altitude, azimuth)
+     # The import must be done after CUDA extension compilation
+     import citydreamer.inference
+
+     return citydreamer.inference.generate_city(
+         get_generated_city.fgm,
+         get_generated_city.bgm,
+         get_generated_city.hf,
+         get_generated_city.seg,
+         radius,
+         altitude,
+         azimuth,
+     )


  def main(debug):
      title = "CityDreamer Demo 🏙️"
      with open("README.md", "r") as f:
          markdown = f.read()
-     desc = markdown[markdown.rfind("---") + 3:]
+     desc = markdown[markdown.rfind("---") + 3 :]
      with open("ARTICLE.md", "r") as f:
          arti = f.read()

      app = gr.Interface(
          get_generated_city,
          [
-             gr.Slider(
-                 128, 512, value=320, step=5, label="Camera Radius (m)"
-             ),
-             gr.Slider(
-                 256, 512, value=384, step=5, label="Camera Altitude (m)"
-             ),
+             gr.Slider(128, 512, value=320, step=5, label="Camera Radius (m)"),
+             gr.Slider(256, 512, value=384, step=5, label="Camera Altitude (m)"),
              gr.Slider(0, 360, value=180, step=5, label="Camera Azimuth (°)"),
          ],
          [gr.Image(type="numpy", label="Generated City")],
@@ -94,9 +104,19 @@ if __name__ == "__main__":
      logging.basicConfig(
          format="[%(levelname)s] %(asctime)s %(message)s", level=logging.INFO
      )
-     logging.info("Compile CUDA extensions...")
+     logging.info("Compiling CUDA extensions...")
      # setup_runtime_env()
+
      logging.info("Downloading pretrained models...")
-     # fgm, bgm = get_models()
+     fgm = get_models("CityDreamer-Fgnd.pth")
+     bgm = get_models("CityDreamer-Bgnd.pth")
+     get_generated_city.fgm = fgm
+     get_generated_city.bgm = bgm
+
+     logging.info("Loading New York city layout to RAM...")
+     hf, seg = get_city_layout()
+     get_generated_city.hf = hf
+     get_generated_city.seg = seg
+
      logging.info("Starting the main application...")
      main(os.getenv("DEBUG") == "1")
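
The rewritten `app.py` avoids globals by caching the loaded generators and the NYC layout as attributes on the `get_generated_city` callback before Gradio starts serving requests. A minimal, self-contained sketch of that function-attribute pattern (toy stand-ins, not the real models):

```python
# Toy illustration of the function-attribute caching used above: expensive
# state is attached to the callback once at startup, so the handler that
# Gradio invokes per request needs no global variables.
def render(radius, altitude, azimuth):
    return render.model(radius, altitude, azimuth)

render.model = lambda r, a, z: (r, a, z)  # stand-in for the loaded generator
assert render(320, 384, 180) == (320, 384, 180)
```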
citydreamer/__init__.py ADDED
File without changes
citydreamer/extensions/__init__.py ADDED
@@ -0,0 +1,8 @@
+ # -*- coding: utf-8 -*-
+ #
+ # @File: __init__.py
+ # @Author: Haozhe Xie
+ # @Date: 2023-03-24 20:23:53
+ # @Last Modified by: Haozhe Xie
+ # @Last Modified at: 2023-03-24 20:23:55
+ # @Email: root@haozhexie.com
citydreamer/extensions/extrude_tensor/__init__.py ADDED
@@ -0,0 +1,40 @@
+ # -*- coding: utf-8 -*-
+ #
+ # @File: __init__.py
+ # @Author: Haozhe Xie
+ # @Date: 2023-03-24 20:24:38
+ # @Last Modified by: Haozhe Xie
+ # @Last Modified at: 2023-06-16 09:55:58
+ # @Email: root@haozhexie.com
+
+ import torch
+
+ import extrude_tensor_ext
+
+
+ class TensorExtruder(torch.nn.Module):
+     def __init__(self, max_height=256):
+         super(TensorExtruder, self).__init__()
+         self.max_height = max_height
+
+     def forward(self, seg_map, height_field):
+         assert torch.max(height_field) < self.max_height, "Max Value %d" % torch.max(
+             height_field
+         )
+         return ExtrudeTensorFunction.apply(seg_map, height_field, self.max_height)
+
+
+ class ExtrudeTensorFunction(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, seg_map, height_field, max_height):
+         # seg_map.shape: (B, C, H, W)
+         # height_field.shape: (B, C, H, W)
+         return extrude_tensor_ext.forward(seg_map, height_field, max_height)
+
+     @staticmethod
+     def backward(ctx, grad_volume):
+         # grad_volume.shape: (B, C, H, W, D)
+         # Combine the gradients along the Z-axis.
+         grad_seg_map = torch.sum(grad_volume, dim=4)
+         grad_height_field = grad_seg_map
+         # One gradient per forward input; max_height is not differentiable.
+         return grad_seg_map, grad_height_field, None
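
For context, a minimal usage sketch of the `TensorExtruder` module defined above. It assumes the extension has been built via its `setup.py`, a CUDA device is available, and the import path follows the `sys.path` convention used in `test.py`; all shapes are toy values:

```python
import torch

from extensions.extrude_tensor import TensorExtruder

# (B, C, H, W) int32 inputs on the GPU: semantic labels and per-pixel heights
seg_map = torch.randint(1, 7, (1, 1, 16, 16)).int().cuda()
height_field = torch.randint(0, 255, (1, 1, 16, 16)).int().cuda()

extruder = TensorExtruder(max_height=256)
volume = extruder(seg_map, height_field)  # -> (1, 16, 16, 256) int32 voxels
```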
citydreamer/extensions/extrude_tensor/bindings.cpp ADDED
@@ -0,0 +1,39 @@
+ /**
+  * @File: extrude_tensor_ext_cuda.cpp
+  * @Author: Haozhe Xie
+  * @Date: 2023-03-26 11:06:13
+  * @Last Modified by: Haozhe Xie
+  * @Last Modified at: 2023-03-26 16:28:20
+  * @Email: root@haozhexie.com
+  */
+
+ #include <ATen/cuda/CUDAContext.h>
+ #include <torch/extension.h>
+
+ // NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
+ #define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
+ #define CHECK_CONTIGUOUS(x) \
+   AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
+ #define CHECK_INPUT(x) \
+   CHECK_CUDA(x);       \
+   CHECK_CONTIGUOUS(x)
+
+ torch::Tensor extrude_tensor_ext_cuda_forward(torch::Tensor seg_map,
+                                               torch::Tensor height_field,
+                                               int max_height,
+                                               cudaStream_t stream);
+
+ torch::Tensor extrude_tensor_ext_forward(torch::Tensor seg_map,
+                                          torch::Tensor height_field,
+                                          int max_height) {
+   CHECK_INPUT(seg_map);
+   CHECK_INPUT(height_field);
+   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+   return extrude_tensor_ext_cuda_forward(seg_map, height_field, max_height,
+                                          stream);
+ }
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+   m.def("forward", &extrude_tensor_ext_forward,
+         "Extrude Tensor Ext. Forward (CUDA)");
+ }
citydreamer/extensions/extrude_tensor/extrude_tensor_ext.cu ADDED
@@ -0,0 +1,67 @@
+ /**
+  * @File: extrude_tensor_ext.cu
+  * @Author: Haozhe Xie
+  * @Date: 2023-03-26 11:06:18
+  * @Last Modified by: Haozhe Xie
+  * @Last Modified at: 2023-05-03 14:55:01
+  * @Email: root@haozhexie.com
+  */
+
+ #include <cmath>
+ #include <cstdio>
+ #include <cstdlib>
+ #include <torch/extension.h>
+
+ #define CUDA_NUM_THREADS 512
+
+ // Compute the number of threads needed on the GPU
+ inline int get_n_threads(int n) {
+   const int pow_2 = std::log(static_cast<float>(n)) / std::log(2.0);
+   return max(min(1 << pow_2, CUDA_NUM_THREADS), 1);
+ }
+
+ __global__ void extrude_tensor_ext_cuda_kernel(
+     int height, int width, int depth, const int *__restrict__ seg_map,
+     const int *__restrict__ height_field, int *__restrict__ volume) {
+   int batch_index = blockIdx.x;
+   int index = threadIdx.x;
+   int stride = blockDim.x;
+
+   seg_map += batch_index * height * width;
+   height_field += batch_index * height * width;
+   volume += batch_index * height * width * depth;
+   for (int i = index; i < height; i += stride) {
+     int offset_2d_r = i * width, offset_3d_r = i * width * depth;
+     for (int j = 0; j < width; ++j) {
+       int offset_2d_c = offset_2d_r + j, offset_3d_c = offset_3d_r + j * depth;
+       int seg = seg_map[offset_2d_c];
+       int hf = height_field[offset_2d_c];
+       for (int k = 0; k < hf + 1; ++k) {
+         volume[offset_3d_c + k] = seg;
+       }
+     }
+   }
+ }
+
+ torch::Tensor extrude_tensor_ext_cuda_forward(torch::Tensor seg_map,
+                                               torch::Tensor height_field,
+                                               int max_height,
+                                               cudaStream_t stream) {
+   int batch_size = seg_map.size(0);
+   int height = seg_map.size(2);
+   int width = seg_map.size(3);
+   torch::Tensor volume = torch::zeros({batch_size, height, width, max_height},
+                                       torch::CUDA(torch::kInt32));
+
+   // NOTE: CUDA_NUM_THREADS / CUDA_NUM_THREADS evaluates to 1, so each block
+   // runs a single thread; get_n_threads() above is defined but unused here.
+   extrude_tensor_ext_cuda_kernel<<<
+       batch_size, int(CUDA_NUM_THREADS / CUDA_NUM_THREADS), 0, stream>>>(
+       height, width, max_height, seg_map.data_ptr<int>(),
+       height_field.data_ptr<int>(), volume.data_ptr<int>());
+
+   cudaError_t err = cudaGetLastError();
+   if (err != cudaSuccess) {
+     printf("Error in extrude_tensor_ext_cuda_forward: %s\n",
+            cudaGetErrorString(err));
+   }
+   return volume;
+ }
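
To make the kernel's semantics explicit, here is an equivalent NumPy sketch for a single (H, W) map pair (illustration only, not part of the commit): each output column is filled with its segmentation label from ground level up to and including the height-field value.

```python
import numpy as np

def extrude(seg_map, height_field, max_height):
    # Mirrors extrude_tensor_ext_cuda_kernel for one batch item / channel.
    h, w = seg_map.shape
    volume = np.zeros((h, w, max_height), dtype=np.int32)
    for i in range(h):
        for j in range(w):
            # Labels fill [0, hf] along the height axis; above stays 0 (air).
            volume[i, j, : height_field[i, j] + 1] = seg_map[i, j]
    return volume
```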
citydreamer/extensions/extrude_tensor/setup.py ADDED
@@ -0,0 +1,26 @@
+ # -*- coding: utf-8 -*-
+ #
+ # @File: setup.py
+ # @Author: Haozhe Xie
+ # @Date: 2023-03-24 20:35:43
+ # @Last Modified by: Haozhe Xie
+ # @Last Modified at: 2023-04-29 10:47:30
+ # @Email: root@haozhexie.com
+
+ from setuptools import setup
+ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+ setup(
+     name="extrude_tensor",
+     version="1.0.0",
+     ext_modules=[
+         CUDAExtension(
+             "extrude_tensor_ext",
+             [
+                 "bindings.cpp",
+                 "extrude_tensor_ext.cu",
+             ],
+         ),
+     ],
+     cmdclass={"build_ext": BuildExtension},
+ )
citydreamer/extensions/extrude_tensor/test.py ADDED
@@ -0,0 +1,124 @@
+ # -*- coding: utf-8 -*-
+ #
+ # @File: test.py
+ # @Author: Haozhe Xie
+ # @Date: 2023-03-26 19:23:26
+ # @Last Modified by: Haozhe Xie
+ # @Last Modified at: 2023-04-15 10:47:53
+ # @Email: root@haozhexie.com
+
+ # Mayavi off-screen rendering
+ # Ref: https://github.com/enthought/mayavi/issues/477#issuecomment-477653210
+ from xvfbwrapper import Xvfb
+
+ vdisplay = Xvfb(width=1920, height=1080)
+ vdisplay.start()
+
+ import logging
+ import mayavi.mlab
+ import numpy as np
+ import os
+ import sys
+ import torch
+ import unittest
+
+ from PIL import Image
+ from torch.autograd import gradcheck
+
+ sys.path.append(
+     os.path.abspath(
+         os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir)
+     )
+ )
+ from extensions.extrude_tensor import ExtrudeTensorFunction
+
+ # Disable the warning message for PIL decompression bomb
+ # Ref: https://stackoverflow.com/questions/25705773/image-cropping-tool-python
+ Image.MAX_IMAGE_PIXELS = None
+
+
+ class ExtrudeTensorTestCase(unittest.TestCase):
+     @unittest.skip("The CUDA extension is compiled with int types by default.")
+     def test_extrude_tensor_grad(self):
+         # To run this test, make sure that the int types are replaced by double types in CUDA
+         SIZE = 16
+         seg_map = (
+             torch.randint(low=1, high=7, size=(SIZE, SIZE))
+             .double()
+             .unsqueeze(dim=0)
+             .unsqueeze(dim=0)
+         )
+         height_field = (
+             torch.randint(low=0, high=255, size=(SIZE, SIZE))
+             .double()
+             .unsqueeze(dim=0)
+             .unsqueeze(dim=0)
+         )
+         logging.debug("SegMap Size: %s" % (seg_map.size(),))
+         logging.debug("HeightField Size: %s" % (height_field.size(),))
+         seg_map.requires_grad = True
+         height_field.requires_grad = True
+         logging.info(
+             "Gradient Check: %s"
+             % (
+                 "OK"
+                 if gradcheck(
+                     ExtrudeTensorFunction.apply,
+                     [seg_map.cuda(), height_field.cuda(), 256],
+                 )
+                 else "Failed"
+             )
+         )
+
+     def test_extrude_tensor_gen(self):
+         MAX_HEIGHT = 256
+         proj_home_dir = os.path.join(
+             os.path.dirname(__file__), os.path.pardir, os.path.pardir
+         )
+         osm_data_dir = os.path.join(proj_home_dir, "data", "osm")
+         osm_name = "US-NewYork"
+         seg_map = Image.open(os.path.join(osm_data_dir, osm_name, "seg.png")).convert(
+             "P"
+         )
+         height_field = Image.open(os.path.join(osm_data_dir, osm_name, "hf.png"))
+         # Crop the maps
+         seg_map = np.array(seg_map)[3840:4096, 3840:4096]
+         height_field = np.array(height_field)[3840:4096, 3840:4096]
+         # Convert to tensors
+         seg_map_tnsr = (
+             torch.from_numpy(seg_map).unsqueeze(dim=0).unsqueeze(dim=0).int().cuda()
+         )
+         height_field_tnsr = (
+             torch.from_numpy(height_field)
+             .unsqueeze(dim=0)
+             .unsqueeze(dim=0)
+             .int()
+             .cuda()
+         )
+         volume = ExtrudeTensorFunction.apply(
+             seg_map_tnsr, height_field_tnsr, MAX_HEIGHT
+         )
+         # 3D Visualization
+         vol = volume.squeeze().cpu().numpy().astype(np.uint8)
+
+         x, y, z = np.where(vol != 0)
+         n_pts = len(x)
+         colors = np.zeros((n_pts, 4), dtype=np.uint8)
+         # fmt: off
+         colors[vol[x, y, z] == 1] = [96, 0, 0, 255]       # highway -> red
+         colors[vol[x, y, z] == 2] = [96, 96, 0, 255]      # building -> yellow
+         colors[vol[x, y, z] == 3] = [0, 96, 0, 255]       # green lands -> green
+         colors[vol[x, y, z] == 4] = [0, 96, 96, 255]      # construction -> cyan
+         colors[vol[x, y, z] == 5] = [0, 0, 96, 255]       # water -> blue
+         colors[vol[x, y, z] == 6] = [128, 128, 128, 255]  # ground -> gray
+         # fmt: on
+         mayavi.mlab.options.offscreen = True
+         mayavi.mlab.figure(size=(1600, 900), bgcolor=(1, 1, 1))
+         pts = mayavi.mlab.points3d(x, y, z, mode="cube", scale_factor=1)
+         pts.glyph.scale_mode = "scale_by_vector"
+         pts.mlab_source.dataset.point_data.scalars = colors
+         mayavi.mlab.savefig(os.path.join(proj_home_dir, "logs", "%s-3d.jpg" % osm_name))


+ if __name__ == "__main__":
+     logging.basicConfig(
+         format="[%(levelname)s] %(asctime)s %(message)s",
+         level=logging.INFO,
+     )
+     unittest.main()
citydreamer/extensions/grid_encoder/__init__.py ADDED
@@ -0,0 +1,193 @@
+ # -*- coding: utf-8 -*-
+ #
+ # @File: __init__.py
+ # @Author: Jiaxiang Tang (@ashawkey)
+ # @Date: 2023-04-15 10:39:28
+ # @Last Modified by: Haozhe Xie
+ # @Last Modified at: 2023-04-15 13:08:46
+ # @Email: ashawkey1999@gmail.com
+ # @Ref: https://github.com/ashawkey/torch-ngp
+
+ import math
+ import numpy as np
+ import torch
+
+ import grid_encoder_ext
+
+
+ class GridEncoderFunction(torch.autograd.Function):
+     @staticmethod
+     def forward(
+         ctx,
+         inputs,
+         embeddings,
+         offsets,
+         per_level_scale,
+         base_resolution,
+         calc_grad_inputs=False,
+         gridtype=0,
+         align_corners=False,
+     ):
+         # inputs: [B, D], float in [0, 1]
+         # embeddings: [sO, C], float
+         # offsets: [L + 1], int
+         # RETURN: [B, F], float
+         inputs = inputs.contiguous()
+         # batch size, coord dim
+         B, D = inputs.shape
+         # level
+         L = offsets.shape[0] - 1
+         # embedding dim for each level
+         C = embeddings.shape[1]
+         # resolution multiplier at each level, apply log2 for later CUDA exp2f
+         S = math.log2(per_level_scale)
+         # base resolution
+         H = base_resolution
+         # L first, optimize cache for cuda kernel, but needs an extra permute later
+         outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype)
+
+         if calc_grad_inputs:
+             dy_dx = torch.empty(
+                 B, L * D * C, device=inputs.device, dtype=embeddings.dtype
+             )
+         else:
+             dy_dx = torch.empty(
+                 1, device=inputs.device, dtype=embeddings.dtype
+             )  # placeholder... TODO: a better way?
+
+         grid_encoder_ext.forward(
+             inputs,
+             embeddings,
+             offsets,
+             outputs,
+             B,
+             D,
+             C,
+             L,
+             S,
+             H,
+             calc_grad_inputs,
+             dy_dx,
+             gridtype,
+             align_corners,
+         )
+         # permute back to [B, L * C]
+         outputs = outputs.permute(1, 0, 2).reshape(B, L * C)
+         ctx.save_for_backward(inputs, embeddings, offsets, dy_dx)
+         ctx.dims = [B, D, C, L, S, H, gridtype]
+         ctx.calc_grad_inputs = calc_grad_inputs
+         ctx.align_corners = align_corners
+
+         return outputs
+
+     @staticmethod
+     def backward(ctx, grad):
+         inputs, embeddings, offsets, dy_dx = ctx.saved_tensors
+         B, D, C, L, S, H, gridtype = ctx.dims
+         calc_grad_inputs = ctx.calc_grad_inputs
+         align_corners = ctx.align_corners
+
+         # grad: [B, L * C] --> [L, B, C]
+         grad = grad.view(B, L, C).permute(1, 0, 2).contiguous()
+         grad_embeddings = torch.zeros_like(embeddings)
+
+         if calc_grad_inputs:
+             grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype)
+         else:
+             grad_inputs = torch.zeros(1, device=inputs.device, dtype=embeddings.dtype)
+
+         grid_encoder_ext.backward(
+             grad,
+             inputs,
+             embeddings,
+             offsets,
+             grad_embeddings,
+             B,
+             D,
+             C,
+             L,
+             S,
+             H,
+             calc_grad_inputs,
+             dy_dx,
+             grad_inputs,
+             gridtype,
+             align_corners,
+         )
+
+         if calc_grad_inputs:
+             grad_inputs = grad_inputs.to(inputs.dtype)
+             return grad_inputs, grad_embeddings, None, None, None, None, None, None
+         else:
+             return None, grad_embeddings, None, None, None, None, None, None
+
+
+ class GridEncoder(torch.nn.Module):
+     def __init__(
+         self,
+         in_channels,
+         n_levels,
+         lvl_channels,
+         desired_resolution,
+         per_level_scale=2,
+         base_resolution=16,
+         log2_hashmap_size=19,
+         gridtype="hash",
+         align_corners=False,
+     ):
+         super(GridEncoder, self).__init__()
+         self.in_channels = in_channels
+         self.n_levels = n_levels  # num levels, each level multiply resolution by 2
+         self.lvl_channels = lvl_channels  # encode channels per level
+         self.per_level_scale = 2 ** (
+             math.log2(desired_resolution / base_resolution) / (n_levels - 1)
+         )
+         self.log2_hashmap_size = log2_hashmap_size
+         self.base_resolution = base_resolution
+         self.output_dim = n_levels * lvl_channels
+         self.gridtype = gridtype
+         self.gridtype_id = 0 if gridtype == "hash" else 1
+         self.align_corners = align_corners
+
+         # allocate parameters
+         offsets = []
+         offset = 0
+         self.max_params = 2**log2_hashmap_size
+         for i in range(n_levels):
+             resolution = int(math.ceil(base_resolution * per_level_scale**i))
+             params_in_level = min(
+                 self.max_params,
+                 (resolution if align_corners else resolution + 1) ** in_channels,
+             )  # limit max number
+             params_in_level = int(math.ceil(params_in_level / 8) * 8)  # make divisible
+             offsets.append(offset)
+             offset += params_in_level
+
+         offsets.append(offset)
+         offsets = torch.from_numpy(np.array(offsets, dtype=np.int32))
+         self.register_buffer("offsets", offsets)
+
+         self.n_params = offsets[-1] * lvl_channels
+         self.embeddings = torch.nn.Parameter(torch.empty(offset, lvl_channels))
+         self._init_weights()
+
+     def _init_weights(self):
+         self.embeddings.data.uniform_(-1e-4, 1e-4)
+
+     def forward(self, inputs, bound=1):
+         # inputs: [..., in_channels], normalized real world positions in [-bound, bound]
+         # return: [..., n_levels * lvl_channels]
+         inputs = (inputs + bound) / (2 * bound)  # map to [0, 1]
+         prefix_shape = list(inputs.shape[:-1])
+         inputs = inputs.view(-1, self.in_channels)
+         outputs = GridEncoderFunction.apply(
+             inputs,
+             self.embeddings,
+             self.offsets,
+             self.per_level_scale,
+             self.base_resolution,
+             inputs.requires_grad,
+             self.gridtype_id,
+             self.align_corners,
+         )
+         return outputs.view(prefix_shape + [self.output_dim])
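
A minimal usage sketch of `GridEncoder` (toy hyper-parameters; assumes the `grid_encoder` extension is built, inputs live on a CUDA device, and the same `sys.path` convention as the extrude_tensor test):

```python
import torch

from extensions.grid_encoder import GridEncoder

encoder = GridEncoder(
    in_channels=3,           # 3-D query positions
    n_levels=16,             # resolution levels
    lvl_channels=2,          # feature channels per level
    desired_resolution=2048,
).cuda()

xyz = torch.rand(4096, 3).cuda() * 2 - 1  # positions in [-1, 1]
feats = encoder(xyz, bound=1)             # -> (4096, 16 * 2) features
```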
citydreamer/extensions/grid_encoder/bindings.cpp ADDED
@@ -0,0 +1,40 @@
+ /**
+  * @File: grid_encoder_ext_cuda.cpp
+  * @Author: Jiaxiang Tang (@ashawkey)
+  * @Date: 2023-04-15 10:39:17
+  * @Last Modified by: Haozhe Xie
+  * @Last Modified at: 2023-04-15 11:01:32
+  * @Email: ashawkey1999@gmail.com
+  * @Ref: https://github.com/ashawkey/torch-ngp
+  */
+
+ #include <stdint.h>
+ #include <torch/extension.h>
+ #include <torch/torch.h>
+
+ // inputs: [B, D], float, in [0, 1]
+ // embeddings: [sO, C], float
+ // offsets: [L + 1], uint32_t
+ // outputs: [B, L * C], float
+ // H: base resolution
+ void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings,
+                          const at::Tensor offsets, at::Tensor outputs,
+                          const uint32_t B, const uint32_t D, const uint32_t C,
+                          const uint32_t L, const float S, const uint32_t H,
+                          const bool calc_grad_inputs, at::Tensor dy_dx,
+                          const uint32_t gridtype, const bool align_corners);
+ void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs,
+                           const at::Tensor embeddings, const at::Tensor offsets,
+                           at::Tensor grad_embeddings, const uint32_t B,
+                           const uint32_t D, const uint32_t C, const uint32_t L,
+                           const float S, const uint32_t H,
+                           const bool calc_grad_inputs, const at::Tensor dy_dx,
+                           at::Tensor grad_inputs, const uint32_t gridtype,
+                           const bool align_corners);
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+   m.def("forward", &grid_encode_forward,
+         "grid_encode_forward (CUDA)");
+   m.def("backward", &grid_encode_backward,
+         "grid_encode_backward (CUDA)");
+ }
citydreamer/extensions/grid_encoder/grid_encoder_ext.cu ADDED
@@ -0,0 +1,605 @@
+ /**
+  * @File: grid_encoder_ext.cu
+  * @Author: Jiaxiang Tang (@ashawkey)
+  * @Date: 2023-04-15 10:43:16
+  * @Last Modified by: Haozhe Xie
+  * @Last Modified at: 2023-04-29 11:47:54
+  * @Email: ashawkey1999@gmail.com
+  * @Ref: https://github.com/ashawkey/torch-ngp
+  */
+
+ #include <cuda.h>
+ #include <cuda_fp16.h>
+ #include <cuda_runtime.h>
+
+ #include <ATen/cuda/CUDAContext.h>
+ #include <torch/torch.h>
+
+ #include <algorithm>
+ #include <stdexcept>
+
+ #include <cstdio>
+ #include <stdint.h>
+
+ #define CHECK_CUDA(x) \
+   TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
+ #define CHECK_CONTIGUOUS(x) \
+   TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
+ #define CHECK_IS_INT(x)                              \
+   TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, \
+               #x " must be an int tensor")
+ #define CHECK_IS_FLOATING(x)                             \
+   TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || \
+                   x.scalar_type() == at::ScalarType::Half || \
+                   x.scalar_type() == at::ScalarType::Double, \
+               #x " must be a floating tensor")
+
+ // just for compatibility of half precision in
+ // AT_DISPATCH_FLOATING_TYPES_AND_HALF...
+ static inline __device__ at::Half atomicAdd(at::Half *address, at::Half val) {
+   // requires CUDA >= 10 and ARCH >= 70
+   // this is very slow compared to float or __half2, and never used.
+   // return atomicAdd(reinterpret_cast<__half*>(address), val);
+ }
+
+ template <typename T>
+ static inline __host__ __device__ T div_round_up(T val, T divisor) {
+   return (val + divisor - 1) / divisor;
+ }
+
+ template <uint32_t D>
+ __device__ uint32_t fast_hash(const uint32_t pos_grid[D]) {
+   static_assert(D <= 7, "fast_hash can only hash up to 7 dimensions.");
+
+   // While 1 is technically not a good prime for hashing (or a prime at all), it
+   // helps memory coherence and is sufficient for our use case of obtaining a
+   // uniformly colliding index from high-dimensional coordinates.
+   constexpr uint32_t primes[7] = {1,          2654435761, 805459861, 3674653429,
+                                   2097192037, 1434869437, 2165219737};
+
+   uint32_t result = 0;
+ #pragma unroll
+   for (uint32_t i = 0; i < D; ++i) {
+     result ^= pos_grid[i] * primes[i];
+   }
+
+   return result;
+ }
+
+ template <uint32_t D, uint32_t C>
+ __device__ uint32_t get_grid_index(const uint32_t gridtype,
+                                    const bool align_corners, const uint32_t ch,
+                                    const uint32_t hashmap_size,
+                                    const uint32_t resolution,
+                                    const uint32_t pos_grid[D]) {
+   uint32_t stride = 1;
+   uint32_t index = 0;
+
+ #pragma unroll
+   for (uint32_t d = 0; d < D && stride <= hashmap_size; d++) {
+     index += pos_grid[d] * stride;
+     stride *= align_corners ? resolution : (resolution + 1);
+   }
+
+   // NOTE: for NeRF, the hash is in fact not necessary. Check
+   // https://github.com/NVlabs/instant-ngp/issues/97.
+   // gridtype: 0 == hash, 1 == tiled
+   if (gridtype == 0 && stride > hashmap_size) {
+     index = fast_hash<D>(pos_grid);
+   }
+
+   return (index % hashmap_size) * C + ch;
+ }
+
+ template <typename scalar_t, uint32_t D, uint32_t C>
+ __global__ void
+ kernel_grid(const float *__restrict__ inputs, const scalar_t *__restrict__ grid,
+             const int *__restrict__ offsets, scalar_t *__restrict__ outputs,
+             const uint32_t B, const uint32_t L, const float S, const uint32_t H,
+             const bool calc_grad_inputs, scalar_t *__restrict__ dy_dx,
+             const uint32_t gridtype, const bool align_corners) {
+   const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;
+
+   if (b >= B)
+     return;
+
+   const uint32_t level = blockIdx.y;
+
+   // locate
+   grid += (uint32_t)offsets[level] * C;
+   inputs += b * D;
+   outputs += level * B * C + b * C;
+
+   // check input range (should be in [0, 1])
+   bool flag_oob = false;
+ #pragma unroll
+   for (uint32_t d = 0; d < D; d++) {
+     if (inputs[d] < 0 || inputs[d] > 1) {
+       flag_oob = true;
+     }
+   }
+   // if input out of bound, just set output to 0
+   if (flag_oob) {
+ #pragma unroll
+     for (uint32_t ch = 0; ch < C; ch++) {
+       outputs[ch] = 0;
+     }
+     if (calc_grad_inputs) {
+       dy_dx += b * D * L * C + level * D * C; // B L D C
+ #pragma unroll
+       for (uint32_t d = 0; d < D; d++) {
+ #pragma unroll
+         for (uint32_t ch = 0; ch < C; ch++) {
+           dy_dx[d * C + ch] = 0;
+         }
+       }
+     }
+     return;
+   }
+
+   const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
+   const float scale = exp2f(level * S) * H - 1.0f;
+   const uint32_t resolution = (uint32_t)ceil(scale) + 1;
+
+   // calculate coordinate
+   float pos[D];
+   uint32_t pos_grid[D];
+
+ #pragma unroll
+   for (uint32_t d = 0; d < D; d++) {
+     pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
+     pos_grid[d] = floorf(pos[d]);
+     pos[d] -= (float)pos_grid[d];
+   }
+
+   // printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1],
+   //        pos_grid[0], pos_grid[1]);
+
+   // interpolate
+   scalar_t results[C] = {0}; // temp results in register
+
+ #pragma unroll
+   for (uint32_t idx = 0; idx < (1 << D); idx++) {
+     float w = 1;
+     uint32_t pos_grid_local[D];
+
+ #pragma unroll
+     for (uint32_t d = 0; d < D; d++) {
+       if ((idx & (1 << d)) == 0) {
+         w *= 1 - pos[d];
+         pos_grid_local[d] = pos_grid[d];
+       } else {
+         w *= pos[d];
+         pos_grid_local[d] = pos_grid[d] + 1;
+       }
+     }
+
+     uint32_t index = get_grid_index<D, C>(
+         gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);
+
+     // writing to register (fast)
+ #pragma unroll
+     for (uint32_t ch = 0; ch < C; ch++) {
+       results[ch] += w * grid[index + ch];
+     }
+
+     // printf("[b=%d, l=%d] int %d, idx %d, w %f, val %f\n", b, level, idx,
+     //        index, w, grid[index]);
+   }
+
+   // writing to global memory (slow)
+ #pragma unroll
+   for (uint32_t ch = 0; ch < C; ch++) {
+     outputs[ch] = results[ch];
+   }
+
+   // prepare dy_dx for calc_grad_inputs
+   // differentiable (soft) indexing:
+   // https://discuss.pytorch.org/t/differentiable-indexing/17647/9
+   if (calc_grad_inputs) {
+     dy_dx += b * D * L * C + level * D * C; // B L D C
+
+ #pragma unroll
+     for (uint32_t gd = 0; gd < D; gd++) {
+       scalar_t results_grad[C] = {0};
+
+ #pragma unroll
+       for (uint32_t idx = 0; idx < (1 << (D - 1)); idx++) {
+         float w = scale;
+         uint32_t pos_grid_local[D];
+
+ #pragma unroll
+         for (uint32_t nd = 0; nd < D - 1; nd++) {
+           const uint32_t d = (nd >= gd) ? (nd + 1) : nd;
+
+           if ((idx & (1 << nd)) == 0) {
+             w *= 1 - pos[d];
+             pos_grid_local[d] = pos_grid[d];
+           } else {
+             w *= pos[d];
+             pos_grid_local[d] = pos_grid[d] + 1;
+           }
+         }
+
+         pos_grid_local[gd] = pos_grid[gd];
+         uint32_t index_left =
+             get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size,
+                                  resolution, pos_grid_local);
+         pos_grid_local[gd] = pos_grid[gd] + 1;
+         uint32_t index_right =
+             get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size,
+                                  resolution, pos_grid_local);
+
+ #pragma unroll
+         for (uint32_t ch = 0; ch < C; ch++) {
+           results_grad[ch] +=
+               w * (grid[index_right + ch] - grid[index_left + ch]);
+         }
+       }
+
+ #pragma unroll
+       for (uint32_t ch = 0; ch < C; ch++) {
+         dy_dx[gd * C + ch] = results_grad[ch];
+       }
+     }
+   }
+ }
+
+ template <typename scalar_t, uint32_t D, uint32_t C, uint32_t N_C>
+ __global__ void kernel_grid_backward(
+     const scalar_t *__restrict__ grad, const float *__restrict__ inputs,
+     const scalar_t *__restrict__ grid, const int *__restrict__ offsets,
+     scalar_t *__restrict__ grad_grid, const uint32_t B, const uint32_t L,
+     const float S, const uint32_t H, const uint32_t gridtype,
+     const bool align_corners) {
+   const uint32_t b = (blockIdx.x * blockDim.x + threadIdx.x) * N_C / C;
+   if (b >= B)
+     return;
+
+   const uint32_t level = blockIdx.y;
+   const uint32_t ch = (blockIdx.x * blockDim.x + threadIdx.x) * N_C - b * C;
+
+   // locate
+   grad_grid += offsets[level] * C;
+   inputs += b * D;
+   grad += level * B * C + b * C + ch; // L, B, C
+
+   const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
+   const float scale = exp2f(level * S) * H - 1.0f;
+   const uint32_t resolution = (uint32_t)ceil(scale) + 1;
+
+   // check input range (should be in [0, 1])
+ #pragma unroll
+   for (uint32_t d = 0; d < D; d++) {
+     if (inputs[d] < 0 || inputs[d] > 1) {
+       return; // grad is init as 0, so we simply return.
+     }
+   }
+
+   // calculate coordinate
+   float pos[D];
+   uint32_t pos_grid[D];
+
+ #pragma unroll
+   for (uint32_t d = 0; d < D; d++) {
+     pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
+     pos_grid[d] = floorf(pos[d]);
+     pos[d] -= (float)pos_grid[d];
+   }
+
+   scalar_t grad_cur[N_C] = {0}; // fetch to register
+ #pragma unroll
+   for (uint32_t c = 0; c < N_C; c++) {
+     grad_cur[c] = grad[c];
+   }
+
+   // interpolate
+ #pragma unroll
+   for (uint32_t idx = 0; idx < (1 << D); idx++) {
+     float w = 1;
+     uint32_t pos_grid_local[D];
+
+ #pragma unroll
+     for (uint32_t d = 0; d < D; d++) {
+       if ((idx & (1 << d)) == 0) {
+         w *= 1 - pos[d];
+         pos_grid_local[d] = pos_grid[d];
+       } else {
+         w *= pos[d];
+         pos_grid_local[d] = pos_grid[d] + 1;
+       }
+     }
+
+     uint32_t index = get_grid_index<D, C>(
+         gridtype, align_corners, ch, hashmap_size, resolution, pos_grid_local);
+
+     // atomicAdd for __half is slow (especially for large values), so we use
+     // __half2 if N_C % 2 == 0
+     // TODO: use float which is better than __half, if N_C % 2 != 0
+     if (std::is_same<scalar_t, at::Half>::value && N_C % 2 == 0) {
+ #pragma unroll
+       for (uint32_t c = 0; c < N_C; c += 2) {
+         // process two __half at once (by interpreting as a __half2)
+         __half2 v = {(__half)(w * grad_cur[c]), (__half)(w * grad_cur[c + 1])};
+         atomicAdd((__half2 *)&grad_grid[index + c], v);
+       }
+       // float, or __half when N_C % 2 != 0 (which means C == 1)
+     } else {
+ #pragma unroll
+       for (uint32_t c = 0; c < N_C; c++) {
+         atomicAdd(&grad_grid[index + c], w * grad_cur[c]);
+       }
+     }
+   }
+ }
+
+ template <typename scalar_t, uint32_t D, uint32_t C>
+ __global__ void kernel_input_backward(const scalar_t *__restrict__ grad,
+                                       const scalar_t *__restrict__ dy_dx,
+                                       scalar_t *__restrict__ grad_inputs,
+                                       uint32_t B, uint32_t L) {
+   const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
+   if (t >= B * D)
+     return;
+
+   const uint32_t b = t / D;
+   const uint32_t d = t - b * D;
+
+   dy_dx += b * L * D * C;
+
+   scalar_t result = 0;
+
+ #pragma unroll
+   for (int l = 0; l < L; l++) {
+ #pragma unroll
+     for (int ch = 0; ch < C; ch++) {
+       result += grad[l * B * C + b * C + ch] * dy_dx[l * D * C + d * C + ch];
+     }
+   }
+
+   grad_inputs[t] = result;
+ }
+
+ template <typename scalar_t, uint32_t D>
+ void kernel_grid_wrapper(const float *inputs, const scalar_t *embeddings,
+                          const int *offsets, scalar_t *outputs,
+                          const uint32_t B, const uint32_t C, const uint32_t L,
+                          const float S, const uint32_t H,
+                          const bool calc_grad_inputs, scalar_t *dy_dx,
+                          const uint32_t gridtype, const bool align_corners) {
+   static constexpr uint32_t N_THREAD = 512;
+   const dim3 blocks_hashgrid = {div_round_up(B, N_THREAD), L, 1};
+   switch (C) {
+   case 1:
+     kernel_grid<scalar_t, D, 1><<<blocks_hashgrid, N_THREAD>>>(
+         inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs,
+         dy_dx, gridtype, align_corners);
+     break;
+   case 2:
+     kernel_grid<scalar_t, D, 2><<<blocks_hashgrid, N_THREAD>>>(
+         inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs,
+         dy_dx, gridtype, align_corners);
+     break;
+   case 4:
+     kernel_grid<scalar_t, D, 4><<<blocks_hashgrid, N_THREAD>>>(
+         inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs,
+         dy_dx, gridtype, align_corners);
+     break;
+   case 8:
+     kernel_grid<scalar_t, D, 8><<<blocks_hashgrid, N_THREAD>>>(
+         inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs,
+         dy_dx, gridtype, align_corners);
+     break;
+   default:
+     throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
+   }
+ }
+
+ // inputs: [B, D], float, in [0, 1]
+ // embeddings: [sO, C], float
+ // offsets: [L + 1], uint32_t
+ // outputs: [L, B, C], float (L first, so only one level of hashmap needs to
+ //          fit into cache at a time.)
+ // H: base resolution
+ // dy_dx: [B, L * D * C]
+ template <typename scalar_t>
+ void grid_encode_forward_cuda(const float *inputs, const scalar_t *embeddings,
+                               const int *offsets, scalar_t *outputs,
+                               const uint32_t B, const uint32_t D,
+                               const uint32_t C, const uint32_t L, const float S,
+                               const uint32_t H, const bool calc_grad_inputs,
+                               scalar_t *dy_dx, const uint32_t gridtype,
+                               const bool align_corners) {
+   switch (D) {
+   case 2:
+     kernel_grid_wrapper<scalar_t, 2>(inputs, embeddings, offsets, outputs, B, C,
+                                      L, S, H, calc_grad_inputs, dy_dx, gridtype,
+                                      align_corners);
+     break;
+   case 3:
+     kernel_grid_wrapper<scalar_t, 3>(inputs, embeddings, offsets, outputs, B, C,
+                                      L, S, H, calc_grad_inputs, dy_dx, gridtype,
+                                      align_corners);
+     break;
+   case 4:
+     kernel_grid_wrapper<scalar_t, 4>(inputs, embeddings, offsets, outputs, B, C,
+                                      L, S, H, calc_grad_inputs, dy_dx, gridtype,
+                                      align_corners);
+     break;
+   case 5:
+     kernel_grid_wrapper<scalar_t, 5>(inputs, embeddings, offsets, outputs, B, C,
+                                      L, S, H, calc_grad_inputs, dy_dx, gridtype,
+                                      align_corners);
+     break;
+   default:
+     throw std::runtime_error{"GridEncoding: D must be 2, 3, 4, or 5."};
+   }
+ }
+
+ template <typename scalar_t, uint32_t D>
+ void kernel_grid_backward_wrapper(
+     const scalar_t *grad, const float *inputs, const scalar_t *embeddings,
+     const int *offsets, scalar_t *grad_embeddings, const uint32_t B,
+     const uint32_t C, const uint32_t L, const float S, const uint32_t H,
+     const bool calc_grad_inputs, scalar_t *dy_dx, scalar_t *grad_inputs,
+     const uint32_t gridtype, const bool align_corners) {
+   static constexpr uint32_t N_THREAD = 256;
+   const uint32_t N_C = std::min(2u, C); // n_features_per_thread
+   const dim3 blocks_hashgrid = {div_round_up(B * C / N_C, N_THREAD), L, 1};
+   switch (C) {
+   case 1:
+     kernel_grid_backward<scalar_t, D, 1, 1><<<blocks_hashgrid, N_THREAD>>>(
+         grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H,
+         gridtype, align_corners);
+     if (calc_grad_inputs)
+       kernel_input_backward<scalar_t, D, 1>
+           <<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx,
+                                                         grad_inputs, B, L);
+     break;
+   case 2:
+     kernel_grid_backward<scalar_t, D, 2, 2><<<blocks_hashgrid, N_THREAD>>>(
+         grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H,
+         gridtype, align_corners);
+     if (calc_grad_inputs)
+       kernel_input_backward<scalar_t, D, 2>
+           <<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx,
+                                                         grad_inputs, B, L);
+     break;
+   case 4:
+     kernel_grid_backward<scalar_t, D, 4, 2><<<blocks_hashgrid, N_THREAD>>>(
+         grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H,
+         gridtype, align_corners);
+     if (calc_grad_inputs)
+       kernel_input_backward<scalar_t, D, 4>
+           <<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx,
+                                                         grad_inputs, B, L);
+     break;
+   case 8:
+     kernel_grid_backward<scalar_t, D, 8, 2><<<blocks_hashgrid, N_THREAD>>>(
+         grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H,
+         gridtype, align_corners);
+     if (calc_grad_inputs)
+       kernel_input_backward<scalar_t, D, 8>
+           <<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx,
+                                                         grad_inputs, B, L);
+     break;
+   default:
+     throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
+   }
+ }
+
+ // grad: [L, B, C], float
+ // inputs: [B, D], float, in [0, 1]
+ // embeddings: [sO, C], float
+ // offsets: [L + 1], uint32_t
+ // grad_embeddings: [sO, C]
+ // H: base resolution
+ template <typename scalar_t>
+ void grid_encode_backward_cuda(
+     const scalar_t *grad, const float *inputs, const scalar_t *embeddings,
+     const int *offsets, scalar_t *grad_embeddings, const uint32_t B,
+     const uint32_t D, const uint32_t C, const uint32_t L, const float S,
+     const uint32_t H, const bool calc_grad_inputs, scalar_t *dy_dx,
+     scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners) {
+   switch (D) {
+   case 2:
+     kernel_grid_backward_wrapper<scalar_t, 2>(
+         grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H,
+         calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners);
+     break;
+   case 3:
+     kernel_grid_backward_wrapper<scalar_t, 3>(
+         grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H,
+         calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners);
+     break;
+   case 4:
+     kernel_grid_backward_wrapper<scalar_t, 4>(
+         grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H,
+         calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners);
+     break;
+   case 5:
+     kernel_grid_backward_wrapper<scalar_t, 5>(
+         grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H,
+         calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners);
+     break;
+   default:
+     throw std::runtime_error{"GridEncoding: D must be 2, 3, 4, or 5."};
+   }
+ }
+
+ void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings,
+                          const at::Tensor offsets, at::Tensor outputs,
+                          const uint32_t B, const uint32_t D, const uint32_t C,
+                          const uint32_t L, const float S, const uint32_t H,
+                          const bool calc_grad_inputs, at::Tensor dy_dx,
+                          const uint32_t gridtype, const bool align_corners) {
+   CHECK_CUDA(inputs);
+   CHECK_CUDA(embeddings);
+   CHECK_CUDA(offsets);
+   CHECK_CUDA(outputs);
+   CHECK_CUDA(dy_dx);
+
+   CHECK_CONTIGUOUS(inputs);
+   CHECK_CONTIGUOUS(embeddings);
+   CHECK_CONTIGUOUS(offsets);
+   CHECK_CONTIGUOUS(outputs);
+   CHECK_CONTIGUOUS(dy_dx);
+
+   CHECK_IS_FLOATING(inputs);
+   CHECK_IS_FLOATING(embeddings);
+   CHECK_IS_INT(offsets);
+   CHECK_IS_FLOATING(outputs);
+   CHECK_IS_FLOATING(dy_dx);
+
+   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+       embeddings.scalar_type(), "grid_encode_forward", ([&] {
+         grid_encode_forward_cuda<scalar_t>(
+             inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(),
+             offsets.data_ptr<int>(), outputs.data_ptr<scalar_t>(), B, D, C, L,
+             S, H, calc_grad_inputs, dy_dx.data_ptr<scalar_t>(), gridtype,
+             align_corners);
+       }));
+ }
+
+ void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs,
+                           const at::Tensor embeddings, const at::Tensor offsets,
+                           at::Tensor grad_embeddings, const uint32_t B,
+                           const uint32_t D, const uint32_t C, const uint32_t L,
+                           const float S, const uint32_t H,
+                           const bool calc_grad_inputs, const at::Tensor dy_dx,
+                           at::Tensor grad_inputs, const uint32_t gridtype,
+                           const bool align_corners) {
+   CHECK_CUDA(grad);
+   CHECK_CUDA(inputs);
+   CHECK_CUDA(embeddings);
+   CHECK_CUDA(offsets);
+   CHECK_CUDA(grad_embeddings);
+   CHECK_CUDA(dy_dx);
+   CHECK_CUDA(grad_inputs);
+
+   CHECK_CONTIGUOUS(grad);
+   CHECK_CONTIGUOUS(inputs);
+   CHECK_CONTIGUOUS(embeddings);
+   CHECK_CONTIGUOUS(offsets);
+   CHECK_CONTIGUOUS(grad_embeddings);
+   CHECK_CONTIGUOUS(dy_dx);
+   CHECK_CONTIGUOUS(grad_inputs);
+
+   CHECK_IS_FLOATING(grad);
+   CHECK_IS_FLOATING(inputs);
+   CHECK_IS_FLOATING(embeddings);
+   CHECK_IS_INT(offsets);
+   CHECK_IS_FLOATING(grad_embeddings);
+   CHECK_IS_FLOATING(dy_dx);
+   CHECK_IS_FLOATING(grad_inputs);
+
+   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+       grad.scalar_type(), "grid_encode_backward", ([&] {
+         grid_encode_backward_cuda<scalar_t>(
+             grad.data_ptr<scalar_t>(), inputs.data_ptr<float>(),
+             embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(),
+             grad_embeddings.data_ptr<scalar_t>(), B, D, C, L, S, H,
+             calc_grad_inputs, dy_dx.data_ptr<scalar_t>(),
+             grad_inputs.data_ptr<scalar_t>(), gridtype, align_corners);
+       }));
+ }
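
For readers following the kernel, `get_grid_index` uses dense indexing while a level's grid still fits in its hash map and falls back to the xor/prime spatial hash otherwise. A small Python sketch of that logic (channel offset omitted; illustration only):

```python
PRIMES = (1, 2654435761, 805459861)  # first three of the kernel's constants

def grid_index(pos_grid, resolution, hashmap_size, align_corners=False):
    side = resolution if align_corners else resolution + 1
    stride, index = 1, 0
    for p in pos_grid:
        if stride > hashmap_size:
            break
        index += p * stride
        stride *= side
    if stride > hashmap_size:  # dense grid too large -> spatial hash
        index = 0
        for p, prime in zip(pos_grid, PRIMES):
            index ^= p * prime
    return index % hashmap_size
```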
citydreamer/extensions/grid_encoder/setup.py ADDED
@@ -0,0 +1,39 @@
+ # -*- coding: utf-8 -*-
+ #
+ # @File: setup.py
+ # @Author: Jiaxiang Tang (@ashawkey)
+ # @Date: 2023-04-15 10:33:32
+ # @Last Modified by: Haozhe Xie
+ # @Last Modified at: 2023-04-29 10:47:10
+ # @Email: ashawkey1999@gmail.com
+ # @Ref: https://github.com/ashawkey/torch-ngp
+
+ from setuptools import setup
+ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+ setup(
+     name="grid_encoder",
+     version="1.0.0",
+     ext_modules=[
+         CUDAExtension(
+             name="grid_encoder_ext",
+             sources=[
+                 "grid_encoder_ext.cu",
+                 "bindings.cpp",
+             ],
+             extra_compile_args={
+                 "cxx": ["-O3", "-std=c++14"],
+                 "nvcc": [
+                     "-O3",
+                     "-std=c++14",
+                     "-U__CUDA_NO_HALF_OPERATORS__",
+                     "-U__CUDA_NO_HALF_CONVERSIONS__",
+                     "-U__CUDA_NO_HALF2_OPERATORS__",
+                 ],
+             },
+         ),
+     ],
+     cmdclass={
+         "build_ext": BuildExtension,
+     },
+ )
citydreamer/extensions/voxlib/__init__.py ADDED
@@ -0,0 +1,5 @@
+ # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ #
+ # This work is made available under the Nvidia Source Code License-NC.
+ # To view a copy of this license, check out LICENSE.md
+ from voxlib import ray_voxel_intersection_perspective
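
A hedged sketch of calling the re-exported NVIDIA ray caster, following the argument list documented in `ray_voxel_intersection.cu` below; the camera values are invented for illustration, and it assumes the `voxlib` extension is built with a GPU present:

```python
import torch

from voxlib import ray_voxel_intersection_perspective

voxel = torch.zeros(40, 512, 512, dtype=torch.int32).cuda()  # [X, Y, Z] labels
voxel_id, depth, raydirs = ray_voxel_intersection_perspective(
    voxel,
    torch.tensor([20.0, 256.0, 256.0]),  # cam_ori
    torch.tensor([-0.5, 0.0, 0.5]),      # cam_dir
    torch.tensor([1.0, 0.0, 0.0]),       # cam_up
    512.0,                               # cam_f
    [240.0, 320.0],                      # cam_c (principal point)
    [480, 640],                          # img_dims
    6,                                   # max_samples per ray
)
```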
citydreamer/extensions/voxlib/ray_voxel_intersection.cu ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // This work is made available under the Nvidia Source Code License-NC.
4
+ // To view a copy of this license, check out LICENSE.md
5
+ //
6
+ // The ray marching algorithm used in this file is a variety of modified
7
+ // Bresenham method:
8
+ // http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.42.3443&rep=rep1&type=pdf
9
+ // Search for "voxel traversal algorithm" for related information
10
+
11
+ #include <torch/types.h>
12
+
13
+ #include <ATen/ATen.h>
14
+ #include <ATen/AccumulateType.h>
15
+ #include <ATen/cuda/CUDAApplyUtils.cuh>
16
+ #include <ATen/cuda/CUDAContext.h>
17
+
18
+ #include <cuda.h>
19
+ #include <cuda_runtime.h>
20
+ #include <curand.h>
21
+ #include <curand_kernel.h>
22
+ #include <time.h>
23
+
24
+ //#include <pybind11/numpy.h>
25
+ #include <pybind11/pybind11.h>
26
+ #include <pybind11/stl.h>
27
+ #include <vector>
28
+
29
+ #include "voxlib_common.h"
30
+
31
+ struct RVIP_Params {
32
+ int voxel_dims[3];
33
+ int voxel_strides[3];
34
+ int max_samples;
35
+ int img_dims[2];
36
+ // Camera parameters
37
+ float cam_ori[3];
38
+ float cam_fwd[3];
39
+ float cam_side[3];
40
+ float cam_up[3];
41
+ float cam_c[2];
42
+ float cam_f;
43
+ // unsigned long seed;
44
+ };
45
+
46
+ /*
47
+ out_voxel_id: torch CUDA int32 [ img_dims[0], img_dims[1], max_samples,
48
+ 1] out_depth: torch CUDA float [2, img_dims[0], img_dims[1], max_samples,
49
+ 1] out_raydirs: torch CUDA float [ img_dims[0], img_dims[1], 1,
50
+ 3] Image coordinates refer to the center of the pixel [0, 0, 0] at voxel
51
+ coordinate is at the corner of the corner block (instead of at the center)
52
+ */
53
+ template <int TILE_DIM>
54
+ static __global__ void ray_voxel_intersection_perspective_kernel(
55
+ int32_t *__restrict__ out_voxel_id, float *__restrict__ out_depth,
56
+ float *__restrict__ out_raydirs, const int32_t *__restrict__ in_voxel,
57
+ const RVIP_Params p) {
58
+
59
+ int img_coords[2];
60
+ img_coords[1] = blockIdx.x * TILE_DIM + threadIdx.x;
61
+ img_coords[0] = blockIdx.y * TILE_DIM + threadIdx.y;
62
+ if (img_coords[0] >= p.img_dims[0] || img_coords[1] >= p.img_dims[1]) {
63
+ return;
64
+ }
65
+ int pix_index = img_coords[0] * p.img_dims[1] + img_coords[1];
66
+
67
+ // Calculate ray origin and direction
68
+ float rayori[3], raydir[3];
69
+ rayori[0] = p.cam_ori[0];
70
+ rayori[1] = p.cam_ori[1];
71
+ rayori[2] = p.cam_ori[2];
72
+
73
+ // Camera intrinsics
74
+ float ndc_imcoords[2];
75
+ ndc_imcoords[0] = p.cam_c[0] - (float)img_coords[0]; // Flip height
76
+ ndc_imcoords[1] = (float)img_coords[1] - p.cam_c[1];
77
+
78
+ raydir[0] = p.cam_up[0] * ndc_imcoords[0] + p.cam_side[0] * ndc_imcoords[1] +
79
+ p.cam_fwd[0] * p.cam_f;
80
+ raydir[1] = p.cam_up[1] * ndc_imcoords[0] + p.cam_side[1] * ndc_imcoords[1] +
81
+ p.cam_fwd[1] * p.cam_f;
82
+ raydir[2] = p.cam_up[2] * ndc_imcoords[0] + p.cam_side[2] * ndc_imcoords[1] +
83
+ p.cam_fwd[2] * p.cam_f;
84
+ normalize<float, 3>(raydir);
85
+
86
+ // Save out_raydirs
87
+ out_raydirs[pix_index * 3] = raydir[0];
88
+ out_raydirs[pix_index * 3 + 1] = raydir[1];
89
+ out_raydirs[pix_index * 3 + 2] = raydir[2];
90
+
91
+ float axis_t[3];
92
+ int axis_int[3];
93
+ // int axis_intbound[3];
94
+
95
+ // Current voxel
96
+ axis_int[0] = floorf(rayori[0]);
97
+ axis_int[1] = floorf(rayori[1]);
98
+ axis_int[2] = floorf(rayori[2]);
99
+
100
+ #pragma unroll
101
+ for (int i = 0; i < 3; i++) {
102
+ if (raydir[i] > 0) {
103
+ // Initial t value
104
+ // Handle boundary case where rayori[i] is a whole number. Always round Up
105
+ // for the next block
106
+ // axis_t[i] = (ceilf(nextafterf(rayori[i], HUGE_VALF)) - rayori[i]) /
107
+ // raydir[i];
108
+ axis_t[i] = ((float)(axis_int[i] + 1) - rayori[i]) / raydir[i];
109
+ } else if (raydir[i] < 0) {
110
+ axis_t[i] = ((float)axis_int[i] - rayori[i]) / raydir[i];
111
+ } else {
112
+ axis_t[i] = HUGE_VALF;
113
+ }
114
+ }
+
+   // Fused raymarching and sampling
+   bool quit = false;
+   for (int cur_plane = 0; cur_plane < p.max_samples;
+        cur_plane++) { // Last cycle is for calculating p2
+     float t = nanf("0");
+     float t2 = nanf("0");
+     int32_t blk_id = 0;
+     // Find the next intersection
+     while (!quit) {
+       // Find the next smallest t
+       float tnow;
+       /*
+       #pragma unroll
+       for (int i=0; i<3; i++) {
+         if (axis_t[i] <= axis_t[(i+1)%3] && axis_t[i] <= axis_t[(i+2)%3]) {
+           // Update current t
+           tnow = axis_t[i];
+           // Update t candidates
+           if (raydir[i] > 0) {
+             axis_int[i] += 1;
+             if (axis_int[i] >= p.voxel_dims[i]) {
+               quit = true;
+             }
+             axis_t[i] = ((float)(axis_int[i]+1) - rayori[i]) / raydir[i];
+           } else {
+             axis_int[i] -= 1;
+             if (axis_int[i] < 0) {
+               quit = true;
+             }
+             axis_t[i] = ((float)axis_int[i] - rayori[i]) / raydir[i];
+           }
+           break; // Avoid advancing multiple steps as axis_t is updated
+         }
+       }
+       */
+       // Hand unroll
+       if (axis_t[0] <= axis_t[1] && axis_t[0] <= axis_t[2]) {
+         // Update current t
+         tnow = axis_t[0];
+         // Update t candidates
+         if (raydir[0] > 0) {
+           axis_int[0] += 1;
+           if (axis_int[0] >= p.voxel_dims[0]) {
+             quit = true;
+           }
+           axis_t[0] = ((float)(axis_int[0] + 1) - rayori[0]) / raydir[0];
+         } else {
+           axis_int[0] -= 1;
+           if (axis_int[0] < 0) {
+             quit = true;
+           }
+           axis_t[0] = ((float)axis_int[0] - rayori[0]) / raydir[0];
+         }
+       } else if (axis_t[1] <= axis_t[2]) {
+         tnow = axis_t[1];
+         if (raydir[1] > 0) {
+           axis_int[1] += 1;
+           if (axis_int[1] >= p.voxel_dims[1]) {
+             quit = true;
+           }
+           axis_t[1] = ((float)(axis_int[1] + 1) - rayori[1]) / raydir[1];
+         } else {
+           axis_int[1] -= 1;
+           if (axis_int[1] < 0) {
+             quit = true;
+           }
+           axis_t[1] = ((float)axis_int[1] - rayori[1]) / raydir[1];
+         }
+       } else {
+         tnow = axis_t[2];
+         if (raydir[2] > 0) {
+           axis_int[2] += 1;
+           if (axis_int[2] >= p.voxel_dims[2]) {
+             quit = true;
+           }
+           axis_t[2] = ((float)(axis_int[2] + 1) - rayori[2]) / raydir[2];
+         } else {
+           axis_int[2] -= 1;
+           if (axis_int[2] < 0) {
+             quit = true;
+           }
+           axis_t[2] = ((float)axis_int[2] - rayori[2]) / raydir[2];
+         }
+       }
+
+       if (quit) {
+         break;
+       }
+
+       // Skip empty space
+       // Could there be deadlock if the ray direction is away from the world?
+       if (axis_int[0] < 0 || axis_int[0] >= p.voxel_dims[0] ||
+           axis_int[1] < 0 || axis_int[1] >= p.voxel_dims[1] ||
+           axis_int[2] < 0 || axis_int[2] >= p.voxel_dims[2]) {
+         continue;
+       }
+
+       // Test intersection using the voxel grid
+       blk_id = in_voxel[axis_int[0] * p.voxel_strides[0] +
+                         axis_int[1] * p.voxel_strides[1] +
+                         axis_int[2] * p.voxel_strides[2]];
+       if (blk_id == 0) {
+         continue;
+       }
+
+       // Now that there is an intersection
+       t = tnow;
+       // Calculate t2
+       /*
+       #pragma unroll
+       for (int i=0; i<3; i++) {
+         if (axis_t[i] <= axis_t[(i+1)%3] && axis_t[i] <= axis_t[(i+2)%3]) {
+           t2 = axis_t[i];
+           break;
+         }
+       }
+       */
+       // Hand unroll
+       if (axis_t[0] <= axis_t[1] && axis_t[0] <= axis_t[2]) {
+         t2 = axis_t[0];
+       } else if (axis_t[1] <= axis_t[2]) {
+         t2 = axis_t[1];
+       } else {
+         t2 = axis_t[2];
+       }
+       break;
+     } // while !quit (ray marching loop)
+
+     out_depth[pix_index * p.max_samples + cur_plane] = t;
+     out_depth[p.img_dims[0] * p.img_dims[1] * p.max_samples +
+               pix_index * p.max_samples + cur_plane] = t2;
+     out_voxel_id[pix_index * p.max_samples + cur_plane] = blk_id;
+   } // cur_plane
+ }
+
+ /*
+ out:
+   out_voxel_id: torch CUDA int32 [img_dims[0], img_dims[1], max_samples, 1]
+   out_depth:    torch CUDA float [2, img_dims[0], img_dims[1], max_samples, 1]
+   out_raydirs:  torch CUDA float [img_dims[0], img_dims[1], 1, 3]
+ in:
+   in_voxel:     torch CUDA int32 [X, Y, Z], e.g. [40, 512, 512]
+   cam_ori:      torch float [3]
+   cam_dir:      torch float [3]
+   cam_up:       torch float [3]
+   cam_f:        float
+   cam_c:        int [2]
+   img_dims:     int [2]
+   max_samples:  int
+ */
+ std::vector<torch::Tensor> ray_voxel_intersection_perspective_cuda(
+     const torch::Tensor &in_voxel, const torch::Tensor &cam_ori,
+     const torch::Tensor &cam_dir, const torch::Tensor &cam_up, float cam_f,
+     const std::vector<float> &cam_c, const std::vector<int> &img_dims,
+     int max_samples) {
+   CHECK_CUDA(in_voxel);
+
+   int curDevice = -1;
+   cudaGetDevice(&curDevice);
+   cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
+   torch::Device device = in_voxel.device();
+
+   // assert(in_voxel.dtype() == torch::kU8);
+   assert(in_voxel.dtype() == torch::kInt32); // Minecraft compatibility
+   assert(in_voxel.dim() == 3);
+   assert(cam_ori.dtype() == torch::kFloat32);
+   assert(cam_ori.numel() == 3);
+   assert(cam_dir.dtype() == torch::kFloat32);
+   assert(cam_dir.numel() == 3);
+   assert(cam_up.dtype() == torch::kFloat32);
+   assert(cam_up.numel() == 3);
+   assert(img_dims.size() == 2);
+
+   RVIP_Params p;
+
+   // Calculate camera rays
+   const torch::Tensor cam_ori_c = cam_ori.cpu();
+   const torch::Tensor cam_dir_c = cam_dir.cpu();
+   const torch::Tensor cam_up_c = cam_up.cpu();
+
+   // Get the coordinate frame of camera space in world space
+   normalize<float, 3>(p.cam_fwd, cam_dir_c.data_ptr<float>());
+   cross<float>(p.cam_side, p.cam_fwd, cam_up_c.data_ptr<float>());
+   normalize<float, 3>(p.cam_side);
+   cross<float>(p.cam_up, p.cam_side, p.cam_fwd);
+   normalize<float, 3>(p.cam_up); // Not absolutely necessary as both vectors
+                                  // are normalized. But just in case...
+
+   copyarr<float, 3>(p.cam_ori, cam_ori_c.data_ptr<float>());
+
+   p.cam_f = cam_f;
+   p.cam_c[0] = cam_c[0];
+   p.cam_c[1] = cam_c[1];
+   p.max_samples = max_samples;
+   // printf("[Renderer] max_dist: %ld\n", max_dist);
+
+   p.voxel_dims[0] = in_voxel.size(0);
+   p.voxel_dims[1] = in_voxel.size(1);
+   p.voxel_dims[2] = in_voxel.size(2);
+   p.voxel_strides[0] = in_voxel.stride(0);
+   p.voxel_strides[1] = in_voxel.stride(1);
+   p.voxel_strides[2] = in_voxel.stride(2);
+
+   // printf("[Renderer] Voxel resolution: %ld, %ld, %ld\n", p.voxel_dims[0],
+   //        p.voxel_dims[1], p.voxel_dims[2]);
+
+   p.img_dims[0] = img_dims[0];
+   p.img_dims[1] = img_dims[1];
+
+   // Create output tensors
+   // For Minecraft Seg Mask
+   torch::Tensor out_voxel_id =
+       torch::empty({p.img_dims[0], p.img_dims[1], p.max_samples, 1},
+                    torch::TensorOptions().dtype(torch::kInt32).device(device));
+
+   torch::Tensor out_depth;
+   // Produce two sets of localcoords, one for the entry point, the other one
+   // for the exit point. They share the same corner_ids.
+   out_depth = torch::empty(
+       {2, p.img_dims[0], p.img_dims[1], p.max_samples, 1},
+       torch::TensorOptions().dtype(torch::kFloat32).device(device));
+
+   torch::Tensor out_raydirs = torch::empty({p.img_dims[0], p.img_dims[1], 1, 3},
+                                            torch::TensorOptions()
+                                                .dtype(torch::kFloat32)
+                                                .device(device)
+                                                .requires_grad(false));
+
+   const int TILE_DIM = 8;
+   dim3 dimGrid((p.img_dims[1] + TILE_DIM - 1) / TILE_DIM,
+                (p.img_dims[0] + TILE_DIM - 1) / TILE_DIM, 1);
+   dim3 dimBlock(TILE_DIM, TILE_DIM, 1);
+
+   ray_voxel_intersection_perspective_kernel<TILE_DIM>
+       <<<dimGrid, dimBlock, 0, stream>>>(
+           out_voxel_id.data_ptr<int32_t>(), out_depth.data_ptr<float>(),
+           out_raydirs.data_ptr<float>(), in_voxel.data_ptr<int32_t>(), p);
+
+   return {out_voxel_id, out_depth, out_raydirs};
+ }
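For reference, a minimal sketch of how this entry point can be driven from Python, assuming the `voxlib` module produced by the setup.py in the next file has been built and a CUDA device is available; the camera values are illustrative only, and the output shapes follow the interface comment above:

    import torch
    import voxlib

    # Illustrative 40 x 512 x 512 int32 voxel grid (0 = empty space)
    voxels = torch.zeros(40, 512, 512, dtype=torch.int32, device="cuda")
    voxels[:8] = 1  # fill the bottom layers with a dummy block ID

    voxel_id, depth2, raydirs = voxlib.ray_voxel_intersection_perspective(
        voxels,
        torch.tensor([20.0, 256.0, 256.0]),  # cam_ori
        torch.tensor([-1.0, 0.0, 0.0]),      # cam_dir
        torch.tensor([0.0, 0.0, 1.0]),       # cam_up
        512.0,                               # cam_f
        [269.5, 479.5],                      # cam_c (principal point)
        [540, 960],                          # img_dims (H, W)
        6,                                   # max_samples
    )
    # voxel_id: [540, 960, 6, 1], depth2: [2, 540, 960, 6, 1], raydirs: [540, 960, 1, 3]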
citydreamer/extensions/voxlib/setup.py ADDED
@@ -0,0 +1,25 @@
+ # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ #
+ # This work is made available under the Nvidia Source Code License-NC.
+ # To view a copy of this license, check out LICENSE.md
+ from setuptools import setup
+ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+ cxx_args = ["-fopenmp"]
+ nvcc_args = []
+
+ setup(
+     name="voxrender",
+     version="1.0.0",
+     ext_modules=[
+         CUDAExtension(
+             "voxlib",
+             [
+                 "voxlib.cpp",
+                 "ray_voxel_intersection.cu",
+             ],
+             extra_compile_args={"cxx": cxx_args, "nvcc": nvcc_args},
+         )
+     ],
+     cmdclass={"build_ext": BuildExtension},
+ )
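As a usage note: running `pip install .` from this directory (or `python setup.py build_ext --inplace` for an in-tree build) compiles the two sources listed above into an importable `voxlib` module; this assumes a CUDA toolkit compatible with the installed PyTorch build is available.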
citydreamer/extensions/voxlib/voxlib.cpp ADDED
@@ -0,0 +1,21 @@
+ // Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ //
+ // This work is made available under the Nvidia Source Code License-NC.
+ // To view a copy of this license, check out LICENSE.md
+ #include <pybind11/pybind11.h>
+ #include <pybind11/stl.h>
+ #include <torch/extension.h>
+ #include <vector>
+
+ // Fast voxel traversal along rays
+ std::vector<torch::Tensor> ray_voxel_intersection_perspective_cuda(
+     const torch::Tensor &in_voxel, const torch::Tensor &cam_ori,
+     const torch::Tensor &cam_dir, const torch::Tensor &cam_up, float cam_f,
+     const std::vector<float> &cam_c, const std::vector<int> &img_dims,
+     int max_samples);
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+   m.def("ray_voxel_intersection_perspective",
+         &ray_voxel_intersection_perspective_cuda,
+         "Ray-voxel intersections given perspective camera parameters (CUDA)");
+ }
citydreamer/extensions/voxlib/voxlib_common.h ADDED
@@ -0,0 +1,83 @@
+ // Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ //
+ // This work is made available under the Nvidia Source Code License-NC.
+ // To view a copy of this license, check out LICENSE.md
+ #ifndef VOXLIB_COMMON_H
+ #define VOXLIB_COMMON_H
+
+ #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
+ #define CHECK_CONTIGUOUS(x) \
+   TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+ #define CHECK_INPUT(x) \
+   CHECK_CUDA(x);       \
+   CHECK_CONTIGUOUS(x)
+ #define CHECK_CPU(x) \
+   TORCH_CHECK(x.device().is_cpu(), #x " must be a CPU tensor")
+
+ #include <cuda.h>
+ #include <cuda_runtime.h>
+ // CUDA vector math functions
+ __host__ __device__ __forceinline__ int floor_div(int a, int b) {
+   int c = a / b;
+
+   if (c * b > a) {
+     c--;
+   }
+
+   return c;
+ }
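+ // e.g. floor_div(-3, 2) == -2, whereas plain integer division truncates
+ // toward zero (-3 / 2 == -1); the correction above rounds toward -infinity.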
+
+ template <typename scalar_t>
+ __host__ __forceinline__ void cross(scalar_t *r, const scalar_t *a,
+                                     const scalar_t *b) {
+   r[0] = a[1] * b[2] - a[2] * b[1];
+   r[1] = a[2] * b[0] - a[0] * b[2];
+   r[2] = a[0] * b[1] - a[1] * b[0];
+ }
+
+ __device__ __host__ __forceinline__ float dot(const float *a, const float *b) {
+   return a[0] * b[0] + a[1] * b[1] + a[2] * b[2];
+ }
+
+ template <typename scalar_t, int ndim>
+ __device__ __host__ __forceinline__ void copyarr(scalar_t *r,
+                                                  const scalar_t *a) {
+ #pragma unroll
+   for (int i = 0; i < ndim; i++) {
+     r[i] = a[i];
+   }
+ }
+
+ // TODO: use rsqrt to speed up
+ // inplace version
+ template <typename scalar_t, int ndim>
+ __device__ __host__ __forceinline__ void normalize(scalar_t *a) {
+   scalar_t vec_len = 0.0f;
+ #pragma unroll
+   for (int i = 0; i < ndim; i++) {
+     vec_len += a[i] * a[i];
+   }
+   vec_len = sqrtf(vec_len);
+ #pragma unroll
+   for (int i = 0; i < ndim; i++) {
+     a[i] /= vec_len;
+   }
+ }
+
+ // normalize + copy
+ template <typename scalar_t, int ndim>
+ __device__ __host__ __forceinline__ void normalize(scalar_t *r,
+                                                    const scalar_t *a) {
+   scalar_t vec_len = 0.0f;
+ #pragma unroll
+   for (int i = 0; i < ndim; i++) {
+     vec_len += a[i] * a[i];
+   }
+   vec_len = sqrtf(vec_len);
+ #pragma unroll
+   for (int i = 0; i < ndim; i++) {
+     r[i] = a[i] / vec_len;
+   }
+ }
+
+ #endif // VOXLIB_COMMON_H
citydreamer/inference.py ADDED
@@ -0,0 +1,537 @@
+ # -*- coding: utf-8 -*-
+ #
+ # @File: inference.py
+ # @Author: Haozhe Xie
+ # @Date: 2024-03-02 16:30:00
+ # @Last Modified by: Haozhe Xie
+ # @Last Modified at: 2024-03-03 12:10:18
+ # @Email: root@haozhexie.com
+
+ import copy
+ import cv2
+ import logging
+ import math
+ import numpy as np
+ import torch
+ import torchvision
+
+ import citydreamer.extensions.extrude_tensor
+ import citydreamer.extensions.voxlib
+
+ # Global constants
+ HEIGHTS = {
+     "ROAD": 4,
+     "GREEN_LANDS": 8,
+     "CONSTRUCTION": 10,
+     "COAST_ZONES": 0,
+     "ROOF": 1,
+ }
+ CLASSES = {
+     "NULL": 0,
+     "ROAD": 1,
+     "BLD_FACADE": 2,
+     "GREEN_LANDS": 3,
+     "CONSTRUCTION": 4,
+     "COAST_ZONES": 5,
+     "OTHERS": 6,
+     "BLD_ROOF": 7,
+ }
+ # NOTE: IDs > 10 are reserved for building instances.
+ # Assume the ID of a facade instance is 2k; the corresponding roof instance is 2k - 1.
+ CONSTANTS = {
+     "BLD_INS_LABEL_MIN": 10,
+     "LAYOUT_N_CLASSES": 7,
+     "LAYOUT_VOL_SIZE": 1536,
+     "BUILDING_VOL_SIZE": 672,
+     "EXTENDED_VOL_SIZE": 2880,
+     "LAYOUT_MAX_HEIGHT": 640,
+     "GES_VFOV": 20,
+     "GES_IMAGE_HEIGHT": 540,
+     "GES_IMAGE_WIDTH": 960,
+     "IMAGE_PADDING": 8,
+     "N_VOXEL_INTERSECT_SAMPLES": 6,
+ }
+
+
+ def generate_city(fgm, bgm, hf, seg, radius, altitude, azimuth):
+     cam_pos = get_orbit_camera_position(radius, altitude, azimuth)
+     seg, building_stats = get_instance_seg_map(seg)
+     # Generate latent codes
+     logging.info("Generating latent codes ...")
+     bg_z, building_zs = get_latent_codes(
+         building_stats,
+         bgm.module.cfg.NETWORK.GANCRAFT.STYLE_DIM,
+         bgm.output_device,
+     )
+     # Randomly choose the center of the patch
+     cy = (
+         np.random.randint(seg.shape[0] - CONSTANTS["EXTENDED_VOL_SIZE"])
+         + CONSTANTS["EXTENDED_VOL_SIZE"] // 2
+     )
+     cx = (
+         np.random.randint(seg.shape[1] - CONSTANTS["EXTENDED_VOL_SIZE"])
+         + CONSTANTS["EXTENDED_VOL_SIZE"] // 2
+     )
+     # Generate the local image patch of the height field and seg map
+     part_hf, part_seg = get_part_hf_seg(hf, seg, cx, cy, CONSTANTS["EXTENDED_VOL_SIZE"])
+     # print(part_hf.shape) # (2880, 2880)
+     # print(part_seg.shape) # (2880, 2880)
+     # Recalculate the building positions based on the current patch
+     _building_stats = get_part_building_stats(part_seg, building_stats, cx, cy)
+     # Generate the concatenated height field and seg. map tensor
+     hf_seg = get_hf_seg_tensor(part_hf, part_seg, bgm.output_device)
+     # print(hf_seg.size()) # torch.Size([1, 8, 2880, 2880])
+     # Build seg_volume
+     logging.info("Generating seg volume ...")
+     seg_volume = get_seg_volume(part_hf, part_seg)
+     logging.info("Rendering City Image ...")
+     img = render(
+         (CONSTANTS["GES_IMAGE_HEIGHT"] // 5, CONSTANTS["GES_IMAGE_WIDTH"] // 5),
+         seg_volume,
+         hf_seg,
+         cam_pos,
+         bgm,
+         fgm,
+         _building_stats,
+         bg_z,
+         building_zs,
+     )
+     return ((img.cpu().numpy().squeeze().transpose((1, 2, 0)) / 2 + 0.5) * 255).astype(
+         np.uint8
+     )
+
+
+ def get_orbit_camera_position(radius, altitude, azimuth):
+     cx = CONSTANTS["LAYOUT_VOL_SIZE"] // 2
+     cy = cx
+     theta = np.deg2rad(azimuth)
+     cam_x = cx + radius * math.cos(theta)
+     cam_y = cy + radius * math.sin(theta)
+     return {"x": cam_x, "y": cam_y, "z": altitude}
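+ # e.g. radius=320, altitude=384, azimuth=180 yields cx = cy = 768 and
+ # theta = pi, so the camera sits at (768 - 320, 768, 384) = (448, 768, 384).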
+
+
+ def get_instance_seg_map(seg_map):
+     # Map constructions to buildings
+     seg_map[seg_map == CLASSES["CONSTRUCTION"]] = CLASSES["BLD_FACADE"]
+     # Use connected components to get building instances
+     _, labels, stats, _ = cv2.connectedComponentsWithStats(
+         (seg_map == CLASSES["BLD_FACADE"]).astype(np.uint8), connectivity=4
+     )
+     # Remove non-building instance masks
+     labels[seg_map != CLASSES["BLD_FACADE"]] = 0
+     # Building instance mask
+     building_mask = labels != 0
+     # Make building instance IDs even numbers starting from 10
+     # Assume the ID of a facade instance is 2k; the corresponding roof instance is 2k - 1.
+     labels = (labels + CONSTANTS["BLD_INS_LABEL_MIN"]) * 2
+
+     seg_map[seg_map == CLASSES["BLD_FACADE"]] = 0
+     seg_map = seg_map * (1 - building_mask) + labels * building_mask
+     assert np.max(labels) < 2147483648
+     return seg_map.astype(np.int32), stats[:, :4]
+
+
+ def get_latent_codes(building_stats, bg_style_dim, output_device):
+     bg_z = _get_z(output_device, bg_style_dim)
+     building_zs = {
+         (i + CONSTANTS["BLD_INS_LABEL_MIN"]) * 2: _get_z(output_device)
+         for i in range(len(building_stats))
+     }
+     return bg_z, building_zs
+
+
+ def _get_z(device, z_dim=256):
+     if z_dim is None:
+         return None
+
+     return torch.randn(1, z_dim, dtype=torch.float32, device=device)
+
+
+ def get_part_hf_seg(hf, seg, cx, cy, patch_size):
+     part_hf = _get_image_patch(hf, cx, cy, patch_size)
+     part_seg = _get_image_patch(seg, cx, cy, patch_size)
+     assert part_hf.shape == (
+         patch_size,
+         patch_size,
+     ), part_hf.shape
+     assert part_hf.shape == part_seg.shape, part_seg.shape
+     return part_hf, part_seg
+
+
+ def _get_image_patch(image, cx, cy, patch_size):
+     sx = cx - patch_size // 2
+     sy = cy - patch_size // 2
+     ex = sx + patch_size
+     ey = sy + patch_size
+     return image[sy:ey, sx:ex]
+
+
+ def get_part_building_stats(part_seg, building_stats, cx, cy):
+     _buildings = np.unique(part_seg[part_seg > CONSTANTS["BLD_INS_LABEL_MIN"]])
+     _building_stats = {}
+     for b in _buildings:
+         _b = b // 2 - CONSTANTS["BLD_INS_LABEL_MIN"]
+         _building_stats[b] = [
+             building_stats[_b, 1] - cy + building_stats[_b, 3] / 2,
+             building_stats[_b, 0] - cx + building_stats[_b, 2] / 2,
+         ]
+     return _building_stats
+
+
+ def get_hf_seg_tensor(part_hf, part_seg, output_device):
+     part_hf = torch.from_numpy(part_hf[None, None, ...]).to(output_device)
+     part_seg = torch.from_numpy(part_seg[None, None, ...]).to(output_device)
+     part_hf = part_hf / CONSTANTS["LAYOUT_MAX_HEIGHT"]
+     part_seg = _masks_to_onehots(part_seg[:, 0, :, :], CONSTANTS["LAYOUT_N_CLASSES"])
+     return torch.cat([part_hf, part_seg], dim=1)
+
+
+ def _masks_to_onehots(masks, n_class, ignored_classes=[]):
+     b, h, w = masks.shape
+     n_class_actual = n_class - len(ignored_classes)
+     one_hot_masks = torch.zeros(
+         (b, n_class_actual, h, w), dtype=torch.float32, device=masks.device
+     )
+
+     n_class_cnt = 0
+     for i in range(n_class):
+         if i not in ignored_classes:
+             one_hot_masks[:, n_class_cnt] = masks == i
+             n_class_cnt += 1
+     return one_hot_masks
+
+
+ def get_seg_volume(part_hf, part_seg):
+     tensor_extruder = citydreamer.extensions.extrude_tensor.TensorExtruder(
+         CONSTANTS["LAYOUT_MAX_HEIGHT"]
+     )
+
+     if part_hf.shape == (
+         CONSTANTS["EXTENDED_VOL_SIZE"],
+         CONSTANTS["EXTENDED_VOL_SIZE"],
+     ):
+         part_hf = part_hf[
+             CONSTANTS["BUILDING_VOL_SIZE"] : -CONSTANTS["BUILDING_VOL_SIZE"],
+             CONSTANTS["BUILDING_VOL_SIZE"] : -CONSTANTS["BUILDING_VOL_SIZE"],
+         ]
+         # print(part_hf.shape) # (1536, 1536)
+         part_seg = part_seg[
+             CONSTANTS["BUILDING_VOL_SIZE"] : -CONSTANTS["BUILDING_VOL_SIZE"],
+             CONSTANTS["BUILDING_VOL_SIZE"] : -CONSTANTS["BUILDING_VOL_SIZE"],
+         ]
+         # print(part_seg.shape) # (1536, 1536)
+
+     assert part_hf.shape == (
+         CONSTANTS["LAYOUT_VOL_SIZE"],
+         CONSTANTS["LAYOUT_VOL_SIZE"],
+     )
+     assert part_hf.shape == part_seg.shape, part_seg.shape
+
+     seg_volume = tensor_extruder(
+         torch.from_numpy(part_seg[None, None, ...]).cuda(),
+         torch.from_numpy(part_hf[None, None, ...]).cuda(),
+     ).squeeze()
+     logging.debug("The shape of SegVolume: %s" % (seg_volume.size(),))
+     # Change the top-level voxel of the "Building Facade" to "Building Roof"
+     roof_seg_map = part_seg.copy()
+     non_roof_msk = part_seg <= CONSTANTS["BLD_INS_LABEL_MIN"]
+     # Assume the ID of a facade instance is 2k; the corresponding roof instance is 2k - 1.
+     roof_seg_map = roof_seg_map - 1
+     roof_seg_map[non_roof_msk] = 0
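+     # Stamp roof labels on top of each extruded facade column: for each of the
+     # HEIGHTS["ROOF"] layers above the height field, overwrite the voxel label
+     # with the (odd) roof instance ID computed above.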
+     for rh in range(1, HEIGHTS["ROOF"] + 1):
+         seg_volume = seg_volume.scatter_(
+             dim=2,
+             index=torch.from_numpy(part_hf[..., None] + rh).long().cuda(),
+             src=torch.from_numpy(roof_seg_map[..., None]).cuda(),
+         )
+     # print(seg_volume.size()) # torch.Size([1536, 1536, 640])
+     return seg_volume
+
+
+ def get_voxel_intersection_perspective(seg_volume, camera_location):
+     CAMERA_FOCAL = (
+         CONSTANTS["GES_IMAGE_HEIGHT"] / 2 / np.tan(np.deg2rad(CONSTANTS["GES_VFOV"]))
+     )
+     # print(seg_volume.size()) # torch.Size([1536, 1536, 640])
+     camera_target = {
+         "x": seg_volume.size(1) // 2 - 1,
+         "y": seg_volume.size(0) // 2 - 1,
+     }
+     cam_origin = torch.tensor(
+         [
+             camera_location["y"],
+             camera_location["x"],
+             camera_location["z"],
+         ],
+         dtype=torch.float32,
+         device=seg_volume.device,
+     )
+
+     (
+         voxel_id,
+         depth2,
+         raydirs,
+     ) = citydreamer.extensions.voxlib.ray_voxel_intersection_perspective(
+         seg_volume,
+         cam_origin,
+         torch.tensor(
+             [
+                 camera_target["y"] - camera_location["y"],
+                 camera_target["x"] - camera_location["x"],
+                 -camera_location["z"],
+             ],
+             dtype=torch.float32,
+             device=seg_volume.device,
+         ),
+         torch.tensor([0, 0, 1], dtype=torch.float32),
+         CAMERA_FOCAL * 2.06,
+         [
+             (CONSTANTS["GES_IMAGE_HEIGHT"] - 1) / 2.0,
+             (CONSTANTS["GES_IMAGE_WIDTH"] - 1) / 2.0,
+         ],
+         [CONSTANTS["GES_IMAGE_HEIGHT"], CONSTANTS["GES_IMAGE_WIDTH"]],
+         CONSTANTS["N_VOXEL_INTERSECT_SAMPLES"],
+     )
+     return (
+         voxel_id.unsqueeze(dim=0),
+         depth2.permute(1, 2, 0, 3, 4).unsqueeze(dim=0),
+         raydirs.unsqueeze(dim=0),
+         cam_origin.unsqueeze(dim=0),
+     )
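+ # Note on the axis convention above: seg_volume is indexed as [y, x, z], which
+ # is why the camera origin and look-at direction are passed in (y, x, z) order.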
+
+
+ def _get_pad_img_bbox(sx, ex, sy, ey):
+     psx = sx - CONSTANTS["IMAGE_PADDING"] if sx != 0 else 0
+     psy = sy - CONSTANTS["IMAGE_PADDING"] if sy != 0 else 0
+     pex = (
+         ex + CONSTANTS["IMAGE_PADDING"]
+         if ex != CONSTANTS["GES_IMAGE_WIDTH"]
+         else CONSTANTS["GES_IMAGE_WIDTH"]
+     )
+     pey = (
+         ey + CONSTANTS["IMAGE_PADDING"]
+         if ey != CONSTANTS["GES_IMAGE_HEIGHT"]
+         else CONSTANTS["GES_IMAGE_HEIGHT"]
+     )
+     return psx, pex, psy, pey
+
+
+ def _get_img_without_pad(img, sx, ex, sy, ey, psx, pex, psy, pey):
+     if CONSTANTS["IMAGE_PADDING"] == 0:
+         return img
+
+     return img[
+         :,
+         :,
+         sy - psy : ey - pey if ey != pey else ey,
+         sx - psx : ex - pex if ex != pex else ex,
+     ]
+
+
+ def render_bg(
+     patch_size, gancraft_bg, hf_seg, voxel_id, depth2, raydirs, cam_origin, z
+ ):
+     assert hf_seg.size(2) == CONSTANTS["EXTENDED_VOL_SIZE"]
+     assert hf_seg.size(3) == CONSTANTS["EXTENDED_VOL_SIZE"]
+     hf_seg = hf_seg[
+         :,
+         :,
+         CONSTANTS["BUILDING_VOL_SIZE"] : -CONSTANTS["BUILDING_VOL_SIZE"],
+         CONSTANTS["BUILDING_VOL_SIZE"] : -CONSTANTS["BUILDING_VOL_SIZE"],
+     ]
+     assert hf_seg.size(2) == CONSTANTS["LAYOUT_VOL_SIZE"]
+     assert hf_seg.size(3) == CONSTANTS["LAYOUT_VOL_SIZE"]
+
+     blurrer = torchvision.transforms.GaussianBlur(kernel_size=3, sigma=(2, 2))
+     _voxel_id = copy.deepcopy(voxel_id)
+     _voxel_id[voxel_id >= CONSTANTS["BLD_INS_LABEL_MIN"]] = CLASSES["BLD_FACADE"]
+     assert (_voxel_id < CONSTANTS["LAYOUT_N_CLASSES"]).all()
+     bg_img = torch.zeros(
+         1,
+         3,
+         CONSTANTS["GES_IMAGE_HEIGHT"],
+         CONSTANTS["GES_IMAGE_WIDTH"],
+         dtype=torch.float32,
+         device=gancraft_bg.output_device,
+     )
+     # Render the background patch by patch to avoid OOM
+     for i in range(CONSTANTS["GES_IMAGE_HEIGHT"] // patch_size[0]):
+         for j in range(CONSTANTS["GES_IMAGE_WIDTH"] // patch_size[1]):
+             sy, sx = i * patch_size[0], j * patch_size[1]
+             ey, ex = sy + patch_size[0], sx + patch_size[1]
+             psx, pex, psy, pey = _get_pad_img_bbox(sx, ex, sy, ey)
+             output_bg = gancraft_bg(
+                 hf_seg=hf_seg,
+                 voxel_id=_voxel_id[:, psy:pey, psx:pex],
+                 depth2=depth2[:, psy:pey, psx:pex],
+                 raydirs=raydirs[:, psy:pey, psx:pex],
+                 cam_origin=cam_origin,
+                 building_stats=None,
+                 z=z,
+                 deterministic=True,
+             )
+             # Make roads blurry
+             road_mask = (
+                 (_voxel_id[:, None, psy:pey, psx:pex, 0, 0] == CLASSES["ROAD"])
+                 .repeat(1, 3, 1, 1)
+                 .float()
+             )
+             output_bg = blurrer(output_bg) * road_mask + output_bg * (1 - road_mask)
+             bg_img[:, :, sy:ey, sx:ex] = _get_img_without_pad(
+                 output_bg, sx, ex, sy, ey, psx, pex, psy, pey
+             )
+
+     return bg_img
+
+
+ def render_fg(
+     patch_size,
+     gancraft_fg,
+     building_id,
+     hf_seg,
+     voxel_id,
+     depth2,
+     raydirs,
+     cam_origin,
+     building_stats,
+     building_z,
+ ):
+     _voxel_id = copy.deepcopy(voxel_id)
+     _curr_bld = torch.tensor([building_id, building_id - 1], device=voxel_id.device)
+     _voxel_id[~torch.isin(_voxel_id, _curr_bld)] = 0
+     _voxel_id[voxel_id == building_id] = CLASSES["BLD_FACADE"]
+     _voxel_id[voxel_id == building_id - 1] = CLASSES["BLD_ROOF"]
+
+     # assert (_voxel_id < CONSTANTS["LAYOUT_N_CLASSES"]).all()
+     _hf_seg = copy.deepcopy(hf_seg)
+     _hf_seg[hf_seg != building_id] = 0
+     _hf_seg[hf_seg == building_id] = CLASSES["BLD_FACADE"]
+     _raydirs = copy.deepcopy(raydirs)
+     _raydirs[_voxel_id[..., 0, 0] == 0] = 0
+
+     # Crop the "hf_seg" image using the center of the target building as the reference
+     cx = CONSTANTS["EXTENDED_VOL_SIZE"] // 2 - int(building_stats[1])
+     cy = CONSTANTS["EXTENDED_VOL_SIZE"] // 2 - int(building_stats[0])
+     sx = cx - CONSTANTS["BUILDING_VOL_SIZE"] // 2
+     ex = cx + CONSTANTS["BUILDING_VOL_SIZE"] // 2
+     sy = cy - CONSTANTS["BUILDING_VOL_SIZE"] // 2
+     ey = cy + CONSTANTS["BUILDING_VOL_SIZE"] // 2
+     _hf_seg = hf_seg[:, :, sy:ey, sx:ex]
+
+     fg_img = torch.zeros(
+         1,
+         3,
+         CONSTANTS["GES_IMAGE_HEIGHT"],
+         CONSTANTS["GES_IMAGE_WIDTH"],
+         dtype=torch.float32,
+         device=gancraft_fg.output_device,
+     )
+     fg_mask = torch.zeros(
+         1,
+         1,
+         CONSTANTS["GES_IMAGE_HEIGHT"],
+         CONSTANTS["GES_IMAGE_WIDTH"],
+         dtype=torch.float32,
+         device=gancraft_fg.output_device,
+     )
+     # Prevent some buildings from being out of bounds.
+     # THIS SHOULD NEVER HAPPEN AGAIN.
+     # if (
+     #     _hf_seg.size(2) != CONSTANTS["BUILDING_VOL_SIZE"]
+     #     or _hf_seg.size(3) != CONSTANTS["BUILDING_VOL_SIZE"]
+     # ):
+     #     return fg_img, fg_mask
+
+     # Render the foreground patch by patch to avoid OOM
+     for i in range(CONSTANTS["GES_IMAGE_HEIGHT"] // patch_size[0]):
+         for j in range(CONSTANTS["GES_IMAGE_WIDTH"] // patch_size[1]):
+             sy, sx = i * patch_size[0], j * patch_size[1]
+             ey, ex = sy + patch_size[0], sx + patch_size[1]
+             psx, pex, psy, pey = _get_pad_img_bbox(sx, ex, sy, ey)
+
+             if torch.count_nonzero(_raydirs[:, sy:ey, sx:ex]) > 0:
+                 output_fg = gancraft_fg(
+                     _hf_seg,
+                     _voxel_id[:, psy:pey, psx:pex],
+                     depth2[:, psy:pey, psx:pex],
+                     _raydirs[:, psy:pey, psx:pex],
+                     cam_origin,
+                     building_stats=torch.from_numpy(np.array(building_stats)).unsqueeze(
+                         dim=0
+                     ),
+                     z=building_z,
+                     deterministic=True,
+                 )
+                 facade_mask = (
+                     voxel_id[:, sy:ey, sx:ex, 0, 0] == building_id
+                 ).unsqueeze(dim=1)
+                 roof_mask = (
+                     voxel_id[:, sy:ey, sx:ex, 0, 0] == building_id - 1
+                 ).unsqueeze(dim=1)
+                 facade_img = facade_mask * _get_img_without_pad(
+                     output_fg, sx, ex, sy, ey, psx, pex, psy, pey
+                 )
+                 # Make roofs blurry
+                 # output_fg = F.interpolate(
+                 #     F.interpolate(output_fg * 0.8, scale_factor=0.75),
+                 #     scale_factor=4 / 3,
+                 # ),
+                 roof_img = roof_mask * _get_img_without_pad(
+                     output_fg,
+                     sx,
+                     ex,
+                     sy,
+                     ey,
+                     psx,
+                     pex,
+                     psy,
+                     pey,
+                 )
+                 fg_mask[:, :, sy:ey, sx:ex] = torch.logical_or(facade_mask, roof_mask)
+                 fg_img[:, :, sy:ey, sx:ex] = (
+                     facade_img * facade_mask + roof_img * roof_mask
+                 )
+
+     return fg_img, fg_mask
+
+
+ def render(
+     patch_size,
+     seg_volume,
+     hf_seg,
+     cam_pos,
+     gancraft_bg,
+     gancraft_fg,
+     building_stats,
+     bg_z,
+     building_zs,
+ ):
+     voxel_id, depth2, raydirs, cam_origin = get_voxel_intersection_perspective(
+         seg_volume, cam_pos
+     )
+     buildings = torch.unique(voxel_id[voxel_id > CONSTANTS["BLD_INS_LABEL_MIN"]])
+     # Remove odd numbers from the list because they are reserved for roofs.
+     buildings = buildings[buildings % 2 == 0]
+     with torch.no_grad():
+         bg_img = render_bg(
+             patch_size, gancraft_bg, hf_seg, voxel_id, depth2, raydirs, cam_origin, bg_z
+         )
+         for b in buildings:
+             assert b % 2 == 0, "Building Instance ID MUST be an even number."
+             fg_img, fg_mask = render_fg(
+                 patch_size,
+                 gancraft_fg,
+                 b.item(),
+                 hf_seg,
+                 voxel_id,
+                 depth2,
+                 raydirs,
+                 cam_origin,
+                 building_stats[b.item()],
+                 building_zs[b.item()],
+             )
+             bg_img = bg_img * (1 - fg_mask) + fg_img * fg_mask
+
+     return bg_img
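Taken together, a minimal driver sketch for this module; everything here is illustrative: the file names are placeholders, `fgm` and `bgm` are assumed to already hold the pretrained foreground and background generators wrapped in `torch.nn.DataParallel`, and the height field and semantic map must each cover at least a 2880 x 2880 patch (EXTENDED_VOL_SIZE):

    import cv2
    import numpy as np

    import citydreamer.inference

    # Hypothetical input files: a height field and a semantic map
    hf = np.array(cv2.imread("HEIGHT_FIELD.png", cv2.IMREAD_UNCHANGED))
    seg = np.array(cv2.imread("SEG_MAP.png", cv2.IMREAD_UNCHANGED))

    img = citydreamer.inference.generate_city(
        fgm, bgm, hf, seg, radius=320, altitude=384, azimuth=180
    )  # -> H x W x 3 uint8 image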
citydreamer/model.py ADDED
@@ -0,0 +1,1264 @@
+ # -*- coding: utf-8 -*-
+ #
+ # @File: gancraft.py
+ # @Author: Haozhe Xie
+ # @Date: 2023-04-12 19:53:21
+ # @Last Modified by: Haozhe Xie
+ # @Last Modified at: 2024-03-03 11:15:36
+ # @Email: root@haozhexie.com
+ # @Ref: https://github.com/FrozenBurning/SceneDreamer
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+
+ import citydreamer.extensions.grid_encoder
+
+
+ class GanCraftGenerator(torch.nn.Module):
+     def __init__(self, cfg):
+         super(GanCraftGenerator, self).__init__()
+         self.cfg = cfg
+         self.render_net = RenderMLP(cfg)
+         self.denoiser = RenderCNN(cfg)
+         if cfg.NETWORK.GANCRAFT.ENCODER == "GLOBAL":
+             self.encoder = GlobalEncoder(cfg)
+         elif cfg.NETWORK.GANCRAFT.ENCODER == "LOCAL":
+             self.encoder = LocalEncoder(cfg)
+         else:
+             self.encoder = None
+
+         if (
+             not cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_CORDS
+             and not cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_FEATURES
+         ):
+             raise ValueError(
+                 "Either POS_EMD_INCUDE_CORDS or POS_EMD_INCUDE_FEATURES should be True."
+             )
+
+         if cfg.NETWORK.GANCRAFT.POS_EMD == "HASH_GRID":
+             grid_encoder_in_dim = 3 if cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_CORDS else 0
+             if (
+                 cfg.NETWORK.GANCRAFT.ENCODER in ["GLOBAL", "LOCAL"]
+                 and cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_FEATURES
+             ):
+                 grid_encoder_in_dim += cfg.NETWORK.GANCRAFT.ENCODER_OUT_DIM
+
+             self.pos_encoder = citydreamer.extensions.grid_encoder.GridEncoder(
+                 in_channels=grid_encoder_in_dim,
+                 n_levels=cfg.NETWORK.GANCRAFT.HASH_GRID_N_LEVELS,
+                 lvl_channels=cfg.NETWORK.GANCRAFT.HASH_GRID_LEVEL_DIM,
+                 desired_resolution=cfg.NETWORK.GANCRAFT.HASH_GRID_RESOLUTION,
+             )
+         elif cfg.NETWORK.GANCRAFT.POS_EMD == "SIN_COS":
+             self.pos_encoder = SinCosEncoder(cfg)
+
+     def forward(
+         self,
+         hf_seg,
+         voxel_id,
+         depth2,
+         raydirs,
+         cam_origin,
+         building_stats=None,
+         z=None,
+         deterministic=False,
+     ):
+         r"""GANcraft Generator forward.
+
+         Args:
+             hf_seg (N x (1 + M) x H' x W' tensor): height field + seg map, where M is the number of classes.
+             voxel_id (N x H x W x max_samples x 1 tensor): IDs of intersected voxels along each ray.
+             depth2 (N x H x W x 2 x max_samples x 1 tensor): Depths of entrance and exit points for each ray-voxel
+                 intersection.
+             raydirs (N x H x W x 1 x 3 tensor): The direction of each ray.
+             cam_origin (N x 3 tensor): Camera origins.
+             building_stats (N x 5 tensor): The dy, dx, h, w, ID of the target building. (Only used in building mode)
+             z (N x STYLE_DIM tensor): The style vector.
+             deterministic (bool): Whether to use equal-distance sampling instead of random stratified sampling.
+         Returns:
+             fake_images (N x 3 x H x W tensor): fake images
+         """
+         bs, device = hf_seg.size(0), hf_seg.device
+         if z is None and self.cfg.NETWORK.GANCRAFT.STYLE_DIM is not None:
+             z = torch.randn(
+                 bs,
+                 self.cfg.NETWORK.GANCRAFT.STYLE_DIM,
+                 dtype=torch.float32,
+                 device=device,
+             )
+
+         features = None
+         if self.encoder is not None:
+             features = self.encoder(hf_seg)
+
+         net_out = self._forward_perpix(
+             features,
+             voxel_id,
+             depth2,
+             raydirs,
+             cam_origin,
+             z,
+             building_stats,
+             deterministic,
+         )
+         fake_images = self._forward_global(net_out, z)
+         return fake_images
+
+     def _forward_perpix(
+         self,
+         features,
+         voxel_id,
+         depth2,
+         raydirs,
+         cam_origin,
+         z,
+         building_stats=None,
+         deterministic=False,
+     ):
+         r"""Sample points along rays, forward the per-point MLP, and aggregate per-pixel features.
+
+         Args:
+             features (N x C1 tensor): Local features determined by the current pixel.
+             voxel_id (N x H x W x M x 1 tensor): Voxel ids from the ray-voxel intersection test. M: num intersected voxels
+             depth2 (N x H x W x 2 x M x 1 tensor): Depths of entrance and exit points for each ray-voxel intersection.
+             raydirs (N x H x W x 1 x 3 tensor): The direction of each ray.
+             cam_origin (N x 3 tensor): Camera origins.
+             z (N x C3 tensor): Intermediate style vectors.
+             building_stats (N x 4 tensor): The dy, dx, h, w of the target building. (Only used in building mode)
+             deterministic (bool): Whether to use equal-distance sampling instead of random stratified sampling.
+         """
+         # Generate sky_mask; PE transform on ray direction.
+         with torch.no_grad():
+             # sky_only_mask: when True, the ray hits nothing but sky
+             sky_only_mask = voxel_id[:, :, :, [0], :] == 0
+
+         with torch.no_grad():
+             normalized_cord, new_dists, new_idx = self._get_sampled_coordinates(
+                 self.cfg.NETWORK.GANCRAFT.N_SAMPLE_POINTS_PER_RAY,
+                 depth2,
+                 raydirs,
+                 cam_origin,
+                 building_stats,
+                 deterministic,
+             )
+             # Generate per-sample segmentation label
+             seg_map_bev = torch.gather(voxel_id, -2, new_idx)
+             # print(seg_map_bev.size()) # torch.Size([N, H, W, n_samples + 1, 1])
+             # In building mode, one extra channel is used for building roofs
+             n_classes = (
+                 self.cfg.NETWORK.GANCRAFT.N_CLASSES + 1
+                 if self.cfg.NETWORK.GANCRAFT.BUILDING_MODE
+                 else self.cfg.NETWORK.GANCRAFT.N_CLASSES
+             )
+             seg_map_bev_onehot = torch.zeros(
+                 [
+                     seg_map_bev.size(0),
+                     seg_map_bev.size(1),
+                     seg_map_bev.size(2),
+                     seg_map_bev.size(3),
+                     n_classes,
+                 ],
+                 dtype=torch.float,
+                 device=voxel_id.device,
+             )
+             # print(seg_map_bev_onehot.size()) # torch.Size([N, H, W, n_samples + 1, n_classes])
+             seg_map_bev_onehot.scatter_(-1, seg_map_bev.long(), 1.0)
+
+         net_out_s, net_out_c = self._forward_perpix_sub(
+             features, normalized_cord, z, seg_map_bev_onehot
+         )
+
+         # Blending
+         weights = self._volum_rendering_relu(
+             net_out_s, new_dists * self.cfg.NETWORK.GANCRAFT.DIST_SCALE, dim=-2
+         )
+         # If a ray exclusively hits the sky (no intersection with the voxels), set its weight to zero.
+         weights = weights * torch.logical_not(sky_only_mask).float()
+         # print(weights.size()) # torch.Size([N, H, W, n_samples + 1, 1])
+
+         rgbs = torch.clamp(net_out_c, -1, 1) + 1
+         net_out = torch.sum(weights * rgbs, dim=-2, keepdim=True)
+         net_out = net_out.squeeze(-2)
+         net_out = net_out - 1
+         return net_out
+
+     def _get_sampled_coordinates(
+         self,
+         n_samples,
+         depth2,
+         raydirs,
+         cam_origin,
+         building_stats=None,
+         deterministic=False,
+     ):
+         # Randomly sample points along the ray
+         rand_depth, new_dists, new_idx = self._sample_depth_batched(
+             depth2,
+             n_samples + 1,
+             deterministic=deterministic,
+             use_box_boundaries=False,
+             sample_depth=3,
+         )
+         nan_mask = torch.isnan(rand_depth)
+         inf_mask = torch.isinf(rand_depth)
+         rand_depth[nan_mask | inf_mask] = 0.0
+         world_coord = raydirs * rand_depth + cam_origin[:, None, None, None, :]
+         # assert world_coord.shape[-1] == 3
+         if self.cfg.NETWORK.GANCRAFT.BUILDING_MODE:
+             assert building_stats is not None
+             # Make the building object-centric
+             building_stats = building_stats[:, None, None, None, :].repeat(
+                 1, world_coord.size(1), world_coord.size(2), world_coord.size(3), 1
+             )
+             world_coord[..., 0] -= (
+                 building_stats[..., 0] + self.cfg.NETWORK.GANCRAFT.CENTER_OFFSET
+             )
+             world_coord[..., 1] -= (
+                 building_stats[..., 1] + self.cfg.NETWORK.GANCRAFT.CENTER_OFFSET
+             )
+             # TODO: Fix non-building rays
+             zero_rd_mask = raydirs.repeat(1, 1, 1, n_samples, 1)
+             world_coord[zero_rd_mask == 0] = 0
+
+         normalized_cord = self._get_normalized_coordinates(world_coord)
+         return normalized_cord, new_dists, new_idx
+
+     def _get_normalized_coordinates(self, world_coord):
+         delimeter = torch.tensor(
+             self.cfg.NETWORK.GANCRAFT.NORMALIZE_DELIMETER, device=world_coord.device
+         )
+         normalized_cord = world_coord / delimeter * 2 - 1
+         # TODO: Temporary fix
+         normalized_cord[normalized_cord > 1] = 1
+         normalized_cord[normalized_cord < -1] = -1
+         # assert (normalized_cord <= 1).all()
+         # assert (normalized_cord >= -1).all()
+         # print(delimeter, torch.min(normalized_cord), torch.max(normalized_cord))
+         # print(normalized_cord.size()) # torch.Size([1, 192, 192, 24, 3])
+         return normalized_cord
+
+     def _sample_depth_batched(
+         self,
+         depth2,
+         n_samples,
+         deterministic=False,
+         use_box_boundaries=True,
+         sample_depth=3,
+     ):
+         r"""Make a best effort to sample points within the same distance for every ray.
+         Exception: when there are not enough voxels.
+
+         Args:
+             depth2 (N x H x W x 2 x M x 1 tensor):
+                 - N: Batch.
+                 - H, W: Height, Width.
+                 - 2: Entrance / exit depth for each intersected box.
+                 - M: Number of intersected boxes along the ray.
+                 - 1: One extra dim for consistent tensor dims.
+                 depth2 can include NaNs.
+             deterministic (bool): Whether to use equal-distance sampling instead of random stratified sampling.
+             use_box_boundaries (bool): Whether to add the entrance / exit points into the sample.
+             sample_depth (float): Truncate the ray when it travels further than sample_depth inside voxels.
+         """
+         bs = depth2.size(0)
+         dim0 = depth2.size(1)
+         dim1 = depth2.size(2)
+         dists = depth2[:, :, :, 1] - depth2[:, :, :, 0]
+         dists[torch.isnan(dists)] = 0
+         # print(dists.size()) # torch.Size([N, H, W, M, 1])
+         accu_depth = torch.cumsum(dists, dim=-2)
+         # print(accu_depth.size()) # torch.Size([N, H, W, M, 1])
+         total_depth = accu_depth[..., [-1], :]
+         # print(total_depth.size()) # torch.Size([N, H, W, 1, 1])
+         total_depth = torch.clamp(total_depth, None, sample_depth)
+
+         # Ignore out-of-range box boundaries. Fill with random samples.
+         if use_box_boundaries:
+             boundary_samples = accu_depth.clone().detach()
+             boundary_samples_filler = torch.rand_like(boundary_samples) * total_depth
+             bad_mask = (accu_depth > sample_depth) | (dists == 0)
+             boundary_samples[bad_mask] = boundary_samples_filler[bad_mask]
+
+         rand_shape = [bs, dim0, dim1, n_samples, 1]
+         if deterministic:
+             rand_samples = torch.empty(
+                 rand_shape, dtype=total_depth.dtype, device=total_depth.device
+             )
+             rand_samples[..., :, 0] = torch.linspace(0, 1, n_samples + 2)[1:-1]
+         else:
+             rand_samples = torch.rand(
+                 rand_shape, dtype=total_depth.dtype, device=total_depth.device
+             )
+             # Stratified sampling as in NeRF
+             rand_samples = rand_samples / n_samples
+             rand_samples[..., :, 0] += torch.linspace(
+                 0, 1, n_samples + 1, device=rand_samples.device
+             )[:-1]
+
+         rand_samples = rand_samples * total_depth
+         # print(rand_samples.size()) # torch.Size([N, H, W, n_samples, 1])
+
+         # Can also include boundaries
+         if use_box_boundaries:
+             rand_samples = torch.cat(
+                 [
+                     rand_samples,
+                     boundary_samples,
+                     torch.zeros(
+                         [bs, dim0, dim1, 1, 1],
+                         dtype=total_depth.dtype,
+                         device=total_depth.device,
+                     ),
+                 ],
+                 dim=-2,
+             )
+         rand_samples, _ = torch.sort(rand_samples, dim=-2, descending=False)
+
+         midpoints = (rand_samples[..., 1:, :] + rand_samples[..., :-1, :]) / 2
+         # print(midpoints.size()) # torch.Size([N, H, W, n_samples, 1])
+         new_dists = rand_samples[..., 1:, :] - rand_samples[..., :-1, :]
+
+         # Scatter the random samples back
+         # print(midpoints.unsqueeze(-3).size()) # torch.Size([N, H, W, 1, n_samples, 1])
+         # print(accu_depth.unsqueeze(-2).size()) # torch.Size([N, H, W, M, 1, 1])
+         idx = torch.sum(midpoints.unsqueeze(-3) > accu_depth.unsqueeze(-2), dim=-3)
+         # print(idx.shape, idx.max(), idx.min()) # torch.Size([N, H, W, n_samples, 1]) max 5, min 0
+
+         depth_deltas = (
+             depth2[:, :, :, 0, 1:, :] - depth2[:, :, :, 1, :-1, :]
+         )  # There might be NaNs!
+         # print(depth_deltas.size()) # torch.Size([N, H, W, M - 1, 1])
+         depth_deltas = torch.cumsum(depth_deltas, dim=-2)
+         depth_deltas = torch.cat(
+             [depth2[:, :, :, 0, [0], :], depth_deltas + depth2[:, :, :, 0, [0], :]],
+             dim=-2,
+         )
+         heads = torch.gather(depth_deltas, -2, idx)
+         # print(heads.size()) # torch.Size([N, H, W, n_samples, 1])
+         # print(torch.any(torch.isnan(heads)))
+         rand_depth = heads + midpoints
+         # print(rand_depth.size()) # torch.Size([N, H, W, n_samples, 1])
+         return rand_depth, new_dists, idx
+
+     def _volum_rendering_relu(self, sigma, dists, dim=2):
+         free_energy = F.relu(sigma) * dists
+         a = 1 - torch.exp(-free_energy.float())  # probability that this sample is not empty
+         b = torch.exp(
+             -self._cumsum_exclusive(free_energy, dim=dim)
+         )  # probability that everything up to this point is empty
+         return a * b  # probability that the ray first hits something here
+
+     def _cumsum_exclusive(self, tensor, dim):
+         cumsum = torch.cumsum(tensor, dim)
+         cumsum = torch.roll(cumsum, 1, dim)
+         cumsum.index_fill_(
+             dim, torch.tensor([0], dtype=torch.long, device=tensor.device), 0
+         )
+         return cumsum
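+     # A quick sanity check of the two helpers above (illustrative values): with
+     # sigma = [1.0, 2.0] and dists = [1.0, 1.0],
+     #   free_energy            = [1.0, 2.0]
+     #   _cumsum_exclusive(...) = [0.0, 1.0]
+     #   weights = (1 - exp(-[1, 2])) * exp(-[0, 1]) ~= [0.632, 0.318]
+     # i.e. standard alpha compositing: each weight is the probability that the
+     # ray survives all previous samples and terminates at the current one.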
+
+     def _forward_perpix_sub(self, features, normalized_cord, z, seg_map_bev_onehot):
+         r"""Forward the MLP.
+
+         Args:
+             features (N x C1 x ...? tensor): Local features determined by the current pixel.
+             normalized_cord (N x H x W x L x 3 tensor): 3D world coordinates of sampled points. L is the number of samples; N is the batch size, always 1.
+             z (N x C3 tensor): Intermediate style vectors.
+             seg_map_bev_onehot (N x H x W x L x C4): One-hot segmentation maps.
+         Returns:
+             net_out_s (N x H x W x L x 1 tensor): Opacities.
+             net_out_c (N x H x W x L x C5 tensor): Color embeddings.
+         """
+         feature_in = torch.empty(
+             normalized_cord.size(0),
+             normalized_cord.size(1),
+             normalized_cord.size(2),
+             normalized_cord.size(3),
+             0,
+             device=normalized_cord.device,
+         )
+         if self.cfg.NETWORK.GANCRAFT.ENCODER == "GLOBAL":
+             # print(features.size()) # torch.Size([N, ENCODER_OUT_DIM])
+             feature_in = features[:, None, None, None, :].repeat(
+                 1,
+                 normalized_cord.size(1),
+                 normalized_cord.size(2),
+                 normalized_cord.size(3),
+                 1,
+             )
+         elif self.cfg.NETWORK.GANCRAFT.ENCODER == "LOCAL":
+             # print(features.size()) # torch.Size([N, ENCODER_OUT_DIM - 1, H, W])
+             # print(world_coord.size()) # torch.Size([N, H, W, L, 3])
+             # NOTE: grid specifies the sampling pixel locations normalized by the input spatial
+             # dimensions. Therefore, it should have most values in the range of [-1, 1].
+             grid = normalized_cord.permute(0, 3, 1, 2, 4).reshape(
+                 -1, normalized_cord.size(1), normalized_cord.size(2), 3
+             )
+             # print(grid.size()) # torch.Size([N * L, H, W, 3])
+             feature_in = F.grid_sample(
+                 features.repeat(grid.size(0), 1, 1, 1),
+                 grid[..., [1, 0]],
+                 align_corners=False,
+             )
+             # print(feature_in.size()) # torch.Size([N * L, ENCODER_OUT_DIM - 1, H, W])
+             feature_in = feature_in.reshape(
+                 normalized_cord.size(0),
+                 normalized_cord.size(3),
+                 feature_in.size(1),
+                 feature_in.size(2),
+                 feature_in.size(3),
+             ).permute(0, 3, 4, 1, 2)
+             # print(feature_in.size()) # torch.Size([N, H, W, L, ENCODER_OUT_DIM - 1])
+             feature_in = torch.cat([feature_in, normalized_cord[..., [2]]], dim=-1)
+             # print(feature_in.size()) # torch.Size([N, H, W, L, ENCODER_OUT_DIM])
+
+         if self.cfg.NETWORK.GANCRAFT.POS_EMD in ["HASH_GRID", "SIN_COS"]:
+             if (
+                 self.cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_CORDS
+                 and self.cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_FEATURES
+             ):
+                 feature_in = self.pos_encoder(
+                     torch.cat([normalized_cord, feature_in], dim=-1)
+                 )
+             elif self.cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_CORDS:
+                 feature_in = torch.cat(
+                     [self.pos_encoder(normalized_cord), feature_in], dim=-1
+                 )
+             elif self.cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_FEATURES:
+                 # Ignore normalized_cord here to keep it decoupled from coordinates
+                 feature_in = torch.cat([self.pos_encoder(feature_in)], dim=-1)
+         else:
+             if (
+                 self.cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_CORDS
+                 and self.cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_FEATURES
+             ):
+                 feature_in = torch.cat([normalized_cord, feature_in], dim=-1)
+             elif self.cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_CORDS:
+                 feature_in = normalized_cord
+             elif self.cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_FEATURES:
+                 feature_in = feature_in
+
+         net_out_s, net_out_c = self.render_net(feature_in, z, seg_map_bev_onehot)
+         return net_out_s, net_out_c
+
+     def _forward_global(self, net_out, z):
+         r"""Forward the CNN.
+
+         Args:
+             net_out (N x C5 x H x W tensor): Intermediate feature maps.
+             z (N x C3 tensor): Intermediate style vectors.
+
+         Returns:
+             fake_images (N x 3 x H x W tensor): Output image.
+         """
+         fake_images = net_out.permute(0, 3, 1, 2).contiguous()
+         if self.denoiser is not None:
+             fake_images = self.denoiser(fake_images, z)
+             fake_images = torch.tanh(fake_images)
+
+         return fake_images
+
+
+ class GlobalEncoder(torch.nn.Module):
+     def __init__(self, cfg):
+         super(GlobalEncoder, self).__init__()
+         n_classes = cfg.NETWORK.GANCRAFT.N_CLASSES
+         self.hf_conv = torch.nn.Conv2d(1, 8, kernel_size=3, stride=2, padding=1)
+         self.seg_conv = torch.nn.Conv2d(
+             n_classes,
+             8,
+             kernel_size=3,
+             stride=2,
+             padding=1,
+         )
+         conv_blocks = []
+         cur_hidden_channels = 16
+         for _ in range(1, cfg.NETWORK.GANCRAFT.GLOBAL_ENCODER_N_BLOCKS):
+             conv_blocks.append(
+                 SRTConvBlock(in_channels=cur_hidden_channels, out_channels=None)
+             )
+             cur_hidden_channels *= 2
+
+         self.conv_blocks = torch.nn.Sequential(*conv_blocks)
+         self.fc1 = torch.nn.Linear(cur_hidden_channels, 16)
+         self.fc2 = torch.nn.Linear(16, cfg.NETWORK.GANCRAFT.ENCODER_OUT_DIM)
+         self.act = torch.nn.LeakyReLU(0.2)
+
+     def forward(self, hf_seg):
+         hf = self.act(self.hf_conv(hf_seg[:, [0]]))
+         seg = self.act(self.seg_conv(hf_seg[:, 1:]))
+         out = torch.cat([hf, seg], dim=1)
+         for layer in self.conv_blocks:
+             out = self.act(layer(out))
+
+         out = out.permute(0, 2, 3, 1)
+         out = torch.mean(out.reshape(out.shape[0], -1, out.shape[-1]), dim=1)
+         cond = self.act(self.fc1(out))
+         cond = torch.tanh(self.fc2(cond))
+         return cond
+
+
+ class LocalEncoder(torch.nn.Module):
+     def __init__(self, cfg):
+         super(LocalEncoder, self).__init__()
+         n_classes = cfg.NETWORK.GANCRAFT.N_CLASSES
+         self.hf_conv = torch.nn.Conv2d(1, 32, kernel_size=7, stride=2, padding=3)
+         self.seg_conv = torch.nn.Conv2d(
+             n_classes, 32, kernel_size=7, stride=2, padding=3
+         )
+         if cfg.NETWORK.GANCRAFT.LOCAL_ENCODER_NORM == "BATCH_NORM":
+             self.bn1 = torch.nn.BatchNorm2d(64)
+         elif cfg.NETWORK.GANCRAFT.LOCAL_ENCODER_NORM == "GROUP_NORM":
+             self.bn1 = torch.nn.GroupNorm(32, 64)
+         else:
+             raise ValueError(
+                 "Unknown normalization: %s" % cfg.NETWORK.GANCRAFT.LOCAL_ENCODER_NORM
+             )
+         self.conv2 = ResConvBlock(64, 128, cfg.NETWORK.GANCRAFT.LOCAL_ENCODER_NORM)
+         self.conv3 = ResConvBlock(128, 256, cfg.NETWORK.GANCRAFT.LOCAL_ENCODER_NORM)
+         self.conv4 = ResConvBlock(256, 512, cfg.NETWORK.GANCRAFT.LOCAL_ENCODER_NORM)
+         self.dconv5 = torch.nn.ConvTranspose2d(
+             512, 128, kernel_size=4, stride=2, padding=1
+         )
+         self.dconv6 = torch.nn.ConvTranspose2d(
+             128, 32, kernel_size=4, stride=2, padding=1
+         )
+         self.dconv7 = torch.nn.Conv2d(
+             32, cfg.NETWORK.GANCRAFT.ENCODER_OUT_DIM - 1, kernel_size=1
+         )
+
+     def forward(self, hf_seg):
+         hf = self.hf_conv(hf_seg[:, [0]])
+         seg = self.seg_conv(hf_seg[:, 1:])
+         out = F.relu(self.bn1(torch.cat([hf, seg], dim=1)), inplace=True)
+         # print(out.size()) # torch.Size([N, 64, H/2, W/2])
+         out = F.avg_pool2d(self.conv2(out), 2, stride=2)
+         # print(out.size()) # torch.Size([N, 128, H/4, W/4])
+         out = self.conv3(out)
+         # print(out.size()) # torch.Size([N, 256, H/4, W/4])
+         out = self.conv4(out)
+         # print(out.size()) # torch.Size([N, 512, H/4, W/4])
+         out = self.dconv5(out)
+         # print(out.size()) # torch.Size([N, 128, H/2, W/2])
+         out = self.dconv6(out)
+         # print(out.size()) # torch.Size([N, 32, H, W])
+         out = self.dconv7(out)
+         # print(out.size()) # torch.Size([N, OUT_DIM - 1, H, W])
+         return torch.tanh(out)
+
+
+ class SinCosEncoder(torch.nn.Module):
+     def __init__(self, cfg):
+         super(SinCosEncoder, self).__init__()
+         self.freq_bands = 2.0 ** torch.linspace(
+             0,
+             cfg.NETWORK.GANCRAFT.SIN_COS_FREQ_BENDS - 1,
+             steps=cfg.NETWORK.GANCRAFT.SIN_COS_FREQ_BENDS,
+         )
+
+     def forward(self, features):
+         cord_sin = torch.cat(
+             [torch.sin(features * fb) for fb in self.freq_bands], dim=-1
+         )
+         cord_cos = torch.cat(
+             [torch.cos(features * fb) for fb in self.freq_bands], dim=-1
+         )
+         return torch.cat([cord_sin, cord_cos], dim=-1)
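+ # For an input with C channels, SinCosEncoder returns 2 * SIN_COS_FREQ_BENDS * C
+ # channels (sin and cos at every frequency band), which matches the in_dim
+ # computation in RenderMLP below.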
567
+
568
+
569
+ class RenderMLP(torch.nn.Module):
570
+ r"""MLP with affine modulation."""
571
+
572
+ def __init__(self, cfg):
573
+ super(RenderMLP, self).__init__()
574
+ in_dim = 0
575
+ f_dim = (
576
+ cfg.NETWORK.GANCRAFT.ENCODER_OUT_DIM
577
+ if cfg.NETWORK.GANCRAFT.ENCODER in ["GLOBAL", "LOCAL"]
578
+ else 0
579
+ )
580
+ if cfg.NETWORK.GANCRAFT.POS_EMD == "HASH_GRID":
581
+ in_dim = (
582
+ cfg.NETWORK.GANCRAFT.HASH_GRID_N_LEVELS
583
+ * cfg.NETWORK.GANCRAFT.HASH_GRID_LEVEL_DIM
584
+ )
585
+ in_dim += (
586
+ f_dim
587
+ if cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_CORDS
588
+ and not cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_FEATURES
589
+ else 0
590
+ )
591
+ elif cfg.NETWORK.GANCRAFT.POS_EMD == "SIN_COS":
592
+ if (
593
+ cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_CORDS
594
+ and cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_FEATURES
595
+ ):
596
+ in_dim = (3 + f_dim) * cfg.NETWORK.GANCRAFT.SIN_COS_FREQ_BENDS * 2
597
+ elif cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_CORDS:
598
+ in_dim = 3 * cfg.NETWORK.GANCRAFT.SIN_COS_FREQ_BENDS * 2 + f_dim
599
+ elif cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_FEATURES:
600
+ in_dim = f_dim * cfg.NETWORK.GANCRAFT.SIN_COS_FREQ_BENDS * 2
601
+ else:
602
+ if (
603
+ cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_CORDS
604
+ and cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_FEATURES
605
+ ):
606
+ in_dim = 3 + f_dim
607
+ elif cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_CORDS:
608
+ in_dim = 3
609
+ elif cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_FEATURES:
610
+ in_dim = f_dim
611
+
612
+ self.fc_m_a = torch.nn.Linear(
613
+ cfg.NETWORK.GANCRAFT.N_CLASSES + 1
614
+ if cfg.NETWORK.GANCRAFT.BUILDING_MODE
615
+ else cfg.NETWORK.GANCRAFT.N_CLASSES,
616
+ cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM,
617
+ bias=False,
618
+ )
619
+ self.fc_1 = torch.nn.Linear(
620
+ in_dim,
621
+ cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM,
622
+ )
623
+ self.fc_2 = (
624
+ ModLinear(
625
+ cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM,
626
+ cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM,
627
+ cfg.NETWORK.GANCRAFT.STYLE_DIM,
628
+ bias=False,
629
+ mod_bias=True,
630
+ output_mode=True,
631
+ )
632
+ if cfg.NETWORK.GANCRAFT.STYLE_DIM is not None
633
+ else torch.nn.Linear(
634
+ cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM,
635
+ cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM,
636
+ )
637
+ )
638
+ self.fc_3 = (
639
+ ModLinear(
640
+ cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM,
641
+ cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM,
642
+ cfg.NETWORK.GANCRAFT.STYLE_DIM,
643
+ bias=False,
644
+ mod_bias=True,
645
+ output_mode=True,
646
+ )
647
+ if cfg.NETWORK.GANCRAFT.STYLE_DIM is not None
648
+ else torch.nn.Linear(
649
+ cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM,
650
+ cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM,
651
+ )
652
+ )
653
+ self.fc_4 = (
654
+ ModLinear(
655
+ cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM,
656
+ cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM,
657
+ cfg.NETWORK.GANCRAFT.STYLE_DIM,
658
+ bias=False,
659
+ mod_bias=True,
660
+ output_mode=True,
661
+ )
662
+ if cfg.NETWORK.GANCRAFT.STYLE_DIM is not None
663
+ else torch.nn.Linear(
664
+ cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM,
665
+ cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM,
666
+ )
667
+ )
668
+ self.fc_sigma = (
669
+ torch.nn.Linear(
670
+ cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM,
671
+ cfg.NETWORK.GANCRAFT.RENDER_OUT_DIM_SIGMA,
672
+ )
673
+ if cfg.NETWORK.GANCRAFT.STYLE_DIM is not None
674
+ else torch.nn.Linear(
675
+ cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM,
676
+ cfg.NETWORK.GANCRAFT.RENDER_OUT_DIM_SIGMA,
677
+ )
678
+ )
679
+ self.fc_5 = (
680
+ ModLinear(
681
+ cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM,
682
+ cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM,
683
+ cfg.NETWORK.GANCRAFT.STYLE_DIM,
684
+ bias=False,
685
+ mod_bias=True,
686
+ output_mode=True,
687
+ )
688
+ if cfg.NETWORK.GANCRAFT.STYLE_DIM is not None
689
+ else torch.nn.Linear(
690
+ cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM,
691
+ cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM,
692
+ )
693
+ )
694
+ self.fc_6 = (
695
+ ModLinear(
696
+ cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM,
697
+ cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM,
698
+ cfg.NETWORK.GANCRAFT.STYLE_DIM,
699
+ bias=False,
700
+ mod_bias=True,
701
+ output_mode=True,
702
+ )
703
+ if cfg.NETWORK.GANCRAFT.STYLE_DIM is not None
704
+ else torch.nn.Linear(
705
+ cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM,
706
+ cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM,
707
+ )
708
+ )
709
+ # As with fc_sigma, both branches were identical; a single Linear suffices.
+ self.fc_out_c = torch.nn.Linear(
+ cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM,
+ cfg.NETWORK.GANCRAFT.RENDER_OUT_DIM_COLOR,
+ )
720
+ self.act = torch.nn.LeakyReLU(negative_slope=0.2)
721
+
722
+ def forward(self, x, z, m):
723
+ r"""Forward network
724
+
725
+ Args:
726
+ x (N x H x W x M x in_channels tensor): Projected features.
727
+ z (N x cfg.NETWORK.GANCRAFT.STYLE_DIM tensor): Style codes.
728
+ m (N x H x W x M x mask_channels tensor): One-hot segmentation maps.
729
+ """
730
+ # b, h, w, n, _ = x.size()
731
+ if z is not None:
732
+ z = z[:, None, None, None, :]
733
+ f = self.fc_1(x)
734
+ f = f + self.fc_m_a(m)
735
+ # Common MLP
736
+ f = self.act(f)
737
+ f = self.act(self.fc_2(f, z)) if z is not None else self.act(self.fc_2(f))
738
+ f = self.act(self.fc_3(f, z)) if z is not None else self.act(self.fc_3(f))
739
+ f = self.act(self.fc_4(f, z)) if z is not None else self.act(self.fc_4(f))
740
+ # Sigma MLP
741
+ sigma = self.fc_sigma(f)  # density head: plain linear output in both the styled and unstyled paths
742
+ # Color MLP
743
+ f = self.act(self.fc_5(f, z)) if z is not None else self.act(self.fc_5(f))
744
+ f = self.act(self.fc_6(f, z)) if z is not None else self.act(self.fc_6(f))
745
+ c = self.fc_out_c(f)
746
+ return sigma, c
747
+
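+ # Shape sketch (hypothetical sizes): for x of shape [N, H, W, M, in_dim],
+ # z of shape [N, STYLE_DIM], and m of shape [N, H, W, M, N_CLASSES (+1 in
+ # building mode)], sigma carries RENDER_OUT_DIM_SIGMA channels and c carries
+ # RENDER_OUT_DIM_COLOR channels along the last dimension.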
748
+
749
+ class RenderCNN(torch.nn.Module):
750
+ r"""CNN converting intermediate feature map to final image."""
751
+
752
+ def __init__(self, cfg):
753
+ super(RenderCNN, self).__init__()
754
+ if cfg.NETWORK.GANCRAFT.STYLE_DIM is not None:
755
+ self.fc_z_cond = torch.nn.Linear(
756
+ cfg.NETWORK.GANCRAFT.STYLE_DIM,
757
+ 2 * 2 * cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM,
758
+ )
759
+ self.conv1 = torch.nn.Conv2d(
760
+ cfg.NETWORK.GANCRAFT.RENDER_OUT_DIM_COLOR,
761
+ cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM,
762
+ 1,
763
+ stride=1,
764
+ padding=0,
765
+ )
766
+ self.conv2a = torch.nn.Conv2d(
767
+ cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM,
768
+ cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM,
769
+ 3,
770
+ stride=1,
771
+ padding=1,
772
+ )
773
+ self.conv2b = torch.nn.Conv2d(
774
+ cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM,
775
+ cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM,
776
+ 3,
777
+ stride=1,
778
+ padding=1,
779
+ bias=False,
780
+ )
781
+ self.conv3a = torch.nn.Conv2d(
782
+ cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM,
783
+ cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM,
784
+ 3,
785
+ stride=1,
786
+ padding=1,
787
+ )
788
+ self.conv3b = torch.nn.Conv2d(
789
+ cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM,
790
+ cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM,
791
+ 3,
792
+ stride=1,
793
+ padding=1,
794
+ bias=False,
795
+ )
796
+ self.conv4a = torch.nn.Conv2d(
797
+ cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM,
798
+ cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM,
799
+ 1,
800
+ stride=1,
801
+ padding=0,
802
+ )
803
+ self.conv4b = torch.nn.Conv2d(
804
+ cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM,
805
+ cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM,
806
+ 1,
807
+ stride=1,
808
+ padding=0,
809
+ )
810
+ self.conv4 = torch.nn.Conv2d(
811
+ cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM, 3, 1, stride=1, padding=0
812
+ )
813
+ self.act = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)
814
+
815
+ def modulate(self, x, w, b):
816
+ w = w[..., None, None]
817
+ b = b[..., None, None]
818
+ return x * (w + 1) + b
819
+
820
+ def forward(self, x, z):
821
+ r"""Forward network.
822
+
823
+ Args:
824
+ x (N x in_channels x H x W tensor): Intermediate feature map.
825
+ z (N x style_dim tensor): Style codes.
826
+ """
827
+ if z is not None:
828
+ z = self.fc_z_cond(z)
829
+ adapt = torch.chunk(z, 2 * 2, dim=-1)
830
+
831
+ y = self.act(self.conv1(x))
832
+ y = y + self.conv2b(self.act(self.conv2a(y)))
833
+ if z is not None:
834
+ y = self.act(self.modulate(y, adapt[0], adapt[1]))
835
+ else:
836
+ y = self.act(y)
837
+
838
+ y = y + self.conv3b(self.act(self.conv3a(y)))
839
+ if z is not None:
840
+ y = self.act(self.modulate(y, adapt[2], adapt[3]))
841
+ else:
842
+ y = self.act(y)
843
+
844
+ y = y + self.conv4b(self.act(self.conv4a(y)))
845
+ y = self.act(y)
846
+ y = self.conv4(y)
847
+
848
+ return y
849
+
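+ # Usage sketch (assumed config values, for illustration only):
+ #   cfg.NETWORK.GANCRAFT.STYLE_DIM = 256; RENDER_HIDDEN_DIM = 256;
+ #   RENDER_OUT_DIM_COLOR = 64
+ #   net = RenderCNN(cfg)
+ #   y = net(torch.randn(1, 64, 120, 160), torch.randn(1, 256))  # -> [1, 3, 120, 160]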
850
+
851
+ class SRTConvBlock(torch.nn.Module):
852
+ def __init__(self, in_channels, hidden_channels=None, out_channels=None):
853
+ super(SRTConvBlock, self).__init__()
854
+ if hidden_channels is None:
855
+ hidden_channels = in_channels
856
+ if out_channels is None:
857
+ out_channels = 2 * hidden_channels
858
+
859
+ self.layers = torch.nn.Sequential(
860
+ torch.nn.Conv2d(
861
+ in_channels,
862
+ hidden_channels,
863
+ stride=1,
864
+ kernel_size=3,
865
+ padding=1,
866
+ bias=False,
867
+ ),
868
+ torch.nn.ReLU(),
869
+ torch.nn.Conv2d(
870
+ hidden_channels,
871
+ out_channels,
872
+ stride=2,
873
+ kernel_size=3,
874
+ padding=1,
875
+ bias=False,
876
+ ),
877
+ torch.nn.ReLU(),
878
+ )
879
+
880
+ def forward(self, x):
881
+ return self.layers(x)
882
+
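+ # Usage sketch (illustrative): by default out_channels = 2 * in_channels,
+ # and the stride-2 convolution halves the spatial resolution, e.g.
+ #   SRTConvBlock(64)(torch.randn(1, 64, 32, 32))  # -> [1, 128, 16, 16]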
883
+
884
+ class ResConvBlock(torch.nn.Module):
885
+ def __init__(self, in_channels, out_channels, norm, bias=False):
886
+ super(ResConvBlock, self).__init__()
887
+ # conv3x3(in_planes, int(out_planes / 2))
888
+ self.conv1 = torch.nn.Conv2d(
889
+ in_channels,
890
+ out_channels // 2,
891
+ kernel_size=3,
892
+ stride=1,
893
+ padding=1,
894
+ bias=bias,
895
+ )
896
+ # conv3x3(int(out_planes / 2), int(out_planes / 4))
897
+ self.conv2 = torch.nn.Conv2d(
898
+ out_channels // 2,
899
+ out_channels // 4,
900
+ kernel_size=3,
901
+ stride=1,
902
+ padding=1,
903
+ bias=bias,
904
+ )
905
+ # conv3x3(int(out_planes / 4), int(out_planes / 4))
906
+ self.conv3 = torch.nn.Conv2d(
907
+ out_channels // 4,
908
+ out_channels // 4,
909
+ kernel_size=3,
910
+ stride=1,
911
+ padding=1,
912
+ bias=bias,
913
+ )
914
+ if norm == "BATCH_NORM":
915
+ self.bn1 = torch.nn.BatchNorm2d(in_channels)
916
+ self.bn2 = torch.nn.BatchNorm2d(out_channels // 2)
917
+ self.bn3 = torch.nn.BatchNorm2d(out_channels // 4)
918
+ self.bn4 = torch.nn.BatchNorm2d(in_channels)
919
+ elif norm == "GROUP_NORM":
920
+ self.bn1 = torch.nn.GroupNorm(32, in_channels)
921
+ self.bn2 = torch.nn.GroupNorm(32, out_channels // 2)
922
+ self.bn3 = torch.nn.GroupNorm(32, out_channels // 4)
923
+ self.bn4 = torch.nn.GroupNorm(32, in_channels)
+ else:
+ # Fail loudly instead of leaving self.bn1..bn4 undefined.
+ raise ValueError("Unknown normalization type: %s" % norm)
924
+
925
+ if in_channels != out_channels:
926
+ self.downsample = torch.nn.Sequential(
927
+ self.bn4,
928
+ torch.nn.ReLU(True),
929
+ torch.nn.Conv2d(
930
+ in_channels, out_channels, kernel_size=1, stride=1, bias=False
931
+ ),
932
+ )
933
+ else:
934
+ self.downsample = None
935
+
936
+ def forward(self, x):
937
+ residual = x
938
+ # print(residual.size()) # torch.Size([N, 64, H, W])
939
+ out1 = self.bn1(x)
940
+ out1 = F.relu(out1, True)
941
+ out1 = self.conv1(out1)
942
+ # print(out1.size()) # torch.Size([N, 64, H, W])
943
+ out2 = self.bn2(out1)
944
+ out2 = F.relu(out2, True)
945
+ out2 = self.conv2(out2)
946
+ # print(out2.size()) # torch.Size([N, 32, H, W])
947
+ out3 = self.bn3(out2)
948
+ out3 = F.relu(out3, True)
949
+ out3 = self.conv3(out3)
950
+ # print(out3.size()) # torch.Size([N, 32, H, W])
951
+ out3 = torch.cat((out1, out2, out3), dim=1)
952
+ # print(out3.size()) # torch.Size([N, 128, H, W])
953
+ if self.downsample is not None:
954
+ residual = self.downsample(residual)
955
+ # print(residual.size()) # torch.Size([N, 128, H, W])
956
+ out3 += residual
957
+ return out3
958
+
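+ # Usage sketch (illustrative): the three branches contribute out_channels // 2,
+ # out_channels // 4, and out_channels // 4 channels, which concatenate back to
+ # out_channels, e.g.
+ #   ResConvBlock(64, 128, "GROUP_NORM")(torch.randn(1, 64, 32, 32))  # -> [1, 128, 32, 32]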
959
+
960
+ class ModLinear(torch.nn.Module):
961
+ r"""Linear layer with affine modulation (Based on StyleGAN2 mod demod).
962
+ Equivalent to affine modulation following linear, but faster when the same modulation parameters are shared across
963
+ multiple inputs.
964
+ Args:
965
+ in_features (int): Number of input features.
966
+ out_features (int): Number of output features.
967
+ style_features (int): Number of style features.
968
+ bias (bool): Apply additive bias before the activation function?
969
+ mod_bias (bool): Whether to modulate bias.
970
+ output_mode (bool): If True, modulate output instead of input.
971
+ weight_gain (float): Initialization gain.
+ bias_init (float): Initial value of the additive bias.
+ """
973
+
974
+ def __init__(
975
+ self,
976
+ in_features,
977
+ out_features,
978
+ style_features,
979
+ bias=True,
980
+ mod_bias=True,
981
+ output_mode=False,
982
+ weight_gain=1,
983
+ bias_init=0,
984
+ ):
985
+ super(ModLinear, self).__init__()
986
+ weight_gain = weight_gain / np.sqrt(in_features)
987
+ self.weight = torch.nn.Parameter(
988
+ torch.randn([out_features, in_features]) * weight_gain
989
+ )
990
+ self.bias = (
991
+ torch.nn.Parameter(torch.full([out_features], np.float32(bias_init)))
992
+ if bias
993
+ else None
994
+ )
995
+ self.weight_alpha = torch.nn.Parameter(
996
+ torch.randn([in_features, style_features]) / np.sqrt(style_features)
997
+ )
998
+ self.bias_alpha = torch.nn.Parameter(
999
+ torch.full([in_features], 1, dtype=torch.float)
1000
+ ) # init to 1
1001
+ self.weight_beta = None
1002
+ self.bias_beta = None
1003
+ self.mod_bias = mod_bias
1004
+ self.output_mode = output_mode
1005
+ if mod_bias:
1006
+ if output_mode:
1007
+ mod_bias_dims = out_features
1008
+ else:
1009
+ mod_bias_dims = in_features
1010
+ self.weight_beta = torch.nn.Parameter(
1011
+ torch.randn([mod_bias_dims, style_features]) / np.sqrt(style_features)
1012
+ )
1013
+ self.bias_beta = torch.nn.Parameter(
1014
+ torch.full([mod_bias_dims], 0, dtype=torch.float)
1015
+ )
1016
+
1017
+ @staticmethod
1018
+ def _linear_f(x, w, b):
1019
+ w = w.to(x.dtype)
1020
+ x_shape = x.shape
1021
+ x = x.reshape(-1, x_shape[-1])
1022
+ if b is not None:
1023
+ b = b.to(x.dtype)
1024
+ x = torch.addmm(b.unsqueeze(0), x, w.t())
1025
+ else:
1026
+ x = x.matmul(w.t())
1027
+ x = x.reshape(*x_shape[:-1], -1)
1028
+ return x
1029
+
1030
+ # x: B, ... , Cin
1031
+ # z: B, 1, 1, 1, Cz
1032
+ def forward(self, x, z):
1033
+ x_shape = x.shape
1034
+ z_shape = z.shape
1035
+ x = x.reshape(x_shape[0], -1, x_shape[-1])
1036
+ z = z.reshape(z_shape[0], 1, z_shape[-1])
1037
+
1038
+ alpha = self._linear_f(z, self.weight_alpha, self.bias_alpha) # [B, ..., I]
1039
+ w = self.weight.to(x.dtype) # [O I]
1040
+ w = w.unsqueeze(0) * alpha # [1 O I] * [B 1 I] = [B O I]
1041
+
1042
+ if self.mod_bias:
1043
+ beta = self._linear_f(z, self.weight_beta, self.bias_beta)  # [B, 1, I], or [B, 1, O] in output_mode
1044
+ if not self.output_mode:
1045
+ x = x + beta
1046
+
1047
+ b = self.bias
1048
+ if b is not None:
1049
+ b = b.to(x.dtype)[None, None, :]
1050
+ if self.mod_bias and self.output_mode:
1051
+ if b is None:
1052
+ b = beta
1053
+ else:
1054
+ b = b + beta
1055
+
1056
+ # [B ? I] @ [B I O] = [B ? O]
1057
+ if b is not None:
1058
+ x = torch.baddbmm(b, x, w.transpose(1, 2))
1059
+ else:
1060
+ x = x.bmm(w.transpose(1, 2))
1061
+ x = x.reshape(*x_shape[:-1], x.shape[-1])
1062
+ return x
1063
+
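+ # Usage sketch (illustrative shapes): the style code z modulates the weight
+ # (and, with mod_bias, the bias) per batch element, e.g.
+ #   fc = ModLinear(64, 64, style_features=128)
+ #   y = fc(torch.randn(2, 8, 64), torch.randn(2, 128))  # -> [2, 8, 64]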
1064
+
1065
+ class GanCraftDiscriminator(torch.nn.Module):
1066
+ def __init__(self, cfg):
1067
+ super(GanCraftDiscriminator, self).__init__()
1068
+ # bottom-up pathway
1069
+ # down_conv2d_block = Conv2dBlock, stride=2, kernel=3, padding=1, weight_norm=spectral
1070
+ # self.enc1 = down_conv2d_block(num_input_channels, num_filters) # 3
1071
+ self.enc1 = torch.nn.Sequential(
1072
+ torch.nn.utils.spectral_norm(
1073
+ torch.nn.Conv2d(
1074
+ 3, # RGB
1075
+ cfg.NETWORK.GANCRAFT.DIS_N_CHANNEL_BASE,
1076
+ stride=2,
1077
+ kernel_size=3,
1078
+ padding=1,
1079
+ bias=True,
1080
+ )
1081
+ ),
1082
+ torch.nn.LeakyReLU(0.2),
1083
+ )
1084
+ # self.enc2 = down_conv2d_block(1 * num_filters, 2 * num_filters) # 7
1085
+ self.enc2 = torch.nn.Sequential(
1086
+ torch.nn.utils.spectral_norm(
1087
+ torch.nn.Conv2d(
1088
+ 1 * cfg.NETWORK.GANCRAFT.DIS_N_CHANNEL_BASE,
1089
+ 2 * cfg.NETWORK.GANCRAFT.DIS_N_CHANNEL_BASE,
1090
+ stride=2,
1091
+ kernel_size=3,
1092
+ padding=1,
1093
+ bias=True,
1094
+ )
1095
+ ),
1096
+ torch.nn.LeakyReLU(0.2),
1097
+ )
1098
+ # self.enc3 = down_conv2d_block(2 * num_filters, 4 * num_filters) # 15
1099
+ self.enc3 = torch.nn.Sequential(
1100
+ torch.nn.utils.spectral_norm(
1101
+ torch.nn.Conv2d(
1102
+ 2 * cfg.NETWORK.GANCRAFT.DIS_N_CHANNEL_BASE,
1103
+ 4 * cfg.NETWORK.GANCRAFT.DIS_N_CHANNEL_BASE,
1104
+ stride=2,
1105
+ kernel_size=3,
1106
+ padding=1,
1107
+ bias=True,
1108
+ )
1109
+ ),
1110
+ torch.nn.LeakyReLU(0.2),
1111
+ )
1112
+ # self.enc4 = down_conv2d_block(4 * num_filters, 8 * num_filters) # 31
1113
+ self.enc4 = torch.nn.Sequential(
1114
+ torch.nn.utils.spectral_norm(
1115
+ torch.nn.Conv2d(
1116
+ 4 * cfg.NETWORK.GANCRAFT.DIS_N_CHANNEL_BASE,
1117
+ 8 * cfg.NETWORK.GANCRAFT.DIS_N_CHANNEL_BASE,
1118
+ stride=2,
1119
+ kernel_size=3,
1120
+ padding=1,
1121
+ bias=True,
1122
+ )
1123
+ ),
1124
+ torch.nn.LeakyReLU(0.2),
1125
+ )
1126
+ # self.enc5 = down_conv2d_block(8 * num_filters, 8 * num_filters) # 63
1127
+ self.enc5 = torch.nn.Sequential(
1128
+ torch.nn.utils.spectral_norm(
1129
+ torch.nn.Conv2d(
1130
+ 8 * cfg.NETWORK.GANCRAFT.DIS_N_CHANNEL_BASE,
1131
+ 8 * cfg.NETWORK.GANCRAFT.DIS_N_CHANNEL_BASE,
1132
+ stride=2,
1133
+ kernel_size=3,
1134
+ padding=1,
1135
+ bias=True,
1136
+ )
1137
+ ),
1138
+ torch.nn.LeakyReLU(0.2),
1139
+ )
1140
+ # top-down pathway
1141
+ # latent_conv2d_block = Conv2dBlock, stride=1, kernel=1, weight_norm=spectral
1142
+ # self.lat2 = latent_conv2d_block(2 * num_filters, 4 * num_filters)
1143
+ self.lat2 = torch.nn.Sequential(
1144
+ torch.nn.utils.spectral_norm(
1145
+ torch.nn.Conv2d(
1146
+ 2 * cfg.NETWORK.GANCRAFT.DIS_N_CHANNEL_BASE,
1147
+ 4 * cfg.NETWORK.GANCRAFT.DIS_N_CHANNEL_BASE,
1148
+ stride=1,
1149
+ kernel_size=1,
1150
+ bias=True,
1151
+ )
1152
+ ),
1153
+ torch.nn.LeakyReLU(0.2),
1154
+ )
1155
+ # self.lat3 = latent_conv2d_block(4 * num_filters, 4 * num_filters)
1156
+ self.lat3 = torch.nn.Sequential(
1157
+ torch.nn.utils.spectral_norm(
1158
+ torch.nn.Conv2d(
1159
+ 4 * cfg.NETWORK.GANCRAFT.DIS_N_CHANNEL_BASE,
1160
+ 4 * cfg.NETWORK.GANCRAFT.DIS_N_CHANNEL_BASE,
1161
+ stride=1,
1162
+ kernel_size=1,
1163
+ bias=True,
1164
+ )
1165
+ ),
1166
+ torch.nn.LeakyReLU(0.2),
1167
+ )
1168
+ # self.lat4 = latent_conv2d_block(8 * num_filters, 4 * num_filters)
1169
+ self.lat4 = torch.nn.Sequential(
1170
+ torch.nn.utils.spectral_norm(
1171
+ torch.nn.Conv2d(
1172
+ 8 * cfg.NETWORK.GANCRAFT.DIS_N_CHANNEL_BASE,
1173
+ 4 * cfg.NETWORK.GANCRAFT.DIS_N_CHANNEL_BASE,
1174
+ stride=1,
1175
+ kernel_size=1,
1176
+ bias=True,
1177
+ )
1178
+ ),
1179
+ torch.nn.LeakyReLU(0.2),
1180
+ )
1181
+ # self.lat5 = latent_conv2d_block(8 * num_filters, 4 * num_filters)
1182
+ self.lat5 = torch.nn.Sequential(
1183
+ torch.nn.utils.spectral_norm(
1184
+ torch.nn.Conv2d(
1185
+ 8 * cfg.NETWORK.GANCRAFT.DIS_N_CHANNEL_BASE,
1186
+ 4 * cfg.NETWORK.GANCRAFT.DIS_N_CHANNEL_BASE,
1187
+ stride=1,
1188
+ kernel_size=1,
1189
+ bias=True,
1190
+ )
1191
+ ),
1192
+ torch.nn.LeakyReLU(0.2),
1193
+ )
1194
+ # upsampling
1195
+ self.upsample2x = torch.nn.Upsample(
1196
+ scale_factor=2, mode="bilinear", align_corners=False
1197
+ )
1198
+ # final layers
1199
+ # stride1_conv2d_block = Conv2dBlock, stride=1, kernel=3, padding=1, weight_norm=spectral
1200
+ # self.final2 = stride1_conv2d_block(4 * num_filters, 2 * num_filters)
1201
+ self.final2 = torch.nn.Sequential(
1202
+ torch.nn.utils.spectral_norm(
1203
+ torch.nn.Conv2d(
1204
+ 4 * cfg.NETWORK.GANCRAFT.DIS_N_CHANNEL_BASE,
1205
+ 2 * cfg.NETWORK.GANCRAFT.DIS_N_CHANNEL_BASE,
1206
+ stride=1,
1207
+ kernel_size=3,
1208
+ padding=1,
1209
+ bias=True,
1210
+ )
1211
+ ),
1212
+ torch.nn.LeakyReLU(0.2),
1213
+ )
1214
+ # self.output = Conv2dBlock(num_filters * 2, num_labels + 1, kernel_size=1)
1215
+ self.output = torch.nn.Sequential(
1216
+ torch.nn.Conv2d(
1217
+ 2 * cfg.NETWORK.GANCRAFT.DIS_N_CHANNEL_BASE,
1218
+ cfg.NETWORK.GANCRAFT.N_CLASSES + 1,
1219
+ stride=1,
1220
+ kernel_size=1,
1221
+ bias=True,
1222
+ ),
1223
+ torch.nn.LeakyReLU(0.2),
1224
+ )
1225
+ self.interpolator = self._smooth_interp
1226
+
1227
+ @staticmethod
1228
+ def _smooth_interp(x, size):
1229
+ r"""Smooth interpolation of segmentation maps.
1230
+
1231
+ Args:
1232
+ x (4D tensor): Segmentation maps.
1233
+ size (2D list): Target size (H, W).
1234
+ """
1235
+ x = F.interpolate(x, size=size, mode="area")
1236
+ onehot_idx = torch.argmax(x, dim=-3, keepdim=True)
1237
+ x.fill_(0.0)
1238
+ x.scatter_(1, onehot_idx, 1.0)
1239
+ return x
1240
+
1241
+ def _single_forward(self, images, seg_maps):
1242
+ # bottom-up pathway
1243
+ feat11 = self.enc1(images)
1244
+ feat12 = self.enc2(feat11)
1245
+ feat13 = self.enc3(feat12)
1246
+ feat14 = self.enc4(feat13)
1247
+ feat15 = self.enc5(feat14)
1248
+ # top-down pathway and lateral connections
1249
+ feat25 = self.lat5(feat15)
1250
+ feat24 = self.upsample2x(feat25) + self.lat4(feat14)
1251
+ feat23 = self.upsample2x(feat24) + self.lat3(feat13)
1252
+ feat22 = self.upsample2x(feat23) + self.lat2(feat12)
1253
+ # final prediction layers
1254
+ feat32 = self.final2(feat22)
1255
+
1256
+ label_map = self.interpolator(seg_maps, size=feat32.size()[2:])
1257
+ pred = self.output(feat32) # N, num_labels + 1, H//4, W//4
1258
+ return {"pred": pred, "label": label_map}
1259
+
1260
+ def forward(self, images, seg_maps, masks):
1261
+ # print(seg_maps.size()) # torch.Size([1, 7, H, W])
1262
+ # print(masks.size()) # torch.Size([1, 1, H, W])
1263
+ seg_maps = seg_maps * masks
1264
+ return self._single_forward(images * masks, seg_maps)
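
A quick interface check for the discriminator above. This is a minimal smoke-test sketch, not part of the commit: the config values (DIS_N_CHANNEL_BASE = 32, N_CLASSES = 7) are assumed for illustration, and the cfg is built with easydict (added to requirements.txt below). The FPN-style encoder ends at 1/4 of the input resolution, so the patch logits come out at H // 4 x W // 4 with N_CLASSES + 1 channels.

import torch
from easydict import EasyDict

cfg = EasyDict({"NETWORK": {"GANCRAFT": {"DIS_N_CHANNEL_BASE": 32, "N_CLASSES": 7}}})
disc = GanCraftDiscriminator(cfg)

images = torch.rand(1, 3, 256, 256)       # RGB frame
seg_maps = torch.zeros(1, 7, 256, 256)    # one-hot semantic maps
seg_maps[:, 0] = 1
masks = torch.ones(1, 1, 256, 256)        # foreground/background mask

out = disc(images, seg_maps, masks)
print(out["pred"].shape)   # torch.Size([1, 8, 64, 64])  -- N_CLASSES + 1 logits at 1/4 resolution
print(out["label"].shape)  # torch.Size([1, 7, 64, 64])  -- downsampled, re-one-hot label maps
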
requirements.txt CHANGED
@@ -2,6 +2,9 @@
  torch==1.12.0
  torchvision
 
+ easydict
+ gradio
  numpy
  opencv-python
- gradio
+ pillow
+