| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import os |
| import sys |
|
|
| sys.path.append("./") |
| import time |
|
|
| import cv2 |
| import numpy as np |
| import torch |
|
|
| torch._dynamo.config.disable = True |
| import glob |
| import json |
| from typing import Dict, Optional, Tuple |
|
|
| import torch |
| from accelerate import Accelerator |
| from omegaconf import DictConfig, OmegaConf |
| from PIL import Image |
|
|
| from core.runners.infer.utils import ( |
| prepare_motion_seqs_cano, |
| prepare_motion_seqs_eval, |
| ) |
| from core.utils.hf_hub import wrap_model_hub |
|
|
|
|
def resize_with_padding(images, target_size):
    """Combine 4 images into a 2x2 grid, resize preserving aspect ratio,
    and pad with white (255) to exactly match ``target_size``.

    Args:
        images: list of 4 np.ndarray, each of shape (H, W) or (H, W, C),
            all with identical spatial size.
        target_size: tuple (target_h, target_w) for the output canvas.

    Returns:
        np.ndarray: uint8 image of shape (target_h, target_w) for grayscale
        inputs or (target_h, target_w, C) for multi-channel inputs, with the
        resized grid centered on a white background.
    """
    assert len(images) == 4, "Exactly 4 images are required"

    H, W = images[0].shape[:2]
    assert all(
        img.shape[:2] == (H, W) for img in images
    ), "All images must have the same shape (H, W)"

    # Assemble the 2x2 grid: images [0, 1] on top, [2, 3] on the bottom.
    top_row = np.hstack([images[0], images[1]])
    bottom_row = np.hstack([images[2], images[3]])
    combined = np.vstack([top_row, bottom_row])

    # fix: the original `Hc, Wc, _ = combined.shape` crashed on the 2-D
    # grayscale inputs the docstring promises; slice instead of unpacking.
    Hc, Wc = combined.shape[:2]

    target_h, target_w = target_size

    # Uniform scale so the whole grid fits inside the target canvas.
    scale_h = target_h / Hc
    scale_w = target_w / Wc
    scale = min(scale_h, scale_w)

    new_h = int(Hc * scale)
    new_w = int(Wc * scale)

    # cv2.resize takes (width, height); INTER_AREA suits downscaling.
    resized = cv2.resize(combined, (new_w, new_h), interpolation=cv2.INTER_AREA)

    # White canvas with the same channel layout as the resized grid
    # (fix: was hard-coded to 3 channels, breaking non-RGB inputs).
    padded = np.full((target_h, target_w) + resized.shape[2:], 255, dtype=np.uint8)

    # Center the resized grid on the canvas.
    top = (target_h - new_h) // 2
    left = (target_w - new_w) // 2

    padded[top : top + new_h, left : left + new_w] = resized

    return padded
