Spaces:
Runtime error
Runtime error
import random | |
from dataclasses import dataclass, field | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import threestudio | |
from threestudio.models.materials.base import BaseMaterial | |
from threestudio.models.networks import get_encoding, get_mlp | |
from threestudio.utils.ops import dot, get_activation | |
from threestudio.utils.typing import * | |
class NeuralRadianceMaterial(BaseMaterial): | |
class Config(BaseMaterial.Config): | |
input_feature_dims: int = 8 | |
color_activation: str = "sigmoid" | |
dir_encoding_config: dict = field( | |
default_factory=lambda: {"otype": "SphericalHarmonics", "degree": 3} | |
) | |
mlp_network_config: dict = field( | |
default_factory=lambda: { | |
"otype": "FullyFusedMLP", | |
"activation": "ReLU", | |
"n_neurons": 16, | |
"n_hidden_layers": 2, | |
} | |
) | |
cfg: Config | |
def configure(self) -> None: | |
self.encoding = get_encoding(3, self.cfg.dir_encoding_config) | |
self.n_input_dims = self.cfg.input_feature_dims + self.encoding.n_output_dims # type: ignore | |
self.network = get_mlp(self.n_input_dims, 3, self.cfg.mlp_network_config) | |
def forward( | |
self, | |
features: Float[Tensor, "*B Nf"], | |
viewdirs: Float[Tensor, "*B 3"], | |
**kwargs, | |
) -> Float[Tensor, "*B 3"]: | |
# viewdirs and normals must be normalized before passing to this function | |
viewdirs = (viewdirs + 1.0) / 2.0 # (-1, 1) => (0, 1) | |
viewdirs_embd = self.encoding(viewdirs.view(-1, 3)) | |
network_inp = torch.cat( | |
[features.view(-1, features.shape[-1]), viewdirs_embd], dim=-1 | |
) | |
color = self.network(network_inp).view(*features.shape[:-1], 3) | |
color = get_activation(self.cfg.color_activation)(color) | |
return color | |