Chao Xu
sparseneus and elev est
854f0d0
raw
history blame contribute delete
No virus
3.19 kB
import torch
import torch.nn as nn
""" Positional encoding embedding. Code was taken from https://github.com/bmild/nerf. """
class Embedder:
    """Positional (sinusoidal) encoder configured through keyword arguments.

    The constructor stores every keyword argument in ``self.kwargs`` and
    immediately builds the list of encoding callables. The following keys
    are read while building:

    * ``input_dims``    -- amount added to ``out_dim`` per encoding term
      (presumably the size of the last input dimension; not checked here).
    * ``include_input`` -- when truthy, the raw input is the first term.
    * ``max_freq_log2`` -- log2 of the highest frequency.
    * ``num_freqs``     -- number of frequency bands.
    * ``log_sampling``  -- when truthy, bands are ``2 ** linspace(0, max, n)``;
      otherwise they are linearly spaced between ``2 ** 0`` and ``2 ** max``.
    * ``periodic_fns``  -- iterable of callables applied to ``x * freq``.
    * ``normalize``     -- when truthy, every periodic term is divided by
      its frequency.

    After construction ``self.embed_fns`` holds the callables and
    ``self.out_dim`` the accumulated output width.
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        """(Re)build ``self.embed_fns`` and ``self.out_dim`` from ``self.kwargs``."""
        encoders = []
        dims = self.kwargs['input_dims']
        width = 0

        # Optional identity term so the raw input is part of the encoding.
        if self.kwargs['include_input']:
            encoders.append(lambda x: x)
            width += dims

        highest = self.kwargs['max_freq_log2']
        band_count = self.kwargs['num_freqs']

        if self.kwargs['log_sampling']:
            bands = 2. ** torch.linspace(0., highest, band_count)
        else:
            bands = torch.linspace(2. ** 0., 2. ** highest, band_count)

        for band in bands:
            for wave in self.kwargs['periodic_fns']:
                # The current band and function are bound as default
                # arguments so each callable keeps its own pair instead of
                # the last one seen by the loop.
                if self.kwargs['normalize']:
                    def term(x, p_fn=wave, freq=band):
                        return p_fn(x * freq) / freq
                else:
                    def term(x, p_fn=wave, freq=band):
                        return p_fn(x * freq)
                encoders.append(term)
                width += dims

        self.embed_fns = encoders
        self.out_dim = width

    def embed(self, inputs):
        """Apply every encoding callable to ``inputs`` and concatenate the
        results along the last dimension."""
        pieces = []
        for encode in self.embed_fns:
            pieces.append(encode(inputs))
        return torch.cat(pieces, -1)
def get_embedder(multires, normalize=False, input_dims=3):
    """Build a log-sampled sin/cos positional encoder.

    Args:
        multires: number of frequency bands; the highest band is
            ``2 ** (multires - 1)``.
        normalize: forwarded to ``Embedder``; when truthy each periodic term
            is divided by its frequency.
        input_dims: forwarded to ``Embedder`` as ``input_dims``.

    Returns:
        A tuple ``(embed, out_dim)`` where ``embed(x)`` applies the encoder
        to ``x`` and ``out_dim`` is the encoder's ``out_dim`` attribute.
    """
    encoder = Embedder(
        include_input=True,
        input_dims=input_dims,
        max_freq_log2=multires - 1,
        num_freqs=multires,
        normalize=normalize,
        log_sampling=True,
        periodic_fns=[torch.sin, torch.cos],
    )

    # The encoder is bound as a default argument so the returned function
    # carries its own reference to it.
    def embed(x, eo=encoder):
        return eo.embed(x)

    return embed, encoder.out_dim
class Embedding(nn.Module):
    """Module form of the positional encoding.

    Maps x to (x, sin(f_0 x), cos(f_0 x), ..., sin(f_{N-1} x), cos(f_{N-1} x)).
    """

    def __init__(self, in_channels, N_freqs, logscale=True, normalize=False):
        """
        Args:
            in_channels: number of input channels (3 for both xyz and direction).
            N_freqs: number of frequency bands.
            logscale: if True the bands are 2^0, 2^1, ..., 2^(N_freqs - 1);
                otherwise they are linearly spaced between 1 and
                2^(N_freqs - 1).
            normalize: if True each sin/cos term is divided by its frequency.
        """
        super().__init__()
        self.N_freqs = N_freqs
        self.in_channels = in_channels
        self.normalize = normalize
        self.funcs = [torch.sin, torch.cos]

        # One copy of the input plus one term per (frequency, function) pair.
        terms_per_channel = len(self.funcs) * N_freqs + 1
        self.out_channels = in_channels * terms_per_channel

        self.freq_bands = (
            2 ** torch.linspace(0, N_freqs - 1, N_freqs)
            if logscale
            else torch.linspace(1, 2 ** (N_freqs - 1), N_freqs)
        )

    def forward(self, x):
        """Encode ``x`` and concatenate all terms along the last dimension.

        Different from the paper, ``x`` itself is also part of the output.
        See https://github.com/bmild/nerf/issues/12

        Inputs:
            x: (B, self.in_channels)  -- shape is not checked by this method

        Outputs:
            out: (B, self.out_channels)
        """
        pieces = [x]
        for freq in self.freq_bands:
            for func in self.funcs:
                encoded = func(freq * x)
                pieces.append(encoded / freq if self.normalize else encoded)
        return torch.cat(pieces, -1)