|
|
|
|
# Registry of evaluation/training data splits, keyed by the CLI `--pre` /
# `exp_name` argument. Each entry provides:
#   root_dirs: directory holding the (possibly tar-packed) video human data.
#   meta_path: JSON label file listing valid samples, or None when the
#              dataset class discovers samples from root_dirs itself.
# NOTE(review): these are cluster-specific mount paths; they will not resolve
# outside the original training environment.
DATASETS_CONFIG = {
    "eval": dict(
        root_dirs="/mnt/workspaces/rmbg/papers/heyuan/stablenorml/d6f8c7e1a2b3c4d5e6f7a8b9c0d1e2f3/tmp/video_human_datasets/LHM_video_dataset/",
        meta_path="./train_data/ClothVideo/label/valid_LHM_dataset_train_val_100.json",
    ),
    "dataset5": dict(
        root_dirs="/mnt/workspaces/dataset/video_human_datasets/selected_dataset_v5_tar/",
        meta_path="/mnt/workspaces/dataset/video_human_datasets/clean_labels/valid_selected_datasetv5_val_filter460-self-rotated-69.json",
    ),
    "dataset6": dict(
        root_dirs="/mnt/workspaces/dataset/video_human_datasets/selected_dataset_v6_tar/",
        meta_path="/mnt/workspaces/dataset/video_human_datasets/clean_labels/valid_selected_datasetv6_test_100-self-rotated-25.json",
    ),
    "synthetic": dict(
        root_dirs="/mnt/workspaces/dataset/video_human_datasets/synthetic_data_tar/",
        meta_path="/mnt/workspaces/dataset/video_human_datasets/clean_labels/valid_LHM_synthetic_dataset_val_17.json",
    ),
    "dataset5_train": dict(
        root_dirs="/mnt/workspaces/dataset/video_human_datasets/selected_dataset_v5_tar/",
        meta_path="/mnt/workspaces/dataset/video_human_datasets/clean_labels/valid_selected_datasetv5_train_filter40K-self-rotated-7147.json",
    ),
    "dataset6_train": dict(
        root_dirs="/mnt/workspaces/dataset/video_human_datasets/selected_dataset_v6_tar/",
        meta_path="/mnt/workspaces/dataset/video_human_datasets/clean_labels/valid_selected_datasetv6_train_5W-self-rotated-11341.json",
    ),
    "eval_train": dict(
        root_dirs="/mnt/workspaces/rmbg/papers/heyuan/stablenorml/d6f8c7e1a2b3c4d5e6f7a8b9c0d1e2f3/tmp/video_human_datasets/LHM_video_dataset/",
        meta_path="/mnt/workspaces/rmbg/papers/heyuan/stablenorml/d6f8c7e1a2b3c4d5e6f7a8b9c0d1e2f3/tmp/video_human_datasets/clean_labels/valid_LHM_dataset_train_filter_16W.json",
    ),
    "in_the_wild": dict(
        root_dirs="/mnt/workspaces/rmbg/papers/heyuan/stablenorml/d6f8c7e1a2b3c4d5e6f7a8b9c0d1e2f3/PFLHM-Causal-Video/",
        meta_path="/mnt/workspaces/rmbg/papers/heyuan/stablenorml/d6f8c7e1a2b3c4d5e6f7a8b9c0d1e2f3/tmp/video_human_datasets/clean_labels/eval_sparse_lhm_wild.json",
    ),
    "in_the_wild_people_snapshot": dict(
        root_dirs="/mnt/workspaces/dataset/people_snapshot/",
        meta_path="/mnt/workspaces/dataset/people_snapshot/peoplesnapshot.json",
    ),
    "in_the_wild_rec_mv": dict(
        root_dirs="/mnt/workspaces/rmbg/papers/heyuan/stablenorml/d6f8c7e1a2b3c4d5e6f7a8b9c0d1e2f3/rec_mv_dataset",
        meta_path="/mnt/workspaces/rmbg/papers/heyuan/stablenorml/d6f8c7e1a2b3c4d5e6f7a8b9c0d1e2f3/rec_mv_dataset/rec_mv_dataset.json",
    ),
    "in_the_wild_mvhumannet": dict(
        root_dirs="/mnt/workspaces/rmbg/papers/heyuan/stablenorml/d6f8c7e1a2b3c4d5e6f7a8b9c0d1e2f3/mvhumannet/",
        meta_path="/mnt/workspaces/rmbg/papers/heyuan/stablenorml/d6f8c7e1a2b3c4d5e6f7a8b9c0d1e2f3/mvhumannet/mvhumannet.json",
    ),
    "dataset_train_real": dict(
        root_dirs="/mnt/workspaces/rmbg/papers/heyuan/stablenorml/d6f8c7e1a2b3c4d5e6f7a8b9c0d1e2f3/tmp/video_human_datasets/LHM_video_dataset/",
        meta_path="/mnt/workspaces/rmbg/papers/heyuan/stablenorml/d6f8c7e1a2b3c4d5e6f7a8b9c0d1e2f3/tmp/video_human_datasets/clean_labels/valid_LHM_dataset_train_filter_16W.json",
    ),
    "dataset5_train_real": dict(
        root_dirs="/mnt/workspaces/dataset/video_human_datasets/selected_dataset_v5_tar/",
        meta_path="/mnt/workspaces/dataset/video_human_datasets/clean_labels/valid_selected_datasetv5_train_filter40K.json",
    ),
    "dataset6_train_real": dict(
        root_dirs="/mnt/workspaces/dataset/video_human_datasets/selected_dataset_v6_tar/",
        meta_path="/mnt/workspaces/dataset/video_human_datasets/clean_labels/valid_selected_datasetv6_train_5W.json",
    ),
    "web_dresscode": dict(
        root_dirs="/mnt/workspaces/datasets/lhm_human_datasets/DressCode/",
        meta_path=None,
    ),
    "hweb_hero": dict(
        root_dirs="/mnt/workspaces/datasets/lhm_human_datasets/pinterest_download_0903_gen_full_fixed_view_random_pose_2_filtered",
        meta_path=None,
    ),
}
|
|
|
|
def obtain_motion_sequence(motion_seqs):
    """Load per-frame SMPL-X parameter dicts from a directory of JSON files.

    Every ``*.json`` under ``motion_seqs`` is read in sorted (filename) order.
    If a sibling FLAME file exists (same path with ``smplx_params`` replaced
    by ``flame_params``), its expression / pose / eye codes are attached as
    float tensors; otherwise a 100-dim zero expression vector is used.

    Returns:
        List of per-frame parameter dicts in filename order.
    """
    frame_files = sorted(glob.glob(os.path.join(motion_seqs, "*.json")))

    smplx_list = []
    for frame_file in frame_files:
        with open(frame_file) as reader:
            params = json.load(reader)

        flame_file = frame_file.replace("smplx_params", "flame_params")
        if not os.path.exists(flame_file):
            # No FLAME annotation for this frame: neutral expression.
            params["expr"] = torch.FloatTensor([0.0] * 100)
        else:
            with open(flame_file) as reader:
                flame_params = json.load(reader)
            params["expr"] = torch.FloatTensor(flame_params["expcode"])
            # Assumed layout: posecode[3:] is the jaw pose, eyecode splits
            # into left/right halves — confirm against the FLAME exporter.
            params["jaw_pose"] = torch.FloatTensor(flame_params["posecode"][3:])
            params["leye_pose"] = torch.FloatTensor(flame_params["eyecode"][:3])
            params["reye_pose"] = torch.FloatTensor(flame_params["eyecode"][3:])

        smplx_list.append(params)

    return smplx_list
