from typing import Optional, Sequence

import torch
import pytorch3d
from pytorch3d.io import load_objs_as_meshes, load_obj, save_obj
from pytorch3d.ops import interpolate_face_attributes
from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
    look_at_view_transform,
    FoVPerspectiveCameras,
    AmbientLights,
    PointLights,
    DirectionalLights,
    Materials,
    RasterizationSettings,
    MeshRenderer,
    MeshRasterizer,
    SoftPhongShader,
    SoftSilhouetteShader,
    HardPhongShader,
    TexturesVertex,
    TexturesUV,
)
from pytorch3d.renderer.blending import BlendParams, hard_rgb_blend
from pytorch3d.renderer.utils import convert_to_tensors_and_broadcast, TensorProperties
from pytorch3d.renderer.lighting import AmbientLights
from pytorch3d.renderer.materials import Materials
from pytorch3d.renderer.mesh.shader import ShaderBase
from pytorch3d.renderer.mesh.shading import _apply_lighting, flat_shading
from pytorch3d.renderer.mesh.rasterizer import Fragments

| """ | |
| Customized the original pytorch3d hard flat shader to support N channel flat shading | |
| """ | |
class HardNChannelFlatShader(ShaderBase):
    """
    Per-face lighting - the lighting model is applied using the average face
    position and the face normal. The blending function hard assigns
    the color of the closest face for each pixel.

    Unlike the stock HardFlatShader, colors, lights and materials may carry an
    arbitrary number of channels (``channels``) instead of only RGB.

    To use the default values, simply initialize the shader with the desired
    device e.g.

    .. code-block::

        shader = HardNChannelFlatShader(device=torch.device("cuda:0"))
    """

    def __init__(
        self,
        device="cpu",
        cameras: Optional[TensorProperties] = None,
        lights: Optional[TensorProperties] = None,
        materials: Optional[Materials] = None,
        blend_params: Optional[BlendParams] = None,
        channels: int = 3,
    ):
        self.channels = channels
        ones = ((1.0,) * channels,)
        zeros = ((0.0,) * channels,)

        # Fall back to pure ambient lighting with the requested channel count
        # if no compatible lights were provided.
        if (
            not isinstance(lights, AmbientLights)
            or not lights.ambient_color.shape[-1] == channels
        ):
            lights = AmbientLights(
                ambient_color=ones,
                device=device,
            )

        # Fall back to ambient-only materials with the requested channel count
        # if no compatible materials were provided.
        if not materials or not materials.ambient_color.shape[-1] == channels:
            materials = Materials(
                device=device,
                diffuse_color=zeros,
                ambient_color=ones,
                specular_color=zeros,
                shininess=0.0,
            )

        # Replace the blend params if the background color does not match the
        # requested channel count.
        blend_params_new = BlendParams(background_color=(1.0,) * channels)
        if not isinstance(blend_params, BlendParams):
            blend_params = blend_params_new
        else:
            background_color_ = blend_params.background_color
            if (
                isinstance(background_color_, Sequence)
                and not len(background_color_) == channels
            ):
                blend_params = blend_params_new
            if (
                isinstance(background_color_, torch.Tensor)
                and not background_color_.shape[-1] == channels
            ):
                blend_params = blend_params_new

        super().__init__(
            device,
            cameras,
            lights,
            materials,
            blend_params,
        )

    def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
        cameras = super()._get_cameras(**kwargs)
        texels = meshes.sample_textures(fragments)
        lights = kwargs.get("lights", self.lights)
        materials = kwargs.get("materials", self.materials)
        blend_params = kwargs.get("blend_params", self.blend_params)
        # Flat shading: one color per face, computed from the mean face position
        # and the face normal, then hard-assigned to the closest face per pixel.
        colors = flat_shading(
            meshes=meshes,
            fragments=fragments,
            texels=texels,
            lights=lights,
            cameras=cameras,
            materials=materials,
        )
        images = hard_rgb_blend(colors, fragments, blend_params)
        return images
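
# --- Usage sketch (not part of the original file) ---
# A minimal, hedged example of wiring HardNChannelFlatShader into a standard
# PyTorch3D MeshRenderer; the mesh, camera and image size are placeholders.
# It uses the default channels=3 so it runs on a stock PyTorch3D install;
# raising `channels` to C > 3 (e.g. to flat-shade latent feature textures)
# additionally assumes the installed PyTorch3D accepts C-channel colors in
# AmbientLights and Materials.
if __name__ == "__main__":
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    C = 3  # number of texture channels to shade

    # A single triangle with a C-channel per-vertex texture.
    verts = torch.tensor(
        [[-1.0, -1.0, 0.0], [1.0, -1.0, 0.0], [0.0, 1.0, 0.0]], device=device
    )
    faces = torch.tensor([[0, 1, 2]], device=device)
    mesh = Meshes(
        verts=[verts],
        faces=[faces],
        textures=TexturesVertex(verts_features=torch.rand(1, 3, C, device=device)),
    )

    # Camera looking at the origin, plus a simple one-face-per-pixel rasterizer.
    R, T = look_at_view_transform(dist=3.0, elev=0.0, azim=0.0)
    cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
    raster_settings = RasterizationSettings(
        image_size=256, blur_radius=0.0, faces_per_pixel=1
    )

    renderer = MeshRenderer(
        rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
        shader=HardNChannelFlatShader(device=device, cameras=cameras, channels=C),
    )

    # hard_rgb_blend appends an alpha channel, so the output is (1, 256, 256, C + 1).
    images = renderer(mesh)
    print(images.shape)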