Spaces:
Running
on
Zero
Running
on
Zero
from typing import Optional | |
import torch | |
import pytorch3d | |
from pytorch3d.io import load_objs_as_meshes, load_obj, save_obj | |
from pytorch3d.ops import interpolate_face_attributes | |
from pytorch3d.structures import Meshes | |
from pytorch3d.renderer import ( | |
look_at_view_transform, | |
FoVPerspectiveCameras, | |
AmbientLights, | |
PointLights, | |
DirectionalLights, | |
Materials, | |
RasterizationSettings, | |
MeshRenderer, | |
MeshRasterizer, | |
SoftPhongShader, | |
SoftSilhouetteShader, | |
HardPhongShader, | |
TexturesVertex, | |
TexturesUV, | |
Materials, | |
) | |
from pytorch3d.renderer.blending import BlendParams, hard_rgb_blend | |
from pytorch3d.renderer.utils import convert_to_tensors_and_broadcast, TensorProperties | |
from pytorch3d.renderer.lighting import AmbientLights | |
from pytorch3d.renderer.materials import Materials | |
from pytorch3d.renderer.mesh.shader import ShaderBase | |
from pytorch3d.renderer.mesh.shading import _apply_lighting, flat_shading | |
from pytorch3d.renderer.mesh.rasterizer import Fragments | |
# Customized version of PyTorch3D's hard flat shader that supports flat shading
# with an arbitrary number (N) of texture channels.


class HardNChannelFlatShader(ShaderBase):
    """
    Hard flat shader for textures with an arbitrary number of channels.

    Per face lighting - the lighting model is applied using the average face
    position and the face normal. The blending function hard assigns
    the color of the closest face for each pixel.

    Unlike PyTorch3D's ``HardFlatShader``, the constructor coerces lights,
    materials and the blend background color to ``channels`` channels:

    * ``lights`` that are not ``AmbientLights`` with ``channels`` channels are
      replaced by ambient lights of all ones;
    * ``materials`` whose ``ambient_color`` does not have ``channels`` channels
      are replaced by materials with ambient = 1, diffuse = specular = 0;
    * a ``blend_params.background_color`` of the wrong width is replaced by a
      background of all ones.

    To use the default values, simply initialize the shader with the desired
    device and channel count, e.g.

    .. code-block:: python

        shader = HardNChannelFlatShader(device=torch.device("cuda:0"), channels=4)
    """

    def __init__(
        self,
        device="cpu",
        cameras: Optional[TensorProperties] = None,
        lights: Optional[TensorProperties] = None,
        materials: Optional[Materials] = None,
        blend_params: Optional[BlendParams] = None,
        channels: int = 3,
    ):
        """
        Args:
            device: Device on which replacement lights/materials are created.
            cameras: Passed through to ``ShaderBase``.
            lights: Kept only if it is an ``AmbientLights`` whose
                ``ambient_color`` has ``channels`` entries in its last
                dimension; otherwise replaced (see class docstring).
            materials: Kept only if its ``ambient_color`` has ``channels``
                entries in its last dimension; otherwise replaced.
            blend_params: Kept only if its ``background_color`` (tensor or
                sequence) has ``channels`` entries; otherwise replaced by
                ``BlendParams`` with a background of all ones.
            channels: Number of channels of the mesh textures.

        Raises:
            ValueError: If ``channels`` is smaller than 1.
        """
        # The runtime check needs the unsubscripted ABC: ``Sequence[float]``
        # cannot be used with ``isinstance`` (it raises TypeError).
        from collections.abc import Sequence

        if channels < 1:
            raise ValueError(f"channels must be >= 1, got {channels}")

        ones = ((1.0,) * channels,)
        zeros = ((0.0,) * channels,)

        # Ambient-only lighting with unit ambient and zero diffuse/specular
        # terms; presumably this makes the shaded color equal the sampled
        # texel for any channel count (relies on PyTorch3D's lighting model).
        if (
            not isinstance(lights, AmbientLights)
            or lights.ambient_color.shape[-1] != channels
        ):
            lights = AmbientLights(ambient_color=ones, device=device)
        if materials is None or materials.ambient_color.shape[-1] != channels:
            materials = Materials(
                device=device,
                diffuse_color=zeros,
                ambient_color=ones,
                specular_color=zeros,
                shininess=0.0,
            )

        default_blend_params = BlendParams(background_color=(1.0,) * channels)
        if not isinstance(blend_params, BlendParams):
            blend_params = default_blend_params
        else:
            background_color = blend_params.background_color
            if isinstance(background_color, torch.Tensor):
                background_channels = background_color.shape[-1]
            elif isinstance(background_color, Sequence):
                background_channels = len(background_color)
            else:
                # Unrecognized representation: leave the caller's value as-is.
                background_channels = channels
            if background_channels != channels:
                blend_params = default_blend_params

        super().__init__(
            device,
            cameras,
            lights,
            materials,
            blend_params,
        )
        self.channels = channels

    def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
        """
        Flat-shade the rasterized fragments and hard-blend them into an image.

        Args:
            fragments: Rasterizer output for ``meshes``.
            meshes: Meshes whose textures are sampled at ``fragments``.
            **kwargs: Optional per-call overrides. ``cameras`` is resolved by
                ``ShaderBase._get_cameras``; ``lights``, ``materials`` and
                ``blend_params`` replace the stored values. These overrides
                are NOT coerced to ``self.channels`` the way constructor
                arguments are.

        Returns:
            The output of ``hard_rgb_blend``; presumably (N, H, W, channels + 1)
            with an alpha channel appended (confirm against the installed
            PyTorch3D version).
        """
        cameras = super()._get_cameras(**kwargs)
        texels = meshes.sample_textures(fragments)
        lights = kwargs.get("lights", self.lights)
        materials = kwargs.get("materials", self.materials)
        blend_params = kwargs.get("blend_params", self.blend_params)
        colors = flat_shading(
            meshes=meshes,
            fragments=fragments,
            texels=texels,
            lights=lights,
            cameras=cameras,
            materials=materials,
        )
        return hard_rgb_blend(colors, fragments, blend_params)