|
|
|
|
def _build_model(cfg):
    """Instantiate the A4O human-LRM model from a pretrained checkpoint.

    The heavy ``core.models`` import is deferred to call time so importing
    this script stays cheap.
    """
    from core.models import model_dict

    hub_cls = wrap_model_hub(model_dict["human_lrm_a4o"])
    return hub_cls.from_pretrained(cfg.model_name)
|
|
|
|
def get_smplx_params(data, device):
    """Extract the SMPL-X parameter tensors from ``data``.

    Each matching tensor gains a leading batch dimension via ``unsqueeze(0)``
    and is moved to ``device``; all other keys are ignored.

    Args:
        data: mapping from parameter name to torch.Tensor.
        device: target device string or torch.device.

    Returns:
        dict containing only the SMPL-X keys, each of shape
        (1, *original_shape), on ``device``.
    """
    # Set membership is O(1); the original looped over data.items() while
    # discarding the values and re-indexing data[k] (ruff PERF102).
    smplx_keys = {
        "root_pose",
        "body_pose",
        "jaw_pose",
        "leye_pose",
        "reye_pose",
        "lhand_pose",
        "rhand_pose",
        "expr",
        "trans",
        "betas",
    }
    return {k: v.unsqueeze(0).to(device) for k, v in data.items() if k in smplx_keys}
|
|
|
|
def animation_infer(
    renderer,
    gs_model_list,
    query_points,
    smplx_params,
    render_c2ws,
    render_intrs,
    render_bg_colors,
) -> dict:
    """Render animation frames view by view and stack the results.

    Args:
        renderer: rendering engine exposing ``forward_animate_gs`` and
            ``get_single_view_smpl_data``.
        gs_model_list: list of Gaussian models.
        query_points: 3D query points.
        smplx_params: SMPL-X parameters.
        render_c2ws: camera-to-world matrices, shape (B, V, 4, 4).
        render_intrs: intrinsics, shape (B, V, 3, 3); the render resolution
            is recovered as twice the principal point.
        render_bg_colors: per-view background colors.

    Returns:
        Dictionary of rendered results (rgb, mask, depth, etc.);
        ``comp_rgb`` / ``comp_mask`` / ``comp_depth`` are reshaped to
        (V, H, W, C) for the first batch element.
    """
    # fix: defaultdict was referenced below but never imported in this file.
    from collections import defaultdict

    # Derive output resolution from the principal point (cx, cy) ~ (W/2, H/2).
    render_h, render_w = int(render_intrs[0, 0, 1, 2] * 2), int(
        render_intrs[0, 0, 0, 2] * 2
    )
    num_views = render_c2ws.shape[1]

    start_time = time.time()

    # Render each view independently with single-view slices of the inputs.
    render_res_list = [
        renderer.forward_animate_gs(
            gs_model_list,
            query_points,
            renderer.get_single_view_smpl_data(smplx_params, view_idx),
            render_c2ws[:, view_idx : view_idx + 1],
            render_intrs[:, view_idx : view_idx + 1],
            render_h,
            render_w,
            render_bg_colors[:, view_idx : view_idx + 1],
        )
        for view_idx in range(num_views)
    ]

    avg_time = (time.time() - start_time) / num_views
    print(f"Average time per frame: {avg_time:.4f}s")

    # Gather per-view outputs, detaching tensors to the CPU as we go.
    out = defaultdict(list)
    for res in render_res_list:
        for k, v in res.items():
            out[k].append(v.detach().cpu() if isinstance(v, torch.Tensor) else v)

    # Concatenate along the view dimension; reshape the image-like outputs
    # from (1, V, C, H, W) to (V, H, W, C).
    for k, v in out.items():
        if isinstance(v[0], torch.Tensor):
            out[k] = torch.concat(v, dim=1)
            if k in {"comp_rgb", "comp_mask", "comp_depth"}:
                out[k] = out[k][0].permute(0, 2, 3, 1)

    return out
