| |
| |
| |
| |
| |
| |
|
|
| import os |
| from typing import Any, Dict, Optional, Tuple |
|
|
| import numpy as np |
| import torch |
| from omegaconf import DictConfig, OmegaConf |
|
|
| from core.utils.hf_hub import wrap_model_hub |
|
|
| |
# Default number of frames rendered per ``animation_infer`` call in
# ``inference_results``; smaller values trade speed for lower peak memory.
DEFAULT_BATCH_SIZE: int = 40
|
|
|
|
def parse_app_configs(
    model_cards: Dict[str, Dict[str, str]],
) -> Tuple[DictConfig, DictConfig]:
    """Parse model configuration from environment variables and config files.

    The model is selected by the ``APP_MODEL_NAME`` environment variable, which
    must name an entry of ``model_cards``. The entry's ``"model_path"`` becomes
    ``cfg.model_name``. Its optional ``"model_config"`` is a path loaded with
    ``OmegaConf.load``; from it, the image sizes and the output dump
    directories (``save_tmp_dump``, ``image_dump``, ``video_dump``) are derived.

    Args:
        model_cards: Mapping from app model name to a card dict with a
            ``"model_path"`` key and an optional ``"model_config"`` key.

    Returns:
        A tuple ``(cfg, cfg_train)``: the merged app config and the loaded
        training config. ``cfg_train`` is an empty ``DictConfig`` when the card
        provides no ``model_config``.

    Raises:
        NotImplementedError: If ``APP_MODEL_NAME`` is not set.
        KeyError: If ``APP_MODEL_NAME`` is not a key of ``model_cards``, or the
            card has no ``"model_path"``.
    """
    cli_cfg = OmegaConf.create()
    cfg = OmegaConf.create()
    # Bound up front so the function can still return when the card has no
    # training config (previously this path raised UnboundLocalError).
    cfg_train = OmegaConf.create()

    app_model_name = os.environ.get("APP_MODEL_NAME")
    if app_model_name is None:
        raise NotImplementedError("APP_MODEL_NAME environment variable must be set")

    try:
        model_card = model_cards[app_model_name]
    except KeyError:
        raise KeyError(
            f"Unknown APP_MODEL_NAME {app_model_name!r}; "
            f"expected one of {list(model_cards)}"
        ) from None
    model_path = model_card["model_path"]
    # The config is optional: a missing key is treated like an explicit None.
    model_config = model_card.get("model_config")

    cli_cfg.model_name = model_path

    if model_config is not None:
        cfg_train = OmegaConf.load(model_config)
        cfg.source_size = cfg_train.dataset.source_image_res
        # Training configs without src_head_size fall back to 112.
        cfg.src_head_size = cfg_train.dataset.get("src_head_size", 112)
        cfg.render_size = cfg_train.dataset.render_image.high

        # Output directories are keyed by experiment parent/child plus the
        # last "_"-separated token of the model path's basename.
        relative_path = os.path.join(
            cfg_train.experiment.parent,
            cfg_train.experiment.child,
            os.path.basename(cli_cfg.model_name).split("_")[-1],
        )
        cfg.save_tmp_dump = os.path.join("exps", "save_tmp", relative_path)
        cfg.image_dump = os.path.join("exps", "images", relative_path)
        cfg.video_dump = os.path.join("exps", "videos", relative_path)

    cfg.motion_video_read_fps = 6
    # The CLI-level values (model_name) take precedence over anything above.
    cfg.merge_with(cli_cfg)
    cfg.setdefault("logger", "INFO")
    assert cfg.model_name is not None, "model_name is required"

    return cfg, cfg_train
|
|
|
|
def build_app_model(cfg: DictConfig) -> torch.nn.Module:
    """Instantiate the ``human_lrm_a4o`` model class and load pretrained weights.

    The class is looked up in ``core.models.model_dict``, wrapped with
    ``wrap_model_hub``, and loaded via ``from_pretrained``.

    Args:
        cfg: Configuration whose ``model_name`` is passed to ``from_pretrained``.

    Returns:
        The model instance returned by ``from_pretrained``.
    """
    # Deferred import: core.models is only loaded when a model is actually built.
    from core.models import model_dict as model_registry

    registry_key = "human_lrm_a4o"
    hub_model_cls = wrap_model_hub(model_registry[registry_key])
    return hub_model_cls.from_pretrained(cfg.model_name)
