| import torch |
| import torch.nn as nn |
| from torch.autograd import Function |
| from torch.cuda.amp import custom_bwd, custom_fwd |
|
|
| from core.models.rendering.utils.typing import * |
|
|
|
|
def get_activation(name):
    """Return an elementwise activation callable selected by ``name``.

    ``name`` is matched case-insensitively. ``None`` and ``"none"`` give the
    identity. The custom names handled here are ``"lin2srgb"``, ``"exp"``,
    ``"shifted_exp"``, ``"trunc_exp"``, ``"shifted_trunc_exp"``,
    ``"sigmoid"``, ``"tanh"``, ``"shifted_softplus"`` and ``"scale_-11_01"``.
    Any other name is looked up as an attribute of ``torch.nn.functional``
    (e.g. ``"relu"``, ``"softplus"``).

    Args:
        name: Activation name, or ``None`` for the identity.

    Returns:
        A callable mapping a tensor to a tensor.

    Raises:
        ValueError: If ``name`` is neither one of the custom names nor an
            attribute of ``torch.nn.functional``.
    """
    # Bind torch.nn.functional locally so this function does not depend on a
    # module-level ``F``; without it the "shifted_softplus" branch and the
    # fallback lookup below fail with NameError.
    import torch.nn.functional as F

    if name is None:
        return lambda x: x
    name = name.lower()
    if name == "none":
        return lambda x: x
    elif name == "lin2srgb":
        # Piecewise linear -> sRGB transfer curve. The clamp inside pow keeps
        # the unused branch of torch.where finite, then the result is clipped
        # to [0, 1].
        return lambda x: torch.where(
            x > 0.0031308,
            torch.pow(torch.clamp(x, min=0.0031308), 1.0 / 2.4) * 1.055 - 0.055,
            12.92 * x,
        ).clamp(0.0, 1.0)
    elif name == "exp":
        return lambda x: torch.exp(x)
    elif name == "shifted_exp":
        return lambda x: torch.exp(x - 1.0)
    elif name == "trunc_exp":
        # trunc_exp is defined at module level further down this file.
        return trunc_exp
    elif name == "shifted_trunc_exp":
        return lambda x: trunc_exp(x - 1.0)
    elif name == "sigmoid":
        return lambda x: torch.sigmoid(x)
    elif name == "tanh":
        return lambda x: torch.tanh(x)
    elif name == "shifted_softplus":
        return lambda x: F.softplus(x - 1.0)
    elif name == "scale_-11_01":
        # Affine map of [-1, 1] onto [0, 1].
        return lambda x: x * 0.5 + 0.5
    else:
        try:
            return getattr(F, name)
        except AttributeError:
            raise ValueError(f"Unknown activation function: {name}")
|
|
|
|
class MLP(nn.Module):
    """Plain fully connected network built from ``nn.Linear`` layers.

    Layer stack, in order:

    * ``Linear(dim_in, n_neurons)`` followed by the hidden activation;
    * ``n_hidden_layers - 1`` repetitions of ``Linear(n_neurons, n_neurons)``
      followed by the hidden activation;
    * ``Linear(n_neurons, dim_out)`` with no activation.

    The stack's output is then passed through
    ``get_activation(output_activation)``.

    Note: when ``n_hidden_layers <= 1`` the repeat loop adds nothing, so the
    network still contains exactly one hidden layer.

    Args:
        dim_in: Size of the input feature (last) dimension.
        dim_out: Size of the output feature dimension.
        n_neurons: Width of every hidden layer.
        n_hidden_layers: Number of hidden ``Linear`` + activation pairs
            (values below 1 behave like 1).
        activation: Hidden activation, ``"relu"`` or ``"silu"``.
        output_activation: Name passed to ``get_activation`` for the final
            output; ``None`` selects the identity.
        bias: Whether every ``nn.Linear`` has a bias term.
    """

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        n_neurons: int,
        n_hidden_layers: int,
        activation: str = "relu",
        output_activation: Optional[str] = None,
        bias: bool = True,
    ):
        super().__init__()
        # Layers are created in network order; keep that order so parameter
        # initialisation consumes the RNG identically.
        layers = [
            self.make_linear(
                dim_in, n_neurons, is_first=True, is_last=False, bias=bias
            ),
            self.make_activation(activation),
        ]
        for _ in range(n_hidden_layers - 1):
            layers += [
                self.make_linear(
                    n_neurons, n_neurons, is_first=False, is_last=False, bias=bias
                ),
                self.make_activation(activation),
            ]
        layers.append(
            self.make_linear(
                n_neurons, dim_out, is_first=False, is_last=True, bias=bias
            )
        )
        self.layers = nn.Sequential(*layers)
        self.output_activation = get_activation(output_activation)

    def forward(self, x):
        """Run ``x`` through the layer stack, then the output activation.

        ``nn.Linear`` acts on the last dimension, so ``x.shape[-1]`` must equal
        ``dim_in``.
        """
        x = self.layers(x)
        x = self.output_activation(x)
        return x

    def make_linear(self, dim_in, dim_out, is_first, is_last, bias=True):
        """Create one ``nn.Linear(dim_in, dim_out, bias=bias)`` layer.

        ``is_first`` and ``is_last`` are unused here. They presumably exist as
        hooks for subclasses that treat the first or last layer specially.
        """
        layer = nn.Linear(dim_in, dim_out, bias=bias)
        return layer

    def make_activation(self, activation):
        """Create the hidden activation module (in-place ReLU or SiLU).

        Raises:
            NotImplementedError: If ``activation`` is not ``"relu"`` or
                ``"silu"``.
        """
        if activation == "relu":
            return nn.ReLU(inplace=True)
        elif activation == "silu":
            return nn.SiLU(inplace=True)
        else:
            raise NotImplementedError(
                f"Unsupported hidden activation {activation!r}; "
                "expected 'relu' or 'silu'"
            )
|
|
|
|
class _TruncExp(Function):
    """Exponential whose gradient uses an exponent capped at 15.

    The forward pass returns ``exp(x)`` exactly. The backward pass returns
    ``grad_output * exp(min(x, 15))``, which bounds the local derivative by
    ``exp(15)`` for large inputs. ``custom_fwd(cast_inputs=torch.float32)``
    makes the forward run on float32 inputs when CUDA autocast is active.
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, x):
        # Keep the raw input; the clamp is applied only in backward.
        ctx.save_for_backward(x)
        return x.exp()

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        (saved_x,) = ctx.saved_tensors
        capped = saved_x.clamp(max=15)
        return grad_output * capped.exp()


# Functional entry point: call as ``trunc_exp(x)``.
trunc_exp = _TruncExp.apply
|
|