| |
| |
| |
| |
| |
| |
| import os |
|
|
| import numpy as np |
| import torch |
|
|
| from core.models.rendering.base_gs_render import BaseGSRender |
| from core.models.rendering.gaussian_decoder.mlp_decoder import GSMLPDecoder |
| from core.models.rendering.utils.typing import * |
| from core.outputs.output import GaussianAppOutput |
|
|
|
|
class GS3DRenderer(BaseGSRender):
    """3D Gaussian Splatting renderer for human avatars.

    Extends :class:`BaseGSRender` with an MLP-based Gaussian attribute
    decoder (:class:`GSMLPDecoder`) that maps per-point features to
    Gaussian appearance attributes (offset, scaling, rotation, opacity,
    color/SH coefficients).
    """

    def __init__(
        self,
        human_model_path,
        subdivide_num,
        smpl_type,
        feat_dim,
        query_dim,
        use_rgb,
        sh_degree,
        xyz_offset_max_step,
        mlp_network_config,
        expr_param_dim,
        shape_param_dim,
        clip_scaling=0.2,
        cano_pose_type=0,
        decoder_mlp=False,
        skip_decoder=False,
        fix_opacity=False,
        fix_rotation=False,
        decode_with_extra_info=None,
        gradient_checkpointing=False,
        apply_pose_blendshape=False,
        dense_sample_pts=40000,
        gs_deform_scale=0.005,
        render_features=False,
    ):
        """
        Initialize the GS3DRenderer.

        Args:
            human_model_path (str): Path to human model files.
            subdivide_num (int): Subdivision number for the base mesh.
            smpl_type (str): Type of SMPL/SMPL-X/other model to use.
            feat_dim (int): Dimension of feature embeddings.
            query_dim (int): Dimension of query points/features; also the
                input channel count of the Gaussian decoder.
            use_rgb (bool): Whether to use RGB channels.
            sh_degree (int): Spherical harmonics degree for appearance.
            xyz_offset_max_step (float): Max per-step position offset.
            mlp_network_config (dict or None): MLP configuration for
                feature mapping.
            expr_param_dim (int): Expression parameter dimension.
            shape_param_dim (int): Shape parameter dimension.
            clip_scaling (float, optional): Output scaling clip for the
                decoder. Default 0.2.
            cano_pose_type (int, optional): Canonical pose type. Default 0.
            decoder_mlp (bool, optional): Use MLP in decoder cross-attention.
                Default False.
            skip_decoder (bool, optional): Skip decoder and cross-attn
                layers. Default False.
            fix_opacity (bool, optional): Fix opacity during training.
                Default False.
            fix_rotation (bool, optional): Fix rotation during training.
                Default False.
            decode_with_extra_info (dict or None, optional): Extra info for
                the decoder; a non-None "type" entry enables fine features.
                Default None.
            gradient_checkpointing (bool, optional): Enable gradient
                checkpointing. Default False.
            apply_pose_blendshape (bool, optional): Apply pose blendshape.
                Default False.
            dense_sample_pts (int, optional): Dense sample points for
                mesh/voxel. Default 40000.
            gs_deform_scale (float, optional): Deformation scale for
                Gaussian Splatting. Default 0.005.
            render_features (bool, optional): Output additional features in
                the renderer. Default False.
        """
        # Python-3 style super() call (was super(GS3DRenderer, self)).
        super().__init__(
            human_model_path,
            subdivide_num,
            smpl_type,
            feat_dim,
            query_dim,
            use_rgb,
            sh_degree,
            xyz_offset_max_step,
            mlp_network_config,
            expr_param_dim,
            shape_param_dim,
            clip_scaling,
            cano_pose_type,
            decoder_mlp,
            skip_decoder,
            fix_opacity,
            fix_rotation,
            decode_with_extra_info,
            gradient_checkpointing,
            apply_pose_blendshape,
            dense_sample_pts,
            gs_deform_scale,
            render_features,
        )

        # Gaussian attribute decoder; note it reads self.sh_degree set by
        # the base class rather than the raw sh_degree argument.
        self.gs_net = GSMLPDecoder(
            in_channels=query_dim,
            use_rgb=use_rgb,
            sh_degree=self.sh_degree,
            clip_scaling=clip_scaling,
            init_scaling=-6.0,
            init_density=0.1,
            xyz_offset=True,
            restrict_offset=True,
            xyz_offset_max_step=xyz_offset_max_step,
            fix_opacity=fix_opacity,
            fix_rotation=fix_rotation,
            # Fine features are used only when extra decode info specifies
            # a type (was the redundant `True if ... else False` form).
            use_fine_feat=(
                decode_with_extra_info is not None
                and decode_with_extra_info["type"] is not None
            ),
        )

        self.gs_deform_scale = gs_deform_scale

    def hyper_step(self, step):
        """Adjust the constraint scale of the decoder for training step `step`."""
        self.gs_net.hyper_step(step)

    def forward_gs_attr(
        self, x, query_points, smplx_data, debug=False, x_fine=None, mesh_meta=None
    ):
        """
        Decode per-point features into Gaussian appearance attributes.

        Args:
            x: [N, C] point features.
            query_points: [N, 3] query point positions.
            smplx_data: SMPL-X parameters (currently unused here; kept for
                interface compatibility).
            debug (bool): Unused debug flag, kept for interface compatibility.
            x_fine: Optional fine-level features, same layout as `x`.
            mesh_meta (dict): Per-vertex part masks; must provide
                "is_constrain_body", "is_rhand", "is_lhand", "is_upper_body".

        Returns:
            GaussianAppOutput: decoded Gaussian attributes.
        """
        if self.mlp_network_config is not None:
            x = self.mlp_net(x)
            if x_fine is not None:
                x_fine = self.mlp_net(x_fine)

        # Move the part masks onto the same device/dtype as the buffers
        # registered on the SMPL-X model (Tensor.to(other) matches both).
        is_constrain_body = mesh_meta["is_constrain_body"].to(
            self.smplx_model.is_constrain_body
        )
        # A point counts as "hand" if it belongs to either hand mask.
        is_hands = (mesh_meta["is_rhand"] + mesh_meta["is_lhand"]).to(
            self.smplx_model.is_rhand
        )
        is_upper_body = mesh_meta["is_upper_body"].to(self.smplx_model.is_upper_body)

        constrain_dict = dict(
            is_constrain_body=is_constrain_body,
            is_hands=is_hands,
            is_upper_body=is_upper_body,
        )

        gs_attr: GaussianAppOutput = self.gs_net(
            x, query_points, x_fine, constrain_dict
        )

        return gs_attr