|
|
|
|
@torch.no_grad()
def inference_results(
    lhm: torch.nn.Module,
    batch: Dict,
    smplx_params: Dict,
    motion_seq: Dict,
    camera_size: int = 40,
    ref_imgs_bool=None,
    batch_size: int = 40,
    device: str = "cuda",
) -> Tuple[np.ndarray, np.ndarray]:
    """Perform inference on a motion sequence in batches to avoid OOM.

    The source view is encoded once via ``lhm.infer_single_view``; the motion
    is then animated ``batch_size`` cameras at a time and the frames are
    converted to uint8 numpy arrays as they are produced.

    Args:
        lhm: model exposing ``infer_single_view`` and ``animation_infer``.
        batch: dataset item; ``batch["source_rgbs"]`` is the input view.
        smplx_params: SMPL-X parameters of the source subject.
        motion_seq: dict from ``prepare_motion_seqs_cano`` (cameras,
            intrinsics, masks, per-frame crop offsets, original size).
        camera_size: total number of cameras/frames to render.
        ref_imgs_bool: optional reference-view indicator passed through.
        batch_size: number of cameras rendered per chunk.
        device: device tensors are moved to.

    Returns:
        Tuple ``(rgbs, masks)`` of uint8 arrays with the frame axis first.
    """

    offset_list = motion_seq["offset_list"]
    ori_h, ori_w = motion_seq["ori_size"]

    # White canvas the rendered crop is composited back onto.
    output_rgb = torch.ones((ori_h, ori_w, 3))

    # Encode the source image once; the loop below only re-poses it.
    (
        gs_model_list,
        query_points,
        transform_mat_neutral_pose,
        gs_hidden_features,
        image_latents,
        motion_emb,
        pos_emb,
    ) = lhm.infer_single_view(
        batch["source_rgbs"].unsqueeze(0).to(device),
        None,
        None,
        render_c2ws=motion_seq["render_c2ws"].to(device),
        render_intrs=motion_seq["render_intrs"].to(device),
        render_bg_colors=motion_seq["render_bg_colors"].to(device),
        smplx_params={k: v.to(device) for k, v in smplx_params.items()},
        ref_imgs_bool=ref_imgs_bool,
    )

    # Shape/neutral-pose terms stay fixed across chunks; the per-frame keys
    # below are re-assigned IN PLACE on this same dict each iteration.
    batch_smplx_params = {
        "betas": smplx_params["betas"].to(device),
        "transform_mat_neutral_pose": transform_mat_neutral_pose,
    }

    # Per-frame SMPL-X / camera keys sliced from the motion sequence.
    keys = [
        "root_pose",
        "body_pose",
        "jaw_pose",
        "leye_pose",
        "reye_pose",
        "lhand_pose",
        "rhand_pose",
        "trans",
        "focal",
        "princpt",
        "img_size_wh",
        "expr",
    ]

    batch_list = []
    batch_mask_list = []
    for batch_i in range(0, camera_size, batch_size):
        print(
            f"Processing batch {batch_i//batch_size + 1}/{(camera_size + batch_size - 1)//batch_size}"
        )

        # Slice this chunk's per-frame parameters onto the shared dict.
        batch_smplx_params.update(
            {
                key: motion_seq["smplx_params"][key][
                    :, batch_i : batch_i + batch_size
                ].to(device)
                for key in keys
            }
        )

        # Render this chunk of cameras.
        batch_rgb, batch_mask = lhm.animation_infer(
            gs_model_list,
            query_points,
            batch_smplx_params,
            render_c2ws=motion_seq["render_c2ws"][:, batch_i : batch_i + batch_size].to(
                device
            ),
            render_intrs=motion_seq["render_intrs"][
                :, batch_i : batch_i + batch_size
            ].to(device),
            render_bg_colors=motion_seq["render_bg_colors"][
                :, batch_i : batch_i + batch_size
            ].to(device),
            gs_hidden_features=gs_hidden_features,
            image_latents=image_latents,
            motion_emb=motion_emb,
            pos_emb=pos_emb,
            offset_list=offset_list[batch_i : batch_i + batch_size],
            mask_seqs=motion_seq["masks"][batch_i : batch_i + batch_size],
            output_rgb=output_rgb,
        )

        # Convert to uint8 numpy immediately to release GPU/graph memory.
        batch_list.append((batch_rgb.clamp(0, 1) * 255).to(torch.uint8).numpy())
        batch_mask_list.append((batch_mask.clamp(0, 1) * 255).to(torch.uint8).numpy())

    return np.concatenate(batch_list, axis=0), np.concatenate(batch_mask_list, axis=0)
