ashawkey committed on
Commit
904ef7d
0 Parent(s):
.gitignore ADDED
@@ -0,0 +1,11 @@
+ __pycache__/
+ build/
+ *.egg-info/
+ *.so
+
+ tmp*
+ data/
+ trial*/
+ .vs/
+
+ TOKEN
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2022 hawkey
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
activation.py ADDED
@@ -0,0 +1,18 @@
+ import torch
+ from torch.autograd import Function
+ from torch.cuda.amp import custom_bwd, custom_fwd
+
+ class _trunc_exp(Function):
+     @staticmethod
+     @custom_fwd(cast_inputs=torch.float)
+     def forward(ctx, x):
+         ctx.save_for_backward(x)
+         return torch.exp(x)
+
+     @staticmethod
+     @custom_bwd
+     def backward(ctx, g):
+         x = ctx.saved_tensors[0]
+         return g * torch.exp(x.clamp(-15, 15))
+
+ trunc_exp = _trunc_exp.apply
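A minimal usage sketch of trunc_exp (illustrative, not part of the commit; it assumes the file above is importable as activation): the forward pass is a plain exp, while the backward pass clamps the exponent so density gradients stay finite under amp.

import torch
from activation import trunc_exp

x = (20 * torch.randn(4, 1)).requires_grad_()  # deliberately large pre-activations
y = trunc_exp(x)                               # forward: plain torch.exp
y.sum().backward()                             # backward: uses exp(x.clamp(-15, 15))
print(x.grad.isfinite().all())                 # gradients remain finite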
assets/gallery.md ADDED
File without changes
assets/update_logs.md ADDED
@@ -0,0 +1,5 @@
+ ### 2022.10.5
+ * Basic reproduction finished.
+ * Running without --cuda_ray and running with --tcnn are not working yet; need to fix.
+ * Shading is not working and is disabled in utils.py for now; surface normals are bad.
+ * Use an entropy loss to regularize weights_sum (alpha); the original L2 regularization always leads to degenerate geometry.
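A sketch of the entropy regularization mentioned in the last bullet (names here are illustrative; the actual term lives in utils.py): treating each ray's accumulated alpha as a Bernoulli probability and penalizing its binary entropy pushes weights_sum toward 0 or 1, instead of uniformly shrinking it the way an L2 penalty does.

import torch

def alpha_entropy_loss(weights_sum, eps=1e-6):
    # weights_sum: per-ray accumulated alpha, expected in [0, 1]
    a = weights_sum.clamp(eps, 1 - eps)
    return (-a * torch.log2(a) - (1 - a) * torch.log2(1 - a)).mean()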
encoding.py ADDED
@@ -0,0 +1,78 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ class FreqEncoder(nn.Module):
+     def __init__(self, input_dim, max_freq_log2, N_freqs,
+                  log_sampling=True, include_input=True,
+                  periodic_fns=(torch.sin, torch.cos)):
+
+         super().__init__()
+
+         self.input_dim = input_dim
+         self.include_input = include_input
+         self.periodic_fns = periodic_fns
+
+         self.output_dim = 0
+         if self.include_input:
+             self.output_dim += self.input_dim
+
+         self.output_dim += self.input_dim * N_freqs * len(self.periodic_fns)
+
+         if log_sampling:
+             self.freq_bands = 2. ** torch.linspace(0., max_freq_log2, N_freqs)
+         else:
+             self.freq_bands = torch.linspace(2. ** 0., 2. ** max_freq_log2, N_freqs)
+
+         self.freq_bands = self.freq_bands.numpy().tolist()
+
+     def forward(self, input, **kwargs):
+
+         out = []
+         if self.include_input:
+             out.append(input)
+
+         for i in range(len(self.freq_bands)):
+             freq = self.freq_bands[i]
+             for p_fn in self.periodic_fns:
+                 out.append(p_fn(input * freq))
+
+         out = torch.cat(out, dim=-1)
+
+         return out
+
+ def get_encoder(encoding, input_dim=3,
+                 multires=6,
+                 degree=4,
+                 num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=2048, align_corners=False,
+                 **kwargs):
+
+     if encoding == 'None':
+         return lambda x, **kwargs: x, input_dim
+
+     elif encoding == 'frequency':
+         # encoder = FreqEncoder(input_dim=input_dim, max_freq_log2=multires-1, N_freqs=multires, log_sampling=True)
+         from freqencoder import FreqEncoder
+         encoder = FreqEncoder(input_dim=input_dim, degree=multires)
+
+     elif encoding == 'sphere_harmonics':
+         from shencoder import SHEncoder
+         encoder = SHEncoder(input_dim=input_dim, degree=degree)
+
+     elif encoding == 'hashgrid':
+         from gridencoder import GridEncoder
+         encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='hash', align_corners=align_corners)
+
+     elif encoding == 'tiledgrid':
+         from gridencoder import GridEncoder
+         encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='tiled', align_corners=align_corners)
+
+     elif encoding == 'ash':
+         from ashencoder import AshEncoder
+         encoder = AshEncoder(input_dim=input_dim, output_dim=16, log2_hashmap_size=log2_hashmap_size, resolution=desired_resolution)
+
+     else:
+         raise NotImplementedError('Unknown encoding mode, choose from [None, frequency, sphere_harmonics, hashgrid, tiledgrid, ash]')
+
+     return encoder, encoder.output_dim
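A usage sketch for get_encoder (illustrative, not part of the commit; the 'hashgrid' branch assumes the gridencoder CUDA extension below is built and a GPU is available): the returned width is what the downstream NeRF MLP uses to size its first layer.

import torch
from encoding import get_encoder

encoder, in_dim = get_encoder('hashgrid', input_dim=3, desired_resolution=2048)
encoder = encoder.cuda()
xyz = torch.rand(1024, 3, device='cuda') * 2 - 1   # positions in [-bound, bound] with bound=1
feats = encoder(xyz, bound=1)                      # [1024, num_levels * level_dim]
assert feats.shape[-1] == in_dim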
freqencoder/__init__.py ADDED
@@ -0,0 +1 @@
+ from .freq import FreqEncoder
freqencoder/backend.py ADDED
@@ -0,0 +1,41 @@
1
+ import os
2
+ from torch.utils.cpp_extension import load
3
+
4
+ _src_path = os.path.dirname(os.path.abspath(__file__))
5
+
6
+ nvcc_flags = [
7
+ '-O3', '-std=c++14',
8
+ '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
9
+ '-use_fast_math'
10
+ ]
11
+
12
+ if os.name == "posix":
13
+ c_flags = ['-O3', '-std=c++14']
14
+ elif os.name == "nt":
15
+ c_flags = ['/O2', '/std:c++17']
16
+
17
+ # find cl.exe
18
+ def find_cl_path():
19
+ import glob
20
+ for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
21
+ paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
22
+ if paths:
23
+ return paths[0]
24
+
25
+ # If cl.exe is not on path, try to find it.
26
+ if os.system("where cl.exe >nul 2>nul") != 0:
27
+ cl_path = find_cl_path()
28
+ if cl_path is None:
29
+ raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
30
+ os.environ["PATH"] += ";" + cl_path
31
+
32
+ _backend = load(name='_freqencoder',
33
+ extra_cflags=c_flags,
34
+ extra_cuda_cflags=nvcc_flags,
35
+ sources=[os.path.join(_src_path, 'src', f) for f in [
36
+ 'freqencoder.cu',
37
+ 'bindings.cpp',
38
+ ]],
39
+ )
40
+
41
+ __all__ = ['_backend']
freqencoder/freq.py ADDED
@@ -0,0 +1,77 @@
+ import numpy as np
+
+ import torch
+ import torch.nn as nn
+ from torch.autograd import Function
+ from torch.autograd.function import once_differentiable
+ from torch.cuda.amp import custom_bwd, custom_fwd
+
+ try:
+     import _freqencoder as _backend
+ except ImportError:
+     from .backend import _backend
+
+
+ class _freq_encoder(Function):
+     @staticmethod
+     @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision
+     def forward(ctx, inputs, degree, output_dim):
+         # inputs: [B, input_dim], float
+         # return: [B, output_dim], float
+
+         if not inputs.is_cuda: inputs = inputs.cuda()
+         inputs = inputs.contiguous()
+
+         B, input_dim = inputs.shape # batch size, coord dim
+
+         outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device)
+
+         _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs)
+
+         ctx.save_for_backward(inputs, outputs)
+         ctx.dims = [B, input_dim, degree, output_dim]
+
+         return outputs
+
+     @staticmethod
+     #@once_differentiable
+     @custom_bwd
+     def backward(ctx, grad):
+         # grad: [B, output_dim]
+
+         grad = grad.contiguous()
+         inputs, outputs = ctx.saved_tensors
+         B, input_dim, degree, output_dim = ctx.dims
+
+         grad_inputs = torch.zeros_like(inputs)
+         _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs)
+
+         return grad_inputs, None, None
+
+
+ freq_encode = _freq_encoder.apply
+
+
+ class FreqEncoder(nn.Module):
+     def __init__(self, input_dim=3, degree=4):
+         super().__init__()
+
+         self.input_dim = input_dim
+         self.degree = degree
+         self.output_dim = input_dim + input_dim * 2 * degree
+
+     def __repr__(self):
+         return f"FreqEncoder: input_dim={self.input_dim} degree={self.degree} output_dim={self.output_dim}"
+
+     def forward(self, inputs, **kwargs):
+         # inputs: [..., input_dim]
+         # return: [..., output_dim]
+
+         prefix_shape = list(inputs.shape[:-1])
+         inputs = inputs.reshape(-1, self.input_dim)
+
+         outputs = freq_encode(inputs, self.degree, self.output_dim)
+
+         outputs = outputs.reshape(prefix_shape + [self.output_dim])
+
+         return outputs
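For reference, a pure-Python check of the output layout (no CUDA needed; not part of the commit): the encoder emits the raw coordinates followed by sin/cos pairs at frequencies 2^0 ... 2^(degree-1), so output_dim = input_dim * (1 + 2 * degree).

input_dim, degree = 3, 4
output_dim = input_dim + input_dim * 2 * degree   # same formula as FreqEncoder.__init__
print(output_dim)  # 27: [x, sin(2^0 x), cos(2^0 x), ..., sin(2^3 x), cos(2^3 x)]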
freqencoder/setup.py ADDED
@@ -0,0 +1,51 @@
1
+ import os
2
+ from setuptools import setup
3
+ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
4
+
5
+ _src_path = os.path.dirname(os.path.abspath(__file__))
6
+
7
+ nvcc_flags = [
8
+ '-O3', '-std=c++14',
9
+ '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
10
+ '-use_fast_math'
11
+ ]
12
+
13
+ if os.name == "posix":
14
+ c_flags = ['-O3', '-std=c++14']
15
+ elif os.name == "nt":
16
+ c_flags = ['/O2', '/std:c++17']
17
+
18
+ # find cl.exe
19
+ def find_cl_path():
20
+ import glob
21
+ for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
22
+ paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
23
+ if paths:
24
+ return paths[0]
25
+
26
+ # If cl.exe is not on path, try to find it.
27
+ if os.system("where cl.exe >nul 2>nul") != 0:
28
+ cl_path = find_cl_path()
29
+ if cl_path is None:
30
+ raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
31
+ os.environ["PATH"] += ";" + cl_path
32
+
33
+ setup(
34
+ name='freqencoder', # package name, import this to use python API
35
+ ext_modules=[
36
+ CUDAExtension(
37
+ name='_freqencoder', # extension name, import this to use CUDA API
38
+ sources=[os.path.join(_src_path, 'src', f) for f in [
39
+ 'freqencoder.cu',
40
+ 'bindings.cpp',
41
+ ]],
42
+ extra_compile_args={
43
+ 'cxx': c_flags,
44
+ 'nvcc': nvcc_flags,
45
+ }
46
+ ),
47
+ ],
48
+ cmdclass={
49
+ 'build_ext': BuildExtension,
50
+ }
51
+ )
freqencoder/src/bindings.cpp ADDED
@@ -0,0 +1,8 @@
+ #include <torch/extension.h>
+
+ #include "freqencoder.h"
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+     m.def("freq_encode_forward", &freq_encode_forward, "freq encode forward (CUDA)");
+     m.def("freq_encode_backward", &freq_encode_backward, "freq encode backward (CUDA)");
+ }
freqencoder/src/freqencoder.cu ADDED
@@ -0,0 +1,129 @@
1
+ #include <stdint.h>
2
+
3
+ #include <cuda.h>
4
+ #include <cuda_fp16.h>
5
+ #include <cuda_runtime.h>
6
+
7
+ #include <ATen/cuda/CUDAContext.h>
8
+ #include <torch/torch.h>
9
+
10
+ #include <algorithm>
11
+ #include <stdexcept>
12
+
13
+ #include <cstdio>
14
+
15
+
16
+ #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
17
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
18
+ #define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
19
+ #define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")
20
+
21
+ inline constexpr __device__ float PI() { return 3.141592653589793f; }
22
+
23
+ template <typename T>
24
+ __host__ __device__ T div_round_up(T val, T divisor) {
25
+ return (val + divisor - 1) / divisor;
26
+ }
27
+
28
+ // inputs: [B, D]
29
+ // outputs: [B, C], C = D + D * deg * 2
30
+ __global__ void kernel_freq(
31
+ const float * __restrict__ inputs,
32
+ uint32_t B, uint32_t D, uint32_t deg, uint32_t C,
33
+ float * outputs
34
+ ) {
35
+ // parallel on per-element
36
+ const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
37
+ if (t >= B * C) return;
38
+
39
+ // get index
40
+ const uint32_t b = t / C;
41
+ const uint32_t c = t - b * C; // t % C;
42
+
43
+ // locate
44
+ inputs += b * D;
45
+ outputs += t;
46
+
47
+ // write self
48
+ if (c < D) {
49
+ outputs[0] = inputs[c];
50
+ // write freq
51
+ } else {
52
+ const uint32_t col = c / D - 1;
53
+ const uint32_t d = c % D;
54
+ const uint32_t freq = col / 2;
55
+ const float phase_shift = (col % 2) * (PI() / 2);
56
+ outputs[0] = __sinf(scalbnf(inputs[d], freq) + phase_shift);
57
+ }
58
+ }
59
+
60
+ // grad: [B, C], C = D + D * deg * 2
61
+ // outputs: [B, C]
62
+ // grad_inputs: [B, D]
63
+ __global__ void kernel_freq_backward(
64
+ const float * __restrict__ grad,
65
+ const float * __restrict__ outputs,
66
+ uint32_t B, uint32_t D, uint32_t deg, uint32_t C,
67
+ float * grad_inputs
68
+ ) {
69
+ // parallel on per-element
70
+ const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
71
+ if (t >= B * D) return;
72
+
73
+ const uint32_t b = t / D;
74
+ const uint32_t d = t - b * D; // t % D;
75
+
76
+ // locate
77
+ grad += b * C;
78
+ outputs += b * C;
79
+ grad_inputs += t;
80
+
81
+ // register
82
+ float result = grad[d];
83
+ grad += D;
84
+ outputs += D;
85
+
86
+ for (uint32_t f = 0; f < deg; f++) {
87
+ result += scalbnf(1.0f, f) * (grad[d] * outputs[D + d] - grad[D + d] * outputs[d]);
88
+ grad += 2 * D;
89
+ outputs += 2 * D;
90
+ }
91
+
92
+ // write
93
+ grad_inputs[0] = result;
94
+ }
95
+
96
+
97
+ void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs) {
98
+ CHECK_CUDA(inputs);
99
+ CHECK_CUDA(outputs);
100
+
101
+ CHECK_CONTIGUOUS(inputs);
102
+ CHECK_CONTIGUOUS(outputs);
103
+
104
+ CHECK_IS_FLOATING(inputs);
105
+ CHECK_IS_FLOATING(outputs);
106
+
107
+ static constexpr uint32_t N_THREADS = 128;
108
+
109
+ kernel_freq<<<div_round_up(B * C, N_THREADS), N_THREADS>>>(inputs.data_ptr<float>(), B, D, deg, C, outputs.data_ptr<float>());
110
+ }
111
+
112
+
113
+ void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs) {
114
+ CHECK_CUDA(grad);
115
+ CHECK_CUDA(outputs);
116
+ CHECK_CUDA(grad_inputs);
117
+
118
+ CHECK_CONTIGUOUS(grad);
119
+ CHECK_CONTIGUOUS(outputs);
120
+ CHECK_CONTIGUOUS(grad_inputs);
121
+
122
+ CHECK_IS_FLOATING(grad);
123
+ CHECK_IS_FLOATING(outputs);
124
+ CHECK_IS_FLOATING(grad_inputs);
125
+
126
+ static constexpr uint32_t N_THREADS = 128;
127
+
128
+ kernel_freq_backward<<<div_round_up(B * D, N_THREADS), N_THREADS>>>(grad.data_ptr<float>(), outputs.data_ptr<float>(), B, D, deg, C, grad_inputs.data_ptr<float>());
129
+ }
freqencoder/src/freqencoder.h ADDED
@@ -0,0 +1,10 @@
+ #pragma once
+
+ #include <stdint.h>
+ #include <torch/torch.h>
+
+ // _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs)
+ void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs);
+
+ // _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs)
+ void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs);
gridencoder/__init__.py ADDED
@@ -0,0 +1 @@
+ from .grid import GridEncoder
gridencoder/backend.py ADDED
@@ -0,0 +1,40 @@
1
+ import os
2
+ from torch.utils.cpp_extension import load
3
+
4
+ _src_path = os.path.dirname(os.path.abspath(__file__))
5
+
6
+ nvcc_flags = [
7
+ '-O3', '-std=c++14',
8
+ '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
9
+ ]
10
+
11
+ if os.name == "posix":
12
+ c_flags = ['-O3', '-std=c++14']
13
+ elif os.name == "nt":
14
+ c_flags = ['/O2', '/std:c++17']
15
+
16
+ # find cl.exe
17
+ def find_cl_path():
18
+ import glob
19
+ for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
20
+ paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
21
+ if paths:
22
+ return paths[0]
23
+
24
+ # If cl.exe is not on path, try to find it.
25
+ if os.system("where cl.exe >nul 2>nul") != 0:
26
+ cl_path = find_cl_path()
27
+ if cl_path is None:
28
+ raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
29
+ os.environ["PATH"] += ";" + cl_path
30
+
31
+ _backend = load(name='_grid_encoder',
32
+ extra_cflags=c_flags,
33
+ extra_cuda_cflags=nvcc_flags,
34
+ sources=[os.path.join(_src_path, 'src', f) for f in [
35
+ 'gridencoder.cu',
36
+ 'bindings.cpp',
37
+ ]],
38
+ )
39
+
40
+ __all__ = ['_backend']
gridencoder/grid.py ADDED
@@ -0,0 +1,154 @@
1
+ import numpy as np
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.autograd import Function
6
+ from torch.autograd.function import once_differentiable
7
+ from torch.cuda.amp import custom_bwd, custom_fwd
8
+
9
+ try:
10
+ import _gridencoder as _backend
11
+ except ImportError:
12
+ from .backend import _backend
13
+
14
+ _gridtype_to_id = {
15
+ 'hash': 0,
16
+ 'tiled': 1,
17
+ }
18
+
19
+ class _grid_encode(Function):
20
+ @staticmethod
21
+ @custom_fwd
22
+ def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False, gridtype=0, align_corners=False):
23
+ # inputs: [B, D], float in [0, 1]
24
+ # embeddings: [sO, C], float
25
+ # offsets: [L + 1], int
26
+ # RETURN: [B, F], float
27
+
28
+ inputs = inputs.contiguous()
29
+
30
+ B, D = inputs.shape # batch size, coord dim
31
+ L = offsets.shape[0] - 1 # level
32
+ C = embeddings.shape[1] # embedding dim for each level
33
+ S = np.log2(per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f
34
+ H = base_resolution # base resolution
35
+
36
+ # manually handle autocast (only use half precision embeddings, inputs must be float for enough precision)
37
+ # if C % 2 != 0, force float, since half for atomicAdd is very slow.
38
+ if torch.is_autocast_enabled() and C % 2 == 0:
39
+ embeddings = embeddings.to(torch.half)
40
+
41
+ # L first, optimize cache for cuda kernel, but needs an extra permute later
42
+ outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype)
43
+
44
+ if calc_grad_inputs:
45
+ dy_dx = torch.empty(B, L * D * C, device=inputs.device, dtype=embeddings.dtype)
46
+ else:
47
+ dy_dx = None
48
+
49
+ _backend.grid_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, L, S, H, dy_dx, gridtype, align_corners)
50
+
51
+ # permute back to [B, L * C]
52
+ outputs = outputs.permute(1, 0, 2).reshape(B, L * C)
53
+
54
+ ctx.save_for_backward(inputs, embeddings, offsets, dy_dx)
55
+ ctx.dims = [B, D, C, L, S, H, gridtype]
56
+ ctx.align_corners = align_corners
57
+
58
+ return outputs
59
+
60
+ @staticmethod
61
+ #@once_differentiable
62
+ @custom_bwd
63
+ def backward(ctx, grad):
64
+
65
+ inputs, embeddings, offsets, dy_dx = ctx.saved_tensors
66
+ B, D, C, L, S, H, gridtype = ctx.dims
67
+ align_corners = ctx.align_corners
68
+
69
+ # grad: [B, L * C] --> [L, B, C]
70
+ grad = grad.view(B, L, C).permute(1, 0, 2).contiguous()
71
+
72
+ grad_embeddings = torch.zeros_like(embeddings)
73
+
74
+ if dy_dx is not None:
75
+ grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype)
76
+ else:
77
+ grad_inputs = None
78
+
79
+ _backend.grid_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, B, D, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners)
80
+
81
+ if dy_dx is not None:
82
+ grad_inputs = grad_inputs.to(inputs.dtype)
83
+
84
+ return grad_inputs, grad_embeddings, None, None, None, None, None, None
85
+
86
+
87
+
88
+ grid_encode = _grid_encode.apply
89
+
90
+
91
+ class GridEncoder(nn.Module):
92
+ def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None, gridtype='hash', align_corners=False):
93
+ super().__init__()
94
+
95
+ # the finest resolution desired at the last level; if provided, it overrides per_level_scale
96
+ if desired_resolution is not None:
97
+ per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1))
98
+
99
+ self.input_dim = input_dim # coord dims, 2 or 3
100
+ self.num_levels = num_levels # num levels, each level multiply resolution by 2
101
+ self.level_dim = level_dim # encode channels per level
102
+ self.per_level_scale = per_level_scale # multiply resolution by this scale at each level.
103
+ self.log2_hashmap_size = log2_hashmap_size
104
+ self.base_resolution = base_resolution
105
+ self.output_dim = num_levels * level_dim
106
+ self.gridtype = gridtype
107
+ self.gridtype_id = _gridtype_to_id[gridtype] # "tiled" or "hash"
108
+ self.align_corners = align_corners
109
+
110
+ # allocate parameters
111
+ offsets = []
112
+ offset = 0
113
+ self.max_params = 2 ** log2_hashmap_size
114
+ for i in range(num_levels):
115
+ resolution = int(np.ceil(base_resolution * per_level_scale ** i))
116
+ params_in_level = min(self.max_params, (resolution if align_corners else resolution + 1) ** input_dim) # limit max number
117
+ params_in_level = int(np.ceil(params_in_level / 8) * 8) # make divisible
118
+ offsets.append(offset)
119
+ offset += params_in_level
120
+ offsets.append(offset)
121
+ offsets = torch.from_numpy(np.array(offsets, dtype=np.int32))
122
+ self.register_buffer('offsets', offsets)
123
+
124
+ self.n_params = offsets[-1] * level_dim
125
+
126
+ # parameters
127
+ self.embeddings = nn.Parameter(torch.empty(offset, level_dim))
128
+
129
+ self.reset_parameters()
130
+
131
+ def reset_parameters(self):
132
+ std = 1e-4
133
+ self.embeddings.data.uniform_(-std, std)
134
+
135
+ def __repr__(self):
136
+ return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners}"
137
+
138
+ def forward(self, inputs, bound=1):
139
+ # inputs: [..., input_dim], normalized real world positions in [-bound, bound]
140
+ # return: [..., num_levels * level_dim]
141
+
142
+ inputs = (inputs + bound) / (2 * bound) # map to [0, 1]
143
+
144
+ #print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item())
145
+
146
+ prefix_shape = list(inputs.shape[:-1])
147
+ inputs = inputs.view(-1, self.input_dim)
148
+
149
+ outputs = grid_encode(inputs, self.embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad, self.gridtype_id, self.align_corners)
150
+ outputs = outputs.view(prefix_shape + [self.output_dim])
151
+
152
+ #print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item())
153
+
154
+ return outputs
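A worked example of the resolution schedule computed in GridEncoder.__init__ above (numbers only, not part of the commit): with base_resolution=16, desired_resolution=2048 and num_levels=16, per_level_scale = exp2(log2(2048 / 16) / 15) ≈ 1.38, so the per-level grid resolutions grow geometrically from 16 to 2048.

import numpy as np

base_resolution, desired_resolution, num_levels = 16, 2048, 16
per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1))
print(round(float(per_level_scale), 4))  # ~1.3819
print([int(np.ceil(base_resolution * per_level_scale ** i)) for i in range(num_levels)])
# [16, 23, 31, ..., 2048] -- matches the resolutions used to allocate the hash levels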
gridencoder/setup.py ADDED
@@ -0,0 +1,50 @@
1
+ import os
2
+ from setuptools import setup
3
+ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
4
+
5
+ _src_path = os.path.dirname(os.path.abspath(__file__))
6
+
7
+ nvcc_flags = [
8
+ '-O3', '-std=c++14',
9
+ '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
10
+ ]
11
+
12
+ if os.name == "posix":
13
+ c_flags = ['-O3', '-std=c++14']
14
+ elif os.name == "nt":
15
+ c_flags = ['/O2', '/std:c++17']
16
+
17
+ # find cl.exe
18
+ def find_cl_path():
19
+ import glob
20
+ for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
21
+ paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
22
+ if paths:
23
+ return paths[0]
24
+
25
+ # If cl.exe is not on path, try to find it.
26
+ if os.system("where cl.exe >nul 2>nul") != 0:
27
+ cl_path = find_cl_path()
28
+ if cl_path is None:
29
+ raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
30
+ os.environ["PATH"] += ";" + cl_path
31
+
32
+ setup(
33
+ name='gridencoder', # package name, import this to use python API
34
+ ext_modules=[
35
+ CUDAExtension(
36
+ name='_gridencoder', # extension name, import this to use CUDA API
37
+ sources=[os.path.join(_src_path, 'src', f) for f in [
38
+ 'gridencoder.cu',
39
+ 'bindings.cpp',
40
+ ]],
41
+ extra_compile_args={
42
+ 'cxx': c_flags,
43
+ 'nvcc': nvcc_flags,
44
+ }
45
+ ),
46
+ ],
47
+ cmdclass={
48
+ 'build_ext': BuildExtension,
49
+ }
50
+ )
gridencoder/src/bindings.cpp ADDED
@@ -0,0 +1,8 @@
+ #include <torch/extension.h>
+
+ #include "gridencoder.h"
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+     m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)");
+     m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)");
+ }
gridencoder/src/gridencoder.cu ADDED
@@ -0,0 +1,479 @@
1
+ #include <cuda.h>
2
+ #include <cuda_fp16.h>
3
+ #include <cuda_runtime.h>
4
+
5
+ #include <ATen/cuda/CUDAContext.h>
6
+ #include <torch/torch.h>
7
+
8
+ #include <algorithm>
9
+ #include <stdexcept>
10
+
11
+ #include <stdint.h>
12
+ #include <cstdio>
13
+
14
+
15
+ #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
16
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
17
+ #define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
18
+ #define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")
19
+
20
+
21
+ // just for compatibility of half precision in AT_DISPATCH_FLOATING_TYPES_AND_HALF...
22
+ static inline __device__ at::Half atomicAdd(at::Half *address, at::Half val) {
23
+ // requires CUDA >= 10 and ARCH >= 70
24
+ // this is very slow compared to float or __half2, and never used.
25
+ //return atomicAdd(reinterpret_cast<__half*>(address), val);
26
+ }
27
+
28
+
29
+ template <typename T>
30
+ static inline __host__ __device__ T div_round_up(T val, T divisor) {
31
+ return (val + divisor - 1) / divisor;
32
+ }
33
+
34
+
35
+ template <uint32_t D>
36
+ __device__ uint32_t fast_hash(const uint32_t pos_grid[D]) {
37
+ static_assert(D <= 7, "fast_hash can only hash up to 7 dimensions.");
38
+
39
+ // While 1 is technically not a good prime for hashing (or a prime at all), it helps memory coherence
40
+ // and is sufficient for our use case of obtaining a uniformly colliding index from high-dimensional
41
+ // coordinates.
42
+ constexpr uint32_t primes[7] = { 1, 2654435761, 805459861, 3674653429, 2097192037, 1434869437, 2165219737 };
43
+
44
+ uint32_t result = 0;
45
+ #pragma unroll
46
+ for (uint32_t i = 0; i < D; ++i) {
47
+ result ^= pos_grid[i] * primes[i];
48
+ }
49
+
50
+ return result;
51
+ }
52
+
53
+
54
+ template <uint32_t D, uint32_t C>
55
+ __device__ uint32_t get_grid_index(const uint32_t gridtype, const bool align_corners, const uint32_t ch, const uint32_t hashmap_size, const uint32_t resolution, const uint32_t pos_grid[D]) {
56
+ uint32_t stride = 1;
57
+ uint32_t index = 0;
58
+
59
+ #pragma unroll
60
+ for (uint32_t d = 0; d < D && stride <= hashmap_size; d++) {
61
+ index += pos_grid[d] * stride;
62
+ stride *= align_corners ? resolution: (resolution + 1);
63
+ }
64
+
65
+ // NOTE: for NeRF, the hash is in fact not necessary. Check https://github.com/NVlabs/instant-ngp/issues/97.
66
+ // gridtype: 0 == hash, 1 == tiled
67
+ if (gridtype == 0 && stride > hashmap_size) {
68
+ index = fast_hash<D>(pos_grid);
69
+ }
70
+
71
+ return (index % hashmap_size) * C + ch;
72
+ }
73
+
74
+
75
+ template <typename scalar_t, uint32_t D, uint32_t C>
76
+ __global__ void kernel_grid(
77
+ const float * __restrict__ inputs,
78
+ const scalar_t * __restrict__ grid,
79
+ const int * __restrict__ offsets,
80
+ scalar_t * __restrict__ outputs,
81
+ const uint32_t B, const uint32_t L, const float S, const uint32_t H,
82
+ scalar_t * __restrict__ dy_dx,
83
+ const uint32_t gridtype,
84
+ const bool align_corners
85
+ ) {
86
+ const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;
87
+
88
+ if (b >= B) return;
89
+
90
+ const uint32_t level = blockIdx.y;
91
+
92
+ // locate
93
+ grid += (uint32_t)offsets[level] * C;
94
+ inputs += b * D;
95
+ outputs += level * B * C + b * C;
96
+
97
+ // check input range (should be in [0, 1])
98
+ bool flag_oob = false;
99
+ #pragma unroll
100
+ for (uint32_t d = 0; d < D; d++) {
101
+ if (inputs[d] < 0 || inputs[d] > 1) {
102
+ flag_oob = true;
103
+ }
104
+ }
105
+ // if input out of bound, just set output to 0
106
+ if (flag_oob) {
107
+ #pragma unroll
108
+ for (uint32_t ch = 0; ch < C; ch++) {
109
+ outputs[ch] = 0;
110
+ }
111
+ if (dy_dx) {
112
+ dy_dx += b * D * L * C + level * D * C; // B L D C
113
+ #pragma unroll
114
+ for (uint32_t d = 0; d < D; d++) {
115
+ #pragma unroll
116
+ for (uint32_t ch = 0; ch < C; ch++) {
117
+ dy_dx[d * C + ch] = 0;
118
+ }
119
+ }
120
+ }
121
+ return;
122
+ }
123
+
124
+ const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
125
+ const float scale = exp2f(level * S) * H - 1.0f;
126
+ const uint32_t resolution = (uint32_t)ceil(scale) + 1;
127
+
128
+ // calculate coordinate
129
+ float pos[D];
130
+ uint32_t pos_grid[D];
131
+
132
+ #pragma unroll
133
+ for (uint32_t d = 0; d < D; d++) {
134
+ pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
135
+ pos_grid[d] = floorf(pos[d]);
136
+ pos[d] -= (float)pos_grid[d];
137
+ }
138
+
139
+ //printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]);
140
+
141
+ // interpolate
142
+ scalar_t results[C] = {0}; // temp results in register
143
+
144
+ #pragma unroll
145
+ for (uint32_t idx = 0; idx < (1 << D); idx++) {
146
+ float w = 1;
147
+ uint32_t pos_grid_local[D];
148
+
149
+ #pragma unroll
150
+ for (uint32_t d = 0; d < D; d++) {
151
+ if ((idx & (1 << d)) == 0) {
152
+ w *= 1 - pos[d];
153
+ pos_grid_local[d] = pos_grid[d];
154
+ } else {
155
+ w *= pos[d];
156
+ pos_grid_local[d] = pos_grid[d] + 1;
157
+ }
158
+ }
159
+
160
+ uint32_t index = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);
161
+
162
+ // writing to register (fast)
163
+ #pragma unroll
164
+ for (uint32_t ch = 0; ch < C; ch++) {
165
+ results[ch] += w * grid[index + ch];
166
+ }
167
+
168
+ //printf("[b=%d, l=%d] int %d, idx %d, w %f, val %f\n", b, level, idx, index, w, grid[index]);
169
+ }
170
+
171
+ // writing to global memory (slow)
172
+ #pragma unroll
173
+ for (uint32_t ch = 0; ch < C; ch++) {
174
+ outputs[ch] = results[ch];
175
+ }
176
+
177
+ // prepare dy_dx
178
+ // differentiable (soft) indexing: https://discuss.pytorch.org/t/differentiable-indexing/17647/9
179
+ if (dy_dx) {
180
+
181
+ dy_dx += b * D * L * C + level * D * C; // B L D C
182
+
183
+ #pragma unroll
184
+ for (uint32_t gd = 0; gd < D; gd++) {
185
+
186
+ scalar_t results_grad[C] = {0};
187
+
188
+ #pragma unroll
189
+ for (uint32_t idx = 0; idx < (1 << (D - 1)); idx++) {
190
+ float w = scale;
191
+ uint32_t pos_grid_local[D];
192
+
193
+ #pragma unroll
194
+ for (uint32_t nd = 0; nd < D - 1; nd++) {
195
+ const uint32_t d = (nd >= gd) ? (nd + 1) : nd;
196
+
197
+ if ((idx & (1 << nd)) == 0) {
198
+ w *= 1 - pos[d];
199
+ pos_grid_local[d] = pos_grid[d];
200
+ } else {
201
+ w *= pos[d];
202
+ pos_grid_local[d] = pos_grid[d] + 1;
203
+ }
204
+ }
205
+
206
+ pos_grid_local[gd] = pos_grid[gd];
207
+ uint32_t index_left = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);
208
+ pos_grid_local[gd] = pos_grid[gd] + 1;
209
+ uint32_t index_right = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);
210
+
211
+ #pragma unroll
212
+ for (uint32_t ch = 0; ch < C; ch++) {
213
+ results_grad[ch] += w * (grid[index_right + ch] - grid[index_left + ch]);
214
+ }
215
+ }
216
+
217
+ #pragma unroll
218
+ for (uint32_t ch = 0; ch < C; ch++) {
219
+ dy_dx[gd * C + ch] = results_grad[ch];
220
+ }
221
+ }
222
+ }
223
+ }
224
+
225
+
226
+ template <typename scalar_t, uint32_t D, uint32_t C, uint32_t N_C>
227
+ __global__ void kernel_grid_backward(
228
+ const scalar_t * __restrict__ grad,
229
+ const float * __restrict__ inputs,
230
+ const scalar_t * __restrict__ grid,
231
+ const int * __restrict__ offsets,
232
+ scalar_t * __restrict__ grad_grid,
233
+ const uint32_t B, const uint32_t L, const float S, const uint32_t H,
234
+ const uint32_t gridtype,
235
+ const bool align_corners
236
+ ) {
237
+ const uint32_t b = (blockIdx.x * blockDim.x + threadIdx.x) * N_C / C;
238
+ if (b >= B) return;
239
+
240
+ const uint32_t level = blockIdx.y;
241
+ const uint32_t ch = (blockIdx.x * blockDim.x + threadIdx.x) * N_C - b * C;
242
+
243
+ // locate
244
+ grad_grid += offsets[level] * C;
245
+ inputs += b * D;
246
+ grad += level * B * C + b * C + ch; // L, B, C
247
+
248
+ const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
249
+ const float scale = exp2f(level * S) * H - 1.0f;
250
+ const uint32_t resolution = (uint32_t)ceil(scale) + 1;
251
+
252
+ // check input range (should be in [0, 1])
253
+ #pragma unroll
254
+ for (uint32_t d = 0; d < D; d++) {
255
+ if (inputs[d] < 0 || inputs[d] > 1) {
256
+ return; // grad is init as 0, so we simply return.
257
+ }
258
+ }
259
+
260
+ // calculate coordinate
261
+ float pos[D];
262
+ uint32_t pos_grid[D];
263
+
264
+ #pragma unroll
265
+ for (uint32_t d = 0; d < D; d++) {
266
+ pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
267
+ pos_grid[d] = floorf(pos[d]);
268
+ pos[d] -= (float)pos_grid[d];
269
+ }
270
+
271
+ scalar_t grad_cur[N_C] = {0}; // fetch to register
272
+ #pragma unroll
273
+ for (uint32_t c = 0; c < N_C; c++) {
274
+ grad_cur[c] = grad[c];
275
+ }
276
+
277
+ // interpolate
278
+ #pragma unroll
279
+ for (uint32_t idx = 0; idx < (1 << D); idx++) {
280
+ float w = 1;
281
+ uint32_t pos_grid_local[D];
282
+
283
+ #pragma unroll
284
+ for (uint32_t d = 0; d < D; d++) {
285
+ if ((idx & (1 << d)) == 0) {
286
+ w *= 1 - pos[d];
287
+ pos_grid_local[d] = pos_grid[d];
288
+ } else {
289
+ w *= pos[d];
290
+ pos_grid_local[d] = pos_grid[d] + 1;
291
+ }
292
+ }
293
+
294
+ uint32_t index = get_grid_index<D, C>(gridtype, align_corners, ch, hashmap_size, resolution, pos_grid_local);
295
+
296
+ // atomicAdd for __half is slow (especially for large values), so we use __half2 if N_C % 2 == 0
297
+ // TODO: use float which is better than __half, if N_C % 2 != 0
298
+ if (std::is_same<scalar_t, at::Half>::value && N_C % 2 == 0) {
299
+ #pragma unroll
300
+ for (uint32_t c = 0; c < N_C; c += 2) {
301
+ // process two __half at once (by interpreting as a __half2)
302
+ __half2 v = {(__half)(w * grad_cur[c]), (__half)(w * grad_cur[c + 1])};
303
+ atomicAdd((__half2*)&grad_grid[index + c], v);
304
+ }
305
+ // float, or __half when N_C % 2 != 0 (which means C == 1)
306
+ } else {
307
+ #pragma unroll
308
+ for (uint32_t c = 0; c < N_C; c++) {
309
+ atomicAdd(&grad_grid[index + c], w * grad_cur[c]);
310
+ }
311
+ }
312
+ }
313
+ }
314
+
315
+
316
+ template <typename scalar_t, uint32_t D, uint32_t C>
317
+ __global__ void kernel_input_backward(
318
+ const scalar_t * __restrict__ grad,
319
+ const scalar_t * __restrict__ dy_dx,
320
+ scalar_t * __restrict__ grad_inputs,
321
+ uint32_t B, uint32_t L
322
+ ) {
323
+ const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
324
+ if (t >= B * D) return;
325
+
326
+ const uint32_t b = t / D;
327
+ const uint32_t d = t - b * D;
328
+
329
+ dy_dx += b * L * D * C;
330
+
331
+ scalar_t result = 0;
332
+
333
+ # pragma unroll
334
+ for (int l = 0; l < L; l++) {
335
+ # pragma unroll
336
+ for (int ch = 0; ch < C; ch++) {
337
+ result += grad[l * B * C + b * C + ch] * dy_dx[l * D * C + d * C + ch];
338
+ }
339
+ }
340
+
341
+ grad_inputs[t] = result;
342
+ }
343
+
344
+
345
+ template <typename scalar_t, uint32_t D>
346
+ void kernel_grid_wrapper(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners) {
347
+ static constexpr uint32_t N_THREAD = 512;
348
+ const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 };
349
+ switch (C) {
350
+ case 1: kernel_grid<scalar_t, D, 1><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners); break;
351
+ case 2: kernel_grid<scalar_t, D, 2><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners); break;
352
+ case 4: kernel_grid<scalar_t, D, 4><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners); break;
353
+ case 8: kernel_grid<scalar_t, D, 8><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners); break;
354
+ default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
355
+ }
356
+ }
357
+
358
+ // inputs: [B, D], float, in [0, 1]
359
+ // embeddings: [sO, C], float
360
+ // offsets: [L + 1], uint32_t
361
+ // outputs: [L, B, C], float (L first, so only one level of hashmap needs to fit into cache at a time.)
362
+ // H: base resolution
363
+ // dy_dx: [B, L * D * C]
364
+ template <typename scalar_t>
365
+ void grid_encode_forward_cuda(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners) {
366
+ switch (D) {
367
+ case 1: kernel_grid_wrapper<scalar_t, 1>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break;
368
+ case 2: kernel_grid_wrapper<scalar_t, 2>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break;
369
+ case 3: kernel_grid_wrapper<scalar_t, 3>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break;
370
+ case 4: kernel_grid_wrapper<scalar_t, 4>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break;
371
+ case 5: kernel_grid_wrapper<scalar_t, 5>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break;
372
+ default: throw std::runtime_error{"GridEncoding: D must be 1, 2, 3, 4, or 5."};
373
+ }
374
+
375
+ }
376
+
377
+ template <typename scalar_t, uint32_t D>
378
+ void kernel_grid_backward_wrapper(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners) {
379
+ static constexpr uint32_t N_THREAD = 256;
380
+ const uint32_t N_C = std::min(2u, C); // n_features_per_thread
381
+ const dim3 blocks_hashgrid = { div_round_up(B * C / N_C, N_THREAD), L, 1 };
382
+ switch (C) {
383
+ case 1:
384
+ kernel_grid_backward<scalar_t, D, 1, 1><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
385
+ if (dy_dx) kernel_input_backward<scalar_t, D, 1><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
386
+ break;
387
+ case 2:
388
+ kernel_grid_backward<scalar_t, D, 2, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
389
+ if (dy_dx) kernel_input_backward<scalar_t, D, 2><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
390
+ break;
391
+ case 4:
392
+ kernel_grid_backward<scalar_t, D, 4, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
393
+ if (dy_dx) kernel_input_backward<scalar_t, D, 4><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
394
+ break;
395
+ case 8:
396
+ kernel_grid_backward<scalar_t, D, 8, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
397
+ if (dy_dx) kernel_input_backward<scalar_t, D, 8><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
398
+ break;
399
+ default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
400
+ }
401
+ }
402
+
403
+
404
+ // grad: [L, B, C], float
405
+ // inputs: [B, D], float, in [0, 1]
406
+ // embeddings: [sO, C], float
407
+ // offsets: [L + 1], uint32_t
408
+ // grad_embeddings: [sO, C]
409
+ // H: base resolution
410
+ template <typename scalar_t>
411
+ void grid_encode_backward_cuda(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners) {
412
+ switch (D) {
413
+ case 1: kernel_grid_backward_wrapper<scalar_t, 1>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break;
414
+ case 2: kernel_grid_backward_wrapper<scalar_t, 2>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break;
415
+ case 3: kernel_grid_backward_wrapper<scalar_t, 3>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break;
416
+ case 4: kernel_grid_backward_wrapper<scalar_t, 4>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break;
417
+ case 5: kernel_grid_backward_wrapper<scalar_t, 5>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break;
418
+ default: throw std::runtime_error{"GridEncoding: D must be 1, 2, 3, 4, or 5."};
419
+ }
420
+ }
421
+
422
+
423
+
424
+ void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, at::optional<at::Tensor> dy_dx, const uint32_t gridtype, const bool align_corners) {
425
+ CHECK_CUDA(inputs);
426
+ CHECK_CUDA(embeddings);
427
+ CHECK_CUDA(offsets);
428
+ CHECK_CUDA(outputs);
429
+ // CHECK_CUDA(dy_dx);
430
+
431
+ CHECK_CONTIGUOUS(inputs);
432
+ CHECK_CONTIGUOUS(embeddings);
433
+ CHECK_CONTIGUOUS(offsets);
434
+ CHECK_CONTIGUOUS(outputs);
435
+ // CHECK_CONTIGUOUS(dy_dx);
436
+
437
+ CHECK_IS_FLOATING(inputs);
438
+ CHECK_IS_FLOATING(embeddings);
439
+ CHECK_IS_INT(offsets);
440
+ CHECK_IS_FLOATING(outputs);
441
+ // CHECK_IS_FLOATING(dy_dx);
442
+
443
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
444
+ embeddings.scalar_type(), "grid_encode_forward", ([&] {
445
+ grid_encode_forward_cuda<scalar_t>(inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(), outputs.data_ptr<scalar_t>(), B, D, C, L, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr<scalar_t>() : nullptr, gridtype, align_corners);
446
+ }));
447
+ }
448
+
449
+ void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const at::optional<at::Tensor> dy_dx, at::optional<at::Tensor> grad_inputs, const uint32_t gridtype, const bool align_corners) {
450
+ CHECK_CUDA(grad);
451
+ CHECK_CUDA(inputs);
452
+ CHECK_CUDA(embeddings);
453
+ CHECK_CUDA(offsets);
454
+ CHECK_CUDA(grad_embeddings);
455
+ // CHECK_CUDA(dy_dx);
456
+ // CHECK_CUDA(grad_inputs);
457
+
458
+ CHECK_CONTIGUOUS(grad);
459
+ CHECK_CONTIGUOUS(inputs);
460
+ CHECK_CONTIGUOUS(embeddings);
461
+ CHECK_CONTIGUOUS(offsets);
462
+ CHECK_CONTIGUOUS(grad_embeddings);
463
+ // CHECK_CONTIGUOUS(dy_dx);
464
+ // CHECK_CONTIGUOUS(grad_inputs);
465
+
466
+ CHECK_IS_FLOATING(grad);
467
+ CHECK_IS_FLOATING(inputs);
468
+ CHECK_IS_FLOATING(embeddings);
469
+ CHECK_IS_INT(offsets);
470
+ CHECK_IS_FLOATING(grad_embeddings);
471
+ // CHECK_IS_FLOATING(dy_dx);
472
+ // CHECK_IS_FLOATING(grad_inputs);
473
+
474
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
475
+ grad.scalar_type(), "grid_encode_backward", ([&] {
476
+ grid_encode_backward_cuda<scalar_t>(grad.data_ptr<scalar_t>(), inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(), grad_embeddings.data_ptr<scalar_t>(), B, D, C, L, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr<scalar_t>() : nullptr, grad_inputs.has_value() ? grad_inputs.value().data_ptr<scalar_t>() : nullptr, gridtype, align_corners);
477
+ }));
478
+
479
+ }
gridencoder/src/gridencoder.h ADDED
@@ -0,0 +1,15 @@
+ #ifndef _HASH_ENCODE_H
+ #define _HASH_ENCODE_H
+
+ #include <stdint.h>
+ #include <torch/torch.h>
+
+ // inputs: [B, D], float, in [0, 1]
+ // embeddings: [sO, C], float
+ // offsets: [L + 1], uint32_t
+ // outputs: [B, L * C], float
+ // H: base resolution
+ void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, at::optional<at::Tensor> dy_dx, const uint32_t gridtype, const bool align_corners);
+ void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const at::optional<at::Tensor> dy_dx, at::optional<at::Tensor> grad_inputs, const uint32_t gridtype, const bool align_corners);
+
+ #endif
loss.py ADDED
@@ -0,0 +1,11 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ def mape_loss(pred, target):
+     # pred, target: [B, 1], torch tensor
+     difference = (pred - target).abs()
+     scale = 1 / (target.abs() + 1e-2)
+     loss = difference * scale
+
+     return loss.mean()
main_nerf.py ADDED
@@ -0,0 +1,137 @@
1
+ import torch
2
+ import argparse
3
+
4
+ from nerf.provider import NeRFDataset
5
+ from nerf.utils import *
6
+ from optimizer import Shampoo
7
+
8
+ from nerf.sd import StableDiffusion
9
+ from nerf.clip import CLIP
10
+ from nerf.gui import NeRFGUI
11
+
12
+ # torch.autograd.set_detect_anomaly(True)
13
+
14
+ if __name__ == '__main__':
15
+
16
+ parser = argparse.ArgumentParser()
17
+ parser.add_argument('--text', help="text prompt")
18
+ parser.add_argument('-O', action='store_true', help="equals --fp16 --cuda_ray --dir_text")
19
+ parser.add_argument('--test', action='store_true', help="test mode")
20
+ parser.add_argument('--workspace', type=str, default='workspace')
21
+ parser.add_argument('--guidance', type=str, default='stable-diffusion', help='choose from [stable-diffusion, clip]')
22
+ parser.add_argument('--seed', type=int, default=0)
23
+
24
+ ### training options
25
+ parser.add_argument('--iters', type=int, default=15000, help="training iters")
26
+ parser.add_argument('--lr', type=float, default=1e-3, help="initial learning rate")
27
+ parser.add_argument('--ckpt', type=str, default='latest')
28
+ parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch")
29
+ parser.add_argument('--max_steps', type=int, default=1024, help="max num steps sampled per ray (only valid when using --cuda_ray)")
30
+ parser.add_argument('--num_steps', type=int, default=256, help="num steps sampled per ray (only valid when not using --cuda_ray)")
31
+ parser.add_argument('--upsample_steps', type=int, default=0, help="num steps up-sampled per ray (only valid when not using --cuda_ray)")
32
+ parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)")
33
+ parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when not using --cuda_ray)")
34
+ parser.add_argument('--albedo_iters', type=int, default=15000, help="training iters that only use albedo shading")
35
+ # model options
36
+ parser.add_argument('--bg_radius', type=float, default=1.4, help="if positive, use a background model at sphere(bg_radius)")
37
+ parser.add_argument('--density_thresh', type=float, default=10, help="threshold for density grid to be occupied")
38
+ # network backbone
39
+ parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training")
40
+ parser.add_argument('--backbone', type=str, default='grid', help="nerf backbone, choose from [grid, tcnn, vanilla]")
41
+ # rendering resolution in training
42
+ parser.add_argument('--w', type=int, default=64, help="render width for CLIP training (<=224)")
43
+ parser.add_argument('--h', type=int, default=64, help="render height for CLIP training (<=224)")
44
+
45
+ ### dataset options
46
+ parser.add_argument('--bound', type=float, default=1, help="assume the scene is bounded in box(-bound, bound)")
47
+ parser.add_argument('--dt_gamma', type=float, default=0, help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)")
48
+ parser.add_argument('--min_near', type=float, default=0.1, help="minimum near distance for camera")
49
+ parser.add_argument('--radius_range', type=float, nargs='*', default=[1.0, 1.5], help="training camera radius range")
50
+ parser.add_argument('--fovy_range', type=float, nargs='*', default=[40, 70], help="training camera fovy range")
51
+ parser.add_argument('--dir_text', action='store_true', help="direction encoded text prompt")
52
+
53
+ ### GUI options
54
+ parser.add_argument('--gui', action='store_true', help="start a GUI")
55
+ parser.add_argument('--W', type=int, default=800, help="GUI width")
56
+ parser.add_argument('--H', type=int, default=800, help="GUI height")
57
+ parser.add_argument('--radius', type=float, default=3, help="default GUI camera radius from center")
58
+ parser.add_argument('--fovy', type=float, default=60, help="default GUI camera fovy")
59
+ parser.add_argument('--light_theta', type=float, default=60, help="default GUI light direction")
60
+ parser.add_argument('--light_phi', type=float, default=0, help="default GUI light direction")
61
+ parser.add_argument('--max_spp', type=int, default=64, help="GUI rendering max sample per pixel")
62
+
63
+ opt = parser.parse_args()
64
+
65
+ if opt.O:
66
+ opt.fp16 = True
67
+ opt.cuda_ray = True
68
+ opt.dir_text = True
69
+
70
+ if opt.backbone == 'vanilla':
71
+ from nerf.network import NeRFNetwork
72
+ elif opt.backbone == 'tcnn':
73
+ from nerf.network_tcnn import NeRFNetwork
74
+ elif opt.backbone == 'grid':
75
+ from nerf.network_grid import NeRFNetwork
76
+ else:
77
+ raise NotImplementedError(f'--backbone {opt.backbone} is not implemented!')
78
+
79
+ print(opt)
80
+
81
+ seed_everything(opt.seed)
82
+
83
+ model = NeRFNetwork(opt)
84
+
85
+ print(model)
86
+
87
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
88
+
89
+ if opt.test:
90
+ guidance = None # do not load guidance at test
91
+
92
+ trainer = Trainer('ngp', opt, model, guidance, device=device, workspace=opt.workspace, fp16=opt.fp16, use_checkpoint=opt.ckpt)
93
+
94
+ if opt.gui:
95
+ gui = NeRFGUI(opt, trainer)
96
+ gui.render()
97
+
98
+ else:
99
+ test_loader = NeRFDataset(opt, device=device, type='test', H=opt.H, W=opt.W, size=100).dataloader()
100
+ trainer.test(test_loader)
101
+ trainer.save_mesh(resolution=256)
102
+
103
+ else:
104
+
105
+ if opt.guidance == 'stable-diffusion':
106
+ guidance = StableDiffusion(device)
107
+ elif opt.guidance == 'clip':
108
+ guidance = CLIP(device)
109
+ else:
110
+ raise NotImplementedError(f'--guidance {opt.guidance} is not implemented.')
111
+
112
+ optimizer = lambda model: torch.optim.Adam(model.get_params(opt.lr), betas=(0.9, 0.99), eps=1e-15)
113
+ # optimizer = lambda model: Shampoo(model.get_params(opt.lr))
114
+
115
+ train_loader = NeRFDataset(opt, device=device, type='train', H=opt.h, W=opt.w, size=100).dataloader()
116
+
117
+ # decay to 0.01 * init_lr at last iter step
118
+ scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 0.01 ** min(iter / opt.iters, 1))
119
+
120
+ trainer = Trainer('ngp', opt, model, guidance, device=device, workspace=opt.workspace, optimizer=optimizer, ema_decay=0.95, fp16=opt.fp16, lr_scheduler=scheduler, use_checkpoint=opt.ckpt, eval_interval=1)
121
+
122
+ if opt.gui:
123
+ trainer.train_loader = train_loader # attach dataloader to trainer
124
+
125
+ gui = NeRFGUI(opt, trainer)
126
+ gui.render()
127
+
128
+ else:
129
+ valid_loader = NeRFDataset(opt, device=device, type='val', H=opt.H, W=opt.W, size=5).dataloader()
130
+
131
+ max_epoch = np.ceil(opt.iters / len(train_loader)).astype(np.int32)
132
+ trainer.train(train_loader, valid_loader, max_epoch)
133
+
134
+ # also test
135
+ test_loader = NeRFDataset(opt, device=device, type='test', H=opt.H, W=opt.W, size=100).dataloader()
136
+ trainer.test(test_loader)
137
+ trainer.save_mesh(resolution=256)
nerf/clip.py ADDED
@@ -0,0 +1,45 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ import torchvision.transforms as T
5
+ import torchvision.transforms.functional as TF
6
+
7
+ import clip
8
+
9
+ class CLIP(nn.Module):
10
+ def __init__(self, device):
11
+ super().__init__()
12
+
13
+ self.device = device
14
+
15
+ self.clip_model, self.clip_preprocess = clip.load("ViT-B/16", device=self.device, jit=False)
16
+
17
+ # image augmentation
18
+ self.aug = T.Compose([
19
+ T.Resize((224, 224)),
20
+ T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
21
+ ])
22
+
23
+ # self.gaussian_blur = T.GaussianBlur(15, sigma=(0.1, 10))
24
+
25
+
26
+ def get_text_embeds(self, prompt):
27
+
28
+ text = clip.tokenize(prompt).to(self.device)
29
+ text_z = self.clip_model.encode_text(text)
30
+ text_z = text_z / text_z.norm(dim=-1, keepdim=True)
31
+
32
+ return text_z
33
+
34
+
35
+ def train_step(self, text_z, pred_rgb):
36
+
37
+ pred_rgb = self.aug(pred_rgb)
38
+
39
+ image_z = self.clip_model.encode_image(pred_rgb)
40
+ image_z = image_z / image_z.norm(dim=-1, keepdim=True) # normalize features
41
+
42
+ loss = - (image_z * text_z).sum(-1).mean()
43
+
44
+ return loss
45
+
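As a rough usage sketch for the CLIP guidance above (assuming a CUDA device and the OpenAI clip package; the prompt string and the random image tensor are placeholders): the text embedding is computed once, and train_step returns the negative cosine similarity between the rendered view and the prompt, which the trainer backpropagates into the NeRF parameters.

import torch

device = torch.device('cuda')
guidance = CLIP(device)                         # class defined above
text_z = guidance.get_text_embeds("a hamburger")

# stand-in for a rendered view, [B, 3, H, W] in [0, 1]
pred_rgb = torch.rand(1, 3, 128, 128, device=device, requires_grad=True)

loss = guidance.train_step(text_z, pred_rgb)    # -cos(image_z, text_z)
loss.backward()                                 # gradients flow back into pred_rgb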
nerf/gui.py ADDED
@@ -0,0 +1,465 @@
1
+ import math
2
+ import torch
3
+ import numpy as np
4
+ import dearpygui.dearpygui as dpg
5
+ from scipy.spatial.transform import Rotation as R
6
+
7
+ from nerf.utils import *
8
+
9
+
10
+ class OrbitCamera:
11
+ def __init__(self, W, H, r=2, fovy=60):
12
+ self.W = W
13
+ self.H = H
14
+ self.radius = r # camera distance from center
15
+ self.fovy = fovy # in degree
16
+ self.center = np.array([0, 0, 0], dtype=np.float32) # look at this point
17
+ self.rot = R.from_quat([1, 0, 0, 0]) # init camera matrix: [[1, 0, 0], [0, -1, 0], [0, 0, -1]] (180 deg about x, to suit ngp convention)
18
+ self.up = np.array([0, 1, 0], dtype=np.float32) # need to be normalized!
19
+
20
+ # pose
21
+ @property
22
+ def pose(self):
23
+ # first move camera to radius
24
+ res = np.eye(4, dtype=np.float32)
25
+ res[2, 3] -= self.radius
26
+ # rotate
27
+ rot = np.eye(4, dtype=np.float32)
28
+ rot[:3, :3] = self.rot.as_matrix()
29
+ res = rot @ res
30
+ # translate
31
+ res[:3, 3] -= self.center
32
+ return res
33
+
34
+ # intrinsics
35
+ @property
36
+ def intrinsics(self):
37
+ focal = self.H / (2 * np.tan(np.radians(self.fovy) / 2))
38
+ return np.array([focal, focal, self.W // 2, self.H // 2])
39
+
40
+ def orbit(self, dx, dy):
41
+ # rotate along camera up/side axis!
42
+ side = self.rot.as_matrix()[:3, 0] # why this is side --> ? # already normalized.
43
+ rotvec_x = self.up * np.radians(-0.1 * dx)
44
+ rotvec_y = side * np.radians(-0.1 * dy)
45
+ self.rot = R.from_rotvec(rotvec_x) * R.from_rotvec(rotvec_y) * self.rot
46
+
47
+ def scale(self, delta):
48
+ self.radius *= 1.1 ** (-delta)
49
+
50
+ def pan(self, dx, dy, dz=0):
51
+ # pan in camera coordinate system (careful on the sensitivity!)
52
+ self.center += 0.0005 * self.rot.as_matrix()[:3, :3] @ np.array([dx, dy, dz])
53
+
54
+
55
+ class NeRFGUI:
56
+ def __init__(self, opt, trainer, debug=True):
57
+ self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
58
+ self.W = opt.W
59
+ self.H = opt.H
60
+ self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy)
61
+ self.debug = debug
62
+ self.bg_color = torch.ones(3, dtype=torch.float32) # default white bg
63
+ self.training = False
64
+ self.step = 0 # training step
65
+
66
+ self.trainer = trainer
67
+ self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32)
68
+ self.need_update = True # camera moved, should reset accumulation
69
+ self.spp = 1 # samples per pixel
70
+ self.light_dir = np.array([opt.light_theta, opt.light_phi])
71
+ self.ambient_ratio = 1.0
72
+ self.mode = 'image' # choose from ['image', 'depth']
73
+ self.shading = 'albedo'
74
+
75
+ self.dynamic_resolution = True
76
+ self.downscale = 1
77
+ self.train_steps = 16
78
+
79
+ dpg.create_context()
80
+ self.register_dpg()
81
+ self.test_step()
82
+
83
+
84
+ def __del__(self):
85
+ dpg.destroy_context()
86
+
87
+
88
+ def train_step(self):
89
+
90
+ starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
91
+ starter.record()
92
+
93
+ outputs = self.trainer.train_gui(self.trainer.train_loader, step=self.train_steps)
94
+
95
+ ender.record()
96
+ torch.cuda.synchronize()
97
+ t = starter.elapsed_time(ender)
98
+
99
+ self.step += self.train_steps
100
+ self.need_update = True
101
+
102
+ dpg.set_value("_log_train_time", f'{t:.4f}ms ({int(1000/t)} FPS)')
103
+ dpg.set_value("_log_train_log", f'step = {self.step: 5d} (+{self.train_steps: 2d}), loss = {outputs["loss"]:.4f}, lr = {outputs["lr"]:.5f}')
104
+
105
+ # dynamic train steps
106
+ # max allowed train time per-frame is 500 ms
107
+ full_t = t / self.train_steps * 16
108
+ train_steps = min(16, max(4, int(16 * 500 / full_t)))
109
+ if train_steps > self.train_steps * 1.2 or train_steps < self.train_steps * 0.8:
110
+ self.train_steps = train_steps
111
+
112
+
113
+ def prepare_buffer(self, outputs):
114
+ if self.mode == 'image':
115
+ return outputs['image']
116
+ else:
117
+ return np.expand_dims(outputs['depth'], -1).repeat(3, -1)
118
+
119
+
120
+ def test_step(self):
121
+
122
+ if self.need_update or self.spp < self.opt.max_spp:
123
+
124
+ starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
125
+ starter.record()
126
+
127
+ outputs = self.trainer.test_gui(self.cam.pose, self.cam.intrinsics, self.W, self.H, self.bg_color, self.spp, self.downscale, self.light_dir, self.ambient_ratio, self.shading)
128
+
129
+ ender.record()
130
+ torch.cuda.synchronize()
131
+ t = starter.elapsed_time(ender)
132
+
133
+ # update dynamic resolution
134
+ if self.dynamic_resolution:
135
+ # max allowed infer time per-frame is 200 ms
136
+ full_t = t / (self.downscale ** 2)
137
+ downscale = min(1, max(1/4, math.sqrt(200 / full_t)))
138
+ if downscale > self.downscale * 1.2 or downscale < self.downscale * 0.8:
139
+ self.downscale = downscale
140
+
141
+ if self.need_update:
142
+ self.render_buffer = self.prepare_buffer(outputs)
143
+ self.spp = 1
144
+ self.need_update = False
145
+ else:
146
+ self.render_buffer = (self.render_buffer * self.spp + self.prepare_buffer(outputs)) / (self.spp + 1)
147
+ self.spp += 1
148
+
149
+ dpg.set_value("_log_infer_time", f'{t:.4f}ms ({int(1000/t)} FPS)')
150
+ dpg.set_value("_log_resolution", f'{int(self.downscale * self.W)}x{int(self.downscale * self.H)}')
151
+ dpg.set_value("_log_spp", self.spp)
152
+ dpg.set_value("_texture", self.render_buffer)
153
+
154
+
155
+ def register_dpg(self):
156
+
157
+ ### register texture
158
+
159
+ with dpg.texture_registry(show=False):
160
+ dpg.add_raw_texture(self.W, self.H, self.render_buffer, format=dpg.mvFormat_Float_rgb, tag="_texture")
161
+
162
+ ### register window
163
+
164
+ # the rendered image, as the primary window
165
+ with dpg.window(tag="_primary_window", width=self.W, height=self.H):
166
+
167
+ # add the texture
168
+ dpg.add_image("_texture")
169
+
170
+ dpg.set_primary_window("_primary_window", True)
171
+
172
+ # control window
173
+ with dpg.window(label="Control", tag="_control_window", width=400, height=300):
174
+
175
+ # text prompt
176
+ if self.opt.text is not None:
177
+ dpg.add_text("text: " + self.opt.text, tag="_log_prompt_text")
178
+
179
+ # button theme
180
+ with dpg.theme() as theme_button:
181
+ with dpg.theme_component(dpg.mvButton):
182
+ dpg.add_theme_color(dpg.mvThemeCol_Button, (23, 3, 18))
183
+ dpg.add_theme_color(dpg.mvThemeCol_ButtonHovered, (51, 3, 47))
184
+ dpg.add_theme_color(dpg.mvThemeCol_ButtonActive, (83, 18, 83))
185
+ dpg.add_theme_style(dpg.mvStyleVar_FrameRounding, 5)
186
+ dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 3, 3)
187
+
188
+ # time
189
+ if not self.opt.test:
190
+ with dpg.group(horizontal=True):
191
+ dpg.add_text("Train time: ")
192
+ dpg.add_text("no data", tag="_log_train_time")
193
+
194
+ with dpg.group(horizontal=True):
195
+ dpg.add_text("Infer time: ")
196
+ dpg.add_text("no data", tag="_log_infer_time")
197
+
198
+ with dpg.group(horizontal=True):
199
+ dpg.add_text("SPP: ")
200
+ dpg.add_text("1", tag="_log_spp")
201
+
202
+ # train button
203
+ if not self.opt.test:
204
+ with dpg.collapsing_header(label="Train", default_open=True):
205
+ with dpg.group(horizontal=True):
206
+ dpg.add_text("Train: ")
207
+
208
+ def callback_train(sender, app_data):
209
+ if self.training:
210
+ self.training = False
211
+ dpg.configure_item("_button_train", label="start")
212
+ else:
213
+ self.training = True
214
+ dpg.configure_item("_button_train", label="stop")
215
+
216
+ dpg.add_button(label="start", tag="_button_train", callback=callback_train)
217
+ dpg.bind_item_theme("_button_train", theme_button)
218
+
219
+ def callback_reset(sender, app_data):
220
+ @torch.no_grad()
221
+ def weight_reset(m: nn.Module):
222
+ reset_parameters = getattr(m, "reset_parameters", None)
223
+ if callable(reset_parameters):
224
+ m.reset_parameters()
225
+ self.trainer.model.apply(fn=weight_reset)
226
+ self.trainer.model.reset_extra_state() # for cuda_ray density_grid and step_counter
227
+ self.need_update = True
228
+
229
+ dpg.add_button(label="reset", tag="_button_reset", callback=callback_reset)
230
+ dpg.bind_item_theme("_button_reset", theme_button)
231
+
232
+
233
+ with dpg.group(horizontal=True):
234
+ dpg.add_text("Checkpoint: ")
235
+
236
+ def callback_save(sender, app_data):
237
+ self.trainer.save_checkpoint(full=True, best=False)
238
+ dpg.set_value("_log_ckpt", "saved " + os.path.basename(self.trainer.stats["checkpoints"][-1]))
239
+ self.trainer.epoch += 1 # use epoch to indicate different calls.
240
+
241
+ dpg.add_button(label="save", tag="_button_save", callback=callback_save)
242
+ dpg.bind_item_theme("_button_save", theme_button)
243
+
244
+ dpg.add_text("", tag="_log_ckpt")
245
+
246
+ # save mesh
247
+ with dpg.group(horizontal=True):
248
+ dpg.add_text("Marching Cubes: ")
249
+
250
+ def callback_mesh(sender, app_data):
251
+ self.trainer.save_mesh(resolution=256, threshold=10)
252
+ dpg.set_value("_log_mesh", "saved " + f'{self.trainer.name}_{self.trainer.epoch}.ply')
253
+ self.trainer.epoch += 1 # use epoch to indicate different calls.
254
+
255
+ dpg.add_button(label="mesh", tag="_button_mesh", callback=callback_mesh)
256
+ dpg.bind_item_theme("_button_mesh", theme_button)
257
+
258
+ dpg.add_text("", tag="_log_mesh")
259
+
260
+ with dpg.group(horizontal=True):
261
+ dpg.add_text("", tag="_log_train_log")
262
+
263
+
264
+ # rendering options
265
+ with dpg.collapsing_header(label="Options", default_open=True):
266
+
267
+ # dynamic rendering resolution
268
+ with dpg.group(horizontal=True):
269
+
270
+ def callback_set_dynamic_resolution(sender, app_data):
271
+ if self.dynamic_resolution:
272
+ self.dynamic_resolution = False
273
+ self.downscale = 1
274
+ else:
275
+ self.dynamic_resolution = True
276
+ self.need_update = True
277
+
278
+ dpg.add_checkbox(label="dynamic resolution", default_value=self.dynamic_resolution, callback=callback_set_dynamic_resolution)
279
+ dpg.add_text(f"{self.W}x{self.H}", tag="_log_resolution")
280
+
281
+ # mode combo
282
+ def callback_change_mode(sender, app_data):
283
+ self.mode = app_data
284
+ self.need_update = True
285
+
286
+ dpg.add_combo(('image', 'depth'), label='mode', default_value=self.mode, callback=callback_change_mode)
287
+
288
+ # bg_color picker
289
+ def callback_change_bg(sender, app_data):
290
+ self.bg_color = torch.tensor(app_data[:3], dtype=torch.float32) # only need RGB in [0, 1]
291
+ self.need_update = True
292
+
293
+ dpg.add_color_edit((255, 255, 255), label="Background Color", width=200, tag="_color_editor", no_alpha=True, callback=callback_change_bg)
294
+
295
+ # fov slider
296
+ def callback_set_fovy(sender, app_data):
297
+ self.cam.fovy = app_data
298
+ self.need_update = True
299
+
300
+ dpg.add_slider_int(label="FoV (vertical)", min_value=1, max_value=120, format="%d deg", default_value=self.cam.fovy, callback=callback_set_fovy)
301
+
302
+ # dt_gamma slider
303
+ def callback_set_dt_gamma(sender, app_data):
304
+ self.opt.dt_gamma = app_data
305
+ self.need_update = True
306
+
307
+ dpg.add_slider_float(label="dt_gamma", min_value=0, max_value=0.1, format="%.5f", default_value=self.opt.dt_gamma, callback=callback_set_dt_gamma)
308
+
309
+ # max_steps slider
310
+ def callback_set_max_steps(sender, app_data):
311
+ self.opt.max_steps = app_data
312
+ self.need_update = True
313
+
314
+ dpg.add_slider_int(label="max steps", min_value=1, max_value=1024, format="%d", default_value=self.opt.max_steps, callback=callback_set_max_steps)
315
+
316
+ # aabb slider
317
+ def callback_set_aabb(sender, app_data, user_data):
318
+ # user_data is the dimension for aabb (xmin, ymin, zmin, xmax, ymax, zmax)
319
+ self.trainer.model.aabb_infer[user_data] = app_data
320
+
321
+ # also change train aabb ? [better not...]
322
+ #self.trainer.model.aabb_train[user_data] = app_data
323
+
324
+ self.need_update = True
325
+
326
+ dpg.add_separator()
327
+ dpg.add_text("Axis-aligned bounding box:")
328
+
329
+ with dpg.group(horizontal=True):
330
+ dpg.add_slider_float(label="x", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=0)
331
+ dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=3)
332
+
333
+ with dpg.group(horizontal=True):
334
+ dpg.add_slider_float(label="y", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=1)
335
+ dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=4)
336
+
337
+ with dpg.group(horizontal=True):
338
+ dpg.add_slider_float(label="z", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=2)
339
+ dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=5)
340
+
341
+ # light dir
342
+ def callback_set_light_dir(sender, app_data, user_data):
343
+ self.light_dir[user_data] = app_data
344
+ self.need_update = True
345
+
346
+ dpg.add_separator()
347
+ dpg.add_text("Plane Light Direction:")
348
+
349
+ with dpg.group(horizontal=True):
350
+ dpg.add_slider_float(label="theta", min_value=0, max_value=180, format="%.2f", default_value=self.opt.light_theta, callback=callback_set_light_dir, user_data=0)
351
+
352
+ with dpg.group(horizontal=True):
353
+ dpg.add_slider_float(label="phi", min_value=0, max_value=360, format="%.2f", default_value=self.opt.light_phi, callback=callback_set_light_dir, user_data=1)
354
+
355
+ # ambient ratio
356
+ def callback_set_abm_ratio(sender, app_data):
357
+ self.ambient_ratio = app_data
358
+ self.need_update = True
359
+
360
+ dpg.add_slider_float(label="ambient", min_value=0, max_value=1.0, format="%.5f", default_value=self.ambient_ratio, callback=callback_set_abm_ratio)
361
+
362
+ # shading mode
363
+ def callback_change_shading(sender, app_data):
364
+ self.shading = app_data
365
+ self.need_update = True
366
+
367
+ dpg.add_combo(('albedo', 'lambertian', 'textureless', 'normal'), label='shading', default_value=self.shading, callback=callback_change_shading)
368
+
369
+
370
+ # debug info
371
+ if self.debug:
372
+ with dpg.collapsing_header(label="Debug"):
373
+ # pose
374
+ dpg.add_separator()
375
+ dpg.add_text("Camera Pose:")
376
+ dpg.add_text(str(self.cam.pose), tag="_log_pose")
377
+
378
+
379
+ ### register camera handler
380
+
381
+ def callback_camera_drag_rotate(sender, app_data):
382
+
383
+ if not dpg.is_item_focused("_primary_window"):
384
+ return
385
+
386
+ dx = app_data[1]
387
+ dy = app_data[2]
388
+
389
+ self.cam.orbit(dx, dy)
390
+ self.need_update = True
391
+
392
+ if self.debug:
393
+ dpg.set_value("_log_pose", str(self.cam.pose))
394
+
395
+
396
+ def callback_camera_wheel_scale(sender, app_data):
397
+
398
+ if not dpg.is_item_focused("_primary_window"):
399
+ return
400
+
401
+ delta = app_data
402
+
403
+ self.cam.scale(delta)
404
+ self.need_update = True
405
+
406
+ if self.debug:
407
+ dpg.set_value("_log_pose", str(self.cam.pose))
408
+
409
+
410
+ def callback_camera_drag_pan(sender, app_data):
411
+
412
+ if not dpg.is_item_focused("_primary_window"):
413
+ return
414
+
415
+ dx = app_data[1]
416
+ dy = app_data[2]
417
+
418
+ self.cam.pan(dx, dy)
419
+ self.need_update = True
420
+
421
+ if self.debug:
422
+ dpg.set_value("_log_pose", str(self.cam.pose))
423
+
424
+
425
+ with dpg.handler_registry():
426
+ dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Left, callback=callback_camera_drag_rotate)
427
+ dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale)
428
+ dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Middle, callback=callback_camera_drag_pan)
429
+
430
+
431
+ dpg.create_viewport(title='torch-ngp', width=self.W, height=self.H, resizable=False)
432
+
433
+ # TODO: seems dearpygui doesn't support resizing texture...
434
+ # def callback_resize(sender, app_data):
435
+ # self.W = app_data[0]
436
+ # self.H = app_data[1]
437
+ # # how to reload texture ???
438
+
439
+ # dpg.set_viewport_resize_callback(callback_resize)
440
+
441
+ ### global theme
442
+ with dpg.theme() as theme_no_padding:
443
+ with dpg.theme_component(dpg.mvAll):
444
+ # set all padding to 0 to avoid scroll bar
445
+ dpg.add_theme_style(dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core)
446
+ dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core)
447
+ dpg.add_theme_style(dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core)
448
+
449
+ dpg.bind_item_theme("_primary_window", theme_no_padding)
450
+
451
+ dpg.setup_dearpygui()
452
+
453
+ #dpg.show_metrics()
454
+
455
+ dpg.show_viewport()
456
+
457
+
458
+ def render(self):
459
+
460
+ while dpg.is_dearpygui_running():
461
+ # update texture every frame
462
+ if self.training:
463
+ self.train_step()
464
+ self.test_step()
465
+ dpg.render_dearpygui_frame()
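Independent of the GUI loop, the OrbitCamera above can be exercised on its own to check the pose/intrinsics conventions (the values below are illustrative): pose is a 4x4 camera-to-world matrix in the ngp convention and intrinsics is [fx, fy, cx, cy].

import numpy as np

cam = OrbitCamera(W=800, H=800, r=3, fovy=60)
print(cam.intrinsics)          # [focal, focal, 400, 400], focal = H / (2 * tan(fovy / 2))
print(cam.pose.shape)          # (4, 4) camera-to-world matrix

cam.orbit(dx=30, dy=0)         # simulate a 30-pixel horizontal drag
cam.scale(delta=1)             # one wheel tick in: radius /= 1.1
print(np.round(cam.pose, 3))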
nerf/network.py ADDED
@@ -0,0 +1,184 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from activation import trunc_exp
6
+ from .renderer import NeRFRenderer
7
+
8
+ import numpy as np
9
+ from encoding import get_encoder
10
+
11
+ from .utils import safe_normalize
12
+
13
+ class MLP(nn.Module):
14
+ def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True):
15
+ super().__init__()
16
+ self.dim_in = dim_in
17
+ self.dim_out = dim_out
18
+ self.dim_hidden = dim_hidden
19
+ self.num_layers = num_layers
20
+
21
+ net = []
22
+ for l in range(num_layers):
23
+ net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=bias))
24
+
25
+ self.net = nn.ModuleList(net)
26
+
27
+ def forward(self, x):
28
+ for l in range(self.num_layers):
29
+ x = self.net[l](x)
30
+ if l != self.num_layers - 1:
31
+ x = F.relu(x, inplace=True)
32
+ return x
33
+
34
+
35
+ class NeRFNetwork(NeRFRenderer):
36
+ def __init__(self,
37
+ opt,
38
+ num_layers=5,
39
+ hidden_dim=128,
40
+ num_layers_bg=3,
41
+ hidden_dim_bg=128,
42
+ ):
43
+
44
+ super().__init__(opt)
45
+
46
+ self.num_layers = num_layers
47
+ self.hidden_dim = hidden_dim
48
+
49
+ self.encoder, self.in_dim = get_encoder('frequency', input_dim=3)
50
+
51
+ self.sigma_net = MLP(self.in_dim, 4, hidden_dim, num_layers, bias=True)
52
+
53
+ # background network
54
+ if self.bg_radius > 0:
55
+ self.num_layers_bg = num_layers_bg
56
+ self.hidden_dim_bg = hidden_dim_bg
57
+
58
+ self.encoder_bg, self.in_dim_bg = get_encoder('tiledgrid', input_dim=2)
59
+
60
+ self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True)
61
+
62
+ else:
63
+ self.bg_net = None
64
+
65
+ def gaussian(self, x):
66
+ # x: [B, N, 3]
67
+
68
+ d = (x ** 2).sum(-1)
69
+ g = 5 * torch.exp(-d / (2 * 0.2 ** 2))
70
+
71
+ return g
72
+
73
+ def common_forward(self, x):
74
+ # x: [N, 3], in [-bound, bound]
75
+
76
+ # sigma
77
+ h = self.encoder(x, bound=self.bound)
78
+
79
+ h = self.sigma_net(h)
80
+
81
+ sigma = trunc_exp(h[..., 0] + self.gaussian(x))
82
+ albedo = torch.sigmoid(h[..., 1:])
83
+
84
+ return sigma, albedo
85
+
86
+ # ref: https://github.com/zhaofuq/Instant-NSR/blob/main/nerf/network_sdf.py#L192
87
+ def finite_differnce_normal(self, x, epsilon=5e-4):
88
+ # x: [N, 3]
89
+ dx_pos, _ = self.common_forward((x + torch.tensor([[epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
90
+ dx_neg, _ = self.common_forward((x + torch.tensor([[-epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
91
+ dy_pos, _ = self.common_forward((x + torch.tensor([[0.00, epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
92
+ dy_neg, _ = self.common_forward((x + torch.tensor([[0.00, -epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
93
+ dz_pos, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, epsilon]], device=x.device)).clamp(-self.bound, self.bound))
94
+ dz_neg, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, -epsilon]], device=x.device)).clamp(-self.bound, self.bound))
95
+
96
+ normal = torch.stack([
97
+ 0.5 * (dx_pos - dx_neg) / epsilon,
98
+ 0.5 * (dy_pos - dy_neg) / epsilon,
99
+ 0.5 * (dz_pos - dz_neg) / epsilon
100
+ ], dim=-1)
101
+
102
+ return normal
103
+
104
+ def forward(self, x, d, l=None, ratio=1, shading='albedo'):
105
+ # x: [N, 3], in [-bound, bound]
106
+ # d: [N, 3], view direction, normalized in [-1, 1]
107
+ # l: [3], plane light direction, normalized in [-1, 1]
108
+ # ratio: scalar, ambient ratio, 1 == no shading (albedo only), 0 == only shading (textureless)
109
+
110
+ if shading == 'albedo':
111
+ # no need to query normal
112
+ sigma, color = self.common_forward(x)
113
+ normal = None
114
+
115
+ else:
116
+ # query normal
117
+
118
+ # sigma, albedo = self.common_forward(x)
119
+ # normal = self.finite_differnce_normal(x)
120
+
121
+ with torch.enable_grad():
122
+ x.requires_grad_(True)
123
+ sigma, albedo = self.common_forward(x)
124
+ # query gradient
125
+ normal = - torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0] # [N, 3]
126
+
127
+ # normalize...
128
+ normal = safe_normalize(normal)
129
+ normal[torch.isnan(normal)] = 0
130
+
131
+ # light direction (random if not provided)
132
+ if l is None:
133
+ l = torch.randn(3, device=x.device, dtype=torch.float)
134
+ l = safe_normalize(l)
135
+
136
+ # lambertian shading
137
+ lambertian = ratio + (1 - ratio) * (normal @ -l).clamp(min=0) # [N,]
138
+
139
+ if shading == 'textureless':
140
+ color = lambertian.unsqueeze(-1).repeat(1, 3)
141
+ elif shading == 'normal':
142
+ color = (normal + 1) / 2
143
+ else: # 'lambertian'
144
+ color = albedo * lambertian.unsqueeze(-1)
145
+
146
+ return sigma, color, normal
147
+
148
+
149
+ def density(self, x):
150
+ # x: [N, 3], in [-bound, bound]
151
+
152
+ sigma, albedo = self.common_forward(x)
153
+
154
+ return {
155
+ 'sigma': sigma,
156
+ 'albedo': albedo,
157
+ }
158
+
159
+
160
+ def background(self, x, d):
161
+ # x: [N, 2], in [-1, 1]
162
+
163
+ h = self.encoder_bg(x) # [N, C]
164
+
165
+ h = self.bg_net(h)
166
+
167
+ # sigmoid activation for rgb
168
+ rgbs = torch.sigmoid(h)
169
+
170
+ return rgbs
171
+
172
+ # optimizer utils
173
+ def get_params(self, lr):
174
+
175
+ params = [
176
+ # {'params': self.encoder.parameters(), 'lr': lr * 10},
177
+ {'params': self.sigma_net.parameters(), 'lr': lr},
178
+ ]
179
+
180
+ if self.bg_radius > 0:
181
+ # params.append({'params': self.encoder_bg.parameters(), 'lr': lr * 10})
182
+ params.append({'params': self.bg_net.parameters(), 'lr': lr})
183
+
184
+ return params
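The gaussian term above acts as a density "blob" prior: it adds a large positive bias to the pre-activation density at the origin so the scene starts as a centred ball rather than empty space, and it decays to almost nothing a few tenths of a unit away. A small numeric check of the same formula (values rounded, illustration only):

import torch

def density_blob(x, scale=5.0, sigma=0.2):
    # identical formula to NeRFNetwork.gaussian above
    d = (x ** 2).sum(-1)
    return scale * torch.exp(-d / (2 * sigma ** 2))

pts = torch.tensor([[0.0, 0.0, 0.0],
                    [0.2, 0.0, 0.0],
                    [0.5, 0.0, 0.0]])
print(density_blob(pts))  # ~[5.00, 3.03, 0.22]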
nerf/network_grid.py ADDED
@@ -0,0 +1,186 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from activation import trunc_exp
6
+ from .renderer import NeRFRenderer
7
+
8
+ import numpy as np
9
+ from encoding import get_encoder
10
+
11
+ from .utils import safe_normalize
12
+
13
+ class MLP(nn.Module):
14
+ def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True):
15
+ super().__init__()
16
+ self.dim_in = dim_in
17
+ self.dim_out = dim_out
18
+ self.dim_hidden = dim_hidden
19
+ self.num_layers = num_layers
20
+
21
+ net = []
22
+ for l in range(num_layers):
23
+ net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=bias))
24
+
25
+ self.net = nn.ModuleList(net)
26
+
27
+ def forward(self, x):
28
+ for l in range(self.num_layers):
29
+ x = self.net[l](x)
30
+ if l != self.num_layers - 1:
31
+ x = F.relu(x, inplace=True)
32
+ return x
33
+
34
+
35
+ class NeRFNetwork(NeRFRenderer):
36
+ def __init__(self,
37
+ opt,
38
+ num_layers=3,
39
+ hidden_dim=64,
40
+ num_layers_bg=2,
41
+ hidden_dim_bg=64,
42
+ ):
43
+
44
+ super().__init__(opt)
45
+
46
+ self.num_layers = num_layers
47
+ self.hidden_dim = hidden_dim
48
+
49
+ self.encoder, self.in_dim = get_encoder('tiledgrid', input_dim=3, desired_resolution=2048 * self.bound)
50
+
51
+ self.sigma_net = MLP(self.in_dim, 4, hidden_dim, num_layers, bias=True)
52
+
53
+ # background network
54
+ if self.bg_radius > 0:
55
+ self.num_layers_bg = num_layers_bg
56
+ self.hidden_dim_bg = hidden_dim_bg
57
+
58
+ # use a very simple network to avoid it learning the prompt...
59
+ # self.encoder_bg, self.in_dim_bg = get_encoder('tiledgrid', input_dim=2, num_levels=4, desired_resolution=2048)
60
+ self.encoder_bg, self.in_dim_bg = get_encoder('frequency', input_dim=2)
61
+
62
+ self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True)
63
+
64
+ else:
65
+ self.bg_net = None
66
+
67
+ def gaussian(self, x):
68
+ # x: [B, N, 3]
69
+
70
+ d = (x ** 2).sum(-1)
71
+ g = 5 * torch.exp(-d / (2 * 0.2 ** 2))
72
+
73
+ return g
74
+
75
+ def common_forward(self, x):
76
+ # x: [N, 3], in [-bound, bound]
77
+
78
+ # sigma
79
+ h = self.encoder(x, bound=self.bound)
80
+
81
+ h = self.sigma_net(h)
82
+
83
+ sigma = trunc_exp(h[..., 0] + self.gaussian(x))
84
+ albedo = torch.sigmoid(h[..., 1:])
85
+
86
+ return sigma, albedo
87
+
88
+ # ref: https://github.com/zhaofuq/Instant-NSR/blob/main/nerf/network_sdf.py#L192
89
+ def finite_differnce_normal(self, x, epsilon=5e-4):
90
+ # x: [N, 3]
91
+ dx_pos, _ = self.common_forward((x + torch.tensor([[epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
92
+ dx_neg, _ = self.common_forward((x + torch.tensor([[-epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
93
+ dy_pos, _ = self.common_forward((x + torch.tensor([[0.00, epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
94
+ dy_neg, _ = self.common_forward((x + torch.tensor([[0.00, -epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
95
+ dz_pos, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, epsilon]], device=x.device)).clamp(-self.bound, self.bound))
96
+ dz_neg, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, -epsilon]], device=x.device)).clamp(-self.bound, self.bound))
97
+
98
+ normal = torch.stack([
99
+ 0.5 * (dx_pos - dx_neg) / epsilon,
100
+ 0.5 * (dy_pos - dy_neg) / epsilon,
101
+ 0.5 * (dz_pos - dz_neg) / epsilon
102
+ ], dim=-1)
103
+
104
+ return normal
105
+
106
+ def forward(self, x, d, l=None, ratio=1, shading='albedo'):
107
+ # x: [N, 3], in [-bound, bound]
108
+ # d: [N, 3], view direction, normalized in [-1, 1]
109
+ # l: [3], plane light direction, normalized in [-1, 1]
110
+ # ratio: scalar, ambient ratio, 1 == no shading (albedo only), 0 == only shading (textureless)
111
+
112
+ if shading == 'albedo':
113
+ # no need to query normal
114
+ sigma, color = self.common_forward(x)
115
+ normal = None
116
+
117
+ else:
118
+ # query normal
119
+
120
+ sigma, albedo = self.common_forward(x)
121
+ normal = self.finite_differnce_normal(x)
122
+
123
+ # with torch.enable_grad():
124
+ # x.requires_grad_(True)
125
+ # sigma, albedo = self.common_forward(x)
126
+ # # query gradient
127
+ # normal = - torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0] # [N, 3]
128
+
129
+ # normalize...
130
+ normal = safe_normalize(normal)
131
+ normal[torch.isnan(normal)] = 0
132
+
133
+ # light direction (random if not provided)
134
+ if l is None:
135
+ l = torch.randn(3, device=x.device, dtype=torch.float)
136
+ l = safe_normalize(l)
137
+
138
+ # lambertian shading
139
+ lambertian = ratio + (1 - ratio) * (normal @ -l).clamp(min=0) # [N,]
140
+
141
+ if shading == 'textureless':
142
+ color = lambertian.unsqueeze(-1).repeat(1, 3)
143
+ elif shading == 'normal':
144
+ color = (normal + 1) / 2
145
+ else: # 'lambertian'
146
+ color = albedo * lambertian.unsqueeze(-1)
147
+
148
+ return sigma, color, normal
149
+
150
+
151
+ def density(self, x):
152
+ # x: [N, 3], in [-bound, bound]
153
+
154
+ sigma, albedo = self.common_forward(x)
155
+
156
+ return {
157
+ 'sigma': sigma,
158
+ 'albedo': albedo,
159
+ }
160
+
161
+
162
+ def background(self, x, d):
163
+ # x: [N, 2], in [-1, 1]
164
+
165
+ h = self.encoder_bg(x) # [N, C]
166
+
167
+ h = self.bg_net(h)
168
+
169
+ # sigmoid activation for rgb
170
+ rgbs = torch.sigmoid(h)
171
+
172
+ return rgbs
173
+
174
+ # optimizer utils
175
+ def get_params(self, lr):
176
+
177
+ params = [
178
+ {'params': self.encoder.parameters(), 'lr': lr * 10},
179
+ {'params': self.sigma_net.parameters(), 'lr': lr},
180
+ ]
181
+
182
+ if self.bg_radius > 0:
183
+ params.append({'params': self.encoder_bg.parameters(), 'lr': lr * 10})
184
+ params.append({'params': self.bg_net.parameters(), 'lr': lr})
185
+
186
+ return params
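The finite_differnce_normal method above is a central-difference estimate of the density gradient, which forward then normalises and uses as the surface normal for shading. The same pattern on a toy analytic density, as a self-contained sketch (the sphere-like density function is only for illustration):

import torch

def sphere_density(x, radius=0.5):
    # toy density: ~1 inside a sphere of the given radius, decaying outside
    return torch.exp(-(x.norm(dim=-1) - radius).clamp(min=0) * 10.0)

def finite_difference_grad(density_fn, x, eps=5e-4):
    # central differences along each axis, mirroring finite_differnce_normal above
    offsets = torch.eye(3, device=x.device) * eps
    grads = []
    for i in range(3):
        pos = density_fn(x + offsets[i])
        neg = density_fn(x - offsets[i])
        grads.append(0.5 * (pos - neg) / eps)
    return torch.stack(grads, dim=-1)

x = torch.tensor([[0.7, 0.0, 0.0]])
print(finite_difference_grad(sphere_density, x))  # negative x-component: density drops moving outward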
nerf/network_tcnn.py ADDED
@@ -0,0 +1,189 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from activation import trunc_exp
6
+ from .renderer import NeRFRenderer
7
+
8
+ import numpy as np
9
+ import tinycudann as tcnn
10
+
11
+ class MLP(nn.Module):
12
+ def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True):
13
+ super().__init__()
14
+ self.dim_in = dim_in
15
+ self.dim_out = dim_out
16
+ self.dim_hidden = dim_hidden
17
+ self.num_layers = num_layers
18
+
19
+ net = []
20
+ for l in range(num_layers):
21
+ net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=bias))
22
+
23
+ self.net = nn.ModuleList(net)
24
+
25
+ def forward(self, x):
26
+ for l in range(self.num_layers):
27
+ x = self.net[l](x)
28
+ if l != self.num_layers - 1:
29
+ x = F.relu(x, inplace=True)
30
+ return x
31
+
32
+
33
+ class NeRFNetwork(NeRFRenderer):
34
+ def __init__(self,
35
+ opt,
36
+ num_layers=3,
37
+ hidden_dim=64,
38
+ num_layers_bg=2,
39
+ hidden_dim_bg=64,
40
+ ):
41
+
42
+ super().__init__(opt)
43
+
44
+ self.num_layers = num_layers
45
+ self.hidden_dim = hidden_dim
46
+
47
+ per_level_scale = np.exp2(np.log2(2048 * self.bound / 16) / (16 - 1))
48
+
49
+ self.encoder = tcnn.Encoding(
50
+ n_input_dims=3,
51
+ encoding_config={
52
+ "otype": "HashGrid",
53
+ "n_levels": 16,
54
+ "n_features_per_level": 2,
55
+ "log2_hashmap_size": 19,
56
+ "base_resolution": 16,
57
+ "per_level_scale": per_level_scale,
58
+ },
59
+ )
60
+
61
+ self.sigma_net = MLP(32, 4, hidden_dim, num_layers, bias=True)
62
+
63
+ # background network
64
+ if self.bg_radius > 0:
65
+ self.num_layers_bg = num_layers_bg
66
+ self.hidden_dim_bg = hidden_dim_bg
67
+
68
+ self.encoder_bg = tcnn.Encoding(
69
+ n_input_dims=2,
70
+ encoding_config={
71
+ "otype": "HashGrid",
72
+ "n_levels": 4,
73
+ "n_features_per_level": 2,
74
+ "log2_hashmap_size": 16,
75
+ "base_resolution": 16,
76
+ "per_level_scale": 1.5,
77
+ },
78
+ )
79
+
80
+ self.bg_net = MLP(8, 3, hidden_dim_bg, num_layers_bg, bias=True)
81
+
82
+ else:
83
+ self.bg_net = None
84
+
85
+ def gaussian(self, x):
86
+ # x: [B, N, 3]
87
+
88
+ d = (x ** 2).sum(-1)
89
+ g = 5 * torch.exp(-d / (2 * 0.2 ** 2))
90
+
91
+ return g
92
+
93
+ def common_forward(self, x):
94
+ # x: [N, 3], in [-bound, bound]
95
+
96
+ # sigma
97
+ h = (x + self.bound) / (2 * self.bound) # to [0, 1]
98
+ h = self.encoder(h)
99
+
100
+ h = self.sigma_net(h)
101
+
102
+ sigma = trunc_exp(h[..., 0] + self.gaussian(x))
103
+ albedo = torch.sigmoid(h[..., 1:])
104
+
105
+ return sigma, albedo
106
+
107
+
108
+ def forward(self, x, d, l=None, ratio=1, shading='albedo'):
109
+ # x: [N, 3], in [-bound, bound]
110
+ # d: [N, 3], view direction, normalized in [-1, 1]
111
+ # l: [3], plane light direction, normalized in [-1, 1]
112
+ # ratio: scalar, ambient ratio, 1 == no shading (albedo only)
113
+
114
+ if shading == 'albedo':
115
+ # no need to query normal
116
+ sigma, color = self.common_forward(x)
117
+ normal = None
118
+
119
+ else:
120
+ # query normal
121
+ has_grad = torch.is_grad_enabled()
122
+
123
+ with torch.enable_grad():
124
+ x.requires_grad_(True)
125
+ sigma, albedo = self.common_forward(x)
126
+ # query gradient
127
+ normal = torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0] # [N, 3]
128
+
129
+ # normalize...
130
+ normal = normal / (torch.norm(normal, dim=-1, keepdim=True) + 1e-9)
131
+ normal[torch.isnan(normal)] = 0
132
+
133
+ if not has_grad:
134
+ normal = normal.detach()
135
+
136
+ # light direction (random if not provided)
137
+ if l is None:
138
+ l = torch.randn(3, device=x.device, dtype=torch.float)
139
+ l = l / (torch.norm(l, dim=-1, keepdim=True) + 1e-9)
140
+
141
+ # lambertian shading
142
+ lambertian = ratio + (1 - ratio) * (normal @ l).clamp(min=0) # [N,]
143
+
144
+ if shading == 'textureless':
145
+ color = lambertian.unsqueeze(-1).repeat(1, 3)
146
+ elif shading == 'normal':
147
+ color = (normal + 1) / 2
148
+ else: # 'lambertian'
149
+ color = albedo * lambertian.unsqueeze(-1)
150
+
151
+ return sigma, color, normal
152
+
153
+
154
+ def density(self, x):
155
+ # x: [N, 3], in [-bound, bound]
156
+
157
+ sigma, _ = self.common_forward(x)
158
+
159
+ return {
160
+ 'sigma': sigma
161
+ }
162
+
163
+
164
+ def background(self, x, d):
165
+ # x: [N, 2], in [-1, 1]
166
+
167
+ h = (x + 1) / (2 * 1) # to [0, 1]
168
+ h = self.encoder_bg(h) # [N, C]
169
+
170
+ h = self.bg_net(h)
171
+
172
+ # sigmoid activation for rgb
173
+ rgbs = torch.sigmoid(h)
174
+
175
+ return rgbs
176
+
177
+ # optimizer utils
178
+ def get_params(self, lr):
179
+
180
+ params = [
181
+ {'params': self.encoder.parameters(), 'lr': lr * 10},
182
+ {'params': self.sigma_net.parameters(), 'lr': lr},
183
+ ]
184
+
185
+ if self.bg_radius > 0:
186
+ params.append({'params': self.encoder_bg.parameters(), 'lr': lr * 10})
187
+ params.append({'params': self.bg_net.parameters(), 'lr': lr})
188
+
189
+ return params
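per_level_scale above follows the Instant-NGP geometric progression: with 16 hash-grid levels starting at base resolution 16, each level grows by a constant factor b chosen so that 16 * b**(16 - 1) reaches the desired finest resolution 2048 * bound. A quick check of the formula (bound = 1 for illustration):

import numpy as np

bound = 1
n_levels = 16
base_resolution = 16
desired_resolution = 2048 * bound

per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (n_levels - 1))
print(per_level_scale)                                       # ~1.3819
print(base_resolution * per_level_scale ** (n_levels - 1))   # ~2048.0, the finest level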
nerf/provider.py ADDED
@@ -0,0 +1,197 @@
1
+ import os
2
+ import cv2
3
+ import glob
4
+ import json
5
+ import tqdm
6
+ import random
7
+ import numpy as np
8
+ from scipy.spatial.transform import Slerp, Rotation
9
+
10
+ import trimesh
11
+
12
+ import torch
13
+ from torch.utils.data import DataLoader
14
+
15
+ from .utils import get_rays, safe_normalize
16
+
17
+ def visualize_poses(poses, size=0.1):
18
+ # poses: [B, 4, 4]
19
+
20
+ axes = trimesh.creation.axis(axis_length=4)
21
+ sphere = trimesh.creation.icosphere(radius=1)
22
+ objects = [axes, sphere]
23
+
24
+ for pose in poses:
25
+ # a camera is visualized with 8 line segments.
26
+ pos = pose[:3, 3]
27
+ a = pos + size * pose[:3, 0] + size * pose[:3, 1] + size * pose[:3, 2]
28
+ b = pos - size * pose[:3, 0] + size * pose[:3, 1] + size * pose[:3, 2]
29
+ c = pos - size * pose[:3, 0] - size * pose[:3, 1] + size * pose[:3, 2]
30
+ d = pos + size * pose[:3, 0] - size * pose[:3, 1] + size * pose[:3, 2]
31
+
32
+ segs = np.array([[pos, a], [pos, b], [pos, c], [pos, d], [a, b], [b, c], [c, d], [d, a]])
33
+ segs = trimesh.load_path(segs)
34
+ objects.append(segs)
35
+
36
+ trimesh.Scene(objects).show()
37
+
38
+ def get_view_direction(thetas, phis):
39
+ # phis [B,]; thetas: [B,]
40
+ # front = 0 0-90
41
+ # side (left) = 1 90-180
42
+ # back = 2 180-270
43
+ # side (right) = 3 270-360
44
+ # top = 4 0-30
45
+ # bottom = 5 150-180
46
+ res = torch.zeros(phis.shape[0], dtype=torch.long)
47
+ # first determine by phis
48
+ res[(phis < (np.pi / 2))] = 0
49
+ res[(phis >= (np.pi / 2)) & (phis < np.pi)] = 1
50
+ res[(phis >= np.pi) & (phis < (3 * np.pi / 2))] = 2
51
+ res[(phis >= (3 * np.pi / 2)) & (phis < (2 * np.pi))] = 3
52
+ # override by thetas
53
+ res[thetas < (np.pi / 6)] = 4
54
+ res[thetas >= (5 * np.pi / 6)] = 5
55
+ return res
56
+
57
+
58
+ def rand_poses(size, device, return_dirs=False, radius_range=[1, 1.5], theta_range=[0, 4 * np.pi / 6], phi_range=[0, 2*np.pi]):
59
+ ''' generate random poses from an orbit camera
60
+ Args:
61
+ size: batch size of generated poses.
62
+ device: where to allocate the output.
63
+ radius_range: [min, max], camera radius range
64
+ theta_range: [min, max], should be in [0, \pi]
65
+ phi_range: [min, max], should be in [0, 2\pi]
66
+ Return:
67
+ poses: [size, 4, 4]
68
+ '''
69
+
70
+ radius = torch.rand(size, device=device) * (radius_range[1] - radius_range[0]) + radius_range[0]
71
+ thetas = torch.rand(size, device=device) * (theta_range[1] - theta_range[0]) + theta_range[0]
72
+ phis = torch.rand(size, device=device) * (phi_range[1] - phi_range[0]) + phi_range[0]
73
+
74
+ centers = torch.stack([
75
+ radius * torch.sin(thetas) * torch.sin(phis),
76
+ radius * torch.cos(thetas),
77
+ radius * torch.sin(thetas) * torch.cos(phis),
78
+ ], dim=-1) # [B, 3]
79
+
80
+ # jitters
81
+ centers = centers + (torch.rand_like(centers) * 0.2 - 0.1)
82
+ targets = torch.randn_like(centers) * 0.2
83
+
84
+ # lookat
85
+ forward_vector = safe_normalize(targets - centers)
86
+ up_vector = torch.FloatTensor([0, -1, 0]).to(device).unsqueeze(0).repeat(size, 1)
87
+ right_vector = safe_normalize(torch.cross(forward_vector, up_vector, dim=-1))
88
+
89
+ up_noise = torch.randn_like(up_vector) * 0.02
90
+ up_vector = safe_normalize(torch.cross(right_vector, forward_vector, dim=-1) + up_noise)
91
+
92
+ poses = torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(size, 1, 1)
93
+ poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1)
94
+ poses[:, :3, 3] = centers
95
+
96
+ if return_dirs:
97
+ dirs = get_view_direction(thetas, phis)
98
+ else:
99
+ dirs = None
100
+
101
+ return poses, dirs
102
+
103
+
104
+ def circle_poses(device, return_dirs=False, radius=1.25, theta=np.pi/2, phi=0):
105
+
106
+ thetas = torch.FloatTensor([theta]).to(device)
107
+ phis = torch.FloatTensor([phi]).to(device)
108
+
109
+ centers = torch.stack([
110
+ radius * torch.sin(thetas) * torch.sin(phis),
111
+ radius * torch.cos(thetas),
112
+ radius * torch.sin(thetas) * torch.cos(phis),
113
+ ], dim=-1) # [B, 3]
114
+
115
+ # lookat
116
+ forward_vector = - safe_normalize(centers)
117
+ up_vector = torch.FloatTensor([0, -1, 0]).to(device).unsqueeze(0)
118
+ right_vector = safe_normalize(torch.cross(forward_vector, up_vector, dim=-1))
119
+ up_vector = safe_normalize(torch.cross(right_vector, forward_vector, dim=-1))
120
+
121
+ poses = torch.eye(4, dtype=torch.float, device=device).unsqueeze(0)
122
+ poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1)
123
+ poses[:, :3, 3] = centers
124
+
125
+ if return_dirs:
126
+ dirs = get_view_direction(thetas, phis)
127
+ else:
128
+ dirs = None
129
+
130
+ return poses, dirs
131
+
132
+
133
+ class NeRFDataset:
134
+ def __init__(self, opt, device, type='train', H=256, W=256, size=100):
135
+ super().__init__()
136
+
137
+ self.opt = opt
138
+ self.device = device
139
+ self.type = type # train, val, test
140
+
141
+ self.H = H
142
+ self.W = W
143
+ self.radius_range = opt.radius_range
144
+ self.fovy_range = opt.fovy_range
145
+ self.size = size
146
+
147
+ self.training = self.type in ['train', 'all']
148
+
149
+ self.cx = self.W / 2
150
+ self.cy = self.H / 2
151
+
152
+ # [debug] visualize poses
153
+ # poses, dirs = rand_poses(100, self.device, return_dirs=self.opt.dir_text, radius_range=self.radius_range)
154
+ # visualize_poses(poses.detach().cpu().numpy())
155
+
156
+
157
+ def collate(self, index):
158
+
159
+ B = len(index) # always 1
160
+
161
+ if self.training:
162
+ # random pose on the fly
163
+ poses, dirs = rand_poses(B, self.device, return_dirs=self.opt.dir_text, radius_range=self.radius_range)
164
+
165
+ # random focal
166
+ fov = random.random() * (self.fovy_range[1] - self.fovy_range[0]) + self.fovy_range[0]
167
+ focal = self.H / (2 * np.tan(np.radians(fov) / 2))
168
+ intrinsics = np.array([focal, focal, self.cx, self.cy])
169
+ else:
170
+ # circle pose
171
+ phi = (index[0] / self.size) * 2 * np.pi
172
+ poses, dirs = circle_poses(self.device, return_dirs=self.opt.dir_text, radius=self.radius_range[1], theta=np.pi/2, phi=phi)
173
+
174
+ # fixed focal
175
+ fov = (self.fovy_range[1] + self.fovy_range[0]) / 2
176
+ focal = self.H / (2 * np.tan(np.radians(fov) / 2))
177
+ intrinsics = np.array([focal, focal, self.cx, self.cy])
178
+
179
+
180
+ # sample a low-resolution but full image for CLIP
181
+ rays = get_rays(poses, intrinsics, self.H, self.W, -1)
182
+
183
+ data = {
184
+ 'H': self.H,
185
+ 'W': self.W,
186
+ 'rays_o': rays['rays_o'],
187
+ 'rays_d': rays['rays_d'],
188
+ 'dir': dirs,
189
+ }
190
+
191
+ return data
192
+
193
+
194
+ def dataloader(self):
195
+ loader = DataLoader(list(range(self.size)), batch_size=1, collate_fn=self.collate, shuffle=self.training, num_workers=0)
196
+ loader._data = self # an ugly fix... we need to access dataset in trainer.
197
+ return loader
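As a sketch of how the pose sampler above is consumed during training (the sizes below are illustrative): rand_poses returns camera-to-world matrices on a jittered orbit around the origin, plus one view-direction bin per pose when return_dirs=True, which is what --dir_text uses to append a front/side/back/top/bottom hint to the prompt.

import torch

device = torch.device('cpu')
poses, dirs = rand_poses(4, device, return_dirs=True, radius_range=[1, 1.5])
print(poses.shape)  # torch.Size([4, 4, 4]): camera-to-world matrices
print(dirs)         # four integers in 0..5, one view-direction bin per pose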
nerf/renderer.py ADDED
@@ -0,0 +1,638 @@
1
+ import os
2
+ import math
3
+ import cv2
4
+ import trimesh
5
+ import numpy as np
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ import mcubes
12
+ import raymarching
13
+ from .utils import custom_meshgrid, safe_normalize
14
+
15
+ def sample_pdf(bins, weights, n_samples, det=False):
16
+ # This implementation is from NeRF
17
+ # bins: [B, T], old_z_vals
18
+ # weights: [B, T - 1], bin weights.
19
+ # return: [B, n_samples], new_z_vals
20
+
21
+ # Get pdf
22
+ weights = weights + 1e-5 # prevent nans
23
+ pdf = weights / torch.sum(weights, -1, keepdim=True)
24
+ cdf = torch.cumsum(pdf, -1)
25
+ cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
26
+ # Take uniform samples
27
+ if det:
28
+ u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples).to(weights.device)
29
+ u = u.expand(list(cdf.shape[:-1]) + [n_samples])
30
+ else:
31
+ u = torch.rand(list(cdf.shape[:-1]) + [n_samples]).to(weights.device)
32
+
33
+ # Invert CDF
34
+ u = u.contiguous()
35
+ inds = torch.searchsorted(cdf, u, right=True)
36
+ below = torch.max(torch.zeros_like(inds - 1), inds - 1)
37
+ above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
38
+ inds_g = torch.stack([below, above], -1) # (B, n_samples, 2)
39
+
40
+ matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
41
+ cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
42
+ bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
43
+
44
+ denom = (cdf_g[..., 1] - cdf_g[..., 0])
45
+ denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
46
+ t = (u - cdf_g[..., 0]) / denom
47
+ samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
48
+
49
+ return samples
50
+
51
+
52
+ def plot_pointcloud(pc, color=None):
53
+ # pc: [N, 3]
54
+ # color: [N, 3/4]
55
+ print('[visualize points]', pc.shape, pc.dtype, pc.min(0), pc.max(0))
56
+ pc = trimesh.PointCloud(pc, color)
57
+ # axis
58
+ axes = trimesh.creation.axis(axis_length=4)
59
+ # sphere
60
+ sphere = trimesh.creation.icosphere(radius=1)
61
+ trimesh.Scene([pc, axes, sphere]).show()
62
+
63
+
64
+ class NeRFRenderer(nn.Module):
65
+ def __init__(self, opt):
66
+ super().__init__()
67
+
68
+ self.opt = opt
69
+ self.bound = opt.bound
70
+ self.cascade = 1 + math.ceil(math.log2(opt.bound))
71
+ self.grid_size = 128
72
+ self.cuda_ray = opt.cuda_ray
73
+ self.min_near = opt.min_near
74
+ self.density_thresh = opt.density_thresh
75
+ self.bg_radius = opt.bg_radius
76
+
77
+ # prepare aabb with a 6D tensor (xmin, ymin, zmin, xmax, ymax, zmax)
78
+ # NOTE: aabb (can be rectangular) is only used to generate points, we still rely on bound (always cubic) to calculate density grid and hashing.
79
+ aabb_train = torch.FloatTensor([-opt.bound, -opt.bound, -opt.bound, opt.bound, opt.bound, opt.bound])
80
+ aabb_infer = aabb_train.clone()
81
+ self.register_buffer('aabb_train', aabb_train)
82
+ self.register_buffer('aabb_infer', aabb_infer)
83
+
84
+ # extra state for cuda raymarching
85
+ if self.cuda_ray:
86
+ # density grid
87
+ density_grid = torch.zeros([self.cascade, self.grid_size ** 3]) # [CAS, H * H * H]
88
+ density_bitfield = torch.zeros(self.cascade * self.grid_size ** 3 // 8, dtype=torch.uint8) # [CAS * H * H * H // 8]
89
+ self.register_buffer('density_grid', density_grid)
90
+ self.register_buffer('density_bitfield', density_bitfield)
91
+ self.mean_density = 0
92
+ self.iter_density = 0
93
+ # step counter
94
+ step_counter = torch.zeros(16, 2, dtype=torch.int32) # 16 is hardcoded for averaging...
95
+ self.register_buffer('step_counter', step_counter)
96
+ self.mean_count = 0
97
+ self.local_step = 0
98
+
99
+
100
+ def forward(self, x, d):
101
+ raise NotImplementedError()
102
+
103
+ def density(self, x):
104
+ raise NotImplementedError()
105
+
106
+ def color(self, x, d, mask=None, **kwargs):
107
+ raise NotImplementedError()
108
+
109
+ def reset_extra_state(self):
110
+ if not self.cuda_ray:
111
+ return
112
+ # density grid
113
+ self.density_grid.zero_()
114
+ self.mean_density = 0
115
+ self.iter_density = 0
116
+ # step counter
117
+ self.step_counter.zero_()
118
+ self.mean_count = 0
119
+ self.local_step = 0
120
+
121
+ @torch.no_grad()
122
+ def export_mesh(self, path, resolution=None, S=128):
123
+
124
+ if resolution is None:
125
+ resolution = self.grid_size
126
+
127
+ density_thresh = min(self.mean_density, self.density_thresh)
128
+
129
+ sigmas = np.zeros([resolution, resolution, resolution], dtype=np.float32)
130
+
131
+ # query
132
+ X = torch.linspace(-1, 1, resolution).split(S)
133
+ Y = torch.linspace(-1, 1, resolution).split(S)
134
+ Z = torch.linspace(-1, 1, resolution).split(S)
135
+
136
+ for xi, xs in enumerate(X):
137
+ for yi, ys in enumerate(Y):
138
+ for zi, zs in enumerate(Z):
139
+ xx, yy, zz = custom_meshgrid(xs, ys, zs)
140
+ pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [S, 3]
141
+ val = self.density(pts.to(self.density_bitfield.device))
142
+ sigmas[xi * S: xi * S + len(xs), yi * S: yi * S + len(ys), zi * S: zi * S + len(zs)] = val['sigma'].reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy() # [S, 1] --> [x, y, z]
143
+
144
+ vertices, triangles = mcubes.marching_cubes(sigmas, density_thresh)
145
+
146
+ vertices = vertices / (resolution - 1.0) * 2 - 1
147
+ vertices = vertices.astype(np.float32)
148
+ triangles = triangles.astype(np.int32)
149
+
150
+ v = torch.from_numpy(vertices).to(self.density_bitfield.device)
151
+ f = torch.from_numpy(triangles).int().to(self.density_bitfield.device)
152
+
153
+ # mesh = trimesh.Trimesh(vertices, triangles, process=False) # important, process=True leads to seg fault...
154
+ # mesh.export(os.path.join(path, f'mesh.ply'))
155
+
156
+ # texture?
157
+ def _export(v, f, h0=2048, w0=2048, ssaa=1, name=''):
158
+ # v, f: torch Tensor
159
+ device = v.device
160
+ v_np = v.cpu().numpy() # [N, 3]
161
+ f_np = f.cpu().numpy() # [M, 3]
162
+
163
+ print(f'[INFO] running xatlas to unwrap UVs for mesh: v={v_np.shape} f={f_np.shape}')
164
+
165
+ # unwrap uvs
166
+ import xatlas
167
+ import nvdiffrast.torch as dr
168
+ from sklearn.neighbors import NearestNeighbors
169
+ from scipy.ndimage import binary_dilation, binary_erosion
170
+
171
+ glctx = dr.RasterizeGLContext()
172
+
173
+ atlas = xatlas.Atlas()
174
+ atlas.add_mesh(v_np, f_np)
175
+ chart_options = xatlas.ChartOptions()
176
+ chart_options.max_iterations = 0 # disable merge_chart for faster unwrap...
177
+ atlas.generate(chart_options=chart_options)
178
+ vmapping, ft_np, vt_np = atlas[0] # [N], [M, 3], [N, 2]
179
+
180
+ # vmapping, ft_np, vt_np = xatlas.parametrize(v_np, f_np) # [N], [M, 3], [N, 2]
181
+
182
+ vt = torch.from_numpy(vt_np.astype(np.float32)).float().to(device)
183
+ ft = torch.from_numpy(ft_np.astype(np.int64)).int().to(device)
184
+
185
+ # render uv maps
186
+ uv = vt * 2.0 - 1.0 # uvs to range [-1, 1]
187
+ uv = torch.cat((uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1) # [N, 4]
188
+
189
+ if ssaa > 1:
190
+ h = int(h0 * ssaa)
191
+ w = int(w0 * ssaa)
192
+ else:
193
+ h, w = h0, w0
194
+
195
+ rast, _ = dr.rasterize(glctx, uv.unsqueeze(0), ft, (h, w)) # [1, h, w, 4]
196
+ xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f) # [1, h, w, 3]
197
+ mask, _ = dr.interpolate(torch.ones_like(v[:, :1]).unsqueeze(0), rast, f) # [1, h, w, 1]
198
+
199
+ # masked query
200
+ xyzs = xyzs.view(-1, 3)
201
+ mask = (mask > 0).view(-1)
202
+
203
+ sigmas = torch.zeros(h * w, device=device, dtype=torch.float32)
204
+ feats = torch.zeros(h * w, 3, device=device, dtype=torch.float32)
205
+
206
+ if mask.any():
207
+ xyzs = xyzs[mask] # [M, 3]
208
+
209
+ # batched inference to avoid OOM
210
+ all_sigmas = []
211
+ all_feats = []
212
+ head = 0
213
+ while head < xyzs.shape[0]:
214
+ tail = min(head + 640000, xyzs.shape[0])
215
+ results_ = self.density(xyzs[head:tail])
216
+ all_sigmas.append(results_['sigma'].float())
217
+ all_feats.append(results_['albedo'].float())
218
+ head += 640000
219
+
220
+ sigmas[mask] = torch.cat(all_sigmas, dim=0)
221
+ feats[mask] = torch.cat(all_feats, dim=0)
222
+
223
+ sigmas = sigmas.view(h, w, 1)
224
+ feats = feats.view(h, w, -1)
225
+ mask = mask.view(h, w)
226
+
227
+ ### alpha mask
228
+ # deltas = 2 * np.sqrt(3) / 1024
229
+ # alphas = 1 - torch.exp(-sigmas * deltas)
230
+ # alphas_mask = alphas > 0.5
231
+ # feats = feats * alphas_mask
232
+
233
+ # quantize [0.0, 1.0] to [0, 255]
234
+ feats = feats.cpu().numpy()
235
+ feats = (feats * 255).astype(np.uint8)
236
+
237
+ # alphas = alphas.cpu().numpy()
238
+ # alphas = (alphas * 255).astype(np.uint8)
239
+
240
+ ### NN search as an antialiasing ...
241
+ mask = mask.cpu().numpy()
242
+
243
+ inpaint_region = binary_dilation(mask, iterations=3)
244
+ inpaint_region[mask] = 0
245
+
246
+ search_region = mask.copy()
247
+ not_search_region = binary_erosion(search_region, iterations=2)
248
+ search_region[not_search_region] = 0
249
+
250
+ search_coords = np.stack(np.nonzero(search_region), axis=-1)
251
+ inpaint_coords = np.stack(np.nonzero(inpaint_region), axis=-1)
252
+
253
+ knn = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(search_coords)
254
+ _, indices = knn.kneighbors(inpaint_coords)
255
+
256
+ feats[tuple(inpaint_coords.T)] = feats[tuple(search_coords[indices[:, 0]].T)]
257
+
258
+ # do ssaa after the NN search, in numpy
259
+ feats = cv2.cvtColor(feats, cv2.COLOR_RGB2BGR)
260
+
261
+ if ssaa > 1:
262
+ # alphas = cv2.resize(alphas, (w0, h0), interpolation=cv2.INTER_NEAREST)
263
+ feats = cv2.resize(feats, (w0, h0), interpolation=cv2.INTER_LINEAR)
264
+
265
+ # cv2.imwrite(os.path.join(path, f'alpha.png'), alphas)
266
+ cv2.imwrite(os.path.join(path, f'{name}albedo.png'), feats)
267
+
268
+ # save obj (v, vt, f /)
269
+ obj_file = os.path.join(path, f'{name}mesh.obj')
270
+ mtl_file = os.path.join(path, f'{name}mesh.mtl')
271
+
272
+ print(f'[INFO] writing obj mesh to {obj_file}')
273
+ with open(obj_file, "w") as fp:
274
+ fp.write(f'mtllib {name}mesh.mtl \n')
275
+
276
+ print(f'[INFO] writing vertices {v_np.shape}')
277
+ for v in v_np:
278
+ fp.write(f'v {v[0]} {v[1]} {v[2]} \n')
279
+
280
+ print(f'[INFO] writing vertices texture coords {vt_np.shape}')
281
+ for v in vt_np:
282
+ fp.write(f'vt {v[0]} {1 - v[1]} \n')
283
+
284
+ print(f'[INFO] writing faces {f_np.shape}')
285
+ fp.write(f'usemtl mat0 \n')
286
+ for i in range(len(f_np)):
287
+ fp.write(f"f {f_np[i, 0] + 1}/{ft_np[i, 0] + 1} {f_np[i, 1] + 1}/{ft_np[i, 1] + 1} {f_np[i, 2] + 1}/{ft_np[i, 2] + 1} \n")
288
+
289
+ with open(mtl_file, "w") as fp:
290
+ fp.write(f'newmtl mat0 \n')
291
+ fp.write(f'Ka 1.000000 1.000000 1.000000 \n')
292
+ fp.write(f'Kd 1.000000 1.000000 1.000000 \n')
293
+ fp.write(f'Ks 0.000000 0.000000 0.000000 \n')
294
+ fp.write(f'Tr 1.000000 \n')
295
+ fp.write(f'illum 1 \n')
296
+ fp.write(f'Ns 0.000000 \n')
297
+ fp.write(f'map_Kd {name}albedo.png \n')
298
+
299
+ _export(v, f)
300
+
301
+ def run(self, rays_o, rays_d, num_steps=128, upsample_steps=128, light_d=None, ambient_ratio=1.0, shading='albedo', bg_color=None, perturb=False, **kwargs):
302
+ # rays_o, rays_d: [B, N, 3], assumes B == 1
303
+ # bg_color: [BN, 3] in range [0, 1]
304
+ # return: image: [B, N, 3], depth: [B, N]
305
+
306
+ prefix = rays_o.shape[:-1]
307
+ rays_o = rays_o.contiguous().view(-1, 3)
308
+ rays_d = rays_d.contiguous().view(-1, 3)
309
+
310
+ N = rays_o.shape[0] # N = B * N, in fact
311
+ device = rays_o.device
312
+
313
+ results = {}
314
+
315
+ # choose aabb
316
+ aabb = self.aabb_train if self.training else self.aabb_infer
317
+
318
+ # sample steps
319
+ nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, aabb, self.min_near)
320
+ nears.unsqueeze_(-1)
321
+ fars.unsqueeze_(-1)
322
+
323
+ #print(f'nears = {nears.min().item()} ~ {nears.max().item()}, fars = {fars.min().item()} ~ {fars.max().item()}')
324
+
325
+ z_vals = torch.linspace(0.0, 1.0, num_steps, device=device).unsqueeze(0) # [1, T]
326
+ z_vals = z_vals.expand((N, num_steps)) # [N, T]
327
+ z_vals = nears + (fars - nears) * z_vals # [N, T], in [nears, fars]
328
+
329
+ # perturb z_vals
330
+ sample_dist = (fars - nears) / num_steps
331
+ if perturb:
332
+ z_vals = z_vals + (torch.rand(z_vals.shape, device=device) - 0.5) * sample_dist
333
+ #z_vals = z_vals.clamp(nears, fars) # avoid out of bounds xyzs.
334
+
335
+ # generate xyzs
336
+ xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze(-1) # [N, 1, 3] * [N, T, 1] -> [N, T, 3]
337
+ xyzs = torch.min(torch.max(xyzs, aabb[:3]), aabb[3:]) # a manual clip.
338
+
339
+ #plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy())
340
+
341
+ # query SDF and RGB
342
+ density_outputs = self.density(xyzs.reshape(-1, 3))
343
+
344
+ #sigmas = density_outputs['sigma'].view(N, num_steps) # [N, T]
345
+ for k, v in density_outputs.items():
346
+ density_outputs[k] = v.view(N, num_steps, -1)
347
+
348
+ # upsample z_vals (nerf-like)
349
+ if upsample_steps > 0:
350
+ with torch.no_grad():
351
+
352
+ deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T-1]
353
+ deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1)
354
+
355
+ alphas = 1 - torch.exp(-deltas * density_outputs['sigma'].squeeze(-1)) # [N, T]
356
+ alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1) # [N, T+1]
357
+ weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1] # [N, T]
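+ # standard volume rendering weights: w_i = alpha_i * prod_{j<i}(1 - alpha_j); the 1e-15 above keeps the cumulative product from hitting exact zero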
358
+
359
+ # sample new z_vals
360
+ z_vals_mid = (z_vals[..., :-1] + 0.5 * deltas[..., :-1]) # [N, T-1]
361
+ new_z_vals = sample_pdf(z_vals_mid, weights[:, 1:-1], upsample_steps, det=not self.training).detach() # [N, t]
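+ # importance sampling (NeRF-style hierarchical sampling): place the extra samples where the coarse weights are large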
362
+
363
+ new_xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * new_z_vals.unsqueeze(-1) # [N, 1, 3] * [N, t, 1] -> [N, t, 3]
364
+ new_xyzs = torch.min(torch.max(new_xyzs, aabb[:3]), aabb[3:]) # a manual clip.
365
+
366
+ # only forward new points to save computation
367
+ new_density_outputs = self.density(new_xyzs.reshape(-1, 3))
368
+ #new_sigmas = new_density_outputs['sigma'].view(N, upsample_steps) # [N, t]
369
+ for k, v in new_density_outputs.items():
370
+ new_density_outputs[k] = v.view(N, upsample_steps, -1)
371
+
372
+ # re-order
373
+ z_vals = torch.cat([z_vals, new_z_vals], dim=1) # [N, T+t]
374
+ z_vals, z_index = torch.sort(z_vals, dim=1)
375
+
376
+ xyzs = torch.cat([xyzs, new_xyzs], dim=1) # [N, T+t, 3]
377
+ xyzs = torch.gather(xyzs, dim=1, index=z_index.unsqueeze(-1).expand_as(xyzs))
378
+
379
+ for k in density_outputs:
380
+ tmp_output = torch.cat([density_outputs[k], new_density_outputs[k]], dim=1)
381
+ density_outputs[k] = torch.gather(tmp_output, dim=1, index=z_index.unsqueeze(-1).expand_as(tmp_output))
382
+
383
+ deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T+t-1]
384
+ deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1)
385
+ alphas = 1 - torch.exp(-deltas * density_outputs['sigma'].squeeze(-1)) # [N, T+t]
386
+ alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1) # [N, T+t+1]
387
+ weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1] # [N, T+t]
388
+
389
+ dirs = rays_d.view(-1, 1, 3).expand_as(xyzs)
390
+ for k, v in density_outputs.items():
391
+ density_outputs[k] = v.view(-1, v.shape[-1])
392
+
393
+ sigmas, rgbs, normals = self(xyzs.reshape(-1, 3), dirs.reshape(-1, 3), light_d, ratio=ambient_ratio, shading=shading)
394
+ rgbs = rgbs.view(N, -1, 3) # [N, T+t, 3]
395
+
396
+ #print(xyzs.shape, 'valid_rgb:', mask.sum().item())
397
+ # orientation loss
398
+ if normals is not None:
399
+ normals = normals.view(N, -1, 3)
400
+ # print(weights.shape, normals.shape, dirs.shape)
401
+ loss_orient = weights.detach() * (normals * dirs).sum(-1).clamp(min=0) ** 2
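+ # orientation regularizer: penalize normals that face away from the camera (n . d > 0), weighted by the detached rendering weights, as in DreamFusion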
402
+ results['loss_orient'] = loss_orient.mean()
403
+
404
+ # calculate weight_sum (mask)
405
+ weights_sum = weights.sum(dim=-1) # [N]
406
+
407
+ # calculate depth
408
+ ori_z_vals = ((z_vals - nears) / (fars - nears)).clamp(0, 1)
409
+ depth = torch.sum(weights * ori_z_vals, dim=-1)
410
+
411
+ # calculate color
412
+ image = torch.sum(weights.unsqueeze(-1) * rgbs, dim=-2) # [N, 3], in [0, 1]
413
+
414
+ # mix background color
415
+ if self.bg_radius > 0:
416
+ # use the bg model to calculate bg_color
417
+ sph = raymarching.sph_from_ray(rays_o, rays_d, self.bg_radius) # [N, 2] in [-1, 1]
418
+ bg_color = self.background(sph, rays_d.reshape(-1, 3)) # [N, 3]
419
+ elif bg_color is None:
420
+ bg_color = 1
421
+
422
+ image = image + (1 - weights_sum).unsqueeze(-1) * bg_color
423
+
424
+ image = image.view(*prefix, 3)
425
+ depth = depth.view(*prefix)
426
+
427
+ mask = (nears < fars).reshape(*prefix)
428
+
429
+ results['image'] = image
430
+ results['depth'] = depth
431
+ results['weights_sum'] = weights_sum
432
+ results['mask'] = mask
433
+
434
+ return results
435
+
436
+
437
+ def run_cuda(self, rays_o, rays_d, dt_gamma=0, light_d=None, ambient_ratio=1.0, shading='albedo', bg_color=None, perturb=False, force_all_rays=False, max_steps=1024, T_thresh=1e-4, **kwargs):
438
+ # rays_o, rays_d: [B, N, 3], assumes B == 1
439
+ # return: image: [B, N, 3], depth: [B, N]
440
+
441
+ prefix = rays_o.shape[:-1]
442
+ rays_o = rays_o.contiguous().view(-1, 3)
443
+ rays_d = rays_d.contiguous().view(-1, 3)
444
+
445
+ N = rays_o.shape[0] # N = B * N, in fact
446
+ device = rays_o.device
447
+
448
+ # pre-calculate near far
449
+ nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, self.aabb_train if self.training else self.aabb_infer)
450
+
451
+ results = {}
452
+
453
+ if self.training:
454
+ # setup counter
455
+ counter = self.step_counter[self.local_step % 16]
456
+ counter.zero_() # set to 0
457
+ self.local_step += 1
458
+
459
+ xyzs, dirs, deltas, rays = raymarching.march_rays_train(rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, counter, self.mean_count, perturb, 128, force_all_rays, dt_gamma, max_steps)
460
+
461
+ #plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy())
462
+
463
+ sigmas, rgbs, normals = self(xyzs, dirs, light_d, ratio=ambient_ratio, shading=shading)
464
+
465
+ #print(f'valid RGB query ratio: {mask.sum().item() / mask.shape[0]} (total = {mask.sum().item()})')
466
+
467
+ weights_sum, depth, image = raymarching.composite_rays_train(sigmas, rgbs, deltas, rays, T_thresh)
468
+
469
+ # orientation loss
470
+ if normals is not None:
471
+ weights = 1 - torch.exp(-sigmas)
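+ # rough per-sample opacity (no delta term) used only to weight the orientation loss in the cuda_ray path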
472
+ loss_orient = weights.detach() * (normals * dirs).sum(-1).clamp(min=0) ** 2
473
+ results['loss_orient'] = loss_orient.mean()
474
+
475
+ else:
476
+
477
+ # allocate outputs
478
+ dtype = torch.float32
479
+
480
+ # if not provided, sample one random light direction shared by all samples
481
+ if light_d is None:
482
+ light_d = torch.randn(3, device=device, dtype=torch.float)
483
+ light_d = safe_normalize(light_d)
484
+
485
+ weights_sum = torch.zeros(N, dtype=dtype, device=device)
486
+ depth = torch.zeros(N, dtype=dtype, device=device)
487
+ image = torch.zeros(N, 3, dtype=dtype, device=device)
488
+
489
+ n_alive = N
490
+ rays_alive = torch.arange(n_alive, dtype=torch.int32, device=device) # [N]
491
+ rays_t = nears.clone() # [N]
492
+
493
+ step = 0
494
+
495
+ while step < max_steps: # hard coded max step
496
+
497
+ # count alive rays
498
+ n_alive = rays_alive.shape[0]
499
+
500
+ # exit loop
501
+ if n_alive <= 0:
502
+ break
503
+
504
+ # decide compact_steps
505
+ n_step = max(min(N // n_alive, 8), 1)
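+ # march more steps per kernel launch as rays terminate (keeps the GPU busy), capped at 8 steps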
506
+
507
+ xyzs, dirs, deltas = raymarching.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, 128, perturb if step == 0 else False, dt_gamma, max_steps)
508
+
509
+ sigmas, rgbs, normals = self(xyzs, dirs, light_d, ratio=ambient_ratio, shading=shading)
510
+
511
+ raymarching.composite_rays(n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image, T_thresh)
512
+
513
+ rays_alive = rays_alive[rays_alive >= 0]
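+ # keep only rays with a non-negative id (composite_rays presumably flags finished rays with -1)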
514
+ #print(f'step = {step}, n_step = {n_step}, n_alive = {n_alive}, xyzs: {xyzs.shape}')
515
+
516
+ step += n_step
517
+
518
+ # mix background color
519
+ if self.bg_radius > 0:
520
+
521
+ # use the bg model to calculate bg_color
522
+ sph = raymarching.sph_from_ray(rays_o, rays_d, self.bg_radius) # [N, 2] in [-1, 1]
523
+ bg_color = self.background(sph, rays_d) # [N, 3]
524
+
525
+ elif bg_color is None:
526
+ bg_color = 1
527
+
528
+ image = image + (1 - weights_sum).unsqueeze(-1) * bg_color
529
+ image = image.view(*prefix, 3)
530
+
531
+ depth = torch.clamp(depth - nears, min=0) / (fars - nears)
532
+ depth = depth.view(*prefix)
533
+
534
+ weights_sum = weights_sum.reshape(*prefix)
535
+
536
+ mask = (nears < fars).reshape(*prefix)
537
+
538
+ results['image'] = image
539
+ results['depth'] = depth
540
+ results['weights_sum'] = weights_sum
541
+ results['mask'] = mask
542
+
543
+ return results
544
+
545
+
546
+ @torch.no_grad()
547
+ def update_extra_state(self, decay=0.95, S=128):
548
+ # call before each epoch to update extra states.
549
+
550
+ if not self.cuda_ray:
551
+ return
552
+
553
+ ### update density grid
554
+ tmp_grid = - torch.ones_like(self.density_grid)
555
+
556
+ X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
557
+ Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
558
+ Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
559
+
560
+ for xs in X:
561
+ for ys in Y:
562
+ for zs in Z:
563
+
564
+ # construct points
565
+ xx, yy, zz = custom_meshgrid(xs, ys, zs)
566
+ coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], in [0, 128)
567
+ indices = raymarching.morton3D(coords).long() # [N]
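+ # Morton (Z-order) encoding of the 3D cell coordinates, matching the memory layout of the density grid / bitfield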
568
+ xyzs = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 3] in [-1, 1]
569
+
570
+ # cascading
571
+ for cas in range(self.cascade):
572
+ bound = min(2 ** cas, self.bound)
573
+ half_grid_size = bound / self.grid_size
574
+ # scale to current cascade's resolution
575
+ cas_xyzs = xyzs * (bound - half_grid_size)
576
+ # add noise in [-hgs, hgs]
577
+ cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size
578
+ # query density
579
+ sigmas = self.density(cas_xyzs)['sigma'].reshape(-1).detach()
580
+ # assign
581
+ tmp_grid[cas, indices] = sigmas
582
+
583
+ # ema update
584
+ valid_mask = self.density_grid >= 0
585
+ self.density_grid[valid_mask] = torch.maximum(self.density_grid[valid_mask] * decay, tmp_grid[valid_mask])
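+ # EMA-like update: take the max of the decayed old density and the fresh query, so stale occupancy fades out gradually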
586
+ self.mean_density = torch.mean(self.density_grid[valid_mask]).item()
587
+ self.iter_density += 1
588
+
589
+ # convert to bitfield
590
+ density_thresh = min(self.mean_density, self.density_thresh)
591
+ self.density_bitfield = raymarching.packbits(self.density_grid, density_thresh, self.density_bitfield)
592
+
593
+ ### update step counter
594
+ total_step = min(16, self.local_step)
595
+ if total_step > 0:
596
+ self.mean_count = int(self.step_counter[:total_step, 0].sum().item() / total_step)
597
+ self.local_step = 0
598
+
599
+ # print(f'[density grid] min={self.density_grid.min().item():.4f}, max={self.density_grid.max().item():.4f}, mean={self.mean_density:.4f}, occ_rate={(self.density_grid > density_thresh).sum() / (128**3 * self.cascade):.3f} | [step counter] mean={self.mean_count}')
600
+
601
+
602
+ def render(self, rays_o, rays_d, staged=False, max_ray_batch=4096, **kwargs):
603
+ # rays_o, rays_d: [B, N, 3], assumes B == 1
604
+ # return: pred_rgb: [B, N, 3]
605
+
606
+ if self.cuda_ray:
607
+ _run = self.run_cuda
608
+ else:
609
+ _run = self.run
610
+
611
+ B, N = rays_o.shape[:2]
612
+ device = rays_o.device
613
+
614
+ # never stage when cuda_ray
615
+ if staged and not self.cuda_ray:
616
+ depth = torch.empty((B, N), device=device)
617
+ image = torch.empty((B, N, 3), device=device)
618
+ weights_sum = torch.empty((B, N), device=device)
619
+
620
+ for b in range(B):
621
+ head = 0
622
+ while head < N:
623
+ tail = min(head + max_ray_batch, N)
624
+ results_ = _run(rays_o[b:b+1, head:tail], rays_d[b:b+1, head:tail], **kwargs)
625
+ depth[b:b+1, head:tail] = results_['depth']
626
+ weights_sum[b:b+1, head:tail] = results_['weights_sum']
627
+ image[b:b+1, head:tail] = results_['image']
628
+ head += max_ray_batch
629
+
630
+ results = {}
631
+ results['depth'] = depth
632
+ results['image'] = image
633
+ results['weights_sum'] = weights_sum
634
+
635
+ else:
636
+ results = _run(rays_o, rays_d, **kwargs)
637
+
638
+ return results
nerf/sd.py ADDED
@@ -0,0 +1,201 @@
1
+ from transformers import CLIPTextModel, CLIPTokenizer, logging
2
+ from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
3
+
4
+ # suppress partial model loading warning
5
+ logging.set_verbosity_error()
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ import time
12
+
13
+ class StableDiffusion(nn.Module):
14
+ def __init__(self, device):
15
+ super().__init__()
16
+
17
+ try:
18
+ with open('./TOKEN', 'r') as f:
19
+ self.token = f.read()
20
+ print(f'[INFO] successfully loaded hugging face user token!')
21
+ except FileNotFoundError as e:
22
+ print(e)
23
+ print(f'[INFO] Please first create a file called TOKEN and copy your hugging face access token into it to download stable diffusion checkpoints.')
24
+
25
+ self.device = device
26
+ self.num_train_timesteps = 1000
27
+ self.min_step = int(self.num_train_timesteps * 0.02)
28
+ self.max_step = int(self.num_train_timesteps * 0.98)
29
+
30
+ print(f'[INFO] loading stable diffusion...')
31
+
32
+ # 1. Load the autoencoder model which will be used to decode the latents into image space.
33
+ self.vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae", use_auth_token=self.token).to(self.device)
34
+
35
+ # 2. Load the tokenizer and text encoder to tokenize and encode the text.
36
+ self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
37
+ self.text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(self.device)
38
+
39
+ # 3. The UNet model for generating the latents.
40
+ self.unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet", use_auth_token=self.token).to(self.device)
41
+
42
+ # 4. Create a scheduler for inference
43
+ self.scheduler = PNDMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=self.num_train_timesteps)
44
+
45
+ print(f'[INFO] loaded stable diffusion!')
46
+
47
+ def get_text_embeds(self, prompt):
48
+ # Tokenize text and get embeddings
49
+ text_input = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length, truncation=True, return_tensors='pt')
50
+
51
+ with torch.no_grad():
52
+ text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
53
+
54
+ # Do the same for unconditional embeddings
55
+ uncond_input = self.tokenizer([''] * len(prompt), padding='max_length', max_length=self.tokenizer.model_max_length, return_tensors='pt')
56
+
57
+ with torch.no_grad():
58
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
59
+
60
+ # Cat for final embeddings
61
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
62
+ return text_embeddings
63
+
64
+
65
+ def train_step(self, text_embeddings, pred_rgb, guidance_scale=100):
66
+
67
+ # interp to 512x512 to be fed into vae.
68
+
69
+ # _t = time.time()
70
+ pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)
71
+ # torch.cuda.synchronize(); print(f'[TIME] guiding: interp {time.time() - _t:.4f}s')
72
+
73
+ # timestep ~ U(0.02, 0.98) to avoid very high/low noise level
74
+ t = torch.randint(self.min_step, self.max_step + 1, [1], dtype=torch.long, device=self.device)
75
+
76
+ # encode image into latents with vae, requires grad!
77
+ # _t = time.time()
78
+ latents = self.encode_imgs(pred_rgb_512)
79
+ # torch.cuda.synchronize(); print(f'[TIME] guiding: vae enc {time.time() - _t:.4f}s')
80
+
81
+ # predict the noise residual with unet, NO grad!
82
+ # _t = time.time()
83
+ with torch.no_grad():
84
+ # add noise
85
+ noise = torch.randn_like(latents)
86
+ latents_noisy = self.scheduler.add_noise(latents, noise, t)
87
+ # pred noise
88
+ latent_model_input = torch.cat([latents_noisy] * 2)
89
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
90
+ # torch.cuda.synchronize(); print(f'[TIME] guiding: unet {time.time() - _t:.4f}s')
91
+
92
+ # perform guidance (high scale from paper!)
93
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
94
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
95
+
96
+ # w(t), one_minus_alpha_prod, i.e., sigma^2
97
+ w = (1 - self.scheduler.alphas_cumprod[t]).to(self.device)
98
+ grad = w * (noise_pred - noise)
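+ # score distillation sampling: grad ~ w(t) * (eps_pred - eps); the U-Net Jacobian term is omitted, as in DreamFusion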
99
+
100
+ # clip grad for stable training?
101
+ # grad = grad.clamp(-1, 1)
102
+
103
+ # manual backward: we dropped a term from the gradient, so we cannot simply autodiff a scalar loss
104
+ # _t = time.time()
105
+ latents.backward(gradient=grad, retain_graph=True)
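+ # feeding grad directly into backward() optimizes a pseudo-loss whose gradient w.r.t. the latents equals grad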
106
+ # torch.cuda.synchronize(); print(f'[TIME] guiding: backward {time.time() - _t:.4f}s')
107
+
108
+ return 0 # fake loss value
109
+
110
+ def produce_latents(self, text_embeddings, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None):
111
+
112
+ if latents is None:
113
+ latents = torch.randn((text_embeddings.shape[0] // 2, self.unet.in_channels, height // 8, width // 8), device=self.device)
114
+
115
+ self.scheduler.set_timesteps(num_inference_steps)
116
+
117
+ with torch.autocast('cuda'):
118
+ for i, t in enumerate(self.scheduler.timesteps):
119
+ # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
120
+ latent_model_input = torch.cat([latents] * 2)
121
+
122
+ # predict the noise residual
123
+ with torch.no_grad():
124
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample']
125
+
126
+ # perform guidance
127
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
128
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
129
+
130
+ # compute the previous noisy sample x_t -> x_t-1
131
+ latents = self.scheduler.step(noise_pred, t, latents)['prev_sample']
132
+
133
+ return latents
134
+
135
+ def decode_latents(self, latents):
136
+
137
+ latents = 1 / 0.18215 * latents
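+ # undo the 0.18215 latent scaling factor used by the Stable Diffusion VAE before decoding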
138
+
139
+ with torch.no_grad():
140
+ imgs = self.vae.decode(latents).sample
141
+
142
+ imgs = (imgs / 2 + 0.5).clamp(0, 1)
143
+
144
+ return imgs
145
+
146
+ def encode_imgs(self, imgs):
147
+ # imgs: [B, 3, H, W]
148
+
149
+ imgs = 2 * imgs - 1
150
+
151
+ posterior = self.vae.encode(imgs).latent_dist
152
+ latents = posterior.sample() * 0.18215
153
+
154
+ return latents
155
+
156
+ def prompt_to_img(self, prompts, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None):
157
+
158
+ if isinstance(prompts, str):
159
+ prompts = [prompts]
160
+
161
+ # Prompts -> text embeds
162
+ text_embeds = self.get_text_embeds(prompts) # [2, 77, 768]
163
+
164
+ # Text embeds -> img latents
165
+ latents = self.produce_latents(text_embeds, height=height, width=width, latents=latents, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale) # [1, 4, 64, 64]
166
+
167
+ # Img latents -> imgs
168
+ imgs = self.decode_latents(latents) # [1, 3, 512, 512]
169
+
170
+ # Img to Numpy
171
+ imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
172
+ imgs = (imgs * 255).round().astype('uint8')
173
+
174
+ return imgs
175
+
176
+
177
+ if __name__ == '__main__':
178
+
179
+ import argparse
180
+ import matplotlib.pyplot as plt
181
+
182
+ parser = argparse.ArgumentParser()
183
+ parser.add_argument('prompt', type=str)
184
+ parser.add_argument('-H', type=int, default=512)
185
+ parser.add_argument('-W', type=int, default=512)
186
+ parser.add_argument('--steps', type=int, default=50)
187
+ opt = parser.parse_args()
188
+
189
+ device = torch.device('cuda')
190
+
191
+ sd = StableDiffusion(device)
192
+
193
+ imgs = sd.prompt_to_img(opt.prompt, opt.H, opt.W, opt.steps)
194
+
195
+ # visualize image
196
+ plt.imshow(imgs[0])
197
+ plt.show()
198
+
199
+
200
+
201
+
nerf/utils.py ADDED
@@ -0,0 +1,935 @@
1
+ import os
2
+ import glob
3
+ import tqdm
4
+ import math
5
+ import imageio
6
+ import random
7
+ import warnings
8
+ import tensorboardX
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+
13
+ import time
14
+ from datetime import datetime
15
+
16
+ import cv2
17
+ import matplotlib.pyplot as plt
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.optim as optim
22
+ import torch.nn.functional as F
23
+ import torch.distributed as dist
24
+ from torch.utils.data import Dataset, DataLoader
25
+
26
+ import trimesh
27
+ from rich.console import Console
28
+ from torch_ema import ExponentialMovingAverage
29
+
30
+ from packaging import version as pver
31
+
32
+ def custom_meshgrid(*args):
33
+ # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
34
+ if pver.parse(torch.__version__) < pver.parse('1.10'):
35
+ return torch.meshgrid(*args)
36
+ else:
37
+ return torch.meshgrid(*args, indexing='ij')
38
+
39
+ def safe_normalize(x, eps=1e-20):
40
+ return x / torch.sqrt(torch.clamp(torch.sum(x * x, -1, keepdim=True), min=eps))
41
+
42
+ @torch.cuda.amp.autocast(enabled=False)
43
+ def get_rays(poses, intrinsics, H, W, N=-1, error_map=None):
44
+ ''' get rays
45
+ Args:
46
+ poses: [B, 4, 4], cam2world
47
+ intrinsics: [4]
48
+ H, W, N: int
49
+ error_map: [B, 128 * 128], sample probability based on training error
50
+ Returns:
51
+ rays_o, rays_d: [B, N, 3]
52
+ inds: [B, N]
53
+ '''
54
+
55
+ device = poses.device
56
+ B = poses.shape[0]
57
+ fx, fy, cx, cy = intrinsics
58
+
59
+ i, j = custom_meshgrid(torch.linspace(0, W-1, W, device=device), torch.linspace(0, H-1, H, device=device))
60
+ i = i.t().reshape([1, H*W]).expand([B, H*W]) + 0.5
61
+ j = j.t().reshape([1, H*W]).expand([B, H*W]) + 0.5
62
+
63
+ results = {}
64
+
65
+ if N > 0:
66
+ N = min(N, H*W)
67
+
68
+ if error_map is None:
69
+ inds = torch.randint(0, H*W, size=[N], device=device) # may duplicate
70
+ inds = inds.expand([B, N])
71
+ else:
72
+
73
+ # weighted sample on a low-reso grid
74
+ inds_coarse = torch.multinomial(error_map.to(device), N, replacement=False) # [B, N], but in [0, 128*128)
75
+
76
+ # map to the original resolution with random perturb.
77
+ inds_x, inds_y = inds_coarse // 128, inds_coarse % 128 # `//` will throw a warning in torch 1.10... anyway.
78
+ sx, sy = H / 128, W / 128
79
+ inds_x = (inds_x * sx + torch.rand(B, N, device=device) * sx).long().clamp(max=H - 1)
80
+ inds_y = (inds_y * sy + torch.rand(B, N, device=device) * sy).long().clamp(max=W - 1)
81
+ inds = inds_x * W + inds_y
82
+
83
+ results['inds_coarse'] = inds_coarse # need this when updating error_map
84
+
85
+ i = torch.gather(i, -1, inds)
86
+ j = torch.gather(j, -1, inds)
87
+
88
+ results['inds'] = inds
89
+
90
+ else:
91
+ inds = torch.arange(H*W, device=device).expand([B, H*W])
92
+
93
+ zs = torch.ones_like(i)
94
+ xs = (i - cx) / fx * zs
95
+ ys = (j - cy) / fy * zs
96
+ directions = torch.stack((xs, ys, zs), dim=-1)
97
+ directions = safe_normalize(directions)
98
+ rays_d = directions @ poses[:, :3, :3].transpose(-1, -2) # (B, N, 3)
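+ # rotate camera-space directions into world space (row-vector convention: d_world = d_cam @ R^T)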
99
+
100
+ rays_o = poses[..., :3, 3] # [B, 3]
101
+ rays_o = rays_o[..., None, :].expand_as(rays_d) # [B, N, 3]
102
+
103
+ results['rays_o'] = rays_o
104
+ results['rays_d'] = rays_d
105
+
106
+ return results
107
+
108
+
109
+ def seed_everything(seed):
110
+ random.seed(seed)
111
+ os.environ['PYTHONHASHSEED'] = str(seed)
112
+ np.random.seed(seed)
113
+ torch.manual_seed(seed)
114
+ torch.cuda.manual_seed(seed)
115
+ #torch.backends.cudnn.deterministic = True
116
+ #torch.backends.cudnn.benchmark = True
117
+
118
+
119
+ def torch_vis_2d(x, renormalize=False):
120
+ # x: [3, H, W] or [1, H, W] or [H, W]
121
+ import matplotlib.pyplot as plt
122
+ import numpy as np
123
+ import torch
124
+
125
+ if isinstance(x, torch.Tensor):
126
+ if len(x.shape) == 3:
127
+ x = x.permute(1,2,0).squeeze()
128
+ x = x.detach().cpu().numpy()
129
+
130
+ print(f'[torch_vis_2d] {x.shape}, {x.dtype}, {x.min()} ~ {x.max()}')
131
+
132
+ x = x.astype(np.float32)
133
+
134
+ # renormalize
135
+ if renormalize:
136
+ x = (x - x.min(axis=0, keepdims=True)) / (x.max(axis=0, keepdims=True) - x.min(axis=0, keepdims=True) + 1e-8)
137
+
138
+ plt.imshow(x)
139
+ plt.show()
140
+
141
+ @torch.jit.script
142
+ def linear_to_srgb(x):
143
+ return torch.where(x < 0.0031308, 12.92 * x, 1.055 * x ** 0.41666 - 0.055)
144
+
145
+
146
+ @torch.jit.script
147
+ def srgb_to_linear(x):
148
+ return torch.where(x < 0.04045, x / 12.92, ((x + 0.055) / 1.055) ** 2.4)
149
+
150
+
151
+ class Trainer(object):
152
+ def __init__(self,
153
+ name, # name of this experiment
154
+ opt, # extra conf
155
+ model, # network
156
+ guidance, # guidance network
157
+ criterion=None, # loss function, if None, assume inline implementation in train_step
158
+ optimizer=None, # optimizer
159
+ ema_decay=None, # if use EMA, set the decay
160
+ lr_scheduler=None, # scheduler
161
+ metrics=[], # metrics for evaluation; if empty, use val_loss to measure performance, else use the first metric.
162
+ local_rank=0, # which GPU am I
163
+ world_size=1, # total num of GPUs
164
+ device=None, # device to use, usually setting to None is OK. (auto choose device)
165
+ mute=False, # whether to mute all print
166
+ fp16=False, # whether to use automatic mixed precision (AMP)
167
+ eval_interval=1, # eval once every $ epoch
168
+ max_keep_ckpt=2, # max num of saved ckpts in disk
169
+ workspace='workspace', # workspace to save logs & ckpts
170
+ best_mode='min', # the smaller/larger result, the better
171
+ use_loss_as_metric=True, # use loss as the first metric
172
+ report_metric_at_train=False, # also report metrics at training
173
+ use_checkpoint="latest", # which ckpt to use at init time
174
+ use_tensorboardX=True, # whether to use tensorboard for logging
175
+ scheduler_update_every_step=False, # whether to call scheduler.step() after every train step
176
+ ):
177
+
178
+ self.name = name
179
+ self.opt = opt
180
+ self.mute = mute
181
+ self.metrics = metrics
182
+ self.local_rank = local_rank
183
+ self.world_size = world_size
184
+ self.workspace = workspace
185
+ self.ema_decay = ema_decay
186
+ self.fp16 = fp16
187
+ self.best_mode = best_mode
188
+ self.use_loss_as_metric = use_loss_as_metric
189
+ self.report_metric_at_train = report_metric_at_train
190
+ self.max_keep_ckpt = max_keep_ckpt
191
+ self.eval_interval = eval_interval
192
+ self.use_checkpoint = use_checkpoint
193
+ self.use_tensorboardX = use_tensorboardX
194
+ self.time_stamp = time.strftime("%Y-%m-%d_%H-%M-%S")
195
+ self.scheduler_update_every_step = scheduler_update_every_step
196
+ self.device = device if device is not None else torch.device(f'cuda:{local_rank}' if torch.cuda.is_available() else 'cpu')
197
+ self.console = Console()
198
+
199
+ # text prompt
200
+ ref_text = self.opt.text
201
+
202
+ model.to(self.device)
203
+ if self.world_size > 1:
204
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
205
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
206
+ self.model = model
207
+
208
+ # guide model
209
+ self.guidance = guidance
210
+
211
+ if self.guidance is not None:
212
+
213
+ for p in self.guidance.parameters():
214
+ p.requires_grad = False
215
+
216
+ if not self.opt.dir_text:
217
+ self.text_z = self.guidance.get_text_embeds([ref_text])
218
+ else:
219
+ self.text_z = []
220
+ for d in ['front', 'side', 'back', 'side', 'overhead', 'bottom']:
221
+ text = f"{ref_text}, {d} view"
222
+ text_z = self.guidance.get_text_embeds([text])
223
+ self.text_z.append(text_z)
224
+
225
+ else:
226
+ self.text_z = None
227
+
228
+ if isinstance(criterion, nn.Module):
229
+ criterion.to(self.device)
230
+ self.criterion = criterion
231
+
232
+ if optimizer is None:
233
+ self.optimizer = optim.Adam(self.model.parameters(), lr=0.001, weight_decay=5e-4) # naive adam
234
+ else:
235
+ self.optimizer = optimizer(self.model)
236
+
237
+ if lr_scheduler is None:
238
+ self.lr_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda epoch: 1) # fake scheduler
239
+ else:
240
+ self.lr_scheduler = lr_scheduler(self.optimizer)
241
+
242
+ if ema_decay is not None:
243
+ self.ema = ExponentialMovingAverage(self.model.parameters(), decay=ema_decay)
244
+ else:
245
+ self.ema = None
246
+
247
+ self.scaler = torch.cuda.amp.GradScaler(enabled=self.fp16)
248
+
249
+ # variable init
250
+ self.epoch = 0
251
+ self.global_step = 0
252
+ self.local_step = 0
253
+ self.stats = {
254
+ "loss": [],
255
+ "valid_loss": [],
256
+ "results": [], # metrics[0], or valid_loss
257
+ "checkpoints": [], # record path of saved ckpt, to automatically remove old ckpt
258
+ "best_result": None,
259
+ }
260
+
261
+ # auto fix
262
+ if len(metrics) == 0 or self.use_loss_as_metric:
263
+ self.best_mode = 'min'
264
+
265
+ # workspace prepare
266
+ self.log_ptr = None
267
+ if self.workspace is not None:
268
+ os.makedirs(self.workspace, exist_ok=True)
269
+ self.log_path = os.path.join(workspace, f"log_{self.name}.txt")
270
+ self.log_ptr = open(self.log_path, "a+")
271
+
272
+ self.ckpt_path = os.path.join(self.workspace, 'checkpoints')
273
+ self.best_path = f"{self.ckpt_path}/{self.name}.pth"
274
+ os.makedirs(self.ckpt_path, exist_ok=True)
275
+
276
+ self.log(f'[INFO] Trainer: {self.name} | {self.time_stamp} | {self.device} | {"fp16" if self.fp16 else "fp32"} | {self.workspace}')
277
+ self.log(f'[INFO] #parameters: {sum([p.numel() for p in model.parameters() if p.requires_grad])}')
278
+
279
+ if self.workspace is not None:
280
+ if self.use_checkpoint == "scratch":
281
+ self.log("[INFO] Training from scratch ...")
282
+ elif self.use_checkpoint == "latest":
283
+ self.log("[INFO] Loading latest checkpoint ...")
284
+ self.load_checkpoint()
285
+ elif self.use_checkpoint == "latest_model":
286
+ self.log("[INFO] Loading latest checkpoint (model only)...")
287
+ self.load_checkpoint(model_only=True)
288
+ elif self.use_checkpoint == "best":
289
+ if os.path.exists(self.best_path):
290
+ self.log("[INFO] Loading best checkpoint ...")
291
+ self.load_checkpoint(self.best_path)
292
+ else:
293
+ self.log(f"[INFO] {self.best_path} not found, loading latest ...")
294
+ self.load_checkpoint()
295
+ else: # path to ckpt
296
+ self.log(f"[INFO] Loading {self.use_checkpoint} ...")
297
+ self.load_checkpoint(self.use_checkpoint)
298
+
299
+ def __del__(self):
300
+ if self.log_ptr:
301
+ self.log_ptr.close()
302
+
303
+
304
+ def log(self, *args, **kwargs):
305
+ if self.local_rank == 0:
306
+ if not self.mute:
307
+ #print(*args)
308
+ self.console.print(*args, **kwargs)
309
+ if self.log_ptr:
310
+ print(*args, file=self.log_ptr)
311
+ self.log_ptr.flush() # write immediately to file
312
+
313
+ ### ------------------------------
314
+
315
+ def train_step(self, data):
316
+
317
+ rays_o = data['rays_o'] # [B, N, 3]
318
+ rays_d = data['rays_d'] # [B, N, 3]
319
+
320
+ B, N = rays_o.shape[:2]
321
+ H, W = data['H'], data['W']
322
+
323
+ # TODO: shading is not working right now...
324
+ if self.global_step < self.opt.albedo_iters:
325
+ shading = 'albedo'
326
+ ambient_ratio = 1.0
327
+ else:
328
+ rand = random.random()
329
+ if rand > 0.8:
330
+ shading = 'albedo'
331
+ ambient_ratio = 1.0
332
+ elif rand > 0.4:
333
+ shading = 'lambertian'
334
+ ambient_ratio = 0.1
335
+ else:
336
+ shading = 'textureless'
337
+ ambient_ratio = 0.1
338
+
339
+ # _t = time.time()
340
+ bg_color = torch.rand((B * N, 3), device=rays_o.device) # pixel-wise random
341
+ outputs = self.model.render(rays_o, rays_d, staged=False, perturb=True, bg_color=bg_color, ambient_ratio=ambient_ratio, shading=shading, force_all_rays=True, **vars(self.opt))
342
+ pred_rgb = outputs['image'].reshape(B, H, W, 3).permute(0, 3, 1, 2).contiguous() # [1, 3, H, W]
343
+ # torch.cuda.synchronize(); print(f'[TIME] nerf render {time.time() - _t:.4f}s')
344
+
345
+ # text embeddings
346
+ if self.opt.dir_text:
347
+ dirs = data['dir'] # [B,]
348
+ text_z = self.text_z[dirs]
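+ # select the view-dependent prompt embedding (front/side/back/... as built in __init__) matching the sampled camera direction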
349
+ else:
350
+ text_z = self.text_z
351
+
352
+ # encode pred_rgb to latents
353
+ # _t = time.time()
354
+ loss_guidance = self.guidance.train_step(text_z, pred_rgb)
355
+ # torch.cuda.synchronize(); print(f'[TIME] total guiding {time.time() - _t:.4f}s')
356
+
357
+ # occupancy loss
358
+ pred_ws = outputs['weights_sum'].reshape(B, 1, H, W)
359
+ # mask_ws = outputs['mask'].reshape(B, 1, H, W) # near < far
360
+
361
+ # loss_ws = (pred_ws ** 2 + 0.01).sqrt().mean()
362
+
363
+ alphas = (pred_ws).clamp(1e-5, 1 - 1e-5)
364
+ # alphas = alphas ** 2 # skewed entropy, favors 0 over 1
365
+ loss_entropy = (- alphas * torch.log2(alphas) - (1 - alphas) * torch.log2(1 - alphas)).mean()
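+ # binary entropy on the accumulated alpha pushes each pixel towards fully opaque or fully transparent, discouraging semi-transparent "fog" geometry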
366
+
367
+ loss = loss_guidance + 1e-3 * loss_entropy
368
+
369
+ if 'loss_orient' in outputs:
370
+ loss_orient = outputs['loss_orient']
371
+ loss = loss + 1e-2 * loss_orient
372
+
373
+ return pred_rgb, pred_ws, loss
374
+
375
+ def eval_step(self, data):
376
+
377
+ rays_o = data['rays_o'] # [B, N, 3]
378
+ rays_d = data['rays_d'] # [B, N, 3]
379
+
380
+ B, N = rays_o.shape[:2]
381
+ H, W = data['H'], data['W']
382
+
383
+ shading = data['shading'] if 'shading' in data else 'albedo'
384
+ ambient_ratio = data['ambient_ratio'] if 'ambient_ratio' in data else 1.0
385
+ light_d = data['light_d'] if 'light_d' in data else None
386
+
387
+ outputs = self.model.render(rays_o, rays_d, staged=True, perturb=False, bg_color=None, light_d=light_d, ambient_ratio=ambient_ratio, shading=shading, force_all_rays=True, **vars(self.opt))
388
+ pred_rgb = outputs['image'].reshape(B, H, W, 3)
389
+ pred_depth = outputs['depth'].reshape(B, H, W)
390
+ pred_ws = outputs['weights_sum'].reshape(B, H, W)
391
+ # mask_ws = outputs['mask'].reshape(B, H, W) # near < far
392
+
393
+ # loss_ws = pred_ws.sum() / mask_ws.sum()
394
+ # loss_ws = pred_ws.mean()
395
+
396
+ alphas = (pred_ws).clamp(1e-5, 1 - 1e-5)
397
+ # alphas = alphas ** 2 # skewed entropy, favors 0 over 1
398
+ loss_entropy = (- alphas * torch.log2(alphas) - (1 - alphas) * torch.log2(1 - alphas)).mean()
399
+
400
+ loss = 1e-3 * loss_entropy
401
+
402
+ return pred_rgb, pred_depth, loss
403
+
404
+ # moved out bg_color and perturb for more flexible control...
405
+ def test_step(self, data, bg_color=None, perturb=False):
406
+ rays_o = data['rays_o'] # [B, N, 3]
407
+ rays_d = data['rays_d'] # [B, N, 3]
408
+
409
+ B, N = rays_o.shape[:2]
410
+ H, W = data['H'], data['W']
411
+
412
+ if bg_color is not None:
413
+ bg_color = bg_color.to(rays_o.device)
414
+ else:
415
+ bg_color = torch.ones(3, device=rays_o.device) # [3]
416
+
417
+ shading = data['shading'] if 'shading' in data else 'albedo'
418
+ ambient_ratio = data['ambient_ratio'] if 'ambient_ratio' in data else 1.0
419
+ light_d = data['light_d'] if 'light_d' in data else None
420
+
421
+ outputs = self.model.render(rays_o, rays_d, staged=True, perturb=perturb, light_d=light_d, ambient_ratio=ambient_ratio, shading=shading, force_all_rays=True, bg_color=bg_color, **vars(self.opt))
422
+
423
+ pred_rgb = outputs['image'].reshape(B, H, W, 3)
424
+ pred_depth = outputs['depth'].reshape(B, H, W)
425
+
426
+ return pred_rgb, pred_depth
427
+
428
+
429
+ def save_mesh(self, save_path=None, resolution=128):
430
+
431
+ if save_path is None:
432
+ save_path = os.path.join(self.workspace, 'mesh')
433
+
434
+ self.log(f"==> Saving mesh to {save_path}")
435
+
436
+ os.makedirs(save_path, exist_ok=True)
437
+
438
+ self.model.export_mesh(save_path, resolution=resolution)
439
+
440
+ self.log(f"==> Finished saving mesh.")
441
+
442
+ ### ------------------------------
443
+
444
+ def train(self, train_loader, valid_loader, max_epochs):
445
+ if self.use_tensorboardX and self.local_rank == 0:
446
+ self.writer = tensorboardX.SummaryWriter(os.path.join(self.workspace, "run", self.name))
447
+
448
+ start_t = time.time()
449
+
450
+ for epoch in range(self.epoch + 1, max_epochs + 1):
451
+ self.epoch = epoch
452
+
453
+ self.train_one_epoch(train_loader)
454
+
455
+ if self.workspace is not None and self.local_rank == 0:
456
+ self.save_checkpoint(full=True, best=False)
457
+
458
+ if self.epoch % self.eval_interval == 0:
459
+ self.evaluate_one_epoch(valid_loader)
460
+ self.save_checkpoint(full=False, best=True)
461
+
462
+ end_t = time.time()
463
+
464
+ self.log(f"[INFO] training takes {(end_t - start_t)/ 60:.4f} minutes.")
465
+
466
+ if self.use_tensorboardX and self.local_rank == 0:
467
+ self.writer.close()
468
+
469
+ def evaluate(self, loader, name=None):
470
+ self.use_tensorboardX, use_tensorboardX = False, self.use_tensorboardX
471
+ self.evaluate_one_epoch(loader, name)
472
+ self.use_tensorboardX = use_tensorboardX
473
+
474
+ def test(self, loader, save_path=None, name=None, write_video=True):
475
+
476
+ if save_path is None:
477
+ save_path = os.path.join(self.workspace, 'results')
478
+
479
+ if name is None:
480
+ name = f'{self.name}_ep{self.epoch:04d}'
481
+
482
+ os.makedirs(save_path, exist_ok=True)
483
+
484
+ self.log(f"==> Start Test, save results to {save_path}")
485
+
486
+ pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, bar_format='{percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]')
487
+ self.model.eval()
488
+
489
+ if write_video:
490
+ all_preds = []
491
+ all_preds_depth = []
492
+
493
+ with torch.no_grad():
494
+
495
+ for i, data in enumerate(loader):
496
+
497
+ with torch.cuda.amp.autocast(enabled=self.fp16):
498
+ preds, preds_depth = self.test_step(data)
499
+
500
+ pred = preds[0].detach().cpu().numpy()
501
+ pred = (pred * 255).astype(np.uint8)
502
+
503
+ pred_depth = preds_depth[0].detach().cpu().numpy()
504
+ pred_depth = (pred_depth * 255).astype(np.uint8)
505
+
506
+ if write_video:
507
+ all_preds.append(pred)
508
+ all_preds_depth.append(pred_depth)
509
+ else:
510
+ cv2.imwrite(os.path.join(save_path, f'{name}_{i:04d}_rgb.png'), cv2.cvtColor(pred, cv2.COLOR_RGB2BGR))
511
+ cv2.imwrite(os.path.join(save_path, f'{name}_{i:04d}_depth.png'), pred_depth)
512
+
513
+ pbar.update(loader.batch_size)
514
+
515
+ if write_video:
516
+ all_preds = np.stack(all_preds, axis=0)
517
+ all_preds_depth = np.stack(all_preds_depth, axis=0)
518
+
519
+ imageio.mimwrite(os.path.join(save_path, f'{name}_rgb.mp4'), all_preds, fps=25, quality=8, macro_block_size=1)
520
+ imageio.mimwrite(os.path.join(save_path, f'{name}_depth.mp4'), all_preds_depth, fps=25, quality=8, macro_block_size=1)
521
+
522
+ self.log(f"==> Finished Test.")
523
+
524
+ # [GUI] train text step.
525
+ def train_gui(self, train_loader, step=16):
526
+
527
+ self.model.train()
528
+
529
+ total_loss = torch.tensor([0], dtype=torch.float32, device=self.device)
530
+
531
+ loader = iter(train_loader)
532
+
533
+ for _ in range(step):
534
+
535
+ # mimic an infinite dataloader (in case the dataset is smaller than the number of steps)
536
+ try:
537
+ data = next(loader)
538
+ except StopIteration:
539
+ loader = iter(train_loader)
540
+ data = next(loader)
541
+
542
+ # update grid every 16 steps
543
+ if self.model.cuda_ray and self.global_step % self.opt.update_extra_interval == 0:
544
+ with torch.cuda.amp.autocast(enabled=self.fp16):
545
+ self.model.update_extra_state()
546
+
547
+ self.global_step += 1
548
+
549
+ self.optimizer.zero_grad()
550
+
551
+ with torch.cuda.amp.autocast(enabled=self.fp16):
552
+ pred_rgbs, pred_ws, loss = self.train_step(data)
553
+
554
+ self.scaler.scale(loss).backward()
555
+ self.scaler.step(self.optimizer)
556
+ self.scaler.update()
557
+
558
+ if self.scheduler_update_every_step:
559
+ self.lr_scheduler.step()
560
+
561
+ total_loss += loss.detach()
562
+
563
+ if self.ema is not None:
564
+ self.ema.update()
565
+
566
+ average_loss = total_loss.item() / step
567
+
568
+ if not self.scheduler_update_every_step:
569
+ if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
570
+ self.lr_scheduler.step(average_loss)
571
+ else:
572
+ self.lr_scheduler.step()
573
+
574
+ outputs = {
575
+ 'loss': average_loss,
576
+ 'lr': self.optimizer.param_groups[0]['lr'],
577
+ }
578
+
579
+ return outputs
580
+
581
+
582
+ # [GUI] test on a single image
583
+ def test_gui(self, pose, intrinsics, W, H, bg_color=None, spp=1, downscale=1, light_d=None, ambient_ratio=1.0, shading='albedo'):
584
+
585
+ # render resolution (may need to downscale for a better frame rate)
586
+ rH = int(H * downscale)
587
+ rW = int(W * downscale)
588
+ intrinsics = intrinsics * downscale
589
+
590
+ pose = torch.from_numpy(pose).unsqueeze(0).to(self.device)
591
+
592
+ rays = get_rays(pose, intrinsics, rH, rW, -1)
593
+
594
+ # from degree theta/phi to 3D normalized vec
595
+ light_d = np.deg2rad(light_d)
596
+ light_d = np.array([
597
+ np.sin(light_d[0]) * np.sin(light_d[1]),
598
+ np.cos(light_d[0]),
599
+ np.sin(light_d[0]) * np.cos(light_d[1]),
600
+ ], dtype=np.float32)
601
+ light_d = torch.from_numpy(light_d).to(self.device)
602
+
603
+ data = {
604
+ 'rays_o': rays['rays_o'],
605
+ 'rays_d': rays['rays_d'],
606
+ 'H': rH,
607
+ 'W': rW,
608
+ 'light_d': light_d,
609
+ 'ambient_ratio': ambient_ratio,
610
+ 'shading': shading,
611
+ }
612
+
613
+ self.model.eval()
614
+
615
+ if self.ema is not None:
616
+ self.ema.store()
617
+ self.ema.copy_to()
618
+
619
+ with torch.no_grad():
620
+ with torch.cuda.amp.autocast(enabled=self.fp16):
621
+ # here spp is used as perturb random seed!
622
+ preds, preds_depth = self.test_step(data, bg_color=bg_color, perturb=spp)
623
+
624
+ if self.ema is not None:
625
+ self.ema.restore()
626
+
627
+ # interpolation to the original resolution
628
+ if downscale != 1:
629
+ # have to permute twice with torch...
630
+ preds = F.interpolate(preds.permute(0, 3, 1, 2), size=(H, W), mode='nearest').permute(0, 2, 3, 1).contiguous()
631
+ preds_depth = F.interpolate(preds_depth.unsqueeze(1), size=(H, W), mode='nearest').squeeze(1)
632
+
633
+ outputs = {
634
+ 'image': preds[0].detach().cpu().numpy(),
635
+ 'depth': preds_depth[0].detach().cpu().numpy(),
636
+ }
637
+
638
+ return outputs
639
+
640
+ def train_one_epoch(self, loader):
641
+ self.log(f"==> Start Training Epoch {self.epoch}, lr={self.optimizer.param_groups[0]['lr']:.6f} ...")
642
+
643
+ total_loss = 0
644
+ if self.local_rank == 0 and self.report_metric_at_train:
645
+ for metric in self.metrics:
646
+ metric.clear()
647
+
648
+ self.model.train()
649
+
650
+ # distributedSampler: must call set_epoch() to shuffle indices across multiple epochs
651
+ # ref: https://pytorch.org/docs/stable/data.html
652
+ if self.world_size > 1:
653
+ loader.sampler.set_epoch(self.epoch)
654
+
655
+ if self.local_rank == 0:
656
+ pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, bar_format='{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]')
657
+
658
+ self.local_step = 0
659
+
660
+ for data in loader:
661
+
662
+ # update grid every 16 steps
663
+ if self.model.cuda_ray and self.global_step % self.opt.update_extra_interval == 0:
664
+ with torch.cuda.amp.autocast(enabled=self.fp16):
665
+ self.model.update_extra_state()
666
+
667
+ self.local_step += 1
668
+ self.global_step += 1
669
+
670
+ self.optimizer.zero_grad()
671
+
672
+ with torch.cuda.amp.autocast(enabled=self.fp16):
673
+ pred_rgbs, pred_ws, loss = self.train_step(data)
674
+
675
+ self.scaler.scale(loss).backward()
676
+ self.scaler.step(self.optimizer)
677
+ self.scaler.update()
678
+
679
+ if self.scheduler_update_every_step:
680
+ self.lr_scheduler.step()
681
+
682
+ loss_val = loss.item()
683
+ total_loss += loss_val
684
+
685
+ if self.local_rank == 0:
686
+ # if self.report_metric_at_train:
687
+ # for metric in self.metrics:
688
+ # metric.update(preds, truths)
689
+
690
+ if self.use_tensorboardX:
691
+ self.writer.add_scalar("train/loss", loss_val, self.global_step)
692
+ self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]['lr'], self.global_step)
693
+
694
+ if self.scheduler_update_every_step:
695
+ pbar.set_description(f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f}), lr={self.optimizer.param_groups[0]['lr']:.6f}")
696
+ else:
697
+ pbar.set_description(f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})")
698
+ pbar.update(loader.batch_size)
699
+
700
+ if self.ema is not None:
701
+ self.ema.update()
702
+
703
+ average_loss = total_loss / self.local_step
704
+ self.stats["loss"].append(average_loss)
705
+
706
+ if self.local_rank == 0:
707
+ pbar.close()
708
+ if self.report_metric_at_train:
709
+ for metric in self.metrics:
710
+ self.log(metric.report(), style="red")
711
+ if self.use_tensorboardX:
712
+ metric.write(self.writer, self.epoch, prefix="train")
713
+ metric.clear()
714
+
715
+ if not self.scheduler_update_every_step:
716
+ if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
717
+ self.lr_scheduler.step(average_loss)
718
+ else:
719
+ self.lr_scheduler.step()
720
+
721
+ self.log(f"==> Finished Epoch {self.epoch}.")
722
+
723
+
724
+ def evaluate_one_epoch(self, loader, name=None):
725
+ self.log(f"++> Evaluate at epoch {self.epoch} ...")
726
+
727
+ if name is None:
728
+ name = f'{self.name}_ep{self.epoch:04d}'
729
+
730
+ total_loss = 0
731
+ if self.local_rank == 0:
732
+ for metric in self.metrics:
733
+ metric.clear()
734
+
735
+ self.model.eval()
736
+
737
+ if self.ema is not None:
738
+ self.ema.store()
739
+ self.ema.copy_to()
740
+
741
+ if self.local_rank == 0:
742
+ pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, bar_format='{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]')
743
+
744
+ with torch.no_grad():
745
+ self.local_step = 0
746
+
747
+ for data in loader:
748
+ self.local_step += 1
749
+
750
+ with torch.cuda.amp.autocast(enabled=self.fp16):
751
+ preds, preds_depth, loss = self.eval_step(data)
752
+
753
+ # all_gather/reduce the statistics (NCCL only support all_*)
754
+ if self.world_size > 1:
755
+ dist.all_reduce(loss, op=dist.ReduceOp.SUM)
756
+ loss = loss / self.world_size
757
+
758
+ preds_list = [torch.zeros_like(preds).to(self.device) for _ in range(self.world_size)] # [[B, ...], [B, ...], ...]
759
+ dist.all_gather(preds_list, preds)
760
+ preds = torch.cat(preds_list, dim=0)
761
+
762
+ preds_depth_list = [torch.zeros_like(preds_depth).to(self.device) for _ in range(self.world_size)] # [[B, ...], [B, ...], ...]
763
+ dist.all_gather(preds_depth_list, preds_depth)
764
+ preds_depth = torch.cat(preds_depth_list, dim=0)
765
+
766
+ loss_val = loss.item()
767
+ total_loss += loss_val
768
+
769
+ # only rank = 0 will perform evaluation.
770
+ if self.local_rank == 0:
771
+
772
+ # save image
773
+ save_path = os.path.join(self.workspace, 'validation', f'{name}_{self.local_step:04d}_rgb.png')
774
+ save_path_depth = os.path.join(self.workspace, 'validation', f'{name}_{self.local_step:04d}_depth.png')
775
+
776
+ #self.log(f"==> Saving validation image to {save_path}")
777
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
778
+
779
+ pred = preds[0].detach().cpu().numpy()
780
+ pred = (pred * 255).astype(np.uint8)
781
+
782
+ pred_depth = preds_depth[0].detach().cpu().numpy()
783
+ pred_depth = (pred_depth * 255).astype(np.uint8)
784
+
785
+ cv2.imwrite(save_path, cv2.cvtColor(pred, cv2.COLOR_RGB2BGR))
786
+ cv2.imwrite(save_path_depth, pred_depth)
787
+
788
+ pbar.set_description(f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})")
789
+ pbar.update(loader.batch_size)
790
+
791
+
792
+ average_loss = total_loss / self.local_step
793
+ self.stats["valid_loss"].append(average_loss)
794
+
795
+ if self.local_rank == 0:
796
+ pbar.close()
797
+ if not self.use_loss_as_metric and len(self.metrics) > 0:
798
+ result = self.metrics[0].measure()
799
+ self.stats["results"].append(result if self.best_mode == 'min' else - result) # if max mode, use -result
800
+ else:
801
+ self.stats["results"].append(average_loss) # if no metric, choose best by min loss
802
+
803
+ for metric in self.metrics:
804
+ self.log(metric.report(), style="blue")
805
+ if self.use_tensorboardX:
806
+ metric.write(self.writer, self.epoch, prefix="evaluate")
807
+ metric.clear()
808
+
809
+ if self.ema is not None:
810
+ self.ema.restore()
811
+
812
+ self.log(f"++> Evaluate epoch {self.epoch} Finished.")
813
+
814
+ def save_checkpoint(self, name=None, full=False, best=False):
815
+
816
+ if name is None:
817
+ name = f'{self.name}_ep{self.epoch:04d}'
818
+
819
+ state = {
820
+ 'epoch': self.epoch,
821
+ 'global_step': self.global_step,
822
+ 'stats': self.stats,
823
+ }
824
+
825
+ if self.model.cuda_ray:
826
+ state['mean_count'] = self.model.mean_count
827
+ state['mean_density'] = self.model.mean_density
828
+
829
+ if full:
830
+ state['optimizer'] = self.optimizer.state_dict()
831
+ state['lr_scheduler'] = self.lr_scheduler.state_dict()
832
+ state['scaler'] = self.scaler.state_dict()
833
+ if self.ema is not None:
834
+ state['ema'] = self.ema.state_dict()
835
+
836
+ if not best:
837
+
838
+ state['model'] = self.model.state_dict()
839
+
840
+ file_path = f"{name}.pth"
841
+
842
+ self.stats["checkpoints"].append(file_path)
843
+
844
+ if len(self.stats["checkpoints"]) > self.max_keep_ckpt:
845
+ old_ckpt = os.path.join(self.ckpt_path, self.stats["checkpoints"].pop(0))
846
+ if os.path.exists(old_ckpt):
847
+ os.remove(old_ckpt)
848
+
849
+ torch.save(state, os.path.join(self.ckpt_path, file_path))
850
+
851
+ else:
852
+ if len(self.stats["results"]) > 0:
853
+ if self.stats["best_result"] is None or self.stats["results"][-1] < self.stats["best_result"]:
854
+ self.log(f"[INFO] New best result: {self.stats['best_result']} --> {self.stats['results'][-1]}")
855
+ self.stats["best_result"] = self.stats["results"][-1]
856
+
857
+ # save ema results
858
+ if self.ema is not None:
859
+ self.ema.store()
860
+ self.ema.copy_to()
861
+
862
+ state['model'] = self.model.state_dict()
863
+
864
+ if self.ema is not None:
865
+ self.ema.restore()
866
+
867
+ torch.save(state, self.best_path)
868
+ else:
869
+ self.log(f"[WARN] no evaluated results found, skip saving best checkpoint.")
870
+
871
+ def load_checkpoint(self, checkpoint=None, model_only=False):
872
+ if checkpoint is None:
873
+ checkpoint_list = sorted(glob.glob(f'{self.ckpt_path}/*.pth'))
874
+ if checkpoint_list:
875
+ checkpoint = checkpoint_list[-1]
876
+ self.log(f"[INFO] Latest checkpoint is {checkpoint}")
877
+ else:
878
+ self.log("[WARN] No checkpoint found, model randomly initialized.")
879
+ return
880
+
881
+ checkpoint_dict = torch.load(checkpoint, map_location=self.device)
882
+
883
+ if 'model' not in checkpoint_dict:
884
+ self.model.load_state_dict(checkpoint_dict)
885
+ self.log("[INFO] loaded model.")
886
+ return
887
+
888
+ missing_keys, unexpected_keys = self.model.load_state_dict(checkpoint_dict['model'], strict=False)
889
+ self.log("[INFO] loaded model.")
890
+ if len(missing_keys) > 0:
891
+ self.log(f"[WARN] missing keys: {missing_keys}")
892
+ if len(unexpected_keys) > 0:
893
+ self.log(f"[WARN] unexpected keys: {unexpected_keys}")
894
+
895
+ if self.ema is not None and 'ema' in checkpoint_dict:
896
+ try:
897
+ self.ema.load_state_dict(checkpoint_dict['ema'])
898
+ self.log("[INFO] loaded EMA.")
899
+ except:
900
+ self.log("[WARN] failed to loaded EMA.")
901
+
902
+ if self.model.cuda_ray:
903
+ if 'mean_count' in checkpoint_dict:
904
+ self.model.mean_count = checkpoint_dict['mean_count']
905
+ if 'mean_density' in checkpoint_dict:
906
+ self.model.mean_density = checkpoint_dict['mean_density']
907
+
908
+ if model_only:
909
+ return
910
+
911
+ self.stats = checkpoint_dict['stats']
912
+ self.epoch = checkpoint_dict['epoch']
913
+ self.global_step = checkpoint_dict['global_step']
914
+ self.log(f"[INFO] load at epoch {self.epoch}, global step {self.global_step}")
915
+
916
+ if self.optimizer and 'optimizer' in checkpoint_dict:
917
+ try:
918
+ self.optimizer.load_state_dict(checkpoint_dict['optimizer'])
919
+ self.log("[INFO] loaded optimizer.")
920
+ except:
921
+ self.log("[WARN] Failed to load optimizer.")
922
+
923
+ if self.lr_scheduler and 'lr_scheduler' in checkpoint_dict:
924
+ try:
925
+ self.lr_scheduler.load_state_dict(checkpoint_dict['lr_scheduler'])
926
+ self.log("[INFO] loaded scheduler.")
927
+ except:
928
+ self.log("[WARN] Failed to load scheduler.")
929
+
930
+ if self.scaler and 'scaler' in checkpoint_dict:
931
+ try:
932
+ self.scaler.load_state_dict(checkpoint_dict['scaler'])
933
+ self.log("[INFO] loaded scaler.")
934
+ except:
935
+ self.log("[WARN] Failed to load scaler.")
optimizer.py ADDED
@@ -0,0 +1,470 @@
1
+ import numpy as np
2
+ import torch
3
+ import enum
4
+ import itertools
5
+ from dataclasses import dataclass
6
+ import torch.optim as optim
7
+
8
+ @torch.no_grad()
9
+ def PowerIter(mat_g, error_tolerance=1e-6, num_iters=100):
10
+ """Power iteration.
11
+ Compute the maximum eigenvalue of mat, for scaling.
12
+ v is a random vector with values in (-1, 1)
13
+ Args:
14
+ mat_g: the symmetric PSD matrix.
15
+ error_tolerance: Iterative exit condition.
16
+ num_iters: Number of iterations.
17
+ Returns:
18
+ eigen vector, eigen value, num_iters
19
+ """
20
+ v = torch.rand(list(mat_g.shape)[0], device=mat_g.get_device()) * 2 - 1
21
+ error = 1
22
+ iters = 0
23
+ singular_val = 0
24
+ while error > error_tolerance and iters < num_iters:
25
+ v = v / torch.norm(v)
26
+ mat_v = torch.mv(mat_g, v)
27
+ s_v = torch.dot(v, mat_v)
28
+ error = torch.abs(s_v - singular_val)
29
+ v = mat_v
30
+ singular_val = s_v
31
+ iters += 1
32
+ return singular_val, v / torch.norm(v), iters
33
+
34
+
35
+ @torch.no_grad()
36
+ def MatPower(mat_m, p):
37
+ """Computes mat_m^p, for p a positive integer.
38
+ Args:
39
+ mat_m: a square matrix
40
+ p: a positive integer
41
+ Returns:
42
+ mat_m^p
43
+ """
44
+ if p in [1, 2, 4, 8, 16, 32]:
45
+ p_done = 1
46
+ res = mat_m
47
+ while p_done < p:
48
+ res = torch.matmul(res, res)
49
+ p_done *= 2
50
+ return res
51
+
52
+ power = None
53
+ while p > 0:
54
+ if p % 2 == 1:
55
+ power = torch.matmul(mat_m, power) if power is not None else mat_m
56
+ p //= 2
57
+ mat_m = torch.matmul(mat_m, mat_m)
58
+ return power
59
+
60
+
61
+ @torch.no_grad()
62
+ def ComputePower(mat_g, p,
63
+ iter_count=100,
64
+ error_tolerance=1e-6,
65
+ ridge_epsilon=1e-6):
66
+ """A method to compute G^{-1/p} using a coupled Newton iteration.
67
+ See for example equation 3.2 on page 9 of:
68
+ A Schur-Newton Method for the Matrix p-th Root and its Inverse
69
+ by Chun-Hua Guo and Nicholas J. Higham
70
+ SIAM Journal on Matrix Analysis and Applications,
71
+ 2006, Vol. 28, No. 3 : pp. 788-804
72
+ https://pdfs.semanticscholar.org/0abe/7f77433cf5908bfe2b79aa91af881da83858.pdf
73
+ Args:
74
+ mat_g: A square positive semidefinite matrix
75
+ p: a positive integer
76
+ iter_count: Stop iterating after this many rounds.
77
+ error_tolerance: Threshold for stopping iteration
78
+ ridge_epsilon: We add this times I to G, to make it positive definite.
79
+ For scaling, we multiply it by the largest eigenvalue of G.
80
+ Returns:
81
+ (mat_g + rI)^{-1/p} (r = ridge_epsilon * max_eigenvalue of mat_g).
82
+ """
83
+ shape = list(mat_g.shape)
84
+ if len(shape) == 1:
85
+ return torch.pow(mat_g + ridge_epsilon, -1/p)
86
+ identity = torch.eye(shape[0], device=mat_g.get_device())
87
+ if shape[0] == 1:
88
+ return identity
89
+ alpha = -1.0/p
90
+ max_ev, _, _ = PowerIter(mat_g)
91
+ ridge_epsilon *= max_ev
92
+ mat_g += ridge_epsilon * identity
93
+ z = (1 + p) / (2 * torch.norm(mat_g))
94
+ # The best value for z is
95
+ # (1 + p) * (c_max^{1/p} - c_min^{1/p}) /
96
+ # (c_max^{1+1/p} - c_min^{1+1/p})
97
+ # where c_max and c_min are the largest and smallest singular values of
98
+ # mat_g.
99
+ # The above estimate assumes that c_max > c_min * 2^p
100
+ # Can replace above line by the one below, but it is less accurate,
101
+ # hence needs more iterations to converge.
102
+ # z = (1 + p) / tf.trace(mat_g)
103
+ # If we want the method to always converge, use z = 1 / norm(mat_g)
104
+ # or z = 1 / tf.trace(mat_g), but these can result in many
105
+ # extra iterations.
106
+
107
+ mat_root = identity * torch.pow(z, 1.0/p)
108
+ mat_m = mat_g * z
109
+ error = torch.max(torch.abs(mat_m - identity))
110
+ count = 0
111
+ while error > error_tolerance and count < iter_count:
112
+ tmp_mat_m = (1 - alpha) * identity + alpha * mat_m
113
+ new_mat_root = torch.matmul(mat_root, tmp_mat_m)
114
+ mat_m = torch.matmul(MatPower(tmp_mat_m, p), mat_m)
115
+ new_error = torch.max(torch.abs(mat_m - identity))
116
+ if new_error > error * 1.2:
117
+ break
118
+ mat_root = new_mat_root
119
+ error = new_error
120
+ count += 1
121
+ return mat_root
122
+
123
+
124
+
125
+ # Grafting is a technique to fix the layerwise scale of Shampoo optimizer.
126
+ # https://arxiv.org/pdf/2002.11803.pdf studies this in detail. This
127
+ # allows us to plugin the Shampoo optimizer into settings where SGD/AdaGrad
128
+ # is already well tuned. Grafting onto Shampoo means take the Shampoo direction,
129
+ # but use the step magnitude from the grafted optimizer such as Adagrad or SGD.
130
+ class LayerwiseGrafting(enum.IntEnum):
131
+ NONE = 0
132
+ SGD = 1
133
+ ADAGRAD = 2
134
+
135
+
136
+ @dataclass
137
+ class ShampooHyperParams:
138
+ """Shampoo hyper parameters."""
139
+ beta2: float = 0.9
140
+ diagonal_eps: float = 1e-6
141
+ matrix_eps: float = 1e-12
142
+ weight_decay: float = 0.0
143
+ inverse_exponent_override: int = 2 # fixed exponent for preconditioner, if >0
144
+ start_preconditioning_step: int = 1
145
+ # Performance tuning params for controlling memory and compute requirements.
146
+ # How often to compute preconditioner.
147
+ preconditioning_compute_steps: int = 1
148
+ # How often to compute statistics.
149
+ statistics_compute_steps: int = 1
150
+ # Block size for large layers (if > 0).
151
+ # Block size = 1 ==> Adagrad (Don't do this, extremely inefficient!)
152
+ # Block size should be as large as feasible under memory/time constraints.
153
+ block_size: int = 128
154
+ # Automatic shape interpretation (e.g. [4, 3, 1024, 512] would result in
155
+ # 12 x [1024, 512] L and R statistics). Enabled by default; when disabled,
156
+ # Shampoo constructs statistics of shape [4, 4], [3, 3], [1024, 1024], [512, 512].
157
+ best_effort_shape_interpretation: bool = True
158
+ # Type of grafting (SGD or AdaGrad).
159
+ # https://arxiv.org/pdf/2002.11803.pdf
160
+ graft_type: int = LayerwiseGrafting.ADAGRAD
161
+ # Nesterov momentum
162
+ nesterov: bool = True
163
+
164
+
165
+ class Graft:
166
+ """Base class to perform grafting onto Shampoo. This class does no grafting.
167
+ """
168
+
169
+ def __init__(self, hps, unused_var):
170
+ self.hps = hps
171
+
172
+ def add_statistics(self, grad):
173
+ pass
174
+
175
+ def precondition_gradient(self, grad):
176
+ return grad
177
+
178
+ def update_momentum(self, update, unused_beta1):
179
+ return update
180
+
181
+
182
+ class SGDGraft(Graft):
183
+ """Graft using SGD+momentum.
184
+ momentum maintains an exponentially weighted moving average of gradients.
185
+ """
186
+
187
+ def __init__(self, hps, var):
188
+ super(SGDGraft, self).__init__(hps, var)
189
+ self.momentum = torch.zeros_like(var.data, device=var.get_device())
190
+
191
+ def update_momentum(self, update, beta1):
192
+ self.momentum.mul_(beta1).add_(update)
193
+ return self.momentum
194
+
195
+
196
+ class AdagradGraft(SGDGraft):
197
+ """Graft using Adagrad.
198
+ Essentially an implementation of Adagrad with momentum.
199
+ """
200
+
201
+ def __init__(self, hps, var):
202
+ super(AdagradGraft, self).__init__(hps, var)
203
+ self.statistics = torch.zeros_like(var.data, device=var.get_device())
204
+
205
+ def add_statistics(self, grad):
206
+ self.statistics.add_(grad * grad)
207
+
208
+ def precondition_gradient(self, grad):
209
+ return grad / (torch.sqrt(self.statistics) + self.hps.diagonal_eps)
210
+
211
+
212
+ class BlockPartitioner:
213
+ """Partitions a tensor into smaller tensors for preconditioning.
214
+ For example, if a variable has shape (4096, 512), we might split the
215
+ 4096 into 4 blocks, so we effectively have 4 variables of size
216
+ (1024, 512) each.
217
+ """
218
+
219
+ def __init__(self, var, hps):
220
+ self._shape = var.shape
221
+ self._splits = []
222
+ self._split_sizes = []
223
+ split_sizes = []
224
+ # We split var into smaller blocks. Here we store the metadata to make
225
+ # that split.
226
+ for i, d in enumerate(var.shape):
227
+ if hps.block_size > 0 and d > hps.block_size:
228
+ # d-1, otherwise split appends a 0-size array.
229
+ nsplit = (d-1) // hps.block_size
230
+ indices = (np.arange(nsplit, dtype=np.int32) + 1) * hps.block_size
231
+ sizes = np.ones(nsplit + 1, dtype=np.int32) * hps.block_size
232
+ sizes[-1] = d - indices[-1]
233
+ self._splits.append((i, indices))
234
+ self._split_sizes.append((i, sizes))
235
+ split_sizes.append(sizes)
236
+ else:
237
+ split_sizes.append(np.array([d], dtype=np.int32))
238
+ self._num_splits = len(split_sizes)
239
+ self._preconditioner_shapes = []
240
+ for t in itertools.product(*split_sizes):
241
+ self._preconditioner_shapes.extend([[d, d] for d in t])
242
+
243
+ def shapes_for_preconditioners(self):
244
+ return self._preconditioner_shapes
245
+
246
+ def num_splits(self):
247
+ return self._num_splits
248
+
249
+ def partition(self, tensor):
250
+ """Partition tensor into blocks."""
251
+
252
+ assert tensor.shape == self._shape
253
+ tensors = [tensor]
254
+ for (i, sizes) in self._split_sizes:
255
+ tensors_local = []
256
+ for t in tensors:
257
+ tensors_local.extend(
258
+ torch.split(t, tuple(sizes), dim=i))
259
+ tensors = tensors_local
260
+ return tensors
261
+
262
+ def merge_partitions(self, partitions):
263
+ """Merge partitions back to original shape."""
264
+
265
+ for (i, indices) in reversed(self._splits):
266
+ n = len(indices) + 1
267
+ partial_merged_tensors = []
268
+ ind = 0
269
+ while ind < len(partitions):
270
+ partial_merged_tensors.append(
271
+ torch.cat(partitions[ind:ind + n], axis=i))
272
+ ind += n
273
+ partitions = partial_merged_tensors
274
+ assert len(partitions) == 1
275
+ return partitions[0]
276
+
277
+
278
+ def _merge_small_dims(shape_to_merge, max_dim):
279
+ """Merge small dimensions.
280
+ If there are some small dimensions, we collapse them:
281
+ e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if max_dim = 1024
282
+ [1, 2, 768, 1, 2048] --> [2, 768, 2048]
283
+ Args:
284
+ shape_to_merge: Shape to merge small dimensions.
285
+ max_dim: Maximal dimension of output shape used in merging.
286
+ Returns:
287
+ Merged shape.
288
+ """
289
+ resulting_shape = []
290
+ product = 1
291
+ for d in shape_to_merge:
292
+ if product * d <= max_dim:
293
+ product *= d
294
+ else:
295
+ if product > 1:
296
+ resulting_shape.append(product)
297
+ product = d
298
+ if product > 1:
299
+ resulting_shape.append(product)
300
+ return resulting_shape
301
+
302
+
303
+ class Preconditioner:
304
+ """Compute statistics/shape from gradients for preconditioning."""
305
+
306
+ def __init__(self, var, hps):
307
+ self._hps = hps
308
+ self._original_shape = var.shape
309
+ self._transformed_shape = var.shape
310
+ if hps.best_effort_shape_interpretation:
311
+ self._transformed_shape = _merge_small_dims(
312
+ self._original_shape, hps.block_size)
313
+
314
+ reshaped_var = torch.reshape(var, self._transformed_shape)
315
+ self._partitioner = BlockPartitioner(reshaped_var, hps)
316
+ shapes = self._partitioner.shapes_for_preconditioners()
317
+ rank = len(self._transformed_shape)
318
+ device = var.get_device()
319
+ if rank <= 1:
320
+ self.statistics = []
321
+ self.preconditioners = []
322
+ else:
323
+ eps = self._hps.matrix_eps
324
+ self.statistics = [eps * torch.eye(s[0], device=device) for s in shapes]
325
+ self.preconditioners = [torch.eye(s[0], device=device) for s in shapes]
326
+
327
+ def add_statistics(self, grad):
328
+ """Compute statistics from gradients and add to the correct state entries.
329
+ Args:
330
+ grad: Gradient to compute statistics from.
331
+ """
332
+ if not self.statistics: return
333
+ reshaped_grad = torch.reshape(grad, self._transformed_shape)
334
+ partitioned_grads = self._partitioner.partition(reshaped_grad)
335
+ w1 = self._hps.beta2
336
+ w2 = 1.0 if w1 == 1.0 else (1.0 - w1)
337
+ rank = len(self._transformed_shape)
338
+ for j, grad in enumerate(partitioned_grads):
339
+ for i in range(rank):
340
+ axes = list(range(i)) + list(range(i + 1, rank))
341
+ stat = torch.tensordot(grad, grad, [axes, axes])
342
+ self.statistics[j*rank + i].mul_(w1).add_(stat, alpha=w2)
343
+
344
+ def exponent_for_preconditioner(self):
345
+ """Returns exponent to use for inverse-pth root M^{-1/p}."""
346
+ if self._hps.inverse_exponent_override > 0:
347
+ return self._hps.inverse_exponent_override
348
+ return 2 * len(self._transformed_shape)
349
+
350
+ def compute_preconditioners(self):
351
+ """Compute L^{-1/exp} for each stats matrix L."""
352
+ exp = self.exponent_for_preconditioner()
353
+ eps = self._hps.matrix_eps
354
+ for i, stat in enumerate(self.statistics):
355
+ self.preconditioners[i] = ComputePower(
356
+ stat, exp, ridge_epsilon=eps)
357
+
358
+ def preconditioned_grad(self, grad):
359
+ """Precondition the gradient.
360
+ Args:
361
+ grad: A gradient tensor to precondition.
362
+ Returns:
363
+ A preconditioned gradient.
364
+ """
365
+ if not self.preconditioners: return grad
366
+ reshaped_grad = torch.reshape(grad, self._transformed_shape)
367
+ partitioned_grads = self._partitioner.partition(reshaped_grad)
368
+ preconditioned_partitioned_grads = []
369
+ num_splits = self._partitioner.num_splits()
370
+ for i, grad in enumerate(partitioned_grads):
371
+ preconditioners_for_grad = self.preconditioners[i * num_splits:(i + 1) *
372
+ num_splits]
373
+ rank = len(grad.shape)
374
+ precond_grad = grad
375
+ for j in range(rank):
376
+ preconditioner = preconditioners_for_grad[j]
377
+ precond_grad = torch.tensordot(
378
+ precond_grad, preconditioner, [[0], [0]])
379
+ preconditioned_partitioned_grads.append(precond_grad)
380
+ merged_grad = self._partitioner.merge_partitions(
381
+ preconditioned_partitioned_grads)
382
+ return torch.reshape(merged_grad, self._original_shape)
383
+
384
+
385
+ STEP = 'step'
386
+ MOMENTUM = 'momentum'
387
+ PRECONDITIONER = 'preconditioner'
388
+ GRAFT = 'graft'
389
+
390
+
391
+ class Shampoo(optim.Optimizer):
392
+ """The Shampoo optimizer."""
393
+
394
+ def __init__(self,
395
+ params,
396
+ lr=1.0,
397
+ momentum=0.9,
398
+ hyperparams=ShampooHyperParams()):
399
+ defaults = dict(lr=lr, momentum=momentum)
400
+ self.hps = hyperparams
401
+ super(Shampoo, self).__init__(params, defaults)
402
+
403
+ def init_var_state(self, var, state):
404
+ """Initialize the PyTorch state of for a single variable."""
405
+ state[STEP] = 0
406
+ state[MOMENTUM] = torch.zeros_like(var.data, device=var.get_device())
407
+ state[PRECONDITIONER] = Preconditioner(var, self.hps)
408
+ if self.hps.graft_type == LayerwiseGrafting.ADAGRAD:
409
+ state[GRAFT] = AdagradGraft(self.hps, var)
410
+ elif self.hps.graft_type == LayerwiseGrafting.SGD:
411
+ state[GRAFT] = SGDGraft(self.hps, var)
412
+ else:
413
+ state[GRAFT] = Graft(self.hps, var)
414
+
415
+ def step(self, closure=None):
416
+ hps = self.hps
417
+ for group in self.param_groups:
418
+ lr = group['lr']
419
+ for p in group['params']:
420
+ if p.grad is None: continue
421
+ grad = p.grad.data
422
+ if grad.is_sparse:
423
+ raise RuntimeError('Shampoo does not support sparse yet')
424
+ state = self.state[p]
425
+ if not state:
426
+ self.init_var_state(p, state)
427
+ state[STEP] += 1
428
+
429
+ preconditioner = state[PRECONDITIONER]
430
+ graft = state[GRAFT]
431
+
432
+ # Gather statistics, compute preconditioners
433
+ graft.add_statistics(grad)
434
+ if state[STEP] % hps.statistics_compute_steps == 0:
435
+ preconditioner.add_statistics(grad)
436
+ if state[STEP] % hps.preconditioning_compute_steps == 0:
437
+ preconditioner.compute_preconditioners()
438
+
439
+ # Precondition gradients
440
+ graft_grad = graft.precondition_gradient(grad)
441
+ shampoo_grad = grad
442
+ if state[STEP] >= self.hps.start_preconditioning_step:
443
+ shampoo_grad = preconditioner.preconditioned_grad(grad)
444
+
445
+ # Grafting
446
+ graft_norm = torch.norm(graft_grad)
447
+ shampoo_norm = torch.norm(shampoo_grad)
448
+ shampoo_grad.mul_(graft_norm / (shampoo_norm + 1e-16))
449
+
450
+ # Weight decay
451
+ if self.hps.weight_decay != 0.0:
452
+ shampoo_grad.add_(p.data, alpha=self.hps.weight_decay)
453
+ graft_grad.add_(p.data, alpha=self.hps.weight_decay)
454
+
455
+ # Momentum and Nesterov momentum, if needed
456
+ state[MOMENTUM].mul_(group['momentum']).add_(shampoo_grad)
457
+ graft_momentum = graft.update_momentum(grad, group['momentum'])
458
+
459
+ if state[STEP] >= self.hps.start_preconditioning_step:
460
+ momentum_update = state[MOMENTUM]
461
+ wd_update = shampoo_grad
462
+ else:
463
+ momentum_update = graft_momentum
464
+ wd_update = graft_grad
465
+
466
+ if hps.nesterov:
467
+ momentum_update.mul_(group['momentum']).add_(wd_update)
468
+
469
+ # Final update
470
+ p.data.add_(momentum_update, alpha=-lr)
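The Shampoo class above follows the standard torch.optim.Optimizer interface, so it can replace Adam/SGD wherever an optimizer is constructed. A minimal usage sketch, assuming optimizer.py is importable from the working directory and a CUDA device is available (the per-parameter state is allocated via var.get_device(), so CUDA parameters are expected); the layer sizes and hyperparameter values are illustrative only:

import torch
from optimizer import Shampoo, ShampooHyperParams, LayerwiseGrafting

model = torch.nn.Linear(64, 10).cuda()
hps = ShampooHyperParams(block_size=128, graft_type=LayerwiseGrafting.ADAGRAD)
opt = Shampoo(model.parameters(), lr=1e-3, momentum=0.9, hyperparams=hps)

x = torch.randn(32, 64, device='cuda')
loss = model(x).square().mean()
loss.backward()
opt.step()       # gathers statistics, preconditions the gradient, applies grafting and momentum
opt.zero_grad()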
raymarching/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .raymarching import *
raymarching/backend.py ADDED
@@ -0,0 +1,40 @@
1
+ import os
2
+ from torch.utils.cpp_extension import load
3
+
4
+ _src_path = os.path.dirname(os.path.abspath(__file__))
5
+
6
+ nvcc_flags = [
7
+ '-O3', '-std=c++14',
8
+ '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
9
+ ]
10
+
11
+ if os.name == "posix":
12
+ c_flags = ['-O3', '-std=c++14']
13
+ elif os.name == "nt":
14
+ c_flags = ['/O2', '/std:c++17']
15
+
16
+ # find cl.exe
17
+ def find_cl_path():
18
+ import glob
19
+ for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
20
+ paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
21
+ if paths:
22
+ return paths[0]
23
+
24
+ # If cl.exe is not on path, try to find it.
25
+ if os.system("where cl.exe >nul 2>nul") != 0:
26
+ cl_path = find_cl_path()
27
+ if cl_path is None:
28
+ raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
29
+ os.environ["PATH"] += ";" + cl_path
30
+
31
+ _backend = load(name='_raymarching',
32
+ extra_cflags=c_flags,
33
+ extra_cuda_cflags=nvcc_flags,
34
+ sources=[os.path.join(_src_path, 'src', f) for f in [
35
+ 'raymarching.cu',
36
+ 'bindings.cpp',
37
+ ]],
38
+ )
39
+
40
+ __all__ = ['_backend']
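backend.py builds the extension lazily with torch.utils.cpp_extension.load, so the first import compiles raymarching.cu and bindings.cpp with nvcc; later imports reuse the cached build (by default under ~/.cache/torch_extensions, overridable via TORCH_EXTENSIONS_DIR). A small sanity check, assuming a CUDA-enabled PyTorch install:

import torch
assert torch.cuda.is_available()          # the extension is CUDA-only

from raymarching.backend import _backend  # first import may take a while to compile
print([n for n in dir(_backend) if not n.startswith('_')])
# expected to include packbits, near_far_from_aabb, march_rays_train, composite_rays, ...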
raymarching/raymarching.py ADDED
@@ -0,0 +1,373 @@
1
+ import numpy as np
2
+ import time
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch.autograd import Function
7
+ from torch.cuda.amp import custom_bwd, custom_fwd
8
+
9
+ try:
10
+ import _raymarching as _backend
11
+ except ImportError:
12
+ from .backend import _backend
13
+
14
+
15
+ # ----------------------------------------
16
+ # utils
17
+ # ----------------------------------------
18
+
19
+ class _near_far_from_aabb(Function):
20
+ @staticmethod
21
+ @custom_fwd(cast_inputs=torch.float32)
22
+ def forward(ctx, rays_o, rays_d, aabb, min_near=0.2):
23
+ ''' near_far_from_aabb, CUDA implementation
24
+ Calculate rays' intersection time (near and far) with aabb
25
+ Args:
26
+ rays_o: float, [N, 3]
27
+ rays_d: float, [N, 3]
28
+ aabb: float, [6], (xmin, ymin, zmin, xmax, ymax, zmax)
29
+ min_near: float, scalar
30
+ Returns:
31
+ nears: float, [N]
32
+ fars: float, [N]
33
+ '''
34
+ if not rays_o.is_cuda: rays_o = rays_o.cuda()
35
+ if not rays_d.is_cuda: rays_d = rays_d.cuda()
36
+
37
+ rays_o = rays_o.contiguous().view(-1, 3)
38
+ rays_d = rays_d.contiguous().view(-1, 3)
39
+
40
+ N = rays_o.shape[0] # num rays
41
+
42
+ nears = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device)
43
+ fars = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device)
44
+
45
+ _backend.near_far_from_aabb(rays_o, rays_d, aabb, N, min_near, nears, fars)
46
+
47
+ return nears, fars
48
+
49
+ near_far_from_aabb = _near_far_from_aabb.apply
50
+
51
+
52
+ class _sph_from_ray(Function):
53
+ @staticmethod
54
+ @custom_fwd(cast_inputs=torch.float32)
55
+ def forward(ctx, rays_o, rays_d, radius):
56
+ ''' sph_from_ray, CUDA implementation
57
+ get spherical coordinate on the background sphere from rays.
58
+ Assume rays_o are inside the Sphere(radius).
59
+ Args:
60
+ rays_o: [N, 3]
61
+ rays_d: [N, 3]
62
+ radius: scalar, float
63
+ Return:
64
+ coords: [N, 2], in [-1, 1], theta and phi on the sphere (the farther intersection).
65
+ '''
66
+ if not rays_o.is_cuda: rays_o = rays_o.cuda()
67
+ if not rays_d.is_cuda: rays_d = rays_d.cuda()
68
+
69
+ rays_o = rays_o.contiguous().view(-1, 3)
70
+ rays_d = rays_d.contiguous().view(-1, 3)
71
+
72
+ N = rays_o.shape[0] # num rays
73
+
74
+ coords = torch.empty(N, 2, dtype=rays_o.dtype, device=rays_o.device)
75
+
76
+ _backend.sph_from_ray(rays_o, rays_d, radius, N, coords)
77
+
78
+ return coords
79
+
80
+ sph_from_ray = _sph_from_ray.apply
81
+
82
+
83
+ class _morton3D(Function):
84
+ @staticmethod
85
+ def forward(ctx, coords):
86
+ ''' morton3D, CUDA implementation
87
+ Args:
88
+ coords: [N, 3], int32, in [0, 128) (for some reason there is no uint32 tensor in torch...)
89
+ TODO: check if the coord range is valid! (current 128 is safe)
90
+ Returns:
91
+ indices: [N], int32, in [0, 128^3)
92
+
93
+ '''
94
+ if not coords.is_cuda: coords = coords.cuda()
95
+
96
+ N = coords.shape[0]
97
+
98
+ indices = torch.empty(N, dtype=torch.int32, device=coords.device)
99
+
100
+ _backend.morton3D(coords.int(), N, indices)
101
+
102
+ return indices
103
+
104
+ morton3D = _morton3D.apply
105
+
106
+ class _morton3D_invert(Function):
107
+ @staticmethod
108
+ def forward(ctx, indices):
109
+ ''' morton3D_invert, CUDA implementation
110
+ Args:
111
+ indices: [N], int32, in [0, 128^3)
112
+ Returns:
113
+ coords: [N, 3], int32, in [0, 128)
114
+
115
+ '''
116
+ if not indices.is_cuda: indices = indices.cuda()
117
+
118
+ N = indices.shape[0]
119
+
120
+ coords = torch.empty(N, 3, dtype=torch.int32, device=indices.device)
121
+
122
+ _backend.morton3D_invert(indices.int(), N, coords)
123
+
124
+ return coords
125
+
126
+ morton3D_invert = _morton3D_invert.apply
127
+
128
+
129
+ class _packbits(Function):
130
+ @staticmethod
131
+ @custom_fwd(cast_inputs=torch.float32)
132
+ def forward(ctx, grid, thresh, bitfield=None):
133
+ ''' packbits, CUDA implementation
134
+ Pack up the density grid into a bit field to accelerate ray marching.
135
+ Args:
136
+ grid: float, [C, H * H * H], assume H % 2 == 0
137
+ thresh: float, threshold
138
+ Returns:
139
+ bitfield: uint8, [C, H * H * H / 8]
140
+ '''
141
+ if not grid.is_cuda: grid = grid.cuda()
142
+ grid = grid.contiguous()
143
+
144
+ C = grid.shape[0]
145
+ H3 = grid.shape[1]
146
+ N = C * H3 // 8
147
+
148
+ if bitfield is None:
149
+ bitfield = torch.empty(N, dtype=torch.uint8, device=grid.device)
150
+
151
+ _backend.packbits(grid, N, thresh, bitfield)
152
+
153
+ return bitfield
154
+
155
+ packbits = _packbits.apply
156
+
157
+ # ----------------------------------------
158
+ # train functions
159
+ # ----------------------------------------
160
+
161
+ class _march_rays_train(Function):
162
+ @staticmethod
163
+ @custom_fwd(cast_inputs=torch.float32)
164
+ def forward(ctx, rays_o, rays_d, bound, density_bitfield, C, H, nears, fars, step_counter=None, mean_count=-1, perturb=False, align=-1, force_all_rays=False, dt_gamma=0, max_steps=1024):
165
+ ''' march rays to generate points (forward only)
166
+ Args:
167
+ rays_o/d: float, [N, 3]
168
+ bound: float, scalar
169
+ density_bitfield: uint8: [CHHH // 8]
170
+ C: int
171
+ H: int
172
+ nears/fars: float, [N]
173
+ step_counter: int32, (2), used to count the actual number of generated points.
174
+ mean_count: int32, estimated mean steps to accelerate training. (but will randomly drop rays if the actual point count exceeds this threshold.)
175
+ perturb: bool
176
+ align: int, pad output so its size is divisible by align, set to -1 to disable.
177
+ force_all_rays: bool, ignore step_counter and mean_count, always calculate all rays. Useful if rendering the whole image, instead of some rays.
178
+ dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerates ray marching if > 0. (very significant effect, but generally leads to worse performance)
179
+ max_steps: int, max number of sampled points along each ray, also affect min_stepsize.
180
+ Returns:
181
+ xyzs: float, [M, 3], all generated points' coords. (all rays concatenated; use `rays` to extract the points belonging to each ray)
182
+ dirs: float, [M, 3], all generated points' view dirs.
183
+ deltas: float, [M, 2], all generated points' deltas. (first for RGB, second for Depth)
184
+ rays: int32, [N, 3], all rays' (index, point_offset, point_count), e.g., xyzs[rays[i, 1]:rays[i, 2]] --> points belonging to rays[i, 0]
185
+ '''
186
+
187
+ if not rays_o.is_cuda: rays_o = rays_o.cuda()
188
+ if not rays_d.is_cuda: rays_d = rays_d.cuda()
189
+ if not density_bitfield.is_cuda: density_bitfield = density_bitfield.cuda()
190
+
191
+ rays_o = rays_o.contiguous().view(-1, 3)
192
+ rays_d = rays_d.contiguous().view(-1, 3)
193
+ density_bitfield = density_bitfield.contiguous()
194
+
195
+ N = rays_o.shape[0] # num rays
196
+ M = N * max_steps # init max points number in total
197
+
198
+ # running average based on previous epoch (mimic `measured_batch_size_before_compaction` in instant-ngp)
199
+ # It estimates the max number of points to enable faster training, but will lead to randomly ignored rays if underestimated.
200
+ if not force_all_rays and mean_count > 0:
201
+ if align > 0:
202
+ mean_count += align - mean_count % align
203
+ M = mean_count
204
+
205
+ xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
206
+ dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
207
+ deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device)
208
+ rays = torch.empty(N, 3, dtype=torch.int32, device=rays_o.device) # id, offset, num_steps
209
+
210
+ if step_counter is None:
211
+ step_counter = torch.zeros(2, dtype=torch.int32, device=rays_o.device) # point counter, ray counter
212
+
213
+ if perturb:
214
+ noises = torch.rand(N, dtype=rays_o.dtype, device=rays_o.device)
215
+ else:
216
+ noises = torch.zeros(N, dtype=rays_o.dtype, device=rays_o.device)
217
+
218
+ _backend.march_rays_train(rays_o, rays_d, density_bitfield, bound, dt_gamma, max_steps, N, C, H, M, nears, fars, xyzs, dirs, deltas, rays, step_counter, noises) # m is the actually used points number
219
+
220
+ #print(step_counter, M)
221
+
222
+ # only used at the first (few) epochs.
223
+ if force_all_rays or mean_count <= 0:
224
+ m = step_counter[0].item() # D2H copy
225
+ if align > 0:
226
+ m += align - m % align
227
+ xyzs = xyzs[:m]
228
+ dirs = dirs[:m]
229
+ deltas = deltas[:m]
230
+
231
+ torch.cuda.empty_cache()
232
+
233
+ return xyzs, dirs, deltas, rays
234
+
235
+ march_rays_train = _march_rays_train.apply
236
+
237
+
238
+ class _composite_rays_train(Function):
239
+ @staticmethod
240
+ @custom_fwd(cast_inputs=torch.float32)
241
+ def forward(ctx, sigmas, rgbs, deltas, rays, T_thresh=1e-4):
242
+ ''' composite rays' rgbs, according to the ray marching formula.
243
+ Args:
244
+ rgbs: float, [M, 3]
245
+ sigmas: float, [M,]
246
+ deltas: float, [M, 2]
247
+ rays: int32, [N, 3]
248
+ Returns:
249
+ weights_sum: float, [N,], the alpha channel
250
+ depth: float, [N, ], the Depth
251
+ image: float, [N, 3], the RGB channel (after multiplying alpha!)
252
+ '''
253
+
254
+ sigmas = sigmas.contiguous()
255
+ rgbs = rgbs.contiguous()
256
+
257
+ M = sigmas.shape[0]
258
+ N = rays.shape[0]
259
+
260
+ weights_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
261
+ depth = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
262
+ image = torch.empty(N, 3, dtype=sigmas.dtype, device=sigmas.device)
263
+
264
+ _backend.composite_rays_train_forward(sigmas, rgbs, deltas, rays, M, N, T_thresh, weights_sum, depth, image)
265
+
266
+ ctx.save_for_backward(sigmas, rgbs, deltas, rays, weights_sum, depth, image)
267
+ ctx.dims = [M, N, T_thresh]
268
+
269
+ return weights_sum, depth, image
270
+
271
+ @staticmethod
272
+ @custom_bwd
273
+ def backward(ctx, grad_weights_sum, grad_depth, grad_image):
274
+
275
+ # NOTE: grad_depth is not used now! It won't be propagated to sigmas.
276
+
277
+ grad_weights_sum = grad_weights_sum.contiguous()
278
+ grad_image = grad_image.contiguous()
279
+
280
+ sigmas, rgbs, deltas, rays, weights_sum, depth, image = ctx.saved_tensors
281
+ M, N, T_thresh = ctx.dims
282
+
283
+ grad_sigmas = torch.zeros_like(sigmas)
284
+ grad_rgbs = torch.zeros_like(rgbs)
285
+
286
+ _backend.composite_rays_train_backward(grad_weights_sum, grad_image, sigmas, rgbs, deltas, rays, weights_sum, image, M, N, T_thresh, grad_sigmas, grad_rgbs)
287
+
288
+ return grad_sigmas, grad_rgbs, None, None, None
289
+
290
+
291
+ composite_rays_train = _composite_rays_train.apply
292
+
293
+ # ----------------------------------------
294
+ # infer functions
295
+ # ----------------------------------------
296
+
297
+ class _march_rays(Function):
298
+ @staticmethod
299
+ @custom_fwd(cast_inputs=torch.float32)
300
+ def forward(ctx, n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, density_bitfield, C, H, near, far, align=-1, perturb=False, dt_gamma=0, max_steps=1024):
301
+ ''' march rays to generate points (forward only, for inference)
302
+ Args:
303
+ n_alive: int, number of alive rays
304
+ n_step: int, how many steps we march
305
+ rays_alive: int, [N], the alive rays' IDs in N (N >= n_alive, but we only use first n_alive)
306
+ rays_t: float, [N], the alive rays' time, we only use the first n_alive.
307
+ rays_o/d: float, [N, 3]
308
+ bound: float, scalar
309
+ density_bitfield: uint8: [CHHH // 8]
310
+ C: int
311
+ H: int
312
+ near/far: float, [N]
313
+ align: int, pad output so its size is divisible by align, set to -1 to disable.
314
+ perturb: bool/int, int > 0 is used as the random seed.
315
+ dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerates ray marching if > 0. (very significant effect, but generally leads to worse performance)
316
+ max_steps: int, max number of sampled points along each ray, also affect min_stepsize.
317
+ Returns:
318
+ xyzs: float, [n_alive * n_step, 3], all generated points' coords
319
+ dirs: float, [n_alive * n_step, 3], all generated points' view dirs.
320
+ deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth).
321
+ '''
322
+
323
+ if not rays_o.is_cuda: rays_o = rays_o.cuda()
324
+ if not rays_d.is_cuda: rays_d = rays_d.cuda()
325
+
326
+ rays_o = rays_o.contiguous().view(-1, 3)
327
+ rays_d = rays_d.contiguous().view(-1, 3)
328
+
329
+ M = n_alive * n_step
330
+
331
+ if align > 0:
332
+ M += align - (M % align)
333
+
334
+ xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
335
+ dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
336
+ deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device) # 2 vals, one for rgb, one for depth
337
+
338
+ if perturb:
339
+ # torch.manual_seed(perturb) # test_gui uses spp index as seed
340
+ noises = torch.rand(n_alive, dtype=rays_o.dtype, device=rays_o.device)
341
+ else:
342
+ noises = torch.zeros(n_alive, dtype=rays_o.dtype, device=rays_o.device)
343
+
344
+ _backend.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, dt_gamma, max_steps, C, H, density_bitfield, near, far, xyzs, dirs, deltas, noises)
345
+
346
+ return xyzs, dirs, deltas
347
+
348
+ march_rays = _march_rays.apply
349
+
350
+
351
+ class _composite_rays(Function):
352
+ @staticmethod
353
+ @custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float
354
+ def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image, T_thresh=1e-2):
355
+ ''' composite rays' rgbs, according to the ray marching formula. (for inference)
356
+ Args:
357
+ n_alive: int, number of alive rays
358
+ n_step: int, how many steps we march
359
+ rays_alive: int, [n_alive], the alive rays' IDs in N (N >= n_alive)
360
+ rays_t: float, [N], the alive rays' time
361
+ sigmas: float, [n_alive * n_step,]
362
+ rgbs: float, [n_alive * n_step, 3]
363
+ deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth).
364
+ In-place Outputs:
365
+ weights_sum: float, [N,], the alpha channel
366
+ depth: float, [N,], the depth value
367
+ image: float, [N, 3], the RGB channel (after multiplying alpha!)
368
+ '''
369
+ _backend.composite_rays(n_alive, n_step, T_thresh, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image)
370
+ return tuple()
371
+
372
+
373
+ composite_rays = _composite_rays.apply
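The autograd wrappers above are thin launchers around the CUDA kernels, so they can be exercised directly on random rays. A minimal sketch for near_far_from_aabb (shapes follow the docstrings above; the scene bounds and min_near value are illustrative only):

import torch
import raymarching  # triggers the extension build on first import

N = 1024
rays_o = torch.zeros(N, 3, device='cuda')  # rays starting at the origin, inside the box
rays_d = torch.nn.functional.normalize(torch.randn(N, 3, device='cuda'), dim=-1)
aabb = torch.tensor([-1, -1, -1, 1, 1, 1], dtype=torch.float32, device='cuda')

nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, aabb, 0.05)
print(nears.shape, fars.shape)  # both [N]; nears is clamped to min_near for origins inside the box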
raymarching/setup.py ADDED
@@ -0,0 +1,62 @@
1
+ import os
2
+ from setuptools import setup
3
+ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
4
+
5
+ _src_path = os.path.dirname(os.path.abspath(__file__))
6
+
7
+ nvcc_flags = [
8
+ '-O3', '-std=c++14',
9
+ '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
10
+ ]
11
+
12
+ if os.name == "posix":
13
+ c_flags = ['-O3', '-std=c++14']
14
+ elif os.name == "nt":
15
+ c_flags = ['/O2', '/std:c++17']
16
+
17
+ # find cl.exe
18
+ def find_cl_path():
19
+ import glob
20
+ for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
21
+ paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
22
+ if paths:
23
+ return paths[0]
24
+
25
+ # If cl.exe is not on path, try to find it.
26
+ if os.system("where cl.exe >nul 2>nul") != 0:
27
+ cl_path = find_cl_path()
28
+ if cl_path is None:
29
+ raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
30
+ os.environ["PATH"] += ";" + cl_path
31
+
32
+ '''
33
+ Usage:
34
+
35
+ python setup.py build_ext --inplace # build extensions locally, do not install (only can be used from the parent directory)
36
+
37
+ python setup.py install # build extensions and install (copy) to PATH.
38
+ pip install . # ditto but better (e.g., dependency & metadata handling)
39
+
40
+ python setup.py develop # build extensions and install (symbolic) to PATH.
41
+ pip install -e . # ditto but better (e.g., dependency & metadata handling)
42
+
43
+ '''
44
+ setup(
45
+ name='raymarching', # package name, import this to use python API
46
+ ext_modules=[
47
+ CUDAExtension(
48
+ name='_raymarching', # extension name, import this to use CUDA API
49
+ sources=[os.path.join(_src_path, 'src', f) for f in [
50
+ 'raymarching.cu',
51
+ 'bindings.cpp',
52
+ ]],
53
+ extra_compile_args={
54
+ 'cxx': c_flags,
55
+ 'nvcc': nvcc_flags,
56
+ }
57
+ ),
58
+ ],
59
+ cmdclass={
60
+ 'build_ext': BuildExtension,
61
+ }
62
+ )
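setup.py is the ahead-of-time alternative to the JIT path in backend.py: after running pip install . (or python setup.py install) from this directory, the pre-built _raymarching module is picked up and the slow first-import compile is skipped. A quick check of that fallback order, mirroring the try/except at the top of raymarching.py:

try:
    import _raymarching as _backend             # pre-built via setup.py / pip install
    print("using the installed extension")
except ImportError:
    from raymarching.backend import _backend    # falls back to JIT compilation
    print("using the JIT-compiled extension")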
raymarching/src/bindings.cpp ADDED
@@ -0,0 +1,19 @@
1
+ #include <torch/extension.h>
2
+
3
+ #include "raymarching.h"
4
+
5
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
6
+ // utils
7
+ m.def("packbits", &packbits, "packbits (CUDA)");
8
+ m.def("near_far_from_aabb", &near_far_from_aabb, "near_far_from_aabb (CUDA)");
9
+ m.def("sph_from_ray", &sph_from_ray, "sph_from_ray (CUDA)");
10
+ m.def("morton3D", &morton3D, "morton3D (CUDA)");
11
+ m.def("morton3D_invert", &morton3D_invert, "morton3D_invert (CUDA)");
12
+ // train
13
+ m.def("march_rays_train", &march_rays_train, "march_rays_train (CUDA)");
14
+ m.def("composite_rays_train_forward", &composite_rays_train_forward, "composite_rays_train_forward (CUDA)");
15
+ m.def("composite_rays_train_backward", &composite_rays_train_backward, "composite_rays_train_backward (CUDA)");
16
+ // infer
17
+ m.def("march_rays", &march_rays, "march rays (CUDA)");
18
+ m.def("composite_rays", &composite_rays, "composite rays (CUDA)");
19
+ }
raymarching/src/raymarching.cu ADDED
@@ -0,0 +1,914 @@
1
+ #include <cuda.h>
2
+ #include <cuda_fp16.h>
3
+ #include <cuda_runtime.h>
4
+
5
+ #include <ATen/cuda/CUDAContext.h>
6
+ #include <torch/torch.h>
7
+
8
+ #include <cstdio>
9
+ #include <stdint.h>
10
+ #include <stdexcept>
11
+ #include <limits>
12
+
13
+ #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
14
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
15
+ #define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
16
+ #define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")
17
+
18
+
19
+ inline constexpr __device__ float SQRT3() { return 1.7320508075688772f; }
20
+ inline constexpr __device__ float RSQRT3() { return 0.5773502691896258f; }
21
+ inline constexpr __device__ float PI() { return 3.141592653589793f; }
22
+ inline constexpr __device__ float RPI() { return 0.3183098861837907f; }
23
+
24
+
25
+ template <typename T>
26
+ inline __host__ __device__ T div_round_up(T val, T divisor) {
27
+ return (val + divisor - 1) / divisor;
28
+ }
29
+
30
+ inline __host__ __device__ float signf(const float x) {
31
+ return copysignf(1.0, x);
32
+ }
33
+
34
+ inline __host__ __device__ float clamp(const float x, const float min, const float max) {
35
+ return fminf(max, fmaxf(min, x));
36
+ }
37
+
38
+ inline __host__ __device__ void swapf(float& a, float& b) {
39
+ float c = a; a = b; b = c;
40
+ }
41
+
42
+ inline __device__ int mip_from_pos(const float x, const float y, const float z, const float max_cascade) {
43
+ const float mx = fmaxf(fabsf(x), fmaxf(fabs(y), fabs(z)));
44
+ int exponent;
45
+ frexpf(mx, &exponent); // [0, 0.5) --> -1, [0.5, 1) --> 0, [1, 2) --> 1, [2, 4) --> 2, ...
46
+ return fminf(max_cascade - 1, fmaxf(0, exponent));
47
+ }
48
+
49
+ inline __device__ int mip_from_dt(const float dt, const float H, const float max_cascade) {
50
+ const float mx = dt * H * 0.5;
51
+ int exponent;
52
+ frexpf(mx, &exponent);
53
+ return fminf(max_cascade - 1, fmaxf(0, exponent));
54
+ }
55
+
56
+ inline __host__ __device__ uint32_t __expand_bits(uint32_t v)
57
+ {
58
+ v = (v * 0x00010001u) & 0xFF0000FFu;
59
+ v = (v * 0x00000101u) & 0x0F00F00Fu;
60
+ v = (v * 0x00000011u) & 0xC30C30C3u;
61
+ v = (v * 0x00000005u) & 0x49249249u;
62
+ return v;
63
+ }
64
+
65
+ inline __host__ __device__ uint32_t __morton3D(uint32_t x, uint32_t y, uint32_t z)
66
+ {
67
+ uint32_t xx = __expand_bits(x);
68
+ uint32_t yy = __expand_bits(y);
69
+ uint32_t zz = __expand_bits(z);
70
+ return xx | (yy << 1) | (zz << 2);
71
+ }
72
+
73
+ inline __host__ __device__ uint32_t __morton3D_invert(uint32_t x)
74
+ {
75
+ x = x & 0x49249249;
76
+ x = (x | (x >> 2)) & 0xc30c30c3;
77
+ x = (x | (x >> 4)) & 0x0f00f00f;
78
+ x = (x | (x >> 8)) & 0xff0000ff;
79
+ x = (x | (x >> 16)) & 0x0000ffff;
80
+ return x;
81
+ }
82
+
83
+
84
+ ////////////////////////////////////////////////////
85
+ ///////////// utils /////////////
86
+ ////////////////////////////////////////////////////
87
+
88
+ // rays_o/d: [N, 3]
89
+ // nears/fars: [N]
90
+ // scalar_t should always be float in use.
91
+ template <typename scalar_t>
92
+ __global__ void kernel_near_far_from_aabb(
93
+ const scalar_t * __restrict__ rays_o,
94
+ const scalar_t * __restrict__ rays_d,
95
+ const scalar_t * __restrict__ aabb,
96
+ const uint32_t N,
97
+ const float min_near,
98
+ scalar_t * nears, scalar_t * fars
99
+ ) {
100
+ // parallel per ray
101
+ const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
102
+ if (n >= N) return;
103
+
104
+ // locate
105
+ rays_o += n * 3;
106
+ rays_d += n * 3;
107
+
108
+ const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
109
+ const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
110
+ const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
111
+
112
+ // get near far (assume cube scene)
113
+ float near = (aabb[0] - ox) * rdx;
114
+ float far = (aabb[3] - ox) * rdx;
115
+ if (near > far) swapf(near, far);
116
+
117
+ float near_y = (aabb[1] - oy) * rdy;
118
+ float far_y = (aabb[4] - oy) * rdy;
119
+ if (near_y > far_y) swapf(near_y, far_y);
120
+
121
+ if (near > far_y || near_y > far) {
122
+ nears[n] = fars[n] = std::numeric_limits<scalar_t>::max();
123
+ return;
124
+ }
125
+
126
+ if (near_y > near) near = near_y;
127
+ if (far_y < far) far = far_y;
128
+
129
+ float near_z = (aabb[2] - oz) * rdz;
130
+ float far_z = (aabb[5] - oz) * rdz;
131
+ if (near_z > far_z) swapf(near_z, far_z);
132
+
133
+ if (near > far_z || near_z > far) {
134
+ nears[n] = fars[n] = std::numeric_limits<scalar_t>::max();
135
+ return;
136
+ }
137
+
138
+ if (near_z > near) near = near_z;
139
+ if (far_z < far) far = far_z;
140
+
141
+ if (near < min_near) near = min_near;
142
+
143
+ nears[n] = near;
144
+ fars[n] = far;
145
+ }
146
+
147
+
148
+ void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars) {
149
+
150
+ static constexpr uint32_t N_THREAD = 128;
151
+
152
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
153
+ rays_o.scalar_type(), "near_far_from_aabb", ([&] {
154
+ kernel_near_far_from_aabb<<<div_round_up(N, N_THREAD), N_THREAD>>>(rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), aabb.data_ptr<scalar_t>(), N, min_near, nears.data_ptr<scalar_t>(), fars.data_ptr<scalar_t>());
155
+ }));
156
+ }
157
+
158
+
159
+ // rays_o/d: [N, 3]
160
+ // radius: float
161
+ // coords: [N, 2]
162
+ template <typename scalar_t>
163
+ __global__ void kernel_sph_from_ray(
164
+ const scalar_t * __restrict__ rays_o,
165
+ const scalar_t * __restrict__ rays_d,
166
+ const float radius,
167
+ const uint32_t N,
168
+ scalar_t * coords
169
+ ) {
170
+ // parallel per ray
171
+ const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
172
+ if (n >= N) return;
173
+
174
+ // locate
175
+ rays_o += n * 3;
176
+ rays_d += n * 3;
177
+ coords += n * 2;
178
+
179
+ const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
180
+ const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
181
+ const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
182
+
183
+ // solve t from || o + td || = radius
184
+ const float A = dx * dx + dy * dy + dz * dz;
185
+ const float B = ox * dx + oy * dy + oz * dz; // in fact B / 2
186
+ const float C = ox * ox + oy * oy + oz * oz - radius * radius;
187
+
188
+ const float t = (- B + sqrtf(B * B - A * C)) / A; // always use the larger solution (positive)
189
+
190
+ // solve theta, phi (assume y is the up axis)
191
+ const float x = ox + t * dx, y = oy + t * dy, z = oz + t * dz;
192
+ const float theta = atan2(sqrtf(x * x + z * z), y); // [0, PI)
193
+ const float phi = atan2(z, x); // [-PI, PI)
194
+
195
+ // normalize to [-1, 1]
196
+ coords[0] = 2 * theta * RPI() - 1;
197
+ coords[1] = phi * RPI();
198
+ }
199
+
200
+
201
+ void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords) {
202
+
203
+ static constexpr uint32_t N_THREAD = 128;
204
+
205
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
206
+ rays_o.scalar_type(), "sph_from_ray", ([&] {
207
+ kernel_sph_from_ray<<<div_round_up(N, N_THREAD), N_THREAD>>>(rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), radius, N, coords.data_ptr<scalar_t>());
208
+ }));
209
+ }
210
+
211
+
212
+ // coords: int32, [N, 3]
213
+ // indices: int32, [N]
214
+ __global__ void kernel_morton3D(
215
+ const int * __restrict__ coords,
216
+ const uint32_t N,
217
+ int * indices
218
+ ) {
219
+ // parallel
220
+ const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
221
+ if (n >= N) return;
222
+
223
+ // locate
224
+ coords += n * 3;
225
+ indices[n] = __morton3D(coords[0], coords[1], coords[2]);
226
+ }
227
+
228
+
229
+ void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices) {
230
+ static constexpr uint32_t N_THREAD = 128;
231
+ kernel_morton3D<<<div_round_up(N, N_THREAD), N_THREAD>>>(coords.data_ptr<int>(), N, indices.data_ptr<int>());
232
+ }
233
+
234
+
235
+ // indices: int32, [N]
236
+ // coords: int32, [N, 3]
237
+ __global__ void kernel_morton3D_invert(
238
+ const int * __restrict__ indices,
239
+ const uint32_t N,
240
+ int * coords
241
+ ) {
242
+ // parallel
243
+ const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
244
+ if (n >= N) return;
245
+
246
+ // locate
247
+ coords += n * 3;
248
+
249
+ const int ind = indices[n];
250
+
251
+ coords[0] = __morton3D_invert(ind >> 0);
252
+ coords[1] = __morton3D_invert(ind >> 1);
253
+ coords[2] = __morton3D_invert(ind >> 2);
254
+ }
255
+
256
+
257
+ void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords) {
258
+ static constexpr uint32_t N_THREAD = 128;
259
+ kernel_morton3D_invert<<<div_round_up(N, N_THREAD), N_THREAD>>>(indices.data_ptr<int>(), N, coords.data_ptr<int>());
260
+ }
261
+
262
+
263
+ // grid: float, [C, H, H, H]
264
+ // N: int, C * H * H * H / 8
265
+ // density_thresh: float
266
+ // bitfield: uint8, [N]
267
+ template <typename scalar_t>
268
+ __global__ void kernel_packbits(
269
+ const scalar_t * __restrict__ grid,
270
+ const uint32_t N,
271
+ const float density_thresh,
272
+ uint8_t * bitfield
273
+ ) {
274
+ // parallel per byte
275
+ const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
276
+ if (n >= N) return;
277
+
278
+ // locate
279
+ grid += n * 8;
280
+
281
+ uint8_t bits = 0;
282
+
283
+ #pragma unroll
284
+ for (uint8_t i = 0; i < 8; i++) {
285
+ bits |= (grid[i] > density_thresh) ? ((uint8_t)1 << i) : 0;
286
+ }
287
+
288
+ bitfield[n] = bits;
289
+ }
290
+
291
+
292
+ void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield) {
293
+
294
+ static constexpr uint32_t N_THREAD = 128;
295
+
296
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
297
+ grid.scalar_type(), "packbits", ([&] {
298
+ kernel_packbits<<<div_round_up(N, N_THREAD), N_THREAD>>>(grid.data_ptr<scalar_t>(), N, density_thresh, bitfield.data_ptr<uint8_t>());
299
+ }));
300
+ }
301
+
302
+ ////////////////////////////////////////////////////
303
+ ///////////// training /////////////
304
+ ////////////////////////////////////////////////////
305
+
306
+ // rays_o/d: [N, 3]
307
+ // grid: [CHHH / 8]
308
+ // xyzs, dirs, deltas: [M, 3], [M, 3], [M, 2]
309
+ // dirs: [M, 3]
310
+ // rays: [N, 3], idx, offset, num_steps
311
+ template <typename scalar_t>
312
+ __global__ void kernel_march_rays_train(
313
+ const scalar_t * __restrict__ rays_o,
314
+ const scalar_t * __restrict__ rays_d,
315
+ const uint8_t * __restrict__ grid,
316
+ const float bound,
317
+ const float dt_gamma, const uint32_t max_steps,
318
+ const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M,
319
+ const scalar_t* __restrict__ nears,
320
+ const scalar_t* __restrict__ fars,
321
+ scalar_t * xyzs, scalar_t * dirs, scalar_t * deltas,
322
+ int * rays,
323
+ int * counter,
324
+ const scalar_t* __restrict__ noises
325
+ ) {
326
+ // parallel per ray
327
+ const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
328
+ if (n >= N) return;
329
+
330
+ // locate
331
+ rays_o += n * 3;
332
+ rays_d += n * 3;
333
+
334
+ // ray marching
335
+ const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
336
+ const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
337
+ const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
338
+ const float rH = 1 / (float)H;
339
+ const float H3 = H * H * H;
340
+
341
+ const float near = nears[n];
342
+ const float far = fars[n];
343
+ const float noise = noises[n];
344
+
345
+ const float dt_min = 2 * SQRT3() / max_steps;
346
+ const float dt_max = 2 * SQRT3() * (1 << (C - 1)) / H;
347
+
348
+ float t0 = near;
349
+
350
+ // perturb
351
+ t0 += clamp(t0 * dt_gamma, dt_min, dt_max) * noise;
352
+
353
+ // first pass: estimation of num_steps
354
+ float t = t0;
355
+ uint32_t num_steps = 0;
356
+
357
+ //if (t < far) printf("valid ray %d t=%f near=%f far=%f \n", n, t, near, far);
358
+
359
+ while (t < far && num_steps < max_steps) {
360
+ // current point
361
+ const float x = clamp(ox + t * dx, -bound, bound);
362
+ const float y = clamp(oy + t * dy, -bound, bound);
363
+ const float z = clamp(oz + t * dz, -bound, bound);
364
+
365
+ const float dt = clamp(t * dt_gamma, dt_min, dt_max);
366
+
367
+ // get mip level
368
+ const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1]
369
+
370
+ const float mip_bound = fminf(scalbnf(1.0f, level), bound);
371
+ const float mip_rbound = 1 / mip_bound;
372
+
373
+ // convert to nearest grid position
374
+ const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
375
+ const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
376
+ const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
377
+
378
+ const uint32_t index = level * H3 + __morton3D(nx, ny, nz);
379
+ const bool occ = grid[index / 8] & (1 << (index % 8));
380
+
381
+ // if occupied, advance a small step, and write to output
382
+ //if (n == 0) printf("t=%f density=%f vs thresh=%f step=%d\n", t, density, density_thresh, num_steps);
383
+
384
+ if (occ) {
385
+ num_steps++;
386
+ t += dt;
387
+ // else, skip a large step (basically skip a voxel grid)
388
+ } else {
389
+ // calc distance to next voxel
390
+ const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx;
391
+ const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy;
392
+ const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz;
393
+
394
+ const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));
395
+ // step until next voxel
396
+ do {
397
+ t += clamp(t * dt_gamma, dt_min, dt_max);
398
+ } while (t < tt);
399
+ }
400
+ }
401
+
402
+ //printf("[n=%d] num_steps=%d, near=%f, far=%f, dt=%f, max_steps=%f\n", n, num_steps, near, far, dt_min, (far - near) / dt_min);
403
+
404
+ // second pass: really locate and write points & dirs
405
+ uint32_t point_index = atomicAdd(counter, num_steps);
406
+ uint32_t ray_index = atomicAdd(counter + 1, 1);
407
+
408
+ //printf("[n=%d] num_steps=%d, point_index=%d, ray_index=%d\n", n, num_steps, point_index, ray_index);
409
+
410
+ // write rays
411
+ rays[ray_index * 3] = n;
412
+ rays[ray_index * 3 + 1] = point_index;
413
+ rays[ray_index * 3 + 2] = num_steps;
414
+
415
+ if (num_steps == 0) return;
416
+ if (point_index + num_steps > M) return;
417
+
418
+ xyzs += point_index * 3;
419
+ dirs += point_index * 3;
420
+ deltas += point_index * 2;
421
+
422
+ t = t0;
423
+ uint32_t step = 0;
424
+
425
+ float last_t = t;
426
+
427
+ while (t < far && step < num_steps) {
428
+ // current point
429
+ const float x = clamp(ox + t * dx, -bound, bound);
430
+ const float y = clamp(oy + t * dy, -bound, bound);
431
+ const float z = clamp(oz + t * dz, -bound, bound);
432
+
433
+ const float dt = clamp(t * dt_gamma, dt_min, dt_max);
434
+
435
+ // get mip level
436
+ const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1]
437
+
438
+ const float mip_bound = fminf(scalbnf(1.0f, level), bound);
439
+ const float mip_rbound = 1 / mip_bound;
440
+
441
+ // convert to nearest grid position
442
+ const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
443
+ const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
444
+ const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
445
+
446
+ // query grid
447
+ const uint32_t index = level * H3 + __morton3D(nx, ny, nz);
448
+ const bool occ = grid[index / 8] & (1 << (index % 8));
449
+
450
+ // if occupied, advance a small step, and write to output
451
+ if (occ) {
452
+ // write step
453
+ xyzs[0] = x;
454
+ xyzs[1] = y;
455
+ xyzs[2] = z;
456
+ dirs[0] = dx;
457
+ dirs[1] = dy;
458
+ dirs[2] = dz;
459
+ t += dt;
460
+ deltas[0] = dt;
461
+ deltas[1] = t - last_t; // used to calc depth
462
+ last_t = t;
463
+ xyzs += 3;
464
+ dirs += 3;
465
+ deltas += 2;
466
+ step++;
467
+ // else, skip a large step (basically skip a voxel grid)
468
+ } else {
469
+ // calc distance to next voxel
470
+ const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx;
471
+ const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy;
472
+ const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz;
473
+ const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));
474
+ // step until next voxel
475
+ do {
476
+ t += clamp(t * dt_gamma, dt_min, dt_max);
477
+ } while (t < tt);
478
+ }
479
+ }
480
+ }
481
+
482
+ void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, at::Tensor noises) {
483
+
484
+ static constexpr uint32_t N_THREAD = 128;
485
+
486
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
487
+ rays_o.scalar_type(), "march_rays_train", ([&] {
488
+ kernel_march_rays_train<<<div_round_up(N, N_THREAD), N_THREAD>>>(rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), grid.data_ptr<uint8_t>(), bound, dt_gamma, max_steps, N, C, H, M, nears.data_ptr<scalar_t>(), fars.data_ptr<scalar_t>(), xyzs.data_ptr<scalar_t>(), dirs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), counter.data_ptr<int>(), noises.data_ptr<scalar_t>());
489
+ }));
490
+ }
491
+
492
+
493
+ // sigmas: [M]
494
+ // rgbs: [M, 3]
495
+ // deltas: [M, 2]
496
+ // rays: [N, 3], idx, offset, num_steps
497
+ // weights_sum: [N], final pixel alpha
498
+ // depth: [N,]
499
+ // image: [N, 3]
500
+ template <typename scalar_t>
501
+ __global__ void kernel_composite_rays_train_forward(
502
+ const scalar_t * __restrict__ sigmas,
503
+ const scalar_t * __restrict__ rgbs,
504
+ const scalar_t * __restrict__ deltas,
505
+ const int * __restrict__ rays,
506
+ const uint32_t M, const uint32_t N, const float T_thresh,
507
+ scalar_t * weights_sum,
508
+ scalar_t * depth,
509
+ scalar_t * image
510
+ ) {
511
+ // parallel per ray
512
+ const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
513
+ if (n >= N) return;
514
+
515
+ // locate
516
+ uint32_t index = rays[n * 3];
517
+ uint32_t offset = rays[n * 3 + 1];
518
+ uint32_t num_steps = rays[n * 3 + 2];
519
+
520
+ // empty ray, or ray that exceeds the max step count.
521
+ if (num_steps == 0 || offset + num_steps > M) {
522
+ weights_sum[index] = 0;
523
+ depth[index] = 0;
524
+ image[index * 3] = 0;
525
+ image[index * 3 + 1] = 0;
526
+ image[index * 3 + 2] = 0;
527
+ return;
528
+ }
529
+
530
+ sigmas += offset;
531
+ rgbs += offset * 3;
532
+ deltas += offset * 2;
533
+
534
+ // accumulate
535
+ uint32_t step = 0;
536
+
537
+ scalar_t T = 1.0f;
538
+ scalar_t r = 0, g = 0, b = 0, ws = 0, t = 0, d = 0;
539
+
540
+ while (step < num_steps) {
541
+
542
+ const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]);
543
+ const scalar_t weight = alpha * T;
544
+
545
+ r += weight * rgbs[0];
546
+ g += weight * rgbs[1];
547
+ b += weight * rgbs[2];
548
+
549
+ t += deltas[1]; // real delta
550
+ d += weight * t;
551
+
552
+ ws += weight;
553
+
554
+ T *= 1.0f - alpha;
555
+
556
+ // minimal remaining transmittance
557
+ if (T < T_thresh) break;
558
+
559
+ //printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d);
560
+
561
+ // locate
562
+ sigmas++;
563
+ rgbs += 3;
564
+ deltas += 2;
565
+
566
+ step++;
567
+ }
568
+
569
+ //printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d);
570
+
571
+ // write
572
+ weights_sum[index] = ws; // weights_sum
573
+ depth[index] = d;
574
+ image[index * 3] = r;
575
+ image[index * 3 + 1] = g;
576
+ image[index * 3 + 2] = b;
577
+ }
578
+
579
+
580
+ void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor depth, at::Tensor image) {
581
+
582
+ static constexpr uint32_t N_THREAD = 128;
583
+
584
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
585
+ sigmas.scalar_type(), "composite_rays_train_forward", ([&] {
586
+ kernel_composite_rays_train_forward<<<div_round_up(N, N_THREAD), N_THREAD>>>(sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), M, N, T_thresh, weights_sum.data_ptr<scalar_t>(), depth.data_ptr<scalar_t>(), image.data_ptr<scalar_t>());
587
+ }));
588
+ }
589
+
590
+
591
+ // grad_weights_sum: [N,]
592
+ // grad: [N, 3]
593
+ // sigmas: [M]
594
+ // rgbs: [M, 3]
595
+ // deltas: [M, 2]
596
+ // rays: [N, 3], idx, offset, num_steps
597
+ // weights_sum: [N,], weights_sum here
598
+ // image: [N, 3]
599
+ // grad_sigmas: [M]
600
+ // grad_rgbs: [M, 3]
601
+ template <typename scalar_t>
602
+ __global__ void kernel_composite_rays_train_backward(
603
+ const scalar_t * __restrict__ grad_weights_sum,
604
+ const scalar_t * __restrict__ grad_image,
605
+ const scalar_t * __restrict__ sigmas,
606
+ const scalar_t * __restrict__ rgbs,
607
+ const scalar_t * __restrict__ deltas,
608
+ const int * __restrict__ rays,
609
+ const scalar_t * __restrict__ weights_sum,
610
+ const scalar_t * __restrict__ image,
611
+ const uint32_t M, const uint32_t N, const float T_thresh,
612
+ scalar_t * grad_sigmas,
613
+ scalar_t * grad_rgbs
614
+ ) {
615
+ // parallel per ray
616
+ const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
617
+ if (n >= N) return;
618
+
619
+ // locate
620
+ uint32_t index = rays[n * 3];
621
+ uint32_t offset = rays[n * 3 + 1];
622
+ uint32_t num_steps = rays[n * 3 + 2];
623
+
624
+ if (num_steps == 0 || offset + num_steps > M) return;
625
+
626
+ grad_weights_sum += index;
627
+ grad_image += index * 3;
628
+ weights_sum += index;
629
+ image += index * 3;
630
+ sigmas += offset;
631
+ rgbs += offset * 3;
632
+ deltas += offset * 2;
633
+ grad_sigmas += offset;
634
+ grad_rgbs += offset * 3;
635
+
636
+ // accumulate
637
+ uint32_t step = 0;
638
+
639
+ scalar_t T = 1.0f;
640
+ const scalar_t r_final = image[0], g_final = image[1], b_final = image[2], ws_final = weights_sum[0];
641
+ scalar_t r = 0, g = 0, b = 0, ws = 0;
642
+
643
+ while (step < num_steps) {
644
+
645
+ const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]);
646
+ const scalar_t weight = alpha * T;
647
+
648
+ r += weight * rgbs[0];
649
+ g += weight * rgbs[1];
650
+ b += weight * rgbs[2];
651
+ ws += weight;
652
+
653
+ T *= 1.0f - alpha;
654
+
655
+ // check https://note.kiui.moe/others/nerf_gradient/ for the gradient calculation.
656
+ // write grad_rgbs
657
+ grad_rgbs[0] = grad_image[0] * weight;
658
+ grad_rgbs[1] = grad_image[1] * weight;
659
+ grad_rgbs[2] = grad_image[2] * weight;
660
+
661
+ // write grad_sigmas
662
+ grad_sigmas[0] = deltas[0] * (
663
+ grad_image[0] * (T * rgbs[0] - (r_final - r)) +
664
+ grad_image[1] * (T * rgbs[1] - (g_final - g)) +
665
+ grad_image[2] * (T * rgbs[2] - (b_final - b)) +
666
+ grad_weights_sum[0] * (1 - ws_final)
667
+ );
668
+
669
+ //printf("[n=%d] num_steps=%d, T=%f, grad_sigmas=%f, r_final=%f, r=%f\n", n, step, T, grad_sigmas[0], r_final, r);
670
+ // minimal remaining transmittance
671
+ if (T < T_thresh) break;
672
+
673
+ // locate
674
+ sigmas++;
675
+ rgbs += 3;
676
+ deltas += 2;
677
+ grad_sigmas++;
678
+ grad_rgbs += 3;
679
+
680
+ step++;
681
+ }
682
+ }
683
+
684
+
685
+ void composite_rays_train_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs) {
686
+
687
+ static constexpr uint32_t N_THREAD = 128;
688
+
689
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
690
+ grad_image.scalar_type(), "composite_rays_train_backward", ([&] {
691
+ kernel_composite_rays_train_backward<<<div_round_up(N, N_THREAD), N_THREAD>>>(grad_weights_sum.data_ptr<scalar_t>(), grad_image.data_ptr<scalar_t>(), sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), weights_sum.data_ptr<scalar_t>(), image.data_ptr<scalar_t>(), M, N, T_thresh, grad_sigmas.data_ptr<scalar_t>(), grad_rgbs.data_ptr<scalar_t>());
692
+ }));
693
+ }
694
+
695
+
696
+ ////////////////////////////////////////////////////
697
+ ///////////// inference /////////////
698
+ ////////////////////////////////////////////////////
699
+
700
+ template <typename scalar_t>
701
+ __global__ void kernel_march_rays(
702
+ const uint32_t n_alive,
703
+ const uint32_t n_step,
704
+ const int* __restrict__ rays_alive,
705
+ const scalar_t* __restrict__ rays_t,
706
+ const scalar_t* __restrict__ rays_o,
707
+ const scalar_t* __restrict__ rays_d,
708
+ const float bound,
709
+ const float dt_gamma, const uint32_t max_steps,
710
+ const uint32_t C, const uint32_t H,
711
+ const uint8_t * __restrict__ grid,
712
+ const scalar_t* __restrict__ nears,
713
+ const scalar_t* __restrict__ fars,
714
+ scalar_t* xyzs, scalar_t* dirs, scalar_t* deltas,
715
+ const scalar_t* __restrict__ noises
716
+ ) {
717
+ const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
718
+ if (n >= n_alive) return;
719
+
720
+ const int index = rays_alive[n]; // ray id
721
+ const float noise = noises[n];
722
+
723
+ // locate
724
+ rays_o += index * 3;
725
+ rays_d += index * 3;
726
+ xyzs += n * n_step * 3;
727
+ dirs += n * n_step * 3;
728
+ deltas += n * n_step * 2;
729
+
730
+ const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
731
+ const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
732
+ const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
733
+ const float rH = 1 / (float)H;
734
+ const float H3 = H * H * H;
735
+
736
+ float t = rays_t[index]; // current ray's t
737
+ const float near = nears[index], far = fars[index];
738
+
739
+ const float dt_min = 2 * SQRT3() / max_steps;
740
+ const float dt_max = 2 * SQRT3() * (1 << (C - 1)) / H;
741
+
742
+ // march for n_step steps, record points
743
+ uint32_t step = 0;
744
+
745
+ // introduce some randomness
746
+ t += clamp(t * dt_gamma, dt_min, dt_max) * noise;
747
+
748
+ float last_t = t;
749
+
750
+ while (t < far && step < n_step) {
751
+ // current point
752
+ const float x = clamp(ox + t * dx, -bound, bound);
753
+ const float y = clamp(oy + t * dy, -bound, bound);
754
+ const float z = clamp(oz + t * dz, -bound, bound);
755
+
756
+ const float dt = clamp(t * dt_gamma, dt_min, dt_max);
757
+
758
+ // get mip level
759
+ const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1]
760
+
761
+ const float mip_bound = fminf(scalbnf(1, level), bound);
762
+ const float mip_rbound = 1 / mip_bound;
763
+
764
+ // convert to nearest grid position
765
+ const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
766
+ const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
767
+ const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
768
+
769
+ const uint32_t index = level * H3 + __morton3D(nx, ny, nz);
770
+ const bool occ = grid[index / 8] & (1 << (index % 8));
771
+
772
+ // if occupied, advance a small step, and write to output
773
+ if (occ) {
774
+ // write step
775
+ xyzs[0] = x;
776
+ xyzs[1] = y;
777
+ xyzs[2] = z;
778
+ dirs[0] = dx;
779
+ dirs[1] = dy;
780
+ dirs[2] = dz;
781
+ // calc dt
782
+ t += dt;
783
+ deltas[0] = dt;
784
+ deltas[1] = t - last_t; // used to calc depth
785
+ last_t = t;
786
+ // step
787
+ xyzs += 3;
788
+ dirs += 3;
789
+ deltas += 2;
790
+ step++;
791
+
792
+ // else, skip a large step (i.e. skip to the next voxel)
793
+ } else {
794
+ // calc distance to next voxel
795
+ const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx;
796
+ const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy;
797
+ const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz;
798
+ const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));
799
+ // step until next voxel
800
+ do {
801
+ t += clamp(t * dt_gamma, dt_min, dt_max);
802
+ } while (t < tt);
803
+ }
804
+ }
805
+ }
806
+
807
+
808
+ void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor near, const at::Tensor far, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor noises) {
809
+ static constexpr uint32_t N_THREAD = 128;
810
+
811
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
812
+ rays_o.scalar_type(), "march_rays", ([&] {
813
+ kernel_march_rays<<<div_round_up(n_alive, N_THREAD), N_THREAD>>>(n_alive, n_step, rays_alive.data_ptr<int>(), rays_t.data_ptr<scalar_t>(), rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), bound, dt_gamma, max_steps, C, H, grid.data_ptr<uint8_t>(), near.data_ptr<scalar_t>(), far.data_ptr<scalar_t>(), xyzs.data_ptr<scalar_t>(), dirs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), noises.data_ptr<scalar_t>());
814
+ }));
815
+ }
816
+
817
+
818
+ template <typename scalar_t>
819
+ __global__ void kernel_composite_rays(
820
+ const uint32_t n_alive,
821
+ const uint32_t n_step,
822
+ const float T_thresh,
823
+ int* rays_alive,
824
+ scalar_t* rays_t,
825
+ const scalar_t* __restrict__ sigmas,
826
+ const scalar_t* __restrict__ rgbs,
827
+ const scalar_t* __restrict__ deltas,
828
+ scalar_t* weights_sum, scalar_t* depth, scalar_t* image
829
+ ) {
830
+ const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
831
+ if (n >= n_alive) return;
832
+
833
+ const int index = rays_alive[n]; // ray id
834
+
835
+ // locate
836
+ sigmas += n * n_step;
837
+ rgbs += n * n_step * 3;
838
+ deltas += n * n_step * 2;
839
+
840
+ rays_t += index;
841
+ weights_sum += index;
842
+ depth += index;
843
+ image += index * 3;
844
+
845
+ scalar_t t = rays_t[0]; // current ray's t
846
+
847
+ scalar_t weight_sum = weights_sum[0];
848
+ scalar_t d = depth[0];
849
+ scalar_t r = image[0];
850
+ scalar_t g = image[1];
851
+ scalar_t b = image[2];
852
+
853
+ // accumulate
854
+ uint32_t step = 0;
855
+ while (step < n_step) {
856
+
857
+ // ray is terminated if delta == 0
858
+ if (deltas[0] == 0) break;
859
+
860
+ const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]);
861
+
862
+ /*
863
+ T_0 = 1; T_i = \prod_{j=0}^{i-1} (1 - alpha_j)
864
+ w_i = alpha_i * T_i
865
+ -->
866
+ T_i = 1 - \sum_{j=0}^{i-1} w_j
867
+ */
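+ // (since T_{i+1} = T_i * (1 - alpha_i) = T_i - w_i and T_0 = 1, the recursion telescopes
+ // to T_i = 1 - sum_{j<i} w_j, so carrying weight_sum between calls is enough to resume
+ // compositing without storing T explicitly.)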
868
+ const scalar_t T = 1 - weight_sum;
869
+ const scalar_t weight = alpha * T;
870
+ weight_sum += weight;
871
+
872
+ t += deltas[1]; // real delta
873
+ d += weight * t;
874
+ r += weight * rgbs[0];
875
+ g += weight * rgbs[1];
876
+ b += weight * rgbs[2];
877
+
878
+ //printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d);
879
+
880
+ // ray is terminated if T is too small
881
+ // use a larger bound to further accelerate inference
882
+ if (T < T_thresh) break;
883
+
884
+ // locate
885
+ sigmas++;
886
+ rgbs += 3;
887
+ deltas += 2;
888
+ step++;
889
+ }
890
+
891
+ //printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d);
892
+
893
+ // rays_alive = -1 means ray is terminated early.
894
+ if (step < n_step) {
895
+ rays_alive[n] = -1;
896
+ } else {
897
+ rays_t[0] = t;
898
+ }
899
+
900
+ weights_sum[0] = weight_sum; // this is the thing I needed!
901
+ depth[0] = d;
902
+ image[0] = r;
903
+ image[1] = g;
904
+ image[2] = b;
905
+ }
906
+
907
+
908
+ void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, at::Tensor weights, at::Tensor depth, at::Tensor image) {
909
+ static constexpr uint32_t N_THREAD = 128;
910
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
911
+ image.scalar_type(), "composite_rays", ([&] {
912
+ kernel_composite_rays<<<div_round_up(n_alive, N_THREAD), N_THREAD>>>(n_alive, n_step, T_thresh, rays_alive.data_ptr<int>(), rays_t.data_ptr<scalar_t>(), sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), weights.data_ptr<scalar_t>(), depth.data_ptr<scalar_t>(), image.data_ptr<scalar_t>());
913
+ }));
914
+ }
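For sanity-checking the packed compositing above, the forward pass of `kernel_composite_rays_train_forward` can be mirrored in a few lines of pure PyTorch. The sketch below is an illustration only (not part of the repository) and follows the tensor layout documented in the kernel's comment header: `sigmas [M]`, `rgbs [M, 3]`, `deltas [M, 2]`, and `rays [N, 3]` holding (idx, offset, num_steps) per ray.

```python
import torch

def composite_rays_train_forward_ref(sigmas, rgbs, deltas, rays, T_thresh=1e-4):
    # slow, loop-based reference for the CUDA kernel; useful only for testing
    M, N = sigmas.shape[0], rays.shape[0]
    device, dtype = sigmas.device, sigmas.dtype
    weights_sum = torch.zeros(N, dtype=dtype, device=device)
    depth = torch.zeros(N, dtype=dtype, device=device)
    image = torch.zeros(N, 3, dtype=dtype, device=device)

    for n in range(N):
        index, offset, num_steps = (int(v) for v in rays[n])
        if num_steps == 0 or offset + num_steps > M:
            continue  # outputs stay zero, as in the kernel
        T, t = 1.0, 0.0
        for s in range(offset, offset + num_steps):
            alpha = 1.0 - torch.exp(-sigmas[s] * deltas[s, 0])
            weight = alpha * T
            image[index] += weight * rgbs[s]
            t = t + deltas[s, 1]               # accumulated distance along the ray
            depth[index] += weight * t
            weights_sum[index] += weight
            T = T * (1.0 - alpha)
            if T < T_thresh:                   # early termination, same as the kernel
                break
    return weights_sum, depth, image
```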
raymarching/src/raymarching.h ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <stdint.h>
4
+ #include <torch/torch.h>
5
+
6
+
7
+ void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars);
8
+ void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords);
9
+ void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices);
10
+ void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords);
11
+ void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield);
12
+
13
+ void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, at::Tensor noises);
14
+ void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor depth, at::Tensor image);
15
+ void composite_rays_train_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs);
16
+
17
+ void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor noises);
18
+ void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor weights_sum, at::Tensor depth, at::Tensor image);
readme.md ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Stable-Dreamfusion
2
+
3
+ A PyTorch implementation of the text-to-3D model **DreamFusion**, powered by the [Stable Diffusion](https://github.com/CompVis/stable-diffusion) text-to-2D model.
4
+
5
+ The original paper's project page: [_DreamFusion: Text-to-3D using 2D Diffusion_](https://dreamfusion3d.github.io/).
6
+
7
+ Example of "a squierrel" and "a hamburger":
8
+
9
+ ### [Gallery](assets/gallery.md) | [Update Logs](assets/update_logs.md)
10
+
11
+ # Important Notice
12
+ This project is a **work-in-progress** and contains many differences from the paper. Many features are still not implemented. The current generation quality cannot match the results from the original paper, and it still fails badly for many prompts.
13
+
14
+ ## Notable differences from the paper
15
+ * Since the Imagen model is not publicly available, we use [Stable Diffusion](https://github.com/CompVis/stable-diffusion) to replace it (implementation from [diffusers](https://github.com/huggingface/diffusers)). Unlike Imagen, Stable Diffusion is a latent diffusion model, which diffuses in a latent space instead of the original image space. Therefore, the loss also needs to propagate back through the VAE's encoder, which adds extra training time. Currently, 15000 training steps take about 5 hours on a V100 (a minimal sketch of this latent-space loss is given after this list).
16
+ * We use the [multi-resolution grid encoder](https://github.com/NVlabs/instant-ngp/) to implement the NeRF backbone (implementation from [torch-ngp](https://github.com/ashawkey/torch-ngp)), which enables much faster rendering (~10FPS at 800x800).
17
+ * We use the Adam optimizer with a larger initial learning rate.
18
+
19
+
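+ A minimal sketch of how this latent-space loss might look (an illustration only, not the actual training code in this repository; `vae`, `unet`, and `scheduler` follow the [diffusers](https://github.com/huggingface/diffusers) conventions, `text_embeddings` is the concatenated unconditional + conditional prompt embedding, and `render_nerf` is a hypothetical stand-in for the differentiable NeRF renderer):
+
+ ```python
+ import torch
+
+ def sds_loss(vae, unet, scheduler, text_embeddings, render_nerf, guidance_scale=100):
+     # render a differentiable image from the NeRF, values in [0, 1]
+     pred_rgb = render_nerf()                                    # [1, 3, 512, 512]
+
+     # encode into the Stable Diffusion latent space;
+     # gradients flow back through the VAE encoder into the NeRF
+     latents = vae.encode(pred_rgb * 2 - 1).latent_dist.sample() * 0.18215
+
+     # perturb the latents at a random mid-range timestep
+     t = torch.randint(20, 980, (1,), device=latents.device)
+     noise = torch.randn_like(latents)
+     noisy_latents = scheduler.add_noise(latents, noise, t)
+
+     # predict the noise with classifier-free guidance (no grad through the U-Net)
+     with torch.no_grad():
+         latent_in = torch.cat([noisy_latents] * 2)
+         noise_pred = unet(latent_in, t, encoder_hidden_states=text_embeddings).sample
+         noise_uncond, noise_text = noise_pred.chunk(2)
+         noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)
+
+     # score-distillation gradient (noise_pred - noise), injected via a surrogate loss
+     grad = noise_pred - noise                                   # timestep weighting omitted
+     return (latents * grad.detach()).sum()                      # d(loss)/d(latents) == grad
+ ```
+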
20
+ ## TODOs
21
+ * The shading part & normal evaluation.
22
+ * Exporting colored mesh.
23
+
24
+
25
+ # Install
26
+
27
+ ```bash
28
+ git clone https://github.com/ashawkey/stable-dreamfusion.git
29
+ cd stable-dreamfusion
30
+ ```
31
+
32
+ **Important**: To download the Stable Diffusion model checkpoint, you should create a file under this directory called `TOKEN` and copy your Hugging Face [access token](https://huggingface.co/docs/hub/security-tokens) into it.
33
+
34
+ ### Install with pip
35
+ ```bash
36
+ pip install -r requirements.txt
37
+
38
+ # (optional) install the tcnn backbone if using --tcnn
39
+ pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch
40
+
41
+ # (optional) install CLIP guidance for the dreamfield setting
42
+ pip install git+https://github.com/openai/CLIP.git
43
+
44
+ # (optional) install nvdiffrast for exporting textured mesh
45
+ pip install git+https://github.com/NVlabs/nvdiffrast/
46
+ ```
47
+
48
+ ### Build extension (optional)
49
+ By default, we use [`load`](https://pytorch.org/docs/stable/cpp_extension.html#torch.utils.cpp_extension.load) to build the extension at runtime.
50
+ We also provide a `setup.py` for each extension, so you can build them ahead of time:
51
+ ```bash
52
+ # install all extension modules
53
+ bash scripts/install_ext.sh
54
+
55
+ # if you want to install manually, here is an example:
56
+ pip install ./raymarching # install to python path (you still need the raymarching/ folder, since this only installs the built extension.)
57
+ ```
58
+
59
+ ### Tested environments
60
+ * Ubuntu 22 with torch 1.12 & CUDA 11.6 on a V100.
61
+
62
+
63
+ # Usage
64
+
65
+ The first run will take some time to compile the CUDA extensions.
66
+
67
+ ```bash
68
+ ### stable-dreamfusion setting
69
+ # train with text prompt
70
+ # `-O` equals `--cuda_ray --fp16 --dir_text`
71
+ python main_nerf.py --text "a hamburger" --workspace trial -O
72
+
73
+ # test (exporting 360 video)
74
+ python main_nerf.py --text "a hamburger" --workspace trial -O --test
75
+
76
+ # test with a GUI (free view control!)
77
+ python main_nerf.py --text "a hamburger" --workspace trial -O --test --gui
78
+
79
+ ### dreamfields (CLIP) setting
80
+ python main_nerf.py --text "a hamburger" --workspace trial_clip -O --guidance clip
81
+ python main_nerf.py --text "a hamburger" --workspace trial_clip -O --test --gui --guidance clip
82
+ ```
83
+
84
+ # Acknowledgement
85
+
86
+ * The amazing original work: [_DreamFusion: Text-to-3D using 2D Diffusion_](https://dreamfusion3d.github.io/).
87
+
88
+ * Huge thanks to the [Stable Diffusion](https://github.com/CompVis/stable-diffusion) and the [diffusers](https://github.com/huggingface/diffusers) library.
89
+
90
+
91
+ * The GUI is developed with [DearPyGui](https://github.com/hoffstadt/DearPyGui).
requirements.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch-ema
2
+ ninja
3
+ trimesh
4
+ opencv-python
5
+ tensorboardX
6
+ torch
7
+ numpy
8
+ pandas
9
+ tqdm
10
+ matplotlib
11
+ PyMCubes
12
+ rich
13
+ pysdf
14
+ dearpygui
15
+ scipy
16
+ diffusers
17
+ xatlas
scripts/install_ext.sh ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ pip install ./raymarching
2
+ pip install ./shencoder
3
+ pip install ./freqencoder
4
+ pip install ./gridencoder
scripts/run.sh ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ #! /bin/bash
2
+
3
+ CUDA_VISIBLE_DEVICES=1 python main_nerf.py -O --text "a DSLR photo of cthulhu" --workspace trial_cthulhu
4
+ CUDA_VISIBLE_DEVICES=1 python main_nerf.py -O --text "a DSLR photo of a squirrel" --workspace trial_squirrel
5
+ CUDA_VISIBLE_DEVICES=1 python main_nerf.py -O --text "a DSLR photo of a cat lying on its side batting at a ball of yarn" --workspace trial_cat_lying
shencoder/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .sphere_harmonics import SHEncoder
shencoder/backend.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from torch.utils.cpp_extension import load
3
+
4
+ _src_path = os.path.dirname(os.path.abspath(__file__))
5
+
6
+ nvcc_flags = [
7
+ '-O3', '-std=c++14',
8
+ '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
9
+ ]
10
+
11
+ if os.name == "posix":
12
+ c_flags = ['-O3', '-std=c++14']
13
+ elif os.name == "nt":
14
+ c_flags = ['/O2', '/std:c++17']
15
+
16
+ # find cl.exe
17
+ def find_cl_path():
18
+ import glob
19
+ for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
20
+ paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
21
+ if paths:
22
+ return paths[0]
23
+
24
+ # If cl.exe is not on path, try to find it.
25
+ if os.system("where cl.exe >nul 2>nul") != 0:
26
+ cl_path = find_cl_path()
27
+ if cl_path is None:
28
+ raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
29
+ os.environ["PATH"] += ";" + cl_path
30
+
31
+ _backend = load(name='_sh_encoder',
32
+ extra_cflags=c_flags,
33
+ extra_cuda_cflags=nvcc_flags,
34
+ sources=[os.path.join(_src_path, 'src', f) for f in [
35
+ 'shencoder.cu',
36
+ 'bindings.cpp',
37
+ ]],
38
+ )
39
+
40
+ __all__ = ['_backend']
shencoder/setup.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from setuptools import setup
3
+ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
4
+
5
+ _src_path = os.path.dirname(os.path.abspath(__file__))
6
+
7
+ nvcc_flags = [
8
+ '-O3', '-std=c++14',
9
+ '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
10
+ ]
11
+
12
+ if os.name == "posix":
13
+ c_flags = ['-O3', '-std=c++14']
14
+ elif os.name == "nt":
15
+ c_flags = ['/O2', '/std:c++17']
16
+
17
+ # find cl.exe
18
+ def find_cl_path():
19
+ import glob
20
+ for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
21
+ paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
22
+ if paths:
23
+ return paths[0]
24
+
25
+ # If cl.exe is not on path, try to find it.
26
+ if os.system("where cl.exe >nul 2>nul") != 0:
27
+ cl_path = find_cl_path()
28
+ if cl_path is None:
29
+ raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
30
+ os.environ["PATH"] += ";" + cl_path
31
+
32
+ setup(
33
+ name='shencoder', # package name, import this to use python API
34
+ ext_modules=[
35
+ CUDAExtension(
36
+ name='_shencoder', # extension name, import this to use CUDA API
37
+ sources=[os.path.join(_src_path, 'src', f) for f in [
38
+ 'shencoder.cu',
39
+ 'bindings.cpp',
40
+ ]],
41
+ extra_compile_args={
42
+ 'cxx': c_flags,
43
+ 'nvcc': nvcc_flags,
44
+ }
45
+ ),
46
+ ],
47
+ cmdclass={
48
+ 'build_ext': BuildExtension,
49
+ }
50
+ )
shencoder/sphere_harmonics.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.autograd import Function
6
+ from torch.autograd.function import once_differentiable
7
+ from torch.cuda.amp import custom_bwd, custom_fwd
8
+
9
+ try:
10
+ import _shencoder as _backend
11
+ except ImportError:
12
+ from .backend import _backend
13
+
14
+ class _sh_encoder(Function):
15
+ @staticmethod
16
+ @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision
17
+ def forward(ctx, inputs, degree, calc_grad_inputs=False):
18
+ # inputs: [B, input_dim], float in [-1, 1]
19
+ # RETURN: [B, F], float
20
+
21
+ inputs = inputs.contiguous()
22
+ B, input_dim = inputs.shape # batch size, coord dim
23
+ output_dim = degree ** 2
24
+
25
+ outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device)
26
+
27
+ if calc_grad_inputs:
28
+ dy_dx = torch.empty(B, input_dim * output_dim, dtype=inputs.dtype, device=inputs.device)
29
+ else:
30
+ dy_dx = None
31
+
32
+ _backend.sh_encode_forward(inputs, outputs, B, input_dim, degree, dy_dx)
33
+
34
+ ctx.save_for_backward(inputs, dy_dx)
35
+ ctx.dims = [B, input_dim, degree]
36
+
37
+ return outputs
38
+
39
+ @staticmethod
40
+ #@once_differentiable
41
+ @custom_bwd
42
+ def backward(ctx, grad):
43
+ # grad: [B, C * C]
44
+
45
+ inputs, dy_dx = ctx.saved_tensors
46
+
47
+ if dy_dx is not None:
48
+ grad = grad.contiguous()
49
+ B, input_dim, degree = ctx.dims
50
+ grad_inputs = torch.zeros_like(inputs)
51
+ _backend.sh_encode_backward(grad, inputs, B, input_dim, degree, dy_dx, grad_inputs)
52
+ return grad_inputs, None, None
53
+ else:
54
+ return None, None, None
55
+
56
+
57
+
58
+ sh_encode = _sh_encoder.apply
59
+
60
+
61
+ class SHEncoder(nn.Module):
62
+ def __init__(self, input_dim=3, degree=4):
63
+ super().__init__()
64
+
65
+ self.input_dim = input_dim # coord dims, must be 3
66
+ self.degree = degree # supported range: 1 ~ 8
67
+ self.output_dim = degree ** 2
68
+
69
+ assert self.input_dim == 3, "SH encoder only supports input_dim == 3"
70
+ assert self.degree > 0 and self.degree <= 8, "SH encoder only supports degree in [1, 8]"
71
+
72
+ def __repr__(self):
73
+ return f"SHEncoder: input_dim={self.input_dim} degree={self.degree}"
74
+
75
+ def forward(self, inputs, size=1):
76
+ # inputs: [..., input_dim], normalized real world positions in [-size, size]
77
+ # return: [..., degree^2]
78
+
79
+ inputs = inputs / size # [-1, 1]
80
+
81
+ prefix_shape = list(inputs.shape[:-1])
82
+ inputs = inputs.reshape(-1, self.input_dim)
83
+
84
+ outputs = sh_encode(inputs, self.degree, inputs.requires_grad)
85
+ outputs = outputs.reshape(prefix_shape + [self.output_dim])
86
+
87
+ return outputs
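A quick usage sketch of the encoder above (assuming the `shencoder` extension has been built and the package is importable; inputs are expected to be CUDA tensors):

```python
import torch
from shencoder import SHEncoder

encoder = SHEncoder(input_dim=3, degree=4)        # output_dim = degree ** 2 = 16
dirs = torch.randn(1024, 3, device='cuda')
dirs = dirs / dirs.norm(dim=-1, keepdim=True)     # unit view directions, components in [-1, 1]
feats = encoder(dirs)                             # [1024, 16]
```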
shencoder/src/bindings.cpp ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ #include <torch/extension.h>
2
+
3
+ #include "shencoder.h"
4
+
5
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
6
+ m.def("sh_encode_forward", &sh_encode_forward, "SH encode forward (CUDA)");
7
+ m.def("sh_encode_backward", &sh_encode_backward, "SH encode backward (CUDA)");
8
+ }
shencoder/src/shencoder.cu ADDED
@@ -0,0 +1,439 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <stdint.h>
2
+
3
+ #include <cuda.h>
4
+ #include <cuda_fp16.h>
5
+ #include <cuda_runtime.h>
6
+
7
+ #include <ATen/cuda/CUDAContext.h>
8
+ #include <torch/torch.h>
9
+
10
+ #include <algorithm>
11
+ #include <stdexcept>
12
+
13
+ #include <cstdio>
14
+
15
+
16
+ #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
17
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
18
+ #define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
19
+ #define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")
20
+
21
+
22
+ template <typename T>
23
+ __host__ __device__ T div_round_up(T val, T divisor) {
24
+ return (val + divisor - 1) / divisor;
25
+ }
26
+
27
+ template <typename scalar_t>
28
+ __global__ void kernel_sh(
29
+ const scalar_t * __restrict__ inputs,
30
+ scalar_t * outputs,
31
+ uint32_t B, uint32_t D, uint32_t C,
32
+ scalar_t * dy_dx
33
+ ) {
34
+ const uint32_t b = threadIdx.x + blockIdx.x * blockDim.x;
35
+ if (b >= B) return;
36
+
37
+ const uint32_t C2 = C * C;
38
+
39
+ // locate
40
+ inputs += b * D;
41
+ outputs += b * C2;
42
+
43
+ scalar_t x = inputs[0], y = inputs[1], z = inputs[2];
44
+
45
+ scalar_t xy=x*y, xz=x*z, yz=y*z, x2=x*x, y2=y*y, z2=z*z, xyz=xy*z;
46
+ scalar_t x4=x2*x2, y4=y2*y2, z4=z2*z2;
47
+ scalar_t x6=x4*x2, y6=y4*y2, z6=z4*z2;
48
+
49
+ auto write_sh = [&]() {
50
+ outputs[0] = 0.28209479177387814f ; // 1/(2*sqrt(pi))
51
+ if (C <= 1) { return; }
52
+ outputs[1] = -0.48860251190291987f*y ; // -sqrt(3)*y/(2*sqrt(pi))
53
+ outputs[2] = 0.48860251190291987f*z ; // sqrt(3)*z/(2*sqrt(pi))
54
+ outputs[3] = -0.48860251190291987f*x ; // -sqrt(3)*x/(2*sqrt(pi))
55
+ if (C <= 2) { return; }
56
+ outputs[4] = 1.0925484305920792f*xy ; // sqrt(15)*xy/(2*sqrt(pi))
57
+ outputs[5] = -1.0925484305920792f*yz ; // -sqrt(15)*yz/(2*sqrt(pi))
58
+ outputs[6] = 0.94617469575755997f*z2 - 0.31539156525251999f ; // sqrt(5)*(3*z2 - 1)/(4*sqrt(pi))
59
+ outputs[7] = -1.0925484305920792f*xz ; // -sqrt(15)*xz/(2*sqrt(pi))
60
+ outputs[8] = 0.54627421529603959f*x2 - 0.54627421529603959f*y2 ; // sqrt(15)*(x2 - y2)/(4*sqrt(pi))
61
+ if (C <= 3) { return; }
62
+ outputs[9] = 0.59004358992664352f*y*(-3.0f*x2 + y2) ; // sqrt(70)*y*(-3*x2 + y2)/(8*sqrt(pi))
63
+ outputs[10] = 2.8906114426405538f*xy*z ; // sqrt(105)*xy*z/(2*sqrt(pi))
64
+ outputs[11] = 0.45704579946446572f*y*(1.0f - 5.0f*z2) ; // sqrt(42)*y*(1 - 5*z2)/(8*sqrt(pi))
65
+ outputs[12] = 0.3731763325901154f*z*(5.0f*z2 - 3.0f) ; // sqrt(7)*z*(5*z2 - 3)/(4*sqrt(pi))
66
+ outputs[13] = 0.45704579946446572f*x*(1.0f - 5.0f*z2) ; // sqrt(42)*x*(1 - 5*z2)/(8*sqrt(pi))
67
+ outputs[14] = 1.4453057213202769f*z*(x2 - y2) ; // sqrt(105)*z*(x2 - y2)/(4*sqrt(pi))
68
+ outputs[15] = 0.59004358992664352f*x*(-x2 + 3.0f*y2) ; // sqrt(70)*x*(-x2 + 3*y2)/(8*sqrt(pi))
69
+ if (C <= 4) { return; }
70
+ outputs[16] = 2.5033429417967046f*xy*(x2 - y2) ; // 3*sqrt(35)*xy*(x2 - y2)/(4*sqrt(pi))
71
+ outputs[17] = 1.7701307697799304f*yz*(-3.0f*x2 + y2) ; // 3*sqrt(70)*yz*(-3*x2 + y2)/(8*sqrt(pi))
72
+ outputs[18] = 0.94617469575756008f*xy*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*xy*(7*z2 - 1)/(4*sqrt(pi))
73
+ outputs[19] = 0.66904654355728921f*yz*(3.0f - 7.0f*z2) ; // 3*sqrt(10)*yz*(3 - 7*z2)/(8*sqrt(pi))
74
+ outputs[20] = -3.1735664074561294f*z2 + 3.7024941420321507f*z4 + 0.31735664074561293f ; // 3*(-30*z2 + 35*z4 + 3)/(16*sqrt(pi))
75
+ outputs[21] = 0.66904654355728921f*xz*(3.0f - 7.0f*z2) ; // 3*sqrt(10)*xz*(3 - 7*z2)/(8*sqrt(pi))
76
+ outputs[22] = 0.47308734787878004f*(x2 - y2)*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*(x2 - y2)*(7*z2 - 1)/(8*sqrt(pi))
77
+ outputs[23] = 1.7701307697799304f*xz*(-x2 + 3.0f*y2) ; // 3*sqrt(70)*xz*(-x2 + 3*y2)/(8*sqrt(pi))
78
+ outputs[24] = -3.7550144126950569f*x2*y2 + 0.62583573544917614f*x4 + 0.62583573544917614f*y4 ; // 3*sqrt(35)*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi))
79
+ if (C <= 5) { return; }
80
+ outputs[25] = 0.65638205684017015f*y*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 3*sqrt(154)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))
81
+ outputs[26] = 8.3026492595241645f*xy*z*(x2 - y2) ; // 3*sqrt(385)*xy*z*(x2 - y2)/(4*sqrt(pi))
82
+ outputs[27] = -0.48923829943525038f*y*(3.0f*x2 - y2)*(9.0f*z2 - 1.0f) ; // -sqrt(770)*y*(3*x2 - y2)*(9*z2 - 1)/(32*sqrt(pi))
83
+ outputs[28] = 4.7935367849733241f*xy*z*(3.0f*z2 - 1.0f) ; // sqrt(1155)*xy*z*(3*z2 - 1)/(4*sqrt(pi))
84
+ outputs[29] = 0.45294665119569694f*y*(14.0f*z2 - 21.0f*z4 - 1.0f) ; // sqrt(165)*y*(14*z2 - 21*z4 - 1)/(16*sqrt(pi))
85
+ outputs[30] = 0.1169503224534236f*z*(-70.0f*z2 + 63.0f*z4 + 15.0f) ; // sqrt(11)*z*(-70*z2 + 63*z4 + 15)/(16*sqrt(pi))
86
+ outputs[31] = 0.45294665119569694f*x*(14.0f*z2 - 21.0f*z4 - 1.0f) ; // sqrt(165)*x*(14*z2 - 21*z4 - 1)/(16*sqrt(pi))
87
+ outputs[32] = 2.3967683924866621f*z*(x2 - y2)*(3.0f*z2 - 1.0f) ; // sqrt(1155)*z*(x2 - y2)*(3*z2 - 1)/(8*sqrt(pi))
88
+ outputs[33] = -0.48923829943525038f*x*(x2 - 3.0f*y2)*(9.0f*z2 - 1.0f) ; // -sqrt(770)*x*(x2 - 3*y2)*(9*z2 - 1)/(32*sqrt(pi))
89
+ outputs[34] = 2.0756623148810411f*z*(-6.0f*x2*y2 + x4 + y4) ; // 3*sqrt(385)*z*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi))
90
+ outputs[35] = 0.65638205684017015f*x*(10.0f*x2*y2 - x4 - 5.0f*y4) ; // 3*sqrt(154)*x*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi))
91
+ if (C <= 6) { return; }
92
+ outputs[36] = 1.3663682103838286f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ; // sqrt(6006)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi))
93
+ outputs[37] = 2.3666191622317521f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 3*sqrt(2002)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))
94
+ outputs[38] = 2.0182596029148963f*xy*(x2 - y2)*(11.0f*z2 - 1.0f) ; // 3*sqrt(91)*xy*(x2 - y2)*(11*z2 - 1)/(8*sqrt(pi))
95
+ outputs[39] = -0.92120525951492349f*yz*(3.0f*x2 - y2)*(11.0f*z2 - 3.0f) ; // -sqrt(2730)*yz*(3*x2 - y2)*(11*z2 - 3)/(32*sqrt(pi))
96
+ outputs[40] = 0.92120525951492349f*xy*(-18.0f*z2 + 33.0f*z4 + 1.0f) ; // sqrt(2730)*xy*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi))
97
+ outputs[41] = 0.58262136251873131f*yz*(30.0f*z2 - 33.0f*z4 - 5.0f) ; // sqrt(273)*yz*(30*z2 - 33*z4 - 5)/(16*sqrt(pi))
98
+ outputs[42] = 6.6747662381009842f*z2 - 20.024298714302954f*z4 + 14.684485723822165f*z6 - 0.31784601133814211f ; // sqrt(13)*(105*z2 - 315*z4 + 231*z6 - 5)/(32*sqrt(pi))
99
+ outputs[43] = 0.58262136251873131f*xz*(30.0f*z2 - 33.0f*z4 - 5.0f) ; // sqrt(273)*xz*(30*z2 - 33*z4 - 5)/(16*sqrt(pi))
100
+ outputs[44] = 0.46060262975746175f*(x2 - y2)*(11.0f*z2*(3.0f*z2 - 1.0f) - 7.0f*z2 + 1.0f) ; // sqrt(2730)*(x2 - y2)*(11*z2*(3*z2 - 1) - 7*z2 + 1)/(64*sqrt(pi))
101
+ outputs[45] = -0.92120525951492349f*xz*(x2 - 3.0f*y2)*(11.0f*z2 - 3.0f) ; // -sqrt(2730)*xz*(x2 - 3*y2)*(11*z2 - 3)/(32*sqrt(pi))
102
+ outputs[46] = 0.50456490072872406f*(11.0f*z2 - 1.0f)*(-6.0f*x2*y2 + x4 + y4) ; // 3*sqrt(91)*(11*z2 - 1)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi))
103
+ outputs[47] = 2.3666191622317521f*xz*(10.0f*x2*y2 - x4 - 5.0f*y4) ; // 3*sqrt(2002)*xz*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi))
104
+ outputs[48] = 10.247761577878714f*x2*y4 - 10.247761577878714f*x4*y2 + 0.6831841051919143f*x6 - 0.6831841051919143f*y6 ; // sqrt(6006)*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi))
105
+ if (C <= 7) { return; }
106
+ outputs[49] = 0.70716273252459627f*y*(-21.0f*x2*y4 + 35.0f*x4*y2 - 7.0f*x6 + y6) ; // 3*sqrt(715)*y*(-21*x2*y4 + 35*x4*y2 - 7*x6 + y6)/(64*sqrt(pi))
107
+ outputs[50] = 5.2919213236038001f*xy*z*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ; // 3*sqrt(10010)*xy*z*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi))
108
+ outputs[51] = -0.51891557872026028f*y*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + 5.0f*x4 + y4) ; // -3*sqrt(385)*y*(13*z2 - 1)*(-10*x2*y2 + 5*x4 + y4)/(64*sqrt(pi))
109
+ outputs[52] = 4.1513246297620823f*xy*z*(x2 - y2)*(13.0f*z2 - 3.0f) ; // 3*sqrt(385)*xy*z*(x2 - y2)*(13*z2 - 3)/(8*sqrt(pi))
110
+ outputs[53] = -0.15645893386229404f*y*(3.0f*x2 - y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f) ; // -3*sqrt(35)*y*(3*x2 - y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi))
111
+ outputs[54] = 0.44253269244498261f*xy*z*(-110.0f*z2 + 143.0f*z4 + 15.0f) ; // 3*sqrt(70)*xy*z*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi))
112
+ outputs[55] = 0.090331607582517306f*y*(-135.0f*z2 + 495.0f*z4 - 429.0f*z6 + 5.0f) ; // sqrt(105)*y*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi))
113
+ outputs[56] = 0.068284276912004949f*z*(315.0f*z2 - 693.0f*z4 + 429.0f*z6 - 35.0f) ; // sqrt(15)*z*(315*z2 - 693*z4 + 429*z6 - 35)/(32*sqrt(pi))
114
+ outputs[57] = 0.090331607582517306f*x*(-135.0f*z2 + 495.0f*z4 - 429.0f*z6 + 5.0f) ; // sqrt(105)*x*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi))
115
+ outputs[58] = 0.07375544874083044f*z*(x2 - y2)*(143.0f*z2*(3.0f*z2 - 1.0f) - 187.0f*z2 + 45.0f) ; // sqrt(70)*z*(x2 - y2)*(143*z2*(3*z2 - 1) - 187*z2 + 45)/(64*sqrt(pi))
116
+ outputs[59] = -0.15645893386229404f*x*(x2 - 3.0f*y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f) ; // -3*sqrt(35)*x*(x2 - 3*y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi))
117
+ outputs[60] = 1.0378311574405206f*z*(13.0f*z2 - 3.0f)*(-6.0f*x2*y2 + x4 + y4) ; // 3*sqrt(385)*z*(13*z2 - 3)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi))
118
+ outputs[61] = -0.51891557872026028f*x*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // -3*sqrt(385)*x*(13*z2 - 1)*(-10*x2*y2 + x4 + 5*y4)/(64*sqrt(pi))
119
+ outputs[62] = 2.6459606618019f*z*(15.0f*x2*y4 - 15.0f*x4*y2 + x6 - y6) ; // 3*sqrt(10010)*z*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi))
120
+ outputs[63] = 0.70716273252459627f*x*(-35.0f*x2*y4 + 21.0f*x4*y2 - x6 + 7.0f*y6) ; // 3*sqrt(715)*x*(-35*x2*y4 + 21*x4*y2 - x6 + 7*y6)/(64*sqrt(pi))
121
+ };
122
+
123
+ write_sh();
124
+
125
+ if (dy_dx) {
126
+ scalar_t *dx = dy_dx + b * D * C2;
127
+ scalar_t *dy = dx + C2;
128
+ scalar_t *dz = dy + C2;
129
+
130
+ auto write_sh_dx = [&]() {
131
+ dx[0] = 0.0f ; // 0
132
+ if (C <= 1) { return; }
133
+ dx[1] = 0.0f ; // 0
134
+ dx[2] = 0.0f ; // 0
135
+ dx[3] = -0.48860251190291992f ; // -sqrt(3)/(2*sqrt(pi))
136
+ if (C <= 2) { return; }
137
+ dx[4] = 1.0925484305920792f*y ; // sqrt(15)*y/(2*sqrt(pi))
138
+ dx[5] = 0.0f ; // 0
139
+ dx[6] = 0.0f ; // 0
140
+ dx[7] = -1.0925484305920792f*z ; // -sqrt(15)*z/(2*sqrt(pi))
141
+ dx[8] = 1.0925484305920792f*x ; // sqrt(15)*x/(2*sqrt(pi))
142
+ if (C <= 3) { return; }
143
+ dx[9] = -3.5402615395598609f*xy ; // -3*sqrt(70)*xy/(4*sqrt(pi))
144
+ dx[10] = 2.8906114426405538f*yz ; // sqrt(105)*yz/(2*sqrt(pi))
145
+ dx[11] = 0.0f ; // 0
146
+ dx[12] = 0.0f ; // 0
147
+ dx[13] = 0.45704579946446572f - 2.2852289973223288f*z2 ; // sqrt(42)*(1 - 5*z2)/(8*sqrt(pi))
148
+ dx[14] = 2.8906114426405538f*xz ; // sqrt(105)*xz/(2*sqrt(pi))
149
+ dx[15] = -1.7701307697799304f*x2 + 1.7701307697799304f*y2 ; // 3*sqrt(70)*(-x2 + y2)/(8*sqrt(pi))
150
+ if (C <= 4) { return; }
151
+ dx[16] = 2.5033429417967046f*y*(3.0f*x2 - y2) ; // 3*sqrt(35)*y*(3*x2 - y2)/(4*sqrt(pi))
152
+ dx[17] = -10.620784618679583f*xy*z ; // -9*sqrt(70)*xy*z/(4*sqrt(pi))
153
+ dx[18] = 0.94617469575756008f*y*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*y*(7*z2 - 1)/(4*sqrt(pi))
154
+ dx[19] = 0.0f ; // 0
155
+ dx[20] = 0.0f ; // 0
156
+ dx[21] = 0.66904654355728921f*z*(3.0f - 7.0f*z2) ; // 3*sqrt(10)*z*(3 - 7*z2)/(8*sqrt(pi))
157
+ dx[22] = 0.94617469575756008f*x*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*x*(7*z2 - 1)/(4*sqrt(pi))
158
+ dx[23] = 5.3103923093397913f*z*(-x2 + y2) ; // 9*sqrt(70)*z*(-x2 + y2)/(8*sqrt(pi))
159
+ dx[24] = 2.5033429417967046f*x*(x2 - 3.0f*y2) ; // 3*sqrt(35)*x*(x2 - 3*y2)/(4*sqrt(pi))
160
+ if (C <= 5) { return; }
161
+ dx[25] = 13.127641136803401f*xy*(-x2 + y2) ; // 15*sqrt(154)*xy*(-x2 + y2)/(8*sqrt(pi))
162
+ dx[26] = 8.3026492595241645f*yz*(3.0f*x2 - y2) ; // 3*sqrt(385)*yz*(3*x2 - y2)/(4*sqrt(pi))
163
+ dx[27] = 2.9354297966115022f*xy*(1.0f - 9.0f*z2) ; // 3*sqrt(770)*xy*(1 - 9*z2)/(16*sqrt(pi))
164
+ dx[28] = 4.7935367849733241f*yz*(3.0f*z2 - 1.0f) ; // sqrt(1155)*yz*(3*z2 - 1)/(4*sqrt(pi))
165
+ dx[29] = 0.0f ; // 0
166
+ dx[30] = 0.0f ; // 0
167
+ dx[31] = 6.3412531167397574f*z2 - 9.5118796751096362f*z4 - 0.45294665119569694f ; // sqrt(165)*(14*z2 - 21*z4 - 1)/(16*sqrt(pi))
168
+ dx[32] = 4.7935367849733241f*xz*(3.0f*z2 - 1.0f) ; // sqrt(1155)*xz*(3*z2 - 1)/(4*sqrt(pi))
169
+ dx[33] = -13.209434084751759f*x2*z2 + 1.4677148983057511f*x2 + 13.209434084751759f*y2*z2 - 1.4677148983057511f*y2 ; // 3*sqrt(770)*(-9*x2*z2 + x2 + 9*y2*z2 - y2)/(32*sqrt(pi))
170
+ dx[34] = 8.3026492595241645f*xz*(x2 - 3.0f*y2) ; // 3*sqrt(385)*xz*(x2 - 3*y2)/(4*sqrt(pi))
171
+ dx[35] = 19.6914617052051f*x2*y2 - 3.2819102842008503f*x4 - 3.2819102842008503f*y4 ; // 15*sqrt(154)*(6*x2*y2 - x4 - y4)/(32*sqrt(pi))
172
+ if (C <= 6) { return; }
173
+ dx[36] = 4.0991046311514854f*y*(-10.0f*x2*y2 + 5.0f*x4 + y4) ; // 3*sqrt(6006)*y*(-10*x2*y2 + 5*x4 + y4)/(32*sqrt(pi))
174
+ dx[37] = 47.332383244635047f*xy*z*(-x2 + y2) ; // 15*sqrt(2002)*xy*z*(-x2 + y2)/(8*sqrt(pi))
175
+ dx[38] = 2.0182596029148963f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f) ; // 3*sqrt(91)*y*(3*x2 - y2)*(11*z2 - 1)/(8*sqrt(pi))
176
+ dx[39] = 5.5272315570895412f*xy*z*(3.0f - 11.0f*z2) ; // 3*sqrt(2730)*xy*z*(3 - 11*z2)/(16*sqrt(pi))
177
+ dx[40] = 0.92120525951492349f*y*(-18.0f*z2 + 33.0f*z4 + 1.0f) ; // sqrt(2730)*y*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi))
178
+ dx[41] = 0.0f ; // 0
179
+ dx[42] = 0.0f ; // 0
180
+ dx[43] = 0.58262136251873131f*z*(30.0f*z2 - 33.0f*z4 - 5.0f) ; // sqrt(273)*z*(30*z2 - 33*z4 - 5)/(16*sqrt(pi))
181
+ dx[44] = 0.92120525951492349f*x*(-18.0f*z2 + 33.0f*z4 + 1.0f) ; // sqrt(2730)*x*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi))
182
+ dx[45] = -2.7636157785447706f*z*(x2 - y2)*(11.0f*z2 - 3.0f) ; // -3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(32*sqrt(pi))
183
+ dx[46] = 2.0182596029148963f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f) ; // 3*sqrt(91)*x*(x2 - 3*y2)*(11*z2 - 1)/(8*sqrt(pi))
184
+ dx[47] = 11.833095811158762f*z*(6.0f*x2*y2 - x4 - y4) ; // 15*sqrt(2002)*z*(6*x2*y2 - x4 - y4)/(32*sqrt(pi))
185
+ dx[48] = 4.0991046311514854f*x*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // 3*sqrt(6006)*x*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi))
186
+ if (C <= 7) { return; }
187
+ dx[49] = 9.9002782553443485f*xy*(10.0f*x2*y2 - 3.0f*x4 - 3.0f*y4) ; // 21*sqrt(715)*xy*(10*x2*y2 - 3*x4 - 3*y4)/(32*sqrt(pi))
188
+ dx[50] = 15.875763970811402f*yz*(-10.0f*x2*y2 + 5.0f*x4 + y4) ; // 9*sqrt(10010)*yz*(-10*x2*y2 + 5*x4 + y4)/(32*sqrt(pi))
189
+ dx[51] = -10.378311574405206f*xy*(x2 - y2)*(13.0f*z2 - 1.0f) ; // -15*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(16*sqrt(pi))
190
+ dx[52] = 4.1513246297620823f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f) ; // 3*sqrt(385)*yz*(3*x2 - y2)*(13*z2 - 3)/(8*sqrt(pi))
191
+ dx[53] = 0.93875360317376422f*xy*(66.0f*z2 - 143.0f*z4 - 3.0f) ; // 9*sqrt(35)*xy*(66*z2 - 143*z4 - 3)/(32*sqrt(pi))
192
+ dx[54] = 0.44253269244498261f*yz*(-110.0f*z2 + 143.0f*z4 + 15.0f) ; // 3*sqrt(70)*yz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi))
193
+ dx[55] = 0.0f ; // 0
194
+ dx[56] = 0.0f ; // 0
195
+ dx[57] = -12.194767023639836f*z2 + 44.714145753346067f*z4 - 38.752259652899923f*z6 + 0.45165803791258652f ; // sqrt(105)*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi))
196
+ dx[58] = 0.44253269244498261f*xz*(-110.0f*z2 + 143.0f*z4 + 15.0f) ; // 3*sqrt(70)*xz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi))
197
+ dx[59] = 30.97886890473422f*x2*z2 - 67.120882626924143f*x2*z4 - 1.4081304047606462f*x2 - 30.97886890473422f*y2*z2 + 67.120882626924143f*y2*z4 + 1.4081304047606462f*y2 ; // 9*sqrt(35)*(66*x2*z2 - 143*x2*z4 - 3*x2 - 66*y2*z2 + 143*y2*z4 + 3*y2)/(64*sqrt(pi))
198
+ dx[60] = 4.1513246297620823f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f) ; // 3*sqrt(385)*xz*(x2 - 3*y2)*(13*z2 - 3)/(8*sqrt(pi))
199
+ dx[61] = -0.51891557872026028f*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + 4.0f*x2*(x2 - 5.0f*y2) + x4 + 5.0f*y4) ; // -3*sqrt(385)*(13*z2 - 1)*(-10*x2*y2 + 4*x2*(x2 - 5*y2) + x4 + 5*y4)/(64*sqrt(pi))
200
+ dx[62] = 15.875763970811402f*xz*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // 9*sqrt(10010)*xz*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi))
201
+ dx[63] = -74.252086915082614f*x2*y4 + 74.252086915082614f*x4*y2 - 4.9501391276721742f*x6 + 4.9501391276721742f*y6 ; // 21*sqrt(715)*(-15*x2*y4 + 15*x4*y2 - x6 + y6)/(64*sqrt(pi))
202
+ };
203
+
204
+ auto write_sh_dy = [&]() {
205
+ dy[0] = 0.0f ; // 0
206
+ if (C <= 1) { return; }
207
+ dy[1] = -0.48860251190291992f ; // -sqrt(3)/(2*sqrt(pi))
208
+ dy[2] = 0.0f ; // 0
209
+ dy[3] = 0.0f ; // 0
210
+ if (C <= 2) { return; }
211
+ dy[4] = 1.0925484305920792f*x ; // sqrt(15)*x/(2*sqrt(pi))
212
+ dy[5] = -1.0925484305920792f*z ; // -sqrt(15)*z/(2*sqrt(pi))
213
+ dy[6] = 0.0f ; // 0
214
+ dy[7] = 0.0f ; // 0
215
+ dy[8] = -1.0925484305920792f*y ; // -sqrt(15)*y/(2*sqrt(pi))
216
+ if (C <= 3) { return; }
217
+ dy[9] = -1.7701307697799304f*x2 + 1.7701307697799304f*y2 ; // 3*sqrt(70)*(-x2 + y2)/(8*sqrt(pi))
218
+ dy[10] = 2.8906114426405538f*xz ; // sqrt(105)*xz/(2*sqrt(pi))
219
+ dy[11] = 0.45704579946446572f - 2.2852289973223288f*z2 ; // sqrt(42)*(1 - 5*z2)/(8*sqrt(pi))
220
+ dy[12] = 0.0f ; // 0
221
+ dy[13] = 0.0f ; // 0
222
+ dy[14] = -2.8906114426405538f*yz ; // -sqrt(105)*yz/(2*sqrt(pi))
223
+ dy[15] = 3.5402615395598609f*xy ; // 3*sqrt(70)*xy/(4*sqrt(pi))
224
+ if (C <= 4) { return; }
225
+ dy[16] = 2.5033429417967046f*x*(x2 - 3.0f*y2) ; // 3*sqrt(35)*x*(x2 - 3*y2)/(4*sqrt(pi))
226
+ dy[17] = 5.3103923093397913f*z*(-x2 + y2) ; // 9*sqrt(70)*z*(-x2 + y2)/(8*sqrt(pi))
227
+ dy[18] = 0.94617469575756008f*x*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*x*(7*z2 - 1)/(4*sqrt(pi))
228
+ dy[19] = 0.66904654355728921f*z*(3.0f - 7.0f*z2) ; // 3*sqrt(10)*z*(3 - 7*z2)/(8*sqrt(pi))
229
+ dy[20] = 0.0f ; // 0
230
+ dy[21] = 0.0f ; // 0
231
+ dy[22] = 0.94617469575756008f*y*(1.0f - 7.0f*z2) ; // 3*sqrt(5)*y*(1 - 7*z2)/(4*sqrt(pi))
232
+ dy[23] = 10.620784618679583f*xy*z ; // 9*sqrt(70)*xy*z/(4*sqrt(pi))
233
+ dy[24] = 2.5033429417967046f*y*(-3.0f*x2 + y2) ; // 3*sqrt(35)*y*(-3*x2 + y2)/(4*sqrt(pi))
234
+ if (C <= 5) { return; }
235
+ dy[25] = 19.6914617052051f*x2*y2 - 3.2819102842008503f*x4 - 3.2819102842008503f*y4 ; // 15*sqrt(154)*(6*x2*y2 - x4 - y4)/(32*sqrt(pi))
236
+ dy[26] = 8.3026492595241645f*xz*(x2 - 3.0f*y2) ; // 3*sqrt(385)*xz*(x2 - 3*y2)/(4*sqrt(pi))
237
+ dy[27] = -1.4677148983057511f*(x2 - y2)*(9.0f*z2 - 1.0f) ; // -3*sqrt(770)*(x2 - y2)*(9*z2 - 1)/(32*sqrt(pi))
238
+ dy[28] = 4.7935367849733241f*xz*(3.0f*z2 - 1.0f) ; // sqrt(1155)*xz*(3*z2 - 1)/(4*sqrt(pi))
239
+ dy[29] = 6.3412531167397574f*z2 - 9.5118796751096362f*z4 - 0.45294665119569694f ; // sqrt(165)*(14*z2 - 21*z4 - 1)/(16*sqrt(pi))
240
+ dy[30] = 0.0f ; // 0
241
+ dy[31] = 0.0f ; // 0
242
+ dy[32] = 4.7935367849733241f*yz*(1.0f - 3.0f*z2) ; // sqrt(1155)*yz*(1 - 3*z2)/(4*sqrt(pi))
243
+ dy[33] = 2.9354297966115022f*xy*(9.0f*z2 - 1.0f) ; // 3*sqrt(770)*xy*(9*z2 - 1)/(16*sqrt(pi))
244
+ dy[34] = 8.3026492595241645f*yz*(-3.0f*x2 + y2) ; // 3*sqrt(385)*yz*(-3*x2 + y2)/(4*sqrt(pi))
245
+ dy[35] = 13.127641136803401f*xy*(x2 - y2) ; // 15*sqrt(154)*xy*(x2 - y2)/(8*sqrt(pi))
246
+ if (C <= 6) { return; }
247
+ dy[36] = 4.0991046311514854f*x*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // 3*sqrt(6006)*x*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi))
248
+ dy[37] = 11.833095811158762f*z*(6.0f*x2*y2 - x4 - y4) ; // 15*sqrt(2002)*z*(6*x2*y2 - x4 - y4)/(32*sqrt(pi))
249
+ dy[38] = 2.0182596029148963f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f) ; // 3*sqrt(91)*x*(x2 - 3*y2)*(11*z2 - 1)/(8*sqrt(pi))
250
+ dy[39] = -2.7636157785447706f*z*(x2 - y2)*(11.0f*z2 - 3.0f) ; // -3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(32*sqrt(pi))
251
+ dy[40] = 0.92120525951492349f*x*(-18.0f*z2 + 33.0f*z4 + 1.0f) ; // sqrt(2730)*x*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi))
252
+ dy[41] = 0.58262136251873131f*z*(30.0f*z2 - 33.0f*z4 - 5.0f) ; // sqrt(273)*z*(30*z2 - 33*z4 - 5)/(16*sqrt(pi))
253
+ dy[42] = 0.0f ; // 0
254
+ dy[43] = 0.0f ; // 0
255
+ dy[44] = 0.92120525951492349f*y*(18.0f*z2 - 33.0f*z4 - 1.0f) ; // sqrt(2730)*y*(18*z2 - 33*z4 - 1)/(32*sqrt(pi))
256
+ dy[45] = 5.5272315570895412f*xy*z*(11.0f*z2 - 3.0f) ; // 3*sqrt(2730)*xy*z*(11*z2 - 3)/(16*sqrt(pi))
257
+ dy[46] = -2.0182596029148963f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f) ; // -3*sqrt(91)*y*(3*x2 - y2)*(11*z2 - 1)/(8*sqrt(pi))
258
+ dy[47] = 47.332383244635047f*xy*z*(x2 - y2) ; // 15*sqrt(2002)*xy*z*(x2 - y2)/(8*sqrt(pi))
259
+ dy[48] = 4.0991046311514854f*y*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 3*sqrt(6006)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))
260
+ if (C <= 7) { return; }
261
+ dy[49] = -74.252086915082614f*x2*y4 + 74.252086915082614f*x4*y2 - 4.9501391276721742f*x6 + 4.9501391276721742f*y6 ; // 21*sqrt(715)*(-15*x2*y4 + 15*x4*y2 - x6 + y6)/(64*sqrt(pi))
262
+ dy[50] = 15.875763970811402f*xz*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // 9*sqrt(10010)*xz*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi))
263
+ dy[51] = 0.51891557872026028f*(13.0f*z2 - 1.0f)*(10.0f*x2*y2 - 5.0f*x4 + 4.0f*y2*(5.0f*x2 - y2) - y4) ; // 3*sqrt(385)*(13*z2 - 1)*(10*x2*y2 - 5*x4 + 4*y2*(5*x2 - y2) - y4)/(64*sqrt(pi))
264
+ dy[52] = 4.1513246297620823f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f) ; // 3*sqrt(385)*xz*(x2 - 3*y2)*(13*z2 - 3)/(8*sqrt(pi))
265
+ dy[53] = -0.46937680158688211f*(x2 - y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f) ; // -9*sqrt(35)*(x2 - y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi))
266
+ dy[54] = 0.44253269244498261f*xz*(-110.0f*z2 + 143.0f*z4 + 15.0f) ; // 3*sqrt(70)*xz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi))
267
+ dy[55] = -12.194767023639836f*z2 + 44.714145753346067f*z4 - 38.752259652899923f*z6 + 0.45165803791258652f ; // sqrt(105)*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi))
268
+ dy[56] = 0.0f ; // 0
269
+ dy[57] = 0.0f ; // 0
270
+ dy[58] = 0.44253269244498261f*yz*(110.0f*z2 - 143.0f*z4 - 15.0f) ; // 3*sqrt(70)*yz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi))
271
+ dy[59] = 0.93875360317376422f*xy*(-66.0f*z2 + 143.0f*z4 + 3.0f) ; // 9*sqrt(35)*xy*(-66*z2 + 143*z4 + 3)/(32*sqrt(pi))
272
+ dy[60] = -4.1513246297620823f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f) ; // -3*sqrt(385)*yz*(3*x2 - y2)*(13*z2 - 3)/(8*sqrt(pi))
273
+ dy[61] = 10.378311574405206f*xy*(x2 - y2)*(13.0f*z2 - 1.0f) ; // 15*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(16*sqrt(pi))
274
+ dy[62] = 15.875763970811402f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 9*sqrt(10010)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))
275
+ dy[63] = 9.9002782553443485f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ; // 21*sqrt(715)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi))
276
+ };
277
+
278
+ auto write_sh_dz = [&]() {
279
+ dz[0] = 0.0f ; // 0
280
+ if (C <= 1) { return; }
281
+ dz[1] = 0.0f ; // 0
282
+ dz[2] = 0.48860251190291992f ; // sqrt(3)/(2*sqrt(pi))
283
+ dz[3] = 0.0f ; // 0
284
+ if (C <= 2) { return; }
285
+ dz[4] = 0.0f ; // 0
286
+ dz[5] = -1.0925484305920792f*y ; // -sqrt(15)*y/(2*sqrt(pi))
287
+ dz[6] = 1.8923493915151202f*z ; // 3*sqrt(5)*z/(2*sqrt(pi))
288
+ dz[7] = -1.0925484305920792f*x ; // -sqrt(15)*x/(2*sqrt(pi))
289
+ dz[8] = 0.0f ; // 0
290
+ if (C <= 3) { return; }
291
+ dz[9] = 0.0f ; // 0
292
+ dz[10] = 2.8906114426405538f*xy ; // sqrt(105)*xy/(2*sqrt(pi))
293
+ dz[11] = -4.5704579946446566f*yz ; // -5*sqrt(42)*yz/(4*sqrt(pi))
294
+ dz[12] = 5.597644988851731f*z2 - 1.1195289977703462f ; // 3*sqrt(7)*(5*z2 - 1)/(4*sqrt(pi))
295
+ dz[13] = -4.5704579946446566f*xz ; // -5*sqrt(42)*xz/(4*sqrt(pi))
296
+ dz[14] = 1.4453057213202769f*x2 - 1.4453057213202769f*y2 ; // sqrt(105)*(x2 - y2)/(4*sqrt(pi))
297
+ dz[15] = 0.0f ; // 0
298
+ if (C <= 4) { return; }
299
+ dz[16] = 0.0f ; // 0
300
+ dz[17] = 1.7701307697799304f*y*(-3.0f*x2 + y2) ; // 3*sqrt(70)*y*(-3*x2 + y2)/(8*sqrt(pi))
301
+ dz[18] = 13.246445740605839f*xy*z ; // 21*sqrt(5)*xy*z/(2*sqrt(pi))
302
+ dz[19] = 2.0071396306718676f*y*(1.0f - 7.0f*z2) ; // 9*sqrt(10)*y*(1 - 7*z2)/(8*sqrt(pi))
303
+ dz[20] = 14.809976568128603f*pow(z, 3) - 6.3471328149122579f*z ; // (105*z**3 - 45*z)/(4*sqrt(pi))
304
+ dz[21] = 2.0071396306718676f*x*(1.0f - 7.0f*z2) ; // 9*sqrt(10)*x*(1 - 7*z2)/(8*sqrt(pi))
305
+ dz[22] = 6.6232228703029197f*z*(x2 - y2) ; // 21*sqrt(5)*z*(x2 - y2)/(4*sqrt(pi))
306
+ dz[23] = 1.7701307697799304f*x*(-x2 + 3.0f*y2) ; // 3*sqrt(70)*x*(-x2 + 3*y2)/(8*sqrt(pi))
307
+ dz[24] = 0.0f ; // 0
308
+ if (C <= 5) { return; }
309
+ dz[25] = 0.0f ; // 0
310
+ dz[26] = 8.3026492595241645f*xy*(x2 - y2) ; // 3*sqrt(385)*xy*(x2 - y2)/(4*sqrt(pi))
311
+ dz[27] = 8.8062893898345074f*yz*(-3.0f*x2 + y2) ; // 9*sqrt(770)*yz*(-3*x2 + y2)/(16*sqrt(pi))
312
+ dz[28] = 4.7935367849733241f*xy*(9.0f*z2 - 1.0f) ; // sqrt(1155)*xy*(9*z2 - 1)/(4*sqrt(pi))
313
+ dz[29] = 12.682506233479513f*yz*(1.0f - 3.0f*z2) ; // 7*sqrt(165)*yz*(1 - 3*z2)/(4*sqrt(pi))
314
+ dz[30] = -24.559567715218954f*z2 + 36.839351572828434f*z4 + 1.754254836801354f ; // 15*sqrt(11)*(-14*z2 + 21*z4 + 1)/(16*sqrt(pi))
315
+ dz[31] = 12.682506233479513f*xz*(1.0f - 3.0f*z2) ; // 7*sqrt(165)*xz*(1 - 3*z2)/(4*sqrt(pi))
316
+ dz[32] = 2.3967683924866621f*(x2 - y2)*(9.0f*z2 - 1.0f) ; // sqrt(1155)*(x2 - y2)*(9*z2 - 1)/(8*sqrt(pi))
317
+ dz[33] = 8.8062893898345074f*xz*(-x2 + 3.0f*y2) ; // 9*sqrt(770)*xz*(-x2 + 3*y2)/(16*sqrt(pi))
318
+ dz[34] = -12.453973889286246f*x2*y2 + 2.0756623148810411f*x4 + 2.0756623148810411f*y4 ; // 3*sqrt(385)*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi))
319
+ dz[35] = 0.0f ; // 0
320
+ if (C <= 6) { return; }
321
+ dz[36] = 0.0f ; // 0
322
+ dz[37] = 2.3666191622317521f*y*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 3*sqrt(2002)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))
323
+ dz[38] = 44.401711264127719f*xy*z*(x2 - y2) ; // 33*sqrt(91)*xy*z*(x2 - y2)/(4*sqrt(pi))
324
+ dz[39] = -2.7636157785447706f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f) ; // -3*sqrt(2730)*y*(3*x2 - y2)*(11*z2 - 1)/(32*sqrt(pi))
325
+ dz[40] = 11.054463114179082f*xy*z*(11.0f*z2 - 3.0f) ; // 3*sqrt(2730)*xy*z*(11*z2 - 3)/(8*sqrt(pi))
326
+ dz[41] = 2.9131068125936568f*y*(18.0f*z2 - 33.0f*z4 - 1.0f) ; // 5*sqrt(273)*y*(18*z2 - 33*z4 - 1)/(16*sqrt(pi))
327
+ dz[42] = 2.6699064952403937f*z*(-30.0f*z2 + 33.0f*z4 + 5.0f) ; // 21*sqrt(13)*z*(-30*z2 + 33*z4 + 5)/(16*sqrt(pi))
328
+ dz[43] = 2.9131068125936568f*x*(18.0f*z2 - 33.0f*z4 - 1.0f) ; // 5*sqrt(273)*x*(18*z2 - 33*z4 - 1)/(16*sqrt(pi))
329
+ dz[44] = 5.5272315570895412f*z*(x2 - y2)*(11.0f*z2 - 3.0f) ; // 3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(16*sqrt(pi))
330
+ dz[45] = -2.7636157785447706f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f) ; // -3*sqrt(2730)*x*(x2 - 3*y2)*(11*z2 - 1)/(32*sqrt(pi))
331
+ dz[46] = 11.10042781603193f*z*(-6.0f*x2*y2 + x4 + y4) ; // 33*sqrt(91)*z*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi))
332
+ dz[47] = 2.3666191622317521f*x*(10.0f*x2*y2 - x4 - 5.0f*y4) ; // 3*sqrt(2002)*x*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi))
333
+ dz[48] = 0.0f ; // 0
334
+ if (C <= 7) { return; }
335
+ dz[49] = 0.0f ; // 0
336
+ dz[50] = 5.2919213236038001f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ; // 3*sqrt(10010)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi))
337
+ dz[51] = 13.491805046726766f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 39*sqrt(385)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))
338
+ dz[52] = 12.453973889286248f*xy*(x2 - y2)*(13.0f*z2 - 1.0f) ; // 9*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(8*sqrt(pi))
339
+ dz[53] = -6.8841930899409371f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f) ; // -33*sqrt(35)*yz*(3*x2 - y2)*(13*z2 - 3)/(16*sqrt(pi))
340
+ dz[54] = 2.2126634622249131f*xy*(-66.0f*z2 + 143.0f*z4 + 3.0f) ; // 15*sqrt(70)*xy*(-66*z2 + 143*z4 + 3)/(32*sqrt(pi))
341
+ dz[55] = 1.6259689364853116f*yz*(110.0f*z2 - 143.0f*z4 - 15.0f) ; // 9*sqrt(105)*yz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi))
342
+ dz[56] = 64.528641681844675f*z2 - 236.60501950009714f*z4 + 205.05768356675085f*z6 - 2.3899496919201733f ; // 7*sqrt(15)*(135*z2 - 495*z4 + 429*z6 - 5)/(32*sqrt(pi))
343
+ dz[57] = 1.6259689364853116f*xz*(110.0f*z2 - 143.0f*z4 - 15.0f) ; // 9*sqrt(105)*xz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi))
344
+ dz[58] = 0.07375544874083044f*(x2 - y2)*(143.0f*z2*(3.0f*z2 - 1.0f) + 132.0f*z2*(13.0f*z2 - 5.0f) - 187.0f*z2 + 45.0f) ; // sqrt(70)*(x2 - y2)*(143*z2*(3*z2 - 1) + 132*z2*(13*z2 - 5) - 187*z2 + 45)/(64*sqrt(pi))
345
+ dz[59] = -6.8841930899409371f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f) ; // -33*sqrt(35)*xz*(x2 - 3*y2)*(13*z2 - 3)/(16*sqrt(pi))
346
+ dz[60] = 3.1134934723215619f*(13.0f*z2 - 1.0f)*(-6.0f*x2*y2 + x4 + y4) ; // 9*sqrt(385)*(13*z2 - 1)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi))
347
+ dz[61] = 13.491805046726766f*xz*(10.0f*x2*y2 - x4 - 5.0f*y4) ; // 39*sqrt(385)*xz*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi))
348
+ dz[62] = 39.6894099270285f*x2*y4 - 39.6894099270285f*x4*y2 + 2.6459606618019f*x6 - 2.6459606618019f*y6 ; // 3*sqrt(10010)*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi))
349
+ dz[63] = 0.0f ; // 0
350
+ };
351
+ write_sh_dx();
352
+ write_sh_dy();
353
+ write_sh_dz();
354
+ }
355
+ }
356
+
357
+
358
+ template <typename scalar_t>
359
+ __global__ void kernel_sh_backward(
360
+ const scalar_t * __restrict__ grad,
361
+ const scalar_t * __restrict__ inputs,
362
+ uint32_t B, uint32_t D, uint32_t C,
363
+ const scalar_t * __restrict__ dy_dx,
364
+ scalar_t * grad_inputs
365
+ ) {
366
+ const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
367
+ const uint32_t b = t / D;
368
+ if (b >= B) return;
369
+
370
+ const uint32_t d = t - b * D;
371
+ const uint32_t C2 = C * C;
372
+
373
+ // locate
374
+ grad += b * C2;
375
+ dy_dx += b * D * C2 + d * C2;
376
+
377
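+ // chain rule: grad_inputs[b, d] = sum over SH channels ch of dL/d(out[b, ch]) * d(out[b, ch])/d(in[b, d])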
+ for (int ch = 0; ch < C2; ch++) {
378
+ grad_inputs[t] += grad[ch] * dy_dx[ch];
379
+ //printf("t=%d, b=%d, d=%d, ch=%d, grad=%f (+= %f * %f)\n", t, b, d, ch, grad_inputs[t], grad[ch], dy_dx[ch]);
380
+ }
381
+
382
+ }
383
+
384
+ // inputs: [B, D], float, in [-1, 1]
385
+ // outputs: [B, C * C], float
386
+ template <typename scalar_t>
387
+ void sh_encode_forward_cuda(const scalar_t *inputs, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, scalar_t *dy_dx) {
388
+ static constexpr uint32_t N_THREADS = 256;
389
+ kernel_sh<scalar_t><<<div_round_up(B, N_THREADS), N_THREADS>>>(inputs, outputs, B, D, C, dy_dx);
390
+ }
391
+
392
+
393
+ template <typename scalar_t>
394
+ void sh_encode_backward_cuda(const scalar_t *grad, const scalar_t *inputs, const uint32_t B, const uint32_t D, const uint32_t C, scalar_t *dy_dx, scalar_t *grad_inputs) {
395
+ static constexpr uint32_t N_THREADS = 256;
396
+ kernel_sh_backward<scalar_t><<<div_round_up(B * D, N_THREADS), N_THREADS>>>(grad, inputs, B, D, C, dy_dx, grad_inputs);
397
+ }
398
+
399
+
400
+ void sh_encode_forward(at::Tensor inputs, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, at::optional<at::Tensor> dy_dx) {
401
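+ // dy_dx is optional: it is only allocated and passed in when gradients w.r.t. the inputs are needed, hence the relaxed checks below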
+ CHECK_CUDA(inputs);
402
+ CHECK_CUDA(outputs);
403
+ // CHECK_CUDA(dy_dx);
404
+
405
+ CHECK_CONTIGUOUS(inputs);
406
+ CHECK_CONTIGUOUS(outputs);
407
+ // CHECK_CONTIGUOUS(dy_dx);
408
+
409
+ CHECK_IS_FLOATING(inputs);
410
+ CHECK_IS_FLOATING(outputs);
411
+ // CHECK_IS_FLOATING(dy_dx);
412
+
413
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
414
+ inputs.scalar_type(), "sh_encode_forward_cuda", ([&] {
415
+ sh_encode_forward_cuda<scalar_t>(inputs.data_ptr<scalar_t>(), outputs.data_ptr<scalar_t>(), B, D, C, dy_dx.has_value() ? dy_dx.value().data_ptr<scalar_t>() : nullptr);
416
+ }));
417
+ }
418
+
419
+ void sh_encode_backward(at::Tensor grad, at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t C, at::Tensor dy_dx, at::Tensor grad_inputs) {
420
+ CHECK_CUDA(grad);
421
+ CHECK_CUDA(inputs);
422
+ CHECK_CUDA(dy_dx);
423
+ CHECK_CUDA(grad_inputs);
424
+
425
+ CHECK_CONTIGUOUS(grad);
426
+ CHECK_CONTIGUOUS(inputs);
427
+ CHECK_CONTIGUOUS(dy_dx);
428
+ CHECK_CONTIGUOUS(grad_inputs);
429
+
430
+ CHECK_IS_FLOATING(grad);
431
+ CHECK_IS_FLOATING(inputs);
432
+ CHECK_IS_FLOATING(dy_dx);
433
+ CHECK_IS_FLOATING(grad_inputs);
434
+
435
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
436
+ grad.scalar_type(), "sh_encode_backward_cuda", ([&] {
437
+ sh_encode_backward_cuda<scalar_t>(grad.data_ptr<scalar_t>(), inputs.data_ptr<scalar_t>(), B, D, C, dy_dx.data_ptr<scalar_t>(), grad_inputs.data_ptr<scalar_t>());
438
+ }));
439
+ }
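The two host functions above are the whole C++ surface of the SH encoder. For orientation, here is a rough sketch of how they would typically be wrapped on the Python side; the module name `_shencoder`, the `calc_grad_inputs` flag, and the exact argument layout are assumptions for illustration, not the wrapper shipped in this commit.

```python
import torch
from torch.autograd import Function

# import _shencoder  # compiled extension exposing sh_encode_forward / sh_encode_backward

class SHEncode(Function):
    @staticmethod
    def forward(ctx, inputs, degree, calc_grad_inputs=False):
        # inputs: [B, 3] direction components in [-1, 1]; outputs: [B, degree ** 2]
        inputs = inputs.contiguous()
        B, D = inputs.shape
        outputs = torch.empty(B, degree ** 2, dtype=inputs.dtype, device=inputs.device)
        # dy_dx caches d(output)/d(input) for the backward pass: [B, D * degree ** 2]
        dy_dx = torch.empty(B, D * degree ** 2, dtype=inputs.dtype, device=inputs.device) if calc_grad_inputs else None
        _shencoder.sh_encode_forward(inputs, outputs, B, D, degree, dy_dx)
        ctx.save_for_backward(inputs, dy_dx)
        ctx.dims = (B, D, degree)
        return outputs

    @staticmethod
    def backward(ctx, grad):
        inputs, dy_dx = ctx.saved_tensors
        if dy_dx is None:
            return None, None, None
        B, D, degree = ctx.dims
        # kernel_sh_backward accumulates into grad_inputs, so it must start at zero
        grad_inputs = torch.zeros_like(inputs)
        _shencoder.sh_encode_backward(grad.contiguous(), inputs, B, D, degree, dy_dx, grad_inputs)
        return grad_inputs, None, None

sh_encode = SHEncode.apply
```

The `dy_dx` buffer layout [B, D * C * C] matches the `b * D * C2 + d * C2` indexing in `kernel_sh_backward`, and `grad_inputs` is zero-initialized because the backward kernel accumulates into it with `+=`.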
shencoder/src/shencoder.h ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <stdint.h>
4
+ #include <torch/torch.h>
5
+
6
+ // inputs: [B, D], float, in [-1, 1]
7
+ // outputs: [B, C * C], float
8
+
9
+ void sh_encode_forward(at::Tensor inputs, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, at::optional<at::Tensor> dy_dx);
10
+ void sh_encode_backward(at::Tensor grad, at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t C, at::Tensor dy_dx, at::Tensor grad_inputs);
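
A short usage sketch, assuming the hypothetical `SHEncode` wrapper above has been compiled and imported; the batch size and degree are illustrative only.

```python
import torch

# unit view directions; components lie in [-1, 1] as the header comment states
dirs = torch.randn(1024, 3, device='cuda')
dirs = dirs / dirs.norm(dim=-1, keepdim=True)

feats = sh_encode(dirs, 4, dirs.requires_grad)  # -> [1024, 16], i.e. degree ** 2 channels
```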