|
|
|
|
def test():
    """Smoke-test GS3DRenderer: build the renderer, render two debug views,
    and dump the resulting images / point clouds under ./debug_vis.

    Requires CUDA, OpenCV, and the ExAvatar-formatted SMPL-X data at the
    hard-coded paths below.
    """
    import cv2

    human_model_path = "./pretrained_models/human_model_files"
    smplx_data_root = "/data1/projects/ExAvatar_RELEASE/avatar/data/Custom/data/gyeongsik/smplx_optimized/smplx_params_smoothed"
    shape_param_file = "/data1/projects/ExAvatar_RELEASE/avatar/data/Custom/data/gyeongsik/smplx_optimized/shape_param.json"

    from core.models.rendering.smpl_x import read_smplx_param

    batch_size = 1
    device = "cuda"
    # Two frames are loaded so that two views are rendered per batch.
    smplx_data, cam_param_list, _ = read_smplx_param(
        smplx_data_root=smplx_data_root, shape_param_file=shape_param_file, batch_size=2
    )
    # Build a fresh dict instead of aliasing and mutating smplx_data while
    # iterating over it. Shape-like params keep only the first frame's value.
    smplx_data_tmp = {}
    for k, v in smplx_data.items():
        if k in ("betas", "face_offset", "joint_offset"):
            smplx_data_tmp[k] = v[0].unsqueeze(0)
        else:
            smplx_data_tmp[k] = v.unsqueeze(0)
    smplx_data = smplx_data_tmp

    gs_render = GS3DRenderer(
        human_model_path=human_model_path,
        subdivide_num=2,
        smpl_type="smplx",
        feat_dim=64,
        query_dim=64,
        use_rgb=False,
        sh_degree=3,
        mlp_network_config=None,
        xyz_offset_max_step=1.8 / 32,
        # These two arguments have no defaults and were previously missing,
        # which made this test raise TypeError before doing any work.
        # TODO(review): confirm the dims match the SMPL-X config in use.
        expr_param_dim=50,
        shape_param_dim=100,
    )

    gs_render.to(device)

    # Assemble per-view camera-to-world and intrinsic matrices.
    c2w_list = []
    intr_list = []
    for cam_param in cam_param_list:
        c2w = torch.eye(4).to(device)
        c2w[:3, :3] = cam_param["R"]
        c2w[:3, 3] = cam_param["t"]
        c2w_list.append(c2w)
        intr = torch.eye(4).to(device)
        intr[0, 0] = cam_param["focal"][0]
        intr[1, 1] = cam_param["focal"][1]
        intr[0, 2] = cam_param["princpt"][0]
        intr[1, 2] = cam_param["princpt"][1]
        intr_list.append(intr)

    c2w = torch.stack(c2w_list).unsqueeze(0)
    intrinsic = torch.stack(intr_list).unsqueeze(0)

    out = gs_render.forward(
        gs_hidden_features=torch.zeros((batch_size, 2048, 64)).float().to(device),
        query_points=None,
        smplx_data=smplx_data,
        c2w=c2w,
        intrinsic=intrinsic,
        # Image size is recovered from the principal point (assumed centered).
        height=int(cam_param_list[0]["princpt"][1]) * 2,
        width=int(cam_param_list[0]["princpt"][0]) * 2,
        background_color=torch.tensor([1.0, 1.0, 1.0])
        .float()
        .view(1, 1, 3)
        .repeat(batch_size, 2, 1)
        .to(device),
        debug=False,
    )

    # Ensure the dump directory exists before writing any output files.
    os.makedirs("./debug_vis", exist_ok=True)

    for k, v in out.items():
        if k == "comp_rgb_bg":
            print("comp_rgb_bg", v)
            continue
        for b_idx in range(len(v)):
            if k == "3dgs":
                # Per-view Gaussian point clouds are saved as PLY files.
                for v_idx in range(len(v[b_idx])):
                    v[b_idx][v_idx].save_ply(f"./debug_vis/{b_idx}_{v_idx}.ply")
                continue
            # Image-like outputs are saved per batch element and view.
            for v_idx in range(v.shape[1]):
                save_path = os.path.join("./debug_vis", f"{b_idx}_{v_idx}_{k}.jpg")
                cv2.imwrite(
                    save_path,
                    (v[b_idx, v_idx].detach().cpu().numpy() * 255).astype(np.uint8),
                )
|
|