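# Provenance (web-page scrape residue, kept as a comment so the module parses):
# "ReubenSun's picture" / "1" / "2ac1c2d" (the last looks like a short commit hash).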
from typing import Optional
import torch
import pytorch3d
from pytorch3d.io import load_objs_as_meshes, load_obj, save_obj
from pytorch3d.ops import interpolate_face_attributes
from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
look_at_view_transform,
FoVPerspectiveCameras,
AmbientLights,
PointLights,
DirectionalLights,
Materials,
RasterizationSettings,
MeshRenderer,
MeshRasterizer,
SoftPhongShader,
SoftSilhouetteShader,
HardPhongShader,
TexturesVertex,
TexturesUV,
Materials,
)
from pytorch3d.renderer.blending import BlendParams, hard_rgb_blend
from pytorch3d.renderer.utils import convert_to_tensors_and_broadcast, TensorProperties
from pytorch3d.renderer.lighting import AmbientLights
from pytorch3d.renderer.materials import Materials
from pytorch3d.renderer.mesh.shader import ShaderBase
from pytorch3d.renderer.mesh.shading import _apply_lighting, flat_shading
from pytorch3d.renderer.mesh.rasterizer import Fragments
"""
Customized the original pytorch3d hard flat shader to support N channel flat shading
"""
class HardNChannelFlatShader(ShaderBase):
    """
    Hard flat shader generalized to an arbitrary number of color channels.

    Per face lighting - the lighting model is applied using the average face
    position and the face normal. The blending function hard assigns
    the color of the closest face for each pixel.

    Constructor arguments whose channel count does not match ``channels`` are
    replaced with N-channel defaults:

    * ``lights``: anything that is not an ``AmbientLights`` whose ambient color
      has ``channels`` components is replaced by ``AmbientLights`` with an
      all-ones ambient color.
    * ``materials``: ``None``, or a ``Materials`` whose ambient color does not
      have ``channels`` components, is replaced by a ``Materials`` with
      ambient=1, diffuse=0, specular=0 and shininess=0.
    * ``blend_params``: anything that is not a ``BlendParams``, or a
      ``BlendParams`` whose background color has a different channel count,
      is replaced by ``BlendParams(background_color=(1.0,) * channels)``.
      Note that a replaced ``BlendParams`` also loses any custom sigma/gamma.

    Per-call overrides passed to ``forward`` via kwargs are NOT validated
    against ``channels``.

    To use the default values, simply initialize the shader with the desired
    device e.g.

    .. code-block::

        shader = HardNChannelFlatShader(device=torch.device("cuda:0"), channels=4)
    """

    def __init__(
        self,
        device="cpu",
        cameras: Optional[TensorProperties] = None,
        lights: Optional[TensorProperties] = None,
        materials: Optional[Materials] = None,
        blend_params: Optional[BlendParams] = None,
        channels: int = 3,
    ):
        """
        Args:
            device: device on which default lights/materials are created.
            cameras: optional default cameras, forwarded to ``ShaderBase``.
            lights: optional lights; replaced unless they are ``AmbientLights``
                with ``channels`` ambient components.
            materials: optional materials; replaced if ``None`` or if their
                ambient color does not have ``channels`` components.
            blend_params: optional blend params; replaced if not a
                ``BlendParams`` or if the background color channel count
                differs from ``channels``.
            channels: number of color channels to shade (default 3).

        Raises:
            ValueError: if ``channels`` is smaller than 1.
        """
        if channels < 1:
            raise ValueError(f"channels must be a positive integer, got {channels}")

        # A single row of `channels` values (batch of one) used for defaults.
        ones = ((1.0,) * channels,)
        zeros = ((0.0,) * channels,)

        # Only ambient lights are used; any other light type is replaced.
        if not (
            isinstance(lights, AmbientLights)
            and lights.ambient_color.shape[-1] == channels
        ):
            lights = AmbientLights(ambient_color=ones, device=device)

        if materials is None or materials.ambient_color.shape[-1] != channels:
            materials = Materials(
                device=device,
                diffuse_color=zeros,
                ambient_color=ones,
                specular_color=zeros,
                shininess=0.0,
            )

        default_blend_params = BlendParams(background_color=(1.0,) * channels)
        if not isinstance(blend_params, BlendParams):
            blend_params = default_blend_params
        else:
            bg_channels = self._num_channels(blend_params.background_color)
            # Unknown representations (None) are kept as-is.
            if bg_channels is not None and bg_channels != channels:
                blend_params = default_blend_params

        super().__init__(
            device,
            cameras,
            lights,
            materials,
            blend_params,
        )
        self.channels = channels

    @staticmethod
    def _num_channels(color) -> Optional[int]:
        """Return the channel count of a background color, or None if unknown.

        A tensor reports its last dimension; a 0-dim tensor returns None
        because it is a scalar. Any other ``collections.abc.Sequence``
        (tuple/list) reports its length. All other types return None so the
        caller leaves them untouched.
        """
        from collections.abc import Sequence

        if isinstance(color, torch.Tensor):
            return color.shape[-1] if color.dim() > 0 else None
        if isinstance(color, Sequence):
            return len(color)
        return None

    def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
        """Flat-shade ``meshes`` and hard-blend the result into an image.

        Args:
            fragments: rasterizer output for ``meshes``.
            meshes: meshes being rendered; ``meshes.sample_textures`` supplies
                the texel values that are shaded.
            **kwargs: optional per-call overrides for ``lights``, ``materials``
                and ``blend_params``. They are also passed to
                ``ShaderBase._get_cameras``, which presumably reads
                ``cameras`` from them.

        Returns:
            The tensor produced by ``hard_rgb_blend``. This is presumably
            (N, H, W, channels + 1) with alpha in the last channel; TODO
            confirm against pytorch3d.
        """
        cameras = super()._get_cameras(**kwargs)
        texels = meshes.sample_textures(fragments)
        lights = kwargs.get("lights", self.lights)
        materials = kwargs.get("materials", self.materials)
        blend_params = kwargs.get("blend_params", self.blend_params)
        colors = flat_shading(
            meshes=meshes,
            fragments=fragments,
            texels=texels,
            lights=lights,
            cameras=cameras,
            materials=materials,
        )
        images = hard_rgb_blend(colors, fragments, blend_params)
        return images