Spaces:
Runtime error
Runtime error
File size: 4,815 Bytes
2fa4776 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import random
from dataclasses import dataclass, field
import envlight
import numpy as np
import nvdiffrast.torch as dr
import torch
import torch.nn as nn
import torch.nn.functional as F
import threestudio
from threestudio.models.materials.base import BaseMaterial
from threestudio.utils.ops import get_activation
from threestudio.utils.typing import *
@threestudio.register("pbr-material")
class PBRMaterial(BaseMaterial):
    """Metallic/roughness material lit by an environment map.

    Raw per-point features are passed through an activation and split into
    albedo (channels 0:3), metallic (3:4), roughness (4:5) and, when
    ``use_bump`` is enabled, a tangent-space normal perturbation (5:8).
    Shading combines a diffuse and a specular term, each multiplied by the
    corresponding lookup into ``self.light``.
    """

    @dataclass
    class Config(BaseMaterial.Config):
        # Name of the activation (resolved through ``get_activation``) applied
        # to the raw feature vector before it is split into material channels.
        material_activation: str = "sigmoid"
        # Path of the environment texture handed to ``envlight.EnvLight``.
        environment_texture: str = "load/lights/mud_road_puresky_1k.hdr"
        # ``scale`` argument forwarded to ``envlight.EnvLight``.
        environment_scale: float = 2.0
        # The activated metallic channel is affinely remapped into
        # [min_metallic, max_metallic] (assuming the activation outputs [0, 1]).
        min_metallic: float = 0.0
        max_metallic: float = 0.9
        # Same remapping for the roughness channel.
        min_roughness: float = 0.08
        max_roughness: float = 0.9
        # When True, channels 5:8 perturb the shading normal in tangent space
        # and ``forward`` requires a ``tangent`` argument.
        use_bump: bool = True

    cfg: Config

    def configure(self) -> None:
        """Set the requirement flags, build the light and load the FG table."""
        # NOTE(review): these flags are presumably read by the renderer to
        # decide which geometry attributes to pass to ``forward`` - confirm.
        self.requires_normal = True
        self.requires_tangent = self.cfg.use_bump

        self.light = envlight.EnvLight(
            self.cfg.environment_texture, scale=self.cfg.environment_scale
        )

        # Raw float32 table viewed as a 1 x 256 x 256 x 2 texture; ``forward``
        # samples it at (n.v, roughness). Presumably the pre-integrated
        # specular BRDF scale/bias terms - confirm against the asset's origin.
        FG_LUT = torch.from_numpy(
            np.fromfile("load/lights/bsdf_256_256.bin", dtype=np.float32).reshape(
                1, 256, 256, 2
            )
        )
        self.register_buffer("FG_LUT", FG_LUT)

    def _decode_material(self, features):
        """Activate ``features`` and split out the basic material channels.

        Shared by ``forward`` and ``export`` so both decode identically.

        Returns:
            A tuple ``(material, albedo, metallic, roughness)`` where
            ``material`` is the full activated feature tensor, ``albedo`` is
            its channels 0:3, and ``metallic`` / ``roughness`` are channels
            3:4 / 4:5 remapped into their configured ``[min, max]`` ranges.
        """
        material: Float[Tensor, "*B Nf"] = get_activation(
            self.cfg.material_activation
        )(features)
        albedo = material[..., :3]
        metallic = (
            material[..., 3:4] * (self.cfg.max_metallic - self.cfg.min_metallic)
            + self.cfg.min_metallic
        )
        roughness = (
            material[..., 4:5] * (self.cfg.max_roughness - self.cfg.min_roughness)
            + self.cfg.min_roughness
        )
        return material, albedo, metallic, roughness

    @staticmethod
    def _decode_bump(material):
        """Return the unit tangent-space normal encoded in channels 5:8.

        The channels are mapped with ``x * 2 - 1`` and added, as a delta, to
        the unperturbed normal [0, 0, 1]; the sum is clamped to [-1, 1] per
        component and normalized along the last axis.
        """
        perturb_normal = (material[..., 5:8] * 2 - 1) + torch.tensor(
            [0, 0, 1], dtype=material.dtype, device=material.device
        )
        return F.normalize(perturb_normal.clamp(-1, 1), dim=-1)

    def forward(
        self,
        features: Float[Tensor, "*B Nf"],
        viewdirs: Float[Tensor, "*B 3"],
        shading_normal: Float[Tensor, "B ... 3"],
        tangent: Optional[Float[Tensor, "B ... 3"]] = None,
        **kwargs,
    ) -> Float[Tensor, "*B 3"]:
        """Shade points with the decoded material and the environment light.

        Args:
            features: Raw material features; the last axis needs at least 5
                channels, or 8 when ``use_bump`` is enabled.
            viewdirs: View directions; they are negated before being dotted
                with the normal (presumably they point from the camera towards
                the surface - confirm against the caller).
            shading_normal: Surface normals used for lighting.
            tangent: Surface tangents; required when ``use_bump`` is enabled.
            **kwargs: Ignored.

        Returns:
            Shaded color with the same leading shape as ``features``, clamped
            to [0, 1].
        """
        prefix_shape = features.shape[:-1]
        material, albedo, metallic, roughness = self._decode_material(features)

        if self.cfg.use_bump:
            assert tangent is not None, "use_bump=True requires a tangent tensor"
            perturb_normal = self._decode_bump(material)

            # apply normal perturbation in tangent space
            # dim=-1 is explicit: the implicit default of torch.cross is
            # deprecated and picks the first axis of size 3, which is wrong
            # whenever a batch axis happens to have length 3.
            bitangent = F.normalize(
                torch.cross(tangent, shading_normal, dim=-1), dim=-1
            )
            shading_normal = (
                tangent * perturb_normal[..., 0:1]
                - bitangent * perturb_normal[..., 1:2]
                + shading_normal * perturb_normal[..., 2:3]
            )
            shading_normal = F.normalize(shading_normal, dim=-1)

        v = -viewdirs
        n_dot_v = (shading_normal * v).sum(-1, keepdim=True)
        # Mirror v about the shading normal: r = 2 (n.v) n - v.
        reflective = n_dot_v * shading_normal * 2 - v

        diffuse_albedo = (1 - metallic) * albedo

        # Sample the FG table at (n.v, roughness); the texture call gets a
        # 1 x N x 1 x 2 coordinate tensor, and the result is restored to the
        # input's leading shape.
        fg_uv = torch.cat([n_dot_v, roughness], -1).clamp(0, 1)
        fg = dr.texture(
            self.FG_LUT,
            fg_uv.reshape(1, -1, 1, 2).contiguous(),
            filter_mode="linear",
            boundary_mode="clamp",
        ).reshape(*prefix_shape, 2)

        # Blend a constant 0.04 reflectance (non-metal) with albedo (metal).
        F0 = (1 - metallic) * 0.04 + metallic * albedo
        # Slice the LAST axis with "...": fg has shape (*prefix_shape, 2), so
        # slicing axis 1 is only correct when the prefix is exactly 1-D.
        specular_albedo = F0 * fg[..., 0:1] + fg[..., 1:2]

        diffuse_light = self.light(shading_normal)
        specular_light = self.light(reflective, roughness)

        color = diffuse_albedo * diffuse_light + specular_albedo * specular_light
        color = color.clamp(0.0, 1.0)
        return color

    def export(self, features: Float[Tensor, "*N Nf"], **kwargs) -> Dict[str, Any]:
        """Decode ``features`` into named material maps.

        Args:
            features: Raw material features, decoded exactly as in ``forward``.
            **kwargs: Ignored.

        Returns:
            A dict with ``"albedo"``, ``"metallic"`` and ``"roughness"``. When
            ``use_bump`` is enabled it also holds ``"bump"``: the unit
            tangent-space normal remapped from [-1, 1] to [0, 1].
        """
        material, albedo, metallic, roughness = self._decode_material(features)
        out = {
            "albedo": albedo,
            "metallic": metallic,
            "roughness": roughness,
        }
        if self.cfg.use_bump:
            perturb_normal = self._decode_bump(material)
            # Remap [-1, 1] -> [0, 1].
            out["bump"] = (perturb_normal + 1) / 2
        return out
|