|
|
|
|
@torch.no_grad()
def inference_gs_model(
    lhm: torch.nn.Module,
    batch: Dict,
    smplx_params: Dict,
    motion_seq: Dict,
    camera_size: int = 40,
    ref_imgs_bool=None,
    batch_size: int = 40,
    device: str = "cuda",
):
    """Run single-view inference and return a posed Gaussian-splat model.

    The motion sequence is processed in chunks of ``batch_size`` cameras to
    bound memory use.

    NOTE(review): only the GS model produced by the LAST chunk is returned.
    Callers build the motion with ``motion_size=1`` (so ``camera_size == 1``
    and there is a single chunk); confirm before using longer sequences.

    Args:
        lhm: model exposing ``infer_single_view`` and ``inference_gs``.
        batch: dataset item; ``batch["source_rgbs"]`` is the input view.
        smplx_params: SMPL-X parameters of the source subject.
        motion_seq: dict from ``prepare_motion_seqs_eval``.
        camera_size: total number of cameras/frames.
        ref_imgs_bool: optional reference-view indicator passed through.
        batch_size: number of cameras processed per chunk.
        device: device tensors are moved to.

    Returns:
        The Gaussian model from the final chunk, or None when
        ``camera_size <= 0``.
    """

    # Encode the source image once; the loop below only re-poses it.
    (
        gs_model_list,
        query_points,
        transform_mat_neutral_pose,
        gs_hidden_features,
        image_latents,
        motion_emb,
    ) = lhm.infer_single_view(
        batch["source_rgbs"].unsqueeze(0).to(device),
        None,
        None,
        render_c2ws=motion_seq["render_c2ws"].to(device),
        render_intrs=motion_seq["render_intrs"].to(device),
        render_bg_colors=motion_seq["render_bg_colors"].to(device),
        smplx_params={k: v.to(device) for k, v in smplx_params.items()},
        ref_imgs_bool=ref_imgs_bool,
    )

    # Shape/neutral-pose terms stay fixed; per-frame keys are re-assigned
    # in place for every chunk.
    batch_smplx_params = {
        "betas": smplx_params["betas"].to(device),
        "transform_mat_neutral_pose": transform_mat_neutral_pose,
    }

    keys = [
        "root_pose",
        "body_pose",
        "jaw_pose",
        "leye_pose",
        "reye_pose",
        "lhand_pose",
        "rhand_pose",
        "trans",
        "focal",
        "princpt",
        "img_size_wh",
        "expr",
    ]

    # fix: removed the unused `batch_list` accumulator; initialize gs_model so
    # an empty camera range returns None instead of raising NameError.
    gs_model = None
    for batch_i in range(0, camera_size, batch_size):
        print(
            f"Processing batch {batch_i//batch_size + 1}/{(camera_size + batch_size - 1)//batch_size}"
        )

        # Slice this chunk's per-frame parameters onto the shared dict.
        batch_smplx_params.update(
            {
                key: motion_seq["smplx_params"][key][
                    :, batch_i : batch_i + batch_size
                ].to(device)
                for key in keys
            }
        )

        gs_model = lhm.inference_gs(
            gs_model_list,
            query_points,
            batch_smplx_params,
            render_c2ws=motion_seq["render_c2ws"][:, batch_i : batch_i + batch_size].to(
                device
            ),
            render_intrs=motion_seq["render_intrs"][
                :, batch_i : batch_i + batch_size
            ].to(device),
            render_bg_colors=motion_seq["render_bg_colors"][
                :, batch_i : batch_i + batch_size
            ].to(device),
            gs_hidden_features=gs_hidden_features,
            image_latents=image_latents,
            motion_emb=motion_emb,
        )
    return gs_model
