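"""Gradio demo app for LAM (Large Avatar Model for One-shot Animatable Gaussian Head).

Builds a Gaussian head avatar from a single portrait image, drives it with a
sample motion video, and optionally exports assets for the cross-platform H5
(in-browser) renderer.
"""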
import os
import cv2
import sys
import base64
import subprocess

import gradio as gr
import numpy as np
from PIL import Image
import argparse
from omegaconf import OmegaConf

import torch
import zipfile
from glob import glob
import moviepy.editor as mpy
from tools.flame_tracking_single_image import FlameTrackingSingleImage
from lam.runners.infer.head_utils import prepare_motion_seqs, preprocess_image

# `spaces` is only available when running on Hugging Face Spaces (ZeroGPU).
try:
    import spaces
except ImportError:
    pass


# Cross-platform H5 (in-browser) Gaussian-splat rendering.
h5_rendering = True
from gradio_gaussian_render import gaussian_render


def launch_env_not_compile_with_cuda():
    os.system('pip install chumpy')
    os.system('pip install numpy==1.23.0')
    os.system(
        'pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt251/download.html'
    )


def assert_input_image(input_image):
    if input_image is None:
        raise gr.Error('No image selected or uploaded!')


def prepare_working_dir():
    import tempfile
    working_dir = tempfile.TemporaryDirectory()
    return working_dir


def init_preprocessor():
    from lam.utils.preprocess import Preprocessor
    global preprocessor
    preprocessor = Preprocessor()


def preprocess_fn(image_in: np.ndarray, remove_bg: bool, recenter: bool,
                  working_dir):
    image_raw = os.path.join(working_dir.name, 'raw.png')
    with Image.fromarray(image_in) as img:
        img.save(image_raw)
    image_out = os.path.join(working_dir.name, 'rembg.png')
    success = preprocessor.preprocess(image_path=image_raw,
                                      save_path=image_out,
                                      rmbg=remove_bg,
                                      recenter=recenter)
    assert success, 'Failed under preprocess_fn!'
    return image_out


def get_image_base64(path):
    with open(path, 'rb') as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode()
    return f'data:image/png;base64,{encoded_string}'


def do_softlink(working_dir, tgt_dir="./runtime_data"):
    # Replace any stale link so the H5 renderer always points at the current working dir.
    os.system(f"rm {tgt_dir}")
    cmd = f"ln -s {working_dir} {tgt_dir}"
    os.system(cmd)
    return cmd


def doRender(working_dir):
    working_dir = working_dir.name
    cmd = do_softlink(working_dir)
    print('=' * 100, "\n" + cmd, '\ndo render', "\n" + "=" * 100)


def save_images2video(img_lst, v_pth, fps):
    from moviepy.editor import ImageSequenceClip

    images = [image.astype(np.uint8) for image in img_lst]
    clip = ImageSequenceClip(images, fps=fps)
    clip.write_videofile(v_pth, codec='libx264')
    print(f"Video saved successfully at {v_pth}")


def add_audio_to_video(video_path, out_path, audio_path, fps=30):
    from moviepy.editor import VideoFileClip, AudioFileClip

    video_clip = VideoFileClip(video_path)
    audio_clip = AudioFileClip(audio_path)

    # Optionally cap the audio length, e.g.:
    # if audio_clip.duration > 10:
    #     audio_clip = audio_clip.subclip(0, 10)

    video_clip_with_audio = video_clip.set_audio(audio_clip)
    video_clip_with_audio.write_videofile(out_path, codec='libx264', audio_codec='aac', fps=fps)
    print(f"Audio added successfully at {out_path}")


def parse_configs():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str)
    parser.add_argument("--infer", type=str)
    args, unknown = parser.parse_known_args()

    cfg = OmegaConf.create()
    cli_cfg = OmegaConf.from_cli(unknown)

    # Environment variables override CLI arguments (used by launch_gradio_app).
    if os.environ.get("APP_INFER") is not None:
        args.infer = os.environ.get("APP_INFER")
    if os.environ.get("APP_MODEL_NAME") is not None:
        cli_cfg.model_name = os.environ.get("APP_MODEL_NAME")

    args.config = args.infer if args.config is None else args.config

    cfg_train = None
    if args.config is not None:
        cfg_train = OmegaConf.load(args.config)
        cfg.source_size = cfg_train.dataset.source_image_res
        try:
            cfg.src_head_size = cfg_train.dataset.src_head_size
        except Exception:
            cfg.src_head_size = 112
        cfg.render_size = cfg_train.dataset.render_image.high
        _relative_path = os.path.join(
            cfg_train.experiment.parent,
            cfg_train.experiment.child,
            os.path.basename(cli_cfg.model_name).split("_")[-1],
        )

        cfg.save_tmp_dump = os.path.join("exps", "save_tmp", _relative_path)
        cfg.image_dump = os.path.join("exps", "images", _relative_path)
        cfg.video_dump = os.path.join("exps", "videos", _relative_path)

    if args.infer is not None:
        cfg_infer = OmegaConf.load(args.infer)
        cfg.merge_with(cfg_infer)
        cfg.setdefault(
            "save_tmp_dump", os.path.join("exps", cli_cfg.model_name, "save_tmp")
        )
        cfg.setdefault("image_dump", os.path.join("exps", cli_cfg.model_name, "images"))
        cfg.setdefault(
            "video_dump", os.path.join("dumps", cli_cfg.model_name, "videos")
        )
        cfg.setdefault("mesh_dump", os.path.join("dumps", cli_cfg.model_name, "meshes"))

    cfg.motion_video_read_fps = 30
    cfg.merge_with(cli_cfg)

    cfg.setdefault("logger", "INFO")

    assert cfg.model_name is not None, "model_name is required"

    return cfg, cfg_train


def create_zip_archive(output_zip='runtime_data/h5_render_data.zip', base_vid="nice", in_fd="./runtime_data"):
    flame_params_pth = os.path.join("./assets/sample_motion/export", base_vid, "flame_params.json")
    file_lst = [
        f'{in_fd}/lbs_weight_20k.json', f'{in_fd}/offset.ply', f'{in_fd}/skin.glb',
        f'{in_fd}/vertex_order.json', f'{in_fd}/bone_tree.json',
        flame_params_pth
    ]
    try:
        with zipfile.ZipFile(output_zip, 'w') as zipf:
            for file_path in file_lst:
                zipf.write(file_path, arcname=os.path.join("h5_render_data", os.path.basename(file_path)))
        print(f"Archive created successfully: {output_zip}")
    except Exception as e:
        print(f"An error occurred: {e}")


def demo_lam(flametracking, lam, cfg):

    def core_fn(image_path: str, video_params, working_dir):
        image_raw = os.path.join(working_dir.name, "raw.png")
        with Image.open(image_path).convert('RGB') as img:
            img.save(image_raw)

        # Locate the FLAME motion sequence exported for the selected driving video.
        base_vid = os.path.basename(video_params).split(".")[0]
        flame_params_dir = os.path.join("./assets/sample_motion/export", base_vid, "flame_param")
        base_iid = os.path.basename(image_path).split('.')[0]
        image_path = os.path.join("./assets/sample_input", base_iid, "images/00000_00.png")

        dump_video_path = os.path.join(working_dir.name, "output.mp4")
        dump_image_path = os.path.join(working_dir.name, "output.png")

        omit_prefix = os.path.dirname(image_raw)
        image_name = os.path.basename(image_raw)
        uid = image_name.split(".")[0]
        subdir_path = os.path.dirname(image_raw).replace(omit_prefix, "")
        subdir_path = (
            subdir_path[1:] if subdir_path.startswith("/") else subdir_path
        )
        print("subdir_path and uid:", subdir_path, uid)

        motion_seqs_dir = flame_params_dir

        dump_image_dir = os.path.dirname(dump_image_path)
        os.makedirs(dump_image_dir, exist_ok=True)

        print(image_raw, motion_seqs_dir, dump_image_dir, dump_video_path)

        dump_tmp_dir = dump_image_dir

        if os.path.exists(dump_video_path):
            return dump_image_path, dump_video_path

        motion_img_need_mask = cfg.get("motion_img_need_mask", False)
        vis_motion = cfg.get("vis_motion", False)

        # Run single-image FLAME tracking: preprocess, optimize, then export the tracked frames.
        return_code = flametracking.preprocess(image_raw)
        assert return_code == 0, "flametracking preprocess failed!"
        return_code = flametracking.optimize()
        assert return_code == 0, "flametracking optimize failed!"
        return_code, output_dir = flametracking.export()
        assert return_code == 0, "flametracking export failed!"

        image_path = os.path.join(output_dir, "images/00000_00.png")
        mask_path = os.path.join(output_dir, "fg_masks/00000_00.png")
        print("image_path:", image_path, "\n" + "mask_path:", mask_path)

        aspect_standard = 1.0 / 1.0
        source_size = cfg.source_size
        render_size = cfg.render_size
        render_fps = 30

        # Crop/resize the tracked image and recover the FLAME shape (beta) parameters.
        image, _, _, shape_param = preprocess_image(image_path, mask_path=mask_path, intr=None, pad_ratio=0, bg_color=1.,
                                                    max_tgt_size=None, aspect_standard=aspect_standard, enlarge_ratio=[1.0, 1.0],
                                                    render_tgt_size=source_size, multiply=14, need_mask=True, get_shape_param=True)

        save_ref_img_path = os.path.join(dump_tmp_dir, "output.png")
        vis_ref_img = (image[0].permute(1, 2, 0).cpu().detach().numpy() * 255).astype(np.uint8)
        Image.fromarray(vis_ref_img).save(save_ref_img_path)

        src = image_path.split('/')[-3]
        driven = motion_seqs_dir.split('/')[-2]
        src_driven = [src, driven]
        motion_seq = prepare_motion_seqs(motion_seqs_dir, None, save_root=dump_tmp_dir, fps=render_fps,
                                         bg_color=1., aspect_standard=aspect_standard, enlarge_ratio=[1.0, 1.0],
                                         render_image_res=render_size, multiply=16,
                                         need_mask=motion_img_need_mask, vis_motion=vis_motion,
                                         shape_param=shape_param, test_sample=False, cross_id=False, src_driven=src_driven)

        motion_seq["flame_params"]["betas"] = shape_param.unsqueeze(0)
        device, dtype = "cuda", torch.float32
        print("start to inference...................")
        with torch.no_grad():
            res = lam.infer_single_view(image.unsqueeze(0).to(device, dtype), None, None,
                                        render_c2ws=motion_seq["render_c2ws"].to(device),
                                        render_intrs=motion_seq["render_intrs"].to(device),
                                        render_bg_colors=motion_seq["render_bg_colors"].to(device),
                                        flame_params={k: v.to(device) for k, v in motion_seq["flame_params"].items()})

        if h5_rendering:
            # Export the assets needed by the in-browser (H5) Gaussian renderer:
            # FLAME shape info, canonical Gaussian offsets, a Blender-generated GLB, and a zip bundle.
            h5_fd = working_dir.name
            lam.renderer.flame_model.save_h5_info(shape_param.unsqueeze(0).cuda(), fd=h5_fd)
            res['cano_gs_lst'][0].save_ply(os.path.join(h5_fd, "offset.ply"), rgb2sh=False, offset2xyz=True)
            do_softlink(h5_fd)
            cmd = "thirdparties/blender/blender --background --python 'tools/generateGLBWithBlender_v2.py'"
            os.system(cmd)
            output_zip = os.path.join(h5_fd, "h5_render_data.zip")
            create_zip_archive(output_zip=output_zip, base_vid=base_vid, in_fd=h5_fd)

        # Composite the rendered frames over a white background; optionally tile the
        # reference image and motion render next to them for visualization.
        rgb = res["comp_rgb"].detach().cpu().numpy()
        mask = res["comp_mask"].detach().cpu().numpy()
        mask[mask < 0.5] = 0.0
        rgb = rgb * mask + (1 - mask) * 1
        rgb = (np.clip(rgb, 0, 1.0) * 255).astype(np.uint8)
        if vis_motion:
            vis_ref_img = np.tile(
                cv2.resize(vis_ref_img, (rgb[0].shape[1], rgb[0].shape[0]), interpolation=cv2.INTER_AREA)[None, :, :, :],
                (rgb.shape[0], 1, 1, 1),
            )
            rgb = np.concatenate([vis_ref_img, rgb, motion_seq["vis_motion_render"]], axis=2)

        os.makedirs(os.path.dirname(dump_video_path), exist_ok=True)

        save_images2video(rgb, dump_video_path, render_fps)
        audio_path = os.path.join("./assets/sample_motion/export", base_vid, base_vid + ".wav")
        dump_video_path_wa = dump_video_path.replace(".mp4", "_audio.mp4")
        add_audio_to_video(dump_video_path, dump_video_path_wa, audio_path)

        return dump_image_path, dump_video_path_wa

    with gr.Blocks(analytics_enabled=False) as demo:

        logo_url = './assets/images/logo.jpeg'
        logo_base64 = get_image_base64(logo_url)
        gr.HTML(f"""
            <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
            <div>
                <h1> <img src="{logo_base64}" style='height:35px; display:inline-block;'/> Large Avatar Model for One-shot Animatable Gaussian Head</h1>
            </div>
            </div>
            """)
        gr.HTML(
            """
            <div style="display: flex; justify-content: center; align-items: center; text-align: center; margin: 20px; gap: 10px;">
                <a class="flex-item" href="https://arxiv.org/abs/2502.17796" target="_blank">
                    <img src="https://img.shields.io/badge/Paper-arXiv-darkred.svg" alt="arXiv Paper">
                </a>
                <a class="flex-item" href="https://aigc3d.github.io/projects/LAM/" target="_blank">
                    <img src="https://img.shields.io/badge/Project-LAM-blue" alt="Project Page">
                </a>
                <a class="flex-item" href="https://github.com/aigc3d/LAM" target="_blank">
                    <img src="https://img.shields.io/github/stars/aigc3d/LAM?label=Github%20★&logo=github&color=C8C" alt="badge-github-stars">
                </a>
                <a class="flex-item" href="https://youtu.be/FrfE3RYSKhk" target="_blank">
                    <img src="https://img.shields.io/badge/Youtube-Video-red.svg" alt="Video">
                </a>
            </div>
            """
        )

        gr.HTML("""<div style="margin-top: -10px">
            <p style="margin: 4px 0; line-height: 1.2"><h4 style="color: red; margin: 2px 0">Note 1: Front-facing input images, or faces oriented similarly to the driving video, give better results.</h4></p>
            <p style="margin: 4px 0; line-height: 1.2"><h4 style="color: red; margin: 2px 0">Note 2: Due to computational constraints on Hugging Face's ZeroGPU infrastructure, video generation takes about 1 minute per request.</h4></p>
            <p style="margin: 4px 0; line-height: 1.2"><h4 style="color: red; margin: 2px 0">Note 3: This demo uses the LAM-20K model (lower quality than LAM-80K) to reduce processing latency.</h4></p>
            </div>""")

        with gr.Row():
            with gr.Column(variant='panel', scale=1):
                with gr.Tabs(elem_id='lam_input_image'):
                    with gr.TabItem('Input Image'):
                        with gr.Row():
                            input_image = gr.Image(label='Input Image',
                                                   image_mode='RGB',
                                                   height=480,
                                                   width=270,
                                                   sources='upload',
                                                   type='filepath',
                                                   elem_id='content_image')

                with gr.Row():
                    examples = [
                        ['assets/sample_input/messi.png'],
                        ['assets/sample_input/status.png'],
                        ['assets/sample_input/james.png'],
                        ['assets/sample_input/cluo.jpg'],
                        ['assets/sample_input/dufu.jpg'],
                        ['assets/sample_input/libai.jpg'],
                        ['assets/sample_input/barbara.jpg'],
                        ['assets/sample_input/pop.png'],
                        ['assets/sample_input/musk.jpg'],
                        ['assets/sample_input/speed.jpg'],
                        ['assets/sample_input/zhouxingchi.jpg'],
                    ]
                    gr.Examples(
                        examples=examples,
                        inputs=[input_image],
                        examples_per_page=20,
                    )

            with gr.Column():
                with gr.Tabs(elem_id='lam_input_video'):
                    with gr.TabItem('Input Video'):
                        with gr.Row():
                            video_input = gr.Video(label='Input Video',
                                                   height=480,
                                                   width=270,
                                                   interactive=False)

                examples = ['./assets/sample_motion/export/Speeding_Scandal/Speeding_Scandal.mp4',
                            './assets/sample_motion/export/Look_In_My_Eyes/Look_In_My_Eyes.mp4',
                            './assets/sample_motion/export/D_ANgelo_Dinero/D_ANgelo_Dinero.mp4',
                            './assets/sample_motion/export/Michael_Wayne_Rosen/Michael_Wayne_Rosen.mp4',
                            './assets/sample_motion/export/I_Am_Iron_Man/I_Am_Iron_Man.mp4',
                            './assets/sample_motion/export/Anti_Drugs/Anti_Drugs.mp4',
                            './assets/sample_motion/export/Pen_Pineapple_Apple_Pen/Pen_Pineapple_Apple_Pen.mp4',
                            './assets/sample_motion/export/Joe_Biden/Joe_Biden.mp4',
                            './assets/sample_motion/export/Donald_Trump/Donald_Trump.mp4',
                            './assets/sample_motion/export/Taylor_Swift/Taylor_Swift.mp4',
                            './assets/sample_motion/export/GEM/GEM.mp4',
                            './assets/sample_motion/export/The_Shawshank_Redemption/The_Shawshank_Redemption.mp4'
                            ]
                print("Video example list {}".format(examples))

                gr.Examples(
                    examples=examples,
                    inputs=[video_input],
                    examples_per_page=20,
                )

            with gr.Column(variant='panel', scale=1):
                with gr.Tabs(elem_id='lam_processed_image'):
                    with gr.TabItem('Processed Image'):
                        with gr.Row():
                            processed_image = gr.Image(
                                label='Processed Image',
                                image_mode='RGBA',
                                type='filepath',
                                elem_id='processed_image',
                                height=480,
                                width=270,
                                interactive=False)

            with gr.Column(variant='panel', scale=1):
                with gr.Tabs(elem_id='lam_render_video'):
                    with gr.TabItem('Rendered Video'):
                        with gr.Row():
                            output_video = gr.Video(label='Rendered Video',
                                                    format='mp4',
                                                    height=480,
                                                    width=270,
                                                    autoplay=True)

        with gr.Row():
            with gr.Column(variant='panel', scale=1):
                submit = gr.Button('Generate',
                                   elem_id='lam_generate',
                                   variant='primary')

        if h5_rendering:
            gr.HTML("""
                <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
                <div>
                    <h2> Cross-platform H5 Rendering</h2>
                </div>
                </div>
                """)
            gr.set_static_paths("runtime_data/")
            assetPrefix = 'gradio_api/file=runtime_data/'
            with gr.Row():
                gs = gaussian_render(width=300, height=400, assets=assetPrefix + 'h5_render_data.zip')

        working_dir = gr.State()
        submit.click(
            fn=assert_input_image,
            inputs=[input_image],
            queue=False,
        ).success(
            fn=prepare_working_dir,
            outputs=[working_dir],
            queue=False,
        ).success(
            fn=core_fn,
            inputs=[input_image, video_input, working_dir],
            outputs=[processed_image, output_video],
        ).success(
            doRender,
            inputs=[working_dir],
            js='''() => window.start()''',
        )

    demo.queue()
    demo.launch()


def _build_model(cfg):
    from lam.models import ModelLAM
    from safetensors.torch import load_file

    model = ModelLAM(**cfg.model)
    resume = os.path.join(cfg.model_name, "model.safetensors")
    print("=" * 100)
    print("loading pretrained weight from:", resume)
    if resume.endswith('safetensors'):
        ckpt = load_file(resume, device='cpu')
    else:
        ckpt = torch.load(resume, map_location='cpu')
    state_dict = model.state_dict()
    for k, v in ckpt.items():
        if k in state_dict:
            if state_dict[k].shape == v.shape:
                state_dict[k].copy_(v)
            else:
                print(f"[WARN] mismatching shape for param {k}: ckpt {v.shape} != model {state_dict[k].shape}, ignored.")
        else:
            print(f"[WARN] unexpected param {k}: {v.shape}")
    print("finish loading pretrained weight from:", resume)
    print("=" * 100)
    return model


def launch_gradio_app():
    os.environ.update({
        'APP_ENABLED': '1',
        'APP_MODEL_NAME':
        './model_zoo/lam_models/releases/lam/lam-20k/step_045500/',
        'APP_INFER': './configs/inference/lam-20k-8gpu.yaml',
        'APP_TYPE': 'infer.lam',
        'NUMBA_THREADING_LAYER': 'omp',
    })

    cfg, _ = parse_configs()
    lam = _build_model(cfg)
    lam.to('cuda')

    flametracking = FlameTrackingSingleImage(output_dir='tracking_output',
                                             alignment_model_path='./model_zoo/flame_tracking_models/68_keypoints_model.pkl',
                                             vgghead_model_path='./model_zoo/flame_tracking_models/vgghead/vgg_heads_l.trcd',
                                             human_matting_path='./model_zoo/flame_tracking_models/matting/stylematte_synth.pt',
                                             facebox_model_path='./model_zoo/flame_tracking_models/FaceBoxesV2.pth',
                                             detect_iris_landmarks=True)

    demo_lam(flametracking, lam, cfg)


if __name__ == '__main__':
    launch_gradio_app()