|
|
|
|
@torch.no_grad()
def inference_results(
    model: torch.nn.Module,
    ref_img_tensors: torch.Tensor,
    smplx_params: Dict[str, torch.Tensor],
    motion_seq: Dict[str, Any],
    video_size: int = 40,
    ref_imgs_bool: Optional[torch.Tensor] = None,
    visualized_center: bool = False,
    batch_size: int = DEFAULT_BATCH_SIZE,
    device: str = "cuda",
) -> np.ndarray:
    """Run inference on a motion sequence, rendering frames in batches to bound memory.

    The reference images are first encoded once with ``model.infer_single_view``.
    The motion sequence is then rendered ``batch_size`` frames at a time with
    ``model.animation_infer``.

    Args:
        model: Model exposing ``infer_single_view`` and ``animation_infer``.
        ref_img_tensors: Reference image tensors; dim 0 indexes the reference
            images (presumably ``(N, C, H, W)`` — TODO confirm).
        smplx_params: SMPL-X parameters for the initial pose; must contain
            ``"betas"``.
        motion_seq: Motion data with the keys ``"render_c2ws"``,
            ``"render_intrs"``, ``"render_bg_colors"`` and ``"smplx_params"``
            (frames indexed along dim 1), plus the optional ``"offset_list"``,
            ``"masks"`` and ``"ori_size"`` (default ``(512, 512)``).
        video_size: Total number of frames to render; must be positive.
        ref_imgs_bool: Boolean mask selecting which reference images to use.
            Defaults to using all of them.
        visualized_center: If True, crops the output to the subject's bounding
            box (mask > 25%) enlarged by 10%.
        batch_size: Number of frames per ``animation_infer`` call; must be
            positive.
        device: Device to run inference on.

    Returns:
        uint8 numpy array of rendered RGB frames concatenated along axis 0
        (presumably ``(T, H, W, 3)`` — depends on ``animation_infer``).

    Raises:
        ValueError: If ``batch_size`` or ``video_size`` is not positive.
    """
    # Validate before the expensive encoding step; these inputs previously
    # failed only later (range step 0 / empty np.concatenate).
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if video_size <= 0:
        raise ValueError(f"video_size must be positive, got {video_size}")

    offset_list = motion_seq.get("offset_list")
    masks = motion_seq.get("masks")
    ori_h, ori_w = motion_seq.get("ori_size", (512, 512))
    # Canvas handed to animation_infer as ``output_rgb`` for every batch.
    output_rgb = torch.ones((ori_h, ori_w, 3))

    # Honour a caller-supplied mask; previously it was silently overwritten.
    if ref_imgs_bool is None:
        ref_imgs_bool = torch.ones(
            ref_img_tensors.shape[0], dtype=torch.bool, device=device
        )
    else:
        ref_imgs_bool = ref_imgs_bool.to(device=device, dtype=torch.bool)

    model_outputs = model.infer_single_view(
        ref_img_tensors.unsqueeze(0).to(device),
        None,
        None,
        render_c2ws=motion_seq["render_c2ws"].to(device),
        render_intrs=motion_seq["render_intrs"].to(device),
        render_bg_colors=motion_seq["render_bg_colors"].to(device),
        smplx_params={k: v.to(device) for k, v in smplx_params.items()},
        ref_imgs_bool=ref_imgs_bool.unsqueeze(0),
    )

    # Some model variants return a 7th output (a positional embedding);
    # otherwise exactly six outputs are expected.
    if len(model_outputs) == 7:
        *core_outputs, pos_emb = model_outputs
    else:
        core_outputs, pos_emb = model_outputs, None
    (
        gs_model_list,
        query_points,
        transform_mat_neutral_pose,
        gs_hidden_features,
        image_latents,
        motion_emb,
    ) = core_outputs

    # Per-sequence SMPL-X parameters shared by every batch.
    static_smplx_params = {
        "betas": smplx_params["betas"].to(device),
        "transform_mat_neutral_pose": transform_mat_neutral_pose,
    }
    # Per-frame SMPL-X parameters, sliced along dim 1 for each batch.
    frame_varying_keys = (
        "root_pose",
        "body_pose",
        "jaw_pose",
        "leye_pose",
        "reye_pose",
        "lhand_pose",
        "rhand_pose",
        "trans",
        "focal",
        "princpt",
        "img_size_wh",
        "expr",
    )
    motion_smplx = motion_seq["smplx_params"]

    batch_rgb_list = []
    batch_mask_list = []
    num_batches = (video_size + batch_size - 1) // batch_size

    for batch_num, start in enumerate(range(0, video_size, batch_size), start=1):
        print(f"Processing batch {batch_num}/{num_batches}")
        frames = slice(start, start + batch_size)

        # Fresh dict per batch so the model never sees a dict mutated later.
        batch_smplx_params = {
            **static_smplx_params,
            **{key: motion_smplx[key][:, frames].to(device) for key in frame_varying_keys},
        }

        anim_kwargs = {
            "gs_model_list": gs_model_list,
            "query_points": query_points,
            "smplx_params": batch_smplx_params,
            "render_c2ws": motion_seq["render_c2ws"][:, frames].to(device),
            "render_intrs": motion_seq["render_intrs"][:, frames].to(device),
            "render_bg_colors": motion_seq["render_bg_colors"][:, frames].to(device),
            "gs_hidden_features": gs_hidden_features,
            "image_latents": image_latents,
            "motion_emb": motion_emb,
        }
        # Optional inputs are only forwarded when present.
        if pos_emb is not None:
            anim_kwargs["pos_emb"] = pos_emb
        if offset_list is not None:
            anim_kwargs["offset_list"] = offset_list[frames]
        if masks is not None:
            anim_kwargs["mask_seqs"] = masks[frames]
        anim_kwargs["output_rgb"] = output_rgb

        batch_rgb, batch_mask = model.animation_infer(**anim_kwargs)
        batch_rgb_list.append(_to_uint8(batch_rgb))
        batch_mask_list.append(_to_uint8(batch_mask))

    print("End of inference")

    rgb = np.concatenate(batch_rgb_list, axis=0)
    if visualized_center:
        rgb = _crop_to_subject(rgb, np.concatenate(batch_mask_list, axis=0))
    return rgb