|
|
|
|
@torch.no_grad()
def lhm_validation_inference(
    lhm: Optional[torch.nn.Module],
    save_path: str,
    view: int = 16,
    cfg: Optional[Dict] = None,
    motion_path: Optional[str] = None,
    exp_name: str = "eval",
    debug: bool = False,
    split: int = 1,
    gpus: int = 0,
) -> None:
    """Run validation inference on the model.

    Renders the canonical motion found under ``motion_path`` for every sample
    of the dataset shard assigned to this process (shard ``gpus`` of
    ``split``) and saves per-frame RGB / mask PNGs under
    ``save_path/<uid>/``.
    """
    if lhm is not None:
        lhm.cuda().eval()

    assert motion_path is not None
    cfg = cfg or {}

    # Auxiliary output dirs next to save_path (populated elsewhere).
    gt_save_path = os.path.join(os.path.dirname(save_path), "gt")
    gt_mask_save_path = os.path.join(os.path.dirname(save_path), "mask")
    os.makedirs(gt_save_path, exist_ok=True)
    os.makedirs(gt_mask_save_path, exist_ok=True)

    # Pick the dataset class for this experiment; imports are deferred so
    # only the needed dataset module gets loaded.
    dataset_config = DATASETS_CONFIG[exp_name]
    kwargs = {}
    if exp_name == "eval" or exp_name == "eval_train":
        from core.datasets.video_human_lhm_dataset_a4o import (
            VideoHumanLHMA4ODatasetEval as VideoDataset,
        )
    elif "in_the_wild" in exp_name:
        from core.datasets.video_in_the_wild_dataset import (
            VideoInTheWildEval as VideoDataset,
        )

        kwargs["heuristic_sampling"] = True
    elif "hweb" in exp_name:
        from core.datasets.video_in_the_wild_web_dataset import (
            WebInTheWildHeurEval as VideoDataset,
        )

        kwargs["heuristic_sampling"] = False
    elif "web" in exp_name:
        from core.datasets.video_in_the_wild_web_dataset import (
            WebInTheWildEval as VideoDataset,
        )

        kwargs["heuristic_sampling"] = False
    else:
        from core.datasets.video_human_dataset_a4o import (
            VideoHumanA4ODatasetEval as VideoDataset,
        )

    dataset = VideoDataset(
        root_dirs=dataset_config["root_dirs"],
        meta_path=dataset_config["meta_path"],
        sample_side_views=7,
        render_image_res_low=420,
        render_image_res_high=420,
        render_region_size=(682, 420),
        source_image_res=512,
        debug=False,
        use_flame=True,
        ref_img_size=view,
        womask=True,
        is_val=True,
        processing_pipeline=[
            dict(name="PadRatioWithScale", target_ratio=5 / 3, tgt_max_size_list=[840]),
            dict(name="ToTensor"),
        ],
        **kwargs,
    )

    # Collect per-frame SMPL-X files and matching segmentation masks.
    smplx_path = os.path.join(motion_path, "smplx_params")
    mask_path = os.path.join(motion_path, "samurai_seg")
    motion_files = sorted(glob.glob(os.path.join(smplx_path, "*.json")))
    motion_id_seqs = [
        motion_file.split("/")[-1].replace(".json", "") for motion_file in motion_files
    ]
    mask_paths = [
        os.path.join(mask_path, motion_id_seq + ".png")
        for motion_id_seq in motion_id_seqs
    ]

    motion_seqs = prepare_motion_seqs_cano(
        obtain_motion_sequence(smplx_path),
        mask_paths=mask_paths,
        bg_color=1.0,
        aspect_standard=5.0 / 3,
        enlarge_ratio=[1.0, 1.0],
        tgt_size=cfg.get("render_size", 420),
        render_image_res=cfg.get("render_size", 420),
        need_mask=cfg.get("motion_img_need_mask", False),
        vis_motion=cfg.get("vis_motion", False),
        motion_size=100 if debug else 1000,
        specific_id_list=None,
    )

    motion_id = motion_seqs["motion_id"]

    # Shard the dataset across `split` workers; this process does shard `gpus`.
    dataset_size = len(dataset)
    bins = int(np.ceil(dataset_size / split))

    for idx in range(bins * gpus, bins * (gpus + 1)):
        # fix: was a bare `except:` that also swallowed KeyboardInterrupt;
        # keep the best-effort skip (out-of-range idx, bad samples) but log.
        try:
            item = dataset.__getitem__(idx, view)
        except Exception as e:
            print(f"Skipping idx {idx}: failed to load item ({e})")
            continue

        uid = item["uid"]
        save_folder = os.path.join(save_path, uid)
        os.makedirs(save_folder, exist_ok=True)

        print(f"Processing {uid}, idx: {idx}")

        video_dir = os.path.join(save_path, f"view_{view:03d}")
        os.makedirs(video_dir, exist_ok=True)

        # Re-pose the shared canonical motion with this subject's body shape.
        motion_seqs["smplx_params"]["betas"] = item["betas"].unsqueeze(0)
        try:
            rgbs, masks = inference_results(
                lhm,
                item,
                motion_seqs["smplx_params"],
                motion_seqs,
                camera_size=motion_seqs["smplx_params"]["root_pose"].shape[1],
                ref_imgs_bool=item["ref_imgs_bool"].unsqueeze(0),
            )
        except Exception as e:
            # fix: bare `except:` replaced; report the reason for the skip.
            print(f"Error in infering: {e}")
            continue

        # fix: the inner loop used to rebind the outer loop variable `idx`.
        for rgb, mask, frame_id in zip(rgbs, masks, motion_id):
            pred_image = Image.fromarray(rgb)
            pred_mask = Image.fromarray(mask)
            pred_name = f"rgb_{frame_id:05d}.png"
            mask_pred_name = f"mask_{frame_id:05d}.png"
            save_img_path = os.path.join(save_folder, pred_name)
            save_mask_path = os.path.join(save_folder, mask_pred_name)
            pred_image.save(save_img_path)
            pred_mask.save(save_mask_path)
