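# Gradio demo app for "MotionLCM: Real-time Controllable Motion Generation via
# Latent Consistency Model". On first run it downloads the pretrained
# text-to-motion checkpoint and its dependencies, builds the MLD/MotionLCM
# model from ./configs/motionlcm_t2m.yaml, and serves an interactive
# text-to-motion UI. (Section comments below are descriptive annotations added
# for readability; they are inferred from the code and config paths.)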
import os
import time
import random
import datetime
import os.path as osp
from functools import partial

import torch
import gradio as gr
from omegaconf import OmegaConf

from mld.config import get_module_config
from mld.data.get_data import get_datasets
from mld.models.modeltype.mld import MLD
from mld.utils.utils import set_seed
from mld.data.humanml.utils.plot_script import plot_3d_motion
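
# Static page content: header HTML (title and author list), footer credits,
# example prompts, and CSS for the embedded video tiles.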
WEBSITE = """
<div class="embed_hidden">
<h1 style='text-align: center'> MotionLCM: Real-time Controllable Motion Generation via Latent Consistency Model </h1>

<h2 style='text-align: center'>
<a href="https://github.com/Dai-Wenxun/" target="_blank"><nobr>Wenxun Dai</nobr></a><sup>1</sup>  
<a href="https://lhchen.top/" target="_blank"><nobr>Ling-Hao Chen</nobr></a><sup>1</sup>  
<a href="https://wangjingbo1219.github.io/" target="_blank"><nobr>Jingbo Wang</nobr></a><sup>2</sup>  
<a href="https://moonsliu.github.io/" target="_blank"><nobr>Jinpeng Liu</nobr></a><sup>1</sup>  
<a href="https://daibo.info/" target="_blank"><nobr>Bo Dai</nobr></a><sup>2</sup>  
<a href="https://andytang15.github.io/" target="_blank"><nobr>Yansong Tang</nobr></a><sup>1</sup>
</h2>

<h2 style='text-align: center'>
<nobr><sup>1</sup>Tsinghua University</nobr>  
<nobr><sup>2</sup>Shanghai AI Laboratory</nobr>
</h2>

</div>
"""
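
# Footer credits (defined here but not currently attached to the layout below).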
WEBSITE_bottom = """
<div class="embed_hidden">
<p>
Space adapted from <a href="https://huggingface.co/spaces/Mathux/TMR" target="_blank">TMR</a>
and <a href="https://huggingface.co/spaces/MeYourHint/MoMask" target="_blank">MoMask</a>.
</p>
</div>
"""
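
# Example text prompts shown in the UI.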
EXAMPLES = [
    "a person does a jump",
    "a person waves both arms in the air.",
    "The person takes 4 steps backwards.",
    "this person bends forward as if to bow.",
    "The person was pushed but did not fall.",
    "a man walks forward in a snake like pattern.",
    "a man paces back and forth along the same line.",
    "with arms out to the sides a person walks forward",
    "A man bends down and picks something up with his right hand.",
    "The man walked forward, spun right on one foot and walked back to his original position.",
    "a person slightly bent over with right hand pressing against the air walks forward slowly"
]
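
# CSS for the video tiles embedded via gr.HTML.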
CSS = """
.contour_video {
    display: flex;
    flex-direction: column;
    justify-content: center;
    align-items: center;
    z-index: var(--layer-5);
    border-radius: var(--block-radius);
    background: var(--background-fill-primary);
    padding: 0 var(--size-6);
    max-height: var(--size-screen-h);
    overflow: hidden;
}
"""
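
# First-run setup: fetch the pretrained models, GloVe embeddings, sentence-T5,
# T2M evaluators, and a tiny HumanML3D subset if they are not present yet.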
if not os.path.exists("./experiments_t2m/"):
    os.system("bash prepare/download_pretrained_models.sh")
if not os.path.exists('./deps/glove/'):
    os.system("bash prepare/download_glove.sh")
if not os.path.exists('./deps/sentence-t5-large/'):
    os.system("bash prepare/prepare_t5.sh")
if not os.path.exists('./deps/t2m/'):
    os.system("bash prepare/download_t2m_evaluators.sh")
if not os.path.exists('./datasets/humanml3d/'):
    os.system("bash prepare/prepare_tiny_humanml3d.sh")
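
# UI defaults: placeholder prompt, maximum number of video tiles, and the
# text-to-motion config used to build the model.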
DEFAULT_TEXT = "A person is "
MAX_VIDEOS = 12
T2M_CFG = "./configs/motionlcm_t2m.yaml"
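
# Pick the device, load and merge the model config, and fix the random seed.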
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

cfg = OmegaConf.load(T2M_CFG)
cfg_model = get_module_config(cfg.model, cfg.model.target)
cfg = OmegaConf.merge(cfg, cfg_model)
set_seed(1949)
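
# Timestamped output directory for this session's rendered sample videos.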
name_time_str = osp.join(cfg.NAME, datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S"))
output_dir = osp.join(cfg.TEST_FOLDER, name_time_str)
vis_dir = osp.join(output_dir, 'samples')
os.makedirs(output_dir, exist_ok=False)
os.makedirs(vis_dir, exist_ok=False)
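
# Load the checkpoint on CPU. If the time-conditioning projection weight is
# present, the denoiser was trained as a latent consistency model, and the
# projection width is written back into the config before the model is built.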
state_dict = torch.load(cfg.TEST.CHECKPOINTS, map_location="cpu")["state_dict"]
print("Loading checkpoints from {}".format(cfg.TEST.CHECKPOINTS))

lcm_key = 'denoiser.time_embedding.cond_proj.weight'
is_lcm = False
if lcm_key in state_dict:
    is_lcm = True
    time_cond_proj_dim = state_dict[lcm_key].shape[1]
    cfg.model.denoiser.params.time_cond_proj_dim = time_cond_proj_dim
print(f'Is LCM: {is_lcm}')
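
# Plain text-to-motion demo: the ControlNet branch is disabled. The model is
# built from the merged config and the test split of the dataset, then the
# checkpoint weights are loaded and the model is put in eval mode.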
cfg.model.is_controlnet = False

datasets = get_datasets(cfg, phase="test")[0]
model = MLD(cfg, datasets)
model.to(device)
model.eval()
model.load_state_dict(state_dict)
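

# Sample `num_videos` motions for a single prompt and render each generated
# joint sequence to an mp4 (20 fps) under the session's sample directory.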
@torch.no_grad()
def generate(text, motion_len, num_videos):
    batch = {"text": [text] * num_videos, "length": [motion_len] * num_videos}

    s = time.time()
    joints, _ = model(batch)
    runtime = round(time.time() - s, 3)
    runtime_info = f'Inferred {len(joints)} motions, runtime: {runtime}s, device: {device}'
    path = []
    for i in range(num_videos):
        uid = random.randrange(999999999)
        video_path = osp.join(vis_dir, f"sample_{uid}.mp4")
        plot_3d_motion(video_path, joints[i].detach().cpu().numpy(), '', fps=20)
        path.append(video_path)
    return path, runtime_info
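

# Wrap a rendered mp4 in an autoplaying <video> tag. The source URL assumes the
# file is served by the Hugging Face Space at wxdai-motionlcm.hf.space.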
def get_video_html(path, video_id, width=700, height=700):
    video_html = f"""
<video class="contour_video" width="{width}" height="{height}" preload="auto" muted playsinline onpause="this.load()"
autoplay loop disablepictureinpicture id="{video_id}">
    <source src="https://wxdai-motionlcm.hf.space/file/{path}" type="video/mp4">
    Your browser does not support the video tag.
</video>
"""
    return video_html
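

# Shared handler for the button, textbox submit, and example clicks: skip empty
# or placeholder prompts, convert the requested duration from seconds to frames
# (20 fps, clamped to 36-196 frames, i.e. about 1.8-9.8 s), and pad the outputs
# so exactly MAX_VIDEOS video slots plus one info string are returned.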
def generate_component(generate_function, text, motion_len, num_videos):
    if text == DEFAULT_TEXT or text == "" or text is None:
        return [None for _ in range(MAX_VIDEOS)] + [None]

    motion_len = max(36, min(int(float(motion_len) * 20), 196))
    paths, info = generate_function(text, motion_len, num_videos)
    htmls = [get_video_html(path, idx) for idx, path in enumerate(paths)]
    htmls = htmls + [None for _ in range(max(0, MAX_VIDEOS - num_videos))]
    return htmls + [info]
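

# Theme and a wrapper with the generation function bound as the first argument.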
theme = gr.themes.Default(primary_hue="purple", secondary_hue="gray")
generate_and_show = partial(generate_component, generate)
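
# Page layout: header, prompt/length/count controls with Generate and Clear
# buttons on the left, example prompts on the right, and a 3 x 4 grid of video
# tiles underneath.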
with gr.Blocks(css=CSS, theme=theme) as demo:
    gr.HTML(WEBSITE)
    videos = []

    with gr.Row():
        with gr.Column(scale=3):
            text = gr.Textbox(
                show_label=True,
                label="Text prompt",
                value=DEFAULT_TEXT,
            )

            with gr.Row():
                with gr.Column(scale=1):
                    motion_len = gr.Textbox(
                        show_label=True,
                        label="Motion length (in seconds, <=9.8s)",
                        value=5,
                        info="Any length exceeding 9.8s will be restricted to 9.8s.",
                    )
                with gr.Column(scale=1):
                    num_videos = gr.Radio(
                        [1, 4, 8, 12],
                        label="Videos",
                        value=8,
                        info="Number of videos to generate.",
                    )

            gen_btn = gr.Button("Generate", variant="primary")
            clear = gr.Button("Clear", variant="secondary")

            results = gr.Textbox(
                show_label=True,
                label='Inference info (runtime and device)',
                info='Real-time inference cannot be achieved using the free CPU. Local GPU deployment is recommended.',
                interactive=False,
            )
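
        # Right column: clickable example prompts. run_on_click is disabled
        # because the video tiles that act as outputs are only created further
        # down, so example clicks are wired up manually after the grid exists.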
        with gr.Column(scale=2):
            def generate_example(text, motion_len, num_videos):
                return generate_and_show(text, motion_len, num_videos)

            examples = gr.Examples(
                examples=[[x, None, None] for x in EXAMPLES],
                inputs=[text, motion_len, num_videos],
                examples_per_page=12,
                run_on_click=False,
                cache_examples=False,
                fn=generate_example,
                outputs=[],
            )
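
    # 3 x 4 grid of HTML slots that hold the generated videos (MAX_VIDEOS = 12).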
    for _ in range(3):
        with gr.Row():
            for _ in range(4):
                video = gr.HTML()
                videos.append(video)
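
    # Manual wiring for example clicks: first fill the prompt textbox from the
    # selected example, then run generation into the video grid and info box.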
    examples.outputs = videos

    def load_example(example_id):
        processed_example = examples.non_none_processed_examples[example_id]
        return gr.utils.resolve_singleton(processed_example)

    examples.dataset.click(
        load_example,
        inputs=[examples.dataset],
        outputs=examples.inputs_with_examples,
        show_progress=False,
        postprocess=False,
        queue=False,
    ).then(fn=generate_example, inputs=examples.inputs, outputs=videos + [results])
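
    # Generate when the button is clicked or the prompt textbox is submitted.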
    gen_btn.click(
        fn=generate_and_show,
        inputs=[text, motion_len, num_videos],
        outputs=videos + [results],
    )
    text.submit(
        fn=generate_and_show,
        inputs=[text, motion_len, num_videos],
        outputs=videos + [results],
    )
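
    # Reset the video grid, the prompt textbox, and the runtime info box.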
    def clear_videos():
        return [None for _ in range(MAX_VIDEOS)] + [DEFAULT_TEXT] + [None]

    clear.click(fn=clear_videos, outputs=videos + [text] + [results])

demo.launch()