def _to_uint8(frames: torch.Tensor) -> np.ndarray:
    """Clamp a tensor to [0, 1], scale to [0, 255] (truncating) and return it as a CPU uint8 array."""
    # .cpu() is a no-op for CPU tensors and required before .numpy() otherwise.
    return (frames.clamp(0, 1) * 255).to(torch.uint8).cpu().numpy()


def _crop_to_subject(
    rgb: np.ndarray,
    mask: np.ndarray,
    threshold: float = 0.25,
    scale: float = 1.1,
) -> np.ndarray:
    """Crop ``rgb`` (axes 1, 2 = H, W) to the union bbox of ``mask`` pixels above ``threshold``, enlarged by ``scale``.

    ``mask`` holds uint8 values in [0, 255] and ``threshold`` is a fraction in
    [0, 1]. When no pixel passes the threshold, ``rgb`` is returned unchanged.
    """
    # The mask was scaled to [0, 255]; scale the fractional threshold to match.
    # Only axes 1 and 2 (H, W) are used; [1:3] also tolerates a trailing channel axis.
    h_indices, w_indices = np.where(mask > threshold * 255)[1:3]
    if h_indices.size == 0:
        return rgb

    top, bottom = h_indices.min(), h_indices.max()
    left, right = w_indices.min(), w_indices.max()
    center_y, center_x = (top + bottom) / 2, (left + right) / 2
    new_height, new_width = (bottom - top) * scale, (right - left) * scale

    top_new = max(0, int(center_y - new_height / 2))
    bottom_new = int(center_y + new_height / 2)
    left_new = max(0, int(center_x - new_width / 2))
    right_new = int(center_x + new_width / 2)
    return rgb[:, top_new:bottom_new, left_new:right_new]
|
|