import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.cuda.amp import custom_bwd, custom_fwd

from lam.models.rendering.utils.typing import *

def get_activation(name):
    """Map an activation name (as used in configs) to a callable.

    Unknown names fall back to the attribute of the same name on torch.nn.functional.
    """
if name is None:
return lambda x: x
name = name.lower()
if name == "none":
return lambda x: x
elif name == "lin2srgb":
return lambda x: torch.where(
x > 0.0031308,
torch.pow(torch.clamp(x, min=0.0031308), 1.0 / 2.4) * 1.055 - 0.055,
12.92 * x,
).clamp(0.0, 1.0)
elif name == "exp":
return lambda x: torch.exp(x)
elif name == "shifted_exp":
return lambda x: torch.exp(x - 1.0)
elif name == "trunc_exp":
return trunc_exp
elif name == "shifted_trunc_exp":
return lambda x: trunc_exp(x - 1.0)
elif name == "sigmoid":
return lambda x: torch.sigmoid(x)
elif name == "tanh":
return lambda x: torch.tanh(x)
elif name == "shifted_softplus":
return lambda x: F.softplus(x - 1.0)
elif name == "scale_-11_01":
return lambda x: x * 0.5 + 0.5
else:
try:
return getattr(F, name)
except AttributeError:
raise ValueError(f"Unknown activation function: {name}")
class MLP(nn.Module):
    """Plain multilayer perceptron: stacked Linear + activation blocks with an optional output activation."""

def __init__(
self,
dim_in: int,
dim_out: int,
n_neurons: int,
n_hidden_layers: int,
activation: str = "relu",
output_activation: Optional[str] = None,
bias: bool = True,
):
super().__init__()
layers = [
self.make_linear(
dim_in, n_neurons, is_first=True, is_last=False, bias=bias
),
self.make_activation(activation),
]
        for _ in range(n_hidden_layers - 1):
layers += [
self.make_linear(
n_neurons, n_neurons, is_first=False, is_last=False, bias=bias
),
self.make_activation(activation),
]
layers += [
self.make_linear(
n_neurons, dim_out, is_first=False, is_last=True, bias=bias
)
]
self.layers = nn.Sequential(*layers)
self.output_activation = get_activation(output_activation)
def forward(self, x):
x = self.layers(x)
x = self.output_activation(x)
return x
def make_linear(self, dim_in, dim_out, is_first, is_last, bias=True):
layer = nn.Linear(dim_in, dim_out, bias=bias)
return layer
    def make_activation(self, activation):
        if activation == "relu":
            return nn.ReLU(inplace=True)
        elif activation == "silu":
            return nn.SiLU(inplace=True)
        else:
            raise NotImplementedError(f"Unsupported hidden activation: {activation}")
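
# Illustrative usage sketch (hypothetical helper; shapes and hyperparameters are
# arbitrary choices for demonstration): building and running the MLP defined above.
def _example_mlp_forward():
    mlp = MLP(
        dim_in=3,
        dim_out=4,
        n_neurons=64,
        n_hidden_layers=2,
        activation="silu",
        output_activation="sigmoid",
    )
    x = torch.randn(8, 3)   # batch of 8 three-dimensional feature vectors
    return mlp(x)           # shape (8, 4), values squashed to (0, 1) by the sigmoid
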
class _TruncExp(Function): # pylint: disable=abstract-method
# Implementation from torch-ngp:
# https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, x): # pylint: disable=arguments-differ
ctx.save_for_backward(x)
return torch.exp(x)
@staticmethod
@custom_bwd
def backward(ctx, g): # pylint: disable=arguments-differ
x = ctx.saved_tensors[0]
        # Clamp the saved input before exponentiating so the gradient stays finite
        # even when the forward pass received very large values.
        return g * torch.exp(torch.clamp(x, max=15))
# exp() activation whose gradient is computed from a clamped input (see _TruncExp above).
trunc_exp = _TruncExp.apply
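
# Illustrative sketch (hypothetical helper, for demonstration only): trunc_exp matches
# torch.exp in the forward pass, but its gradient uses clamp(x, max=15), so large
# pre-activations do not blow up the backward pass.
def _example_trunc_exp_gradient():
    x = torch.tensor([1.0, 20.0, 50.0], requires_grad=True)
    trunc_exp(x).sum().backward()
    # x.grad == [exp(1), exp(15), exp(15)]; plain torch.exp would give exp(20) and exp(50) here.
    return x.grad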