|
|
|
@torch.no_grad()
def lhm_validation_inference_gs(
    lhm: Optional[torch.nn.Module],
    save_path: str,
    view: int = 16,
    cfg: Optional[Dict] = None,
    motion_path: Optional[str] = None,
    exp_name: str = "eval",
    debug: bool = False,
    split: int = 1,
    gpus: int = 0,
) -> None:
    """Run validation inference on the model, exporting Gaussian models.

    For each sample of the dataset shard assigned to this process (shard
    ``gpus`` of ``split``), produces a posed Gaussian-splat model and saves
    it as ``save_path/view_<view>/<uid>.ply``. Existing files are skipped so
    runs are resumable.
    """
    if lhm is not None:
        lhm.cuda().eval()

    assert motion_path is not None
    cfg = cfg or {}

    # Pick the dataset class for this experiment; imports are deferred so
    # only the needed dataset module gets loaded.
    dataset_config = DATASETS_CONFIG[exp_name]
    kwargs = {}
    if exp_name == "eval" or exp_name == "eval_train":
        from core.datasets.video_human_lhm_dataset_a4o import (
            VideoHumanLHMA4ODatasetEval as VideoDataset,
        )
    elif "in_the_wild" in exp_name:
        from core.datasets.video_in_the_wild_dataset import (
            VideoInTheWildEval as VideoDataset,
        )

        kwargs["heuristic_sampling"] = True
    else:
        from core.datasets.video_human_dataset_a4o import (
            VideoHumanA4ODatasetEval as VideoDataset,
        )

    dataset = VideoDataset(
        root_dirs=dataset_config["root_dirs"],
        meta_path=dataset_config["meta_path"],
        sample_side_views=7,
        render_image_res_low=420,
        render_image_res_high=420,
        render_region_size=(700, 420),
        source_image_res=420,
        debug=False,
        use_flame=True,
        ref_img_size=view,
        womask=True,
        is_val=True,
        **kwargs,
    )

    # A single motion frame is enough: only one posed GS model is exported.
    smplx_path = os.path.join(motion_path, "smplx_params")
    motion_seqs = prepare_motion_seqs_eval(
        obtain_motion_sequence(smplx_path),
        bg_color=1.0,
        aspect_standard=5.0 / 3,
        enlarge_ratio=[1.0, 1.0],
        render_image_res=cfg.get("render_size", 384),
        need_mask=cfg.get("motion_img_need_mask", False),
        vis_motion=cfg.get("vis_motion", False),
        motion_size=1,
        specific_id_list=None,
    )

    # Shard the dataset across `split` workers; this process does shard `gpus`.
    dataset_size = len(dataset)
    bins = int(np.ceil(dataset_size / split))

    for idx in range(bins * gpus, bins * (gpus + 1)):
        # fix: was a bare `except:` that also swallowed KeyboardInterrupt;
        # keep the best-effort skip (out-of-range idx, bad samples) but log.
        try:
            item = dataset.__getitem__(idx, view)
        except Exception as e:
            print(f"Skipping idx {idx}: failed to load item ({e})")
            continue

        uid = item["uid"]
        print(f"Processing {uid}, idx: {idx}")

        gs_dir = os.path.join(save_path, f"view_{view:03d}")
        os.makedirs(gs_dir, exist_ok=True)
        gs_file_path = os.path.join(gs_dir, f"{uid}.ply")

        # Skip subjects already exported (resumable runs).
        if os.path.exists(gs_file_path):
            continue

        # Use this subject's body shape with the shared motion frame.
        motion_seqs["smplx_params"]["betas"] = item["betas"]

        gs_model = inference_gs_model(
            lhm,
            item,
            motion_seqs["smplx_params"],
            motion_seqs,
            camera_size=motion_seqs["smplx_params"]["root_pose"].shape[1],
        )
        print("generated GS Model!")  # fix: dropped pointless f-prefix (F541)

        gs_model.save_ply(gs_file_path)
|
|
|
|
def parse_configs() -> Tuple[DictConfig, str]:
    """Build the inference configuration from the environment.

    Reads MODEL_PATH (required) and APP_MODEL_CONFIG (optional training
    config file) and merges them into a single OmegaConf object.

    Returns:
        Tuple containing:
        - Merged configuration object
        - Model name extracted from MODEL_PATH
    """
    cfg = OmegaConf.create()
    cli_cfg = OmegaConf.create()

    model_path = os.environ.get("MODEL_PATH", "").rstrip("/")
    if not model_path:
        raise ValueError("MODEL_PATH environment variable is required")

    cli_cfg.model_name = model_path
    # NOTE(review): assumes MODEL_PATH has at least two path components;
    # the second-to-last one names the model. Confirm for flat paths.
    model_name = model_path.split("/")[-2]

    model_config = os.environ.get("APP_MODEL_CONFIG")
    if model_config:
        cfg_train = OmegaConf.load(model_config)

        cfg.update(
            {
                "source_size": cfg_train.dataset.source_image_res,
                "src_head_size": getattr(cfg_train.dataset, "src_head_size", 112),
                "render_size": cfg_train.dataset.render_image.high,
                "motion_video_read_fps": 30,
                "logger": "INFO",
            }
        )

        # Derive experiment-relative dump directories from the training config.
        exp_id = os.path.basename(model_path).split("_")[-1]
        relative_path = os.path.join(
            cfg_train.experiment.parent, cfg_train.experiment.child, exp_id
        )
        for cfg_key, leaf in (
            ("save_tmp_dump", "save_tmp"),
            ("image_dump", "images"),
            ("video_dump", "videos"),
        ):
            cfg[cfg_key] = os.path.join("exps", leaf, relative_path)

    # CLI-derived values win over config-file values.
    cfg.merge_with(cli_cfg)
    if not cfg.get("model_name"):
        raise ValueError("model_name is required in configuration")

    return cfg, model_name
|
|
|
|
def get_parse():
    """Parse command-line arguments for the inference script."""
    import argparse

    parser = argparse.ArgumentParser(description="Inference_Config")
    parser.add_argument("-c", "--config", required=True, help="config file path")
    parser.add_argument("-w", "--ckpt", required=True, help="model checkpoint")
    parser.add_argument("-v", "--view", help="input views", default=16, type=int)
    parser.add_argument("-p", "--pre", help="exp_name", type=str)
    parser.add_argument("-m", "--motion", help="motion_path", type=str)
    parser.add_argument(
        "-s",
        "--split",
        help="split_dataset, used for distribution inference.",
        type=int,
        default=1,
    )
    parser.add_argument("-g", "--gpus", help="current gpu id", type=int, default=0)
    parser.add_argument("--gs", help="output gaussian model", action="store_true")
    parser.add_argument(
        "--output", help="output_cano_folder", default="./debug/cano_output", type=str
    )
    # fix: help text previously said "motion_path" (copy-paste from --motion).
    parser.add_argument("--debug", help="enable debug mode", action="store_true")
    args = parser.parse_args()
    return args
|
|
|
|
def main():
    """Entry point: configure the environment, build the model, run inference."""
    args = get_parse()

    # Expose the CLI configuration to parse_configs() via the environment.
    # fix: the original dict literal listed the "APP_TYPE" key twice.
    os.environ.update(
        {
            "APP_ENABLED": "1",
            "APP_MODEL_CONFIG": args.config,
            "APP_TYPE": "infer.human_lrm_a4o",
            "NUMBA_THREADING_LAYER": "omp",
            "MODEL_PATH": args.ckpt,
        }
    )

    exp_name = args.pre
    # Normalize away trailing slashes so downstream path joins behave.
    # fix: `motion[-1]` crashed on an empty or missing --motion argument.
    motion = args.motion.rstrip("/") if args.motion else args.motion

    cfg, model_name = parse_configs()

    assert exp_name in DATASETS_CONFIG, f"unknown exp_name: {exp_name}"

    output_path = args.output

    lhm = _build_model(cfg)

    if args.gs:
        # GS export path is not wired up for this script.
        raise NotImplementedError
    os.makedirs(output_path, exist_ok=True)
    lhm_validation_inference(
        lhm,
        output_path,
        args.view,
        cfg,
        motion,
        exp_name,
        debug=args.debug,
        split=args.split,
        gpus=args.gpus,
    )
|
|
|
|
# Script entry point: run inference only when executed directly.
if __name__ == "__main__":
    main()
|
|