import os
from pathlib import Path
import gradio as gr
import spaces
import torch
import smplx
import numpy as np
from website import CREDITS, WEB_source, WEB_target, WEBSITE
from download_deps import get_smpl_models, download_models, download_model_config
from download_deps import download_tmr, download_motionfix, download_motionfix_dataset
from download_deps import download_embeddings
import random

# DO NOT initialize CUDA here
DEFAULT_TEXT = "do it slower"

os.environ['PYOPENGL_PLATFORM'] = 'egl'
os.environ['LD_LIBRARY_PATH'] = '/usr/lib/x86_64-linux-gnu:/usr/lib/x86_64-linux-gnu/nvidia/current:' + os.environ.get('LD_LIBRARY_PATH', '')
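# NOTE (explanatory, assumed): PYOPENGL_PLATFORM='egl' selects headless, off-screen
# OpenGL so the body renderer can run on a display-less Space; the LD_LIBRARY_PATH
# entries presumably point at the NVIDIA EGL driver libraries in the Space's base image.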

# Optional debugging
import subprocess
try:
    result = subprocess.run(['ldconfig', '-p'], capture_output=True, text=True)
    egl_libs = [line for line in result.stdout.split('\n') if 'EGL' in line]
    print("Available EGL libraries:", egl_libs)
except Exception as e:
    print(f"Error finding libraries: {e}")

# Example videos
example_videos = [
    "./examples/1919.mp4",
    "./examples/5376.mp4",
    "./examples/1259.mp4",
    "./examples/3686.mp4",
    "./examples/1289.mp4",
    "./examples/1893.mp4",
    "./examples/3262.mp4",
    "./examples/6117.mp4",
    "./examples/1031.mp4",
    "./examples/6247.mp4",
]
# Dataset keys corresponding to the example videos
example_keys = [
    "001919",
    "005376",
    "001259",
    "003686",
    "001289",
    "001893",
    "003262",
    "006117",
    "001031",
    "006247",
]
# Edit texts corresponding to the example videos
example_texts = [
    "mirror",
    "move in a smaller circle",
    "less deep",
    "turn back faster",
    "cross your legs",
    "step to the right",
    "start sitting down a bit later",
    "start a bit later, hold elbow lower at the end",
    "extend the arm further back and catch higher",
    "hold right arm higher",
]

example_video_outputs = [gr.Video(label=f"Example {i+1}",
                                  value=example_videos[i])
                         for i in range(4)]
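# Note (observation): example_video_outputs does not appear to be referenced by the
# Blocks layout below, which creates its own gr.Video components; it looks like a
# leftover from an earlier version of the UI.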


class MotionEditor:
    def __init__(self):
        # Don't initialize any CUDA components in __init__
        self.is_initialized = False
        self.MFIX_p = download_motionfix() + '/motionfix'
        # self.SOURCE_MOTS_p = download_embeddings() + '/embeddings'
        self.MFIX_DATASET_DICT = download_motionfix_dataset()
        self.model_ckpt_path = download_models("899_bs128_zipped")  # small_model_zipped_last/last_zipped
        self.model_cfg = download_model_config('bs_128_conf')  # small_model_config / big_model_config
        self.model_config_feats = self.model_cfg.model.input_feats
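
    # NOTE (explanatory, assumed): all CUDA-dependent setup is deferred to
    # initialize_if_needed() below. On a ZeroGPU Space the GPU is presumably only
    # attached while a @spaces.GPU-decorated function is running, so touching CUDA
    # at import or constructor time would fail.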

    def initialize_if_needed(self):
        """Initialize models only when needed, within a GPU-decorated function"""
        if self.is_initialized:
            return

        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is not available")
        print(f"Current CUDA device: {torch.cuda.current_device()}")
        print(f"CUDA device name: {torch.cuda.get_device_name(0)}")

        # Check total and available memory
        total_memory = torch.cuda.get_device_properties(0).total_memory
        reserved_memory = torch.cuda.memory_reserved(0)
        allocated_memory = torch.cuda.memory_allocated(0)
        print(f"Total GPU Memory: {total_memory / 1e9} GB")
        print(f"Reserved Memory: {reserved_memory / 1e9} GB")
        print(f"Allocated Memory: {allocated_memory / 1e9} GB")

        from normalization import Normalizer
        from diffusion import create_diffusion
        from text_encoder import ClipTextEncoder
        from tmed_denoiser import TMED_denoiser

        # Initialize components
        self.device = torch.device('cuda')
        self.normalizer = Normalizer()
        self.text_encoder = ClipTextEncoder()

        # Load models and configs
        model_ckpt = self.model_ckpt_path
        self.infeats = self.model_config_feats
        checkpoint = torch.load(model_ckpt, map_location=self.device)
        checkpoint = {k.replace('denoiser.', ''): v for k, v in checkpoint.items()}
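        # NOTE (explanatory, assumed): the checkpoint was presumably saved from a wrapper
        # module that held the denoiser under a 'denoiser.' attribute, so the prefix is
        # stripped before loading the state dict into TMED_denoiser directly.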

        # Setup denoiser
        self.tmed_denoiser = TMED_denoiser(latent_dim=self.model_cfg.model.latent_dim,
                                           num_layers=8,
                                           ff_size=1024,
                                           num_heads=4).to(self.device)
        self.tmed_denoiser.load_state_dict(checkpoint, strict=False)
        self.tmed_denoiser.eval()

        # Setup diffusion
        self.diffusion = create_diffusion(
            timestep_respacing=None,
            learn_sigma=False,
            sigma_small=True,
            diffusion_steps=self.model_cfg.model.diff_params.num_train_timesteps,
            noise_schedule='squaredcos_cap_v2',
            predict_xstart=True
        )
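        # NOTE (explanatory): predict_xstart=True means the denoiser predicts the clean
        # sample x0 rather than the added noise, and 'squaredcos_cap_v2' is the cosine
        # noise schedule; timestep_respacing=None presumably keeps the full set of
        # training timesteps at sampling time.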

        # Setup SMPL model
        smpl_models_path = str(Path(get_smpl_models()))
        self.body_model = smplx.SMPLHLayer(
            f"{smpl_models_path}/smplh",
            model_type='smplh',
            gender='neutral',
            ext='npz'
        )
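        # NOTE (observation): the SMPL-H body model stays on CPU; render_result() also
        # moves the motion tensors to CPU, so mesh generation and rendering presumably
        # run on the CPU/EGL side rather than on the GPU.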
        self.is_initialized = True

    def process_motion(self, input_text, key_to_use):
        """Main processing function, GPU-decorated"""
        self.initialize_if_needed()
        # import ipdb; ipdb.set_trace()

        # Load dataset sample
        ds_sample = self.MFIX_DATASET_DICT[key_to_use]

        # Process features
        data_dict = self.process_features(ds_sample)
        source_motion_norm, target_motion_norm = self.normalize_motions(data_dict)
        source_motion = self.denormalize_motion(source_motion_norm)

        # Generate edited motion
        edited_motion = self.generate_edited_motion(
            input_text,
            source_motion_norm,
            target_motion_norm
        )

        # Render result
        return self.render_result(edited_motion, source_motion)
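
    # NOTE (observation): despite the docstring, no @spaces.GPU decorator is visible on
    # process_motion in this file; presumably the GPU decoration is applied elsewhere,
    # e.g. on the UI callback that wraps this method (process_and_show_video below).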

    def process_features(self, ds_sample):
        """Process features - called from within GPU-decorated function"""
        from feature_extractor import FEAT_GET_METHODS
        data_dict = {}
        for feat in self.infeats:
            data_dict[f'{feat}_source'] = FEAT_GET_METHODS[feat](
                ds_sample['motion_source']
            )[None].to(self.device)
            data_dict[f'{feat}_target'] = FEAT_GET_METHODS[feat](
                ds_sample['motion_target']
            )[None].to(self.device)
        return data_dict

    def normalize_motions(self, data_dict):
        """Normalize motions - called from within GPU-decorated function"""
        batch = self.normalizer.norm_and_cat(data_dict, self.infeats)
        return batch['source'], batch['target']

    def generate_edited_motion(self, input_text, source_motion, target_motion):
        """Generate edited motion - called from within GPU-decorated function"""
        # Encode text
        texts_cond = [''] * 2 + [input_text]
        text_emb, text_mask = self.text_encoder(texts_cond)
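        # NOTE (explanatory, assumed): the two empty strings are presumably the
        # unconditional branches used for classifier-free guidance over both the edit
        # text and the source motion; gd_text and gd_motion below are the corresponding
        # guidance scales.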

        # Setup masks
        bsz = 1
        seqlen_src = source_motion.shape[0]
        seqlen_tgt = target_motion.shape[0]
        cond_motion_mask = torch.ones((bsz, seqlen_src), dtype=bool, device=self.device)
        mask_target = torch.ones((bsz, seqlen_tgt), dtype=bool, device=self.device)

        # Generate diffusion output
        diff_out = self.tmed_denoiser._diffusion_reverse(
            text_emb.to(self.device),
            text_mask.to(self.device),
            source_motion,
            cond_motion_mask,
            mask_target,
            self.diffusion,
            init_vec=None,
            init_from='noise',
            gd_text=2.0,
            gd_motion=3.0,
            steps_num=self.model_cfg.model.diff_params.num_train_timesteps
        )
        return self.denormalize_motion(diff_out)

    def denormalize_motion(self, diff_out):
        """Denormalize motion - called from within GPU-decorated function"""
        from geometry_utils import diffout2motion
        # import ipdb; ipdb.set_trace()
        return diffout2motion(diff_out.permute(1, 0, 2), self.normalizer).squeeze()

    def render_result(self, edited_motion, source_motion):
        """Render result - called from within GPU-decorated function"""
        from body_renderer import get_render
        from transform3d import transform_body_pose, rotate_body_degrees

        # Transform motions
        edited_motion_transformed = self.transform_motion(edited_motion)
        source_motion_transformed = self.transform_motion(source_motion)

        # Render video
        if os.path.exists('./output_movie.mp4'):
            os.remove('./output_movie.mp4')
        # import ipdb; ipdb.set_trace()
        return get_render(
            self.body_model,
            [edited_motion_transformed['trans'].detach().cpu(),
             source_motion_transformed['trans'].detach().cpu()],
            [edited_motion_transformed['rots_init'].detach().cpu(),
             source_motion_transformed['rots_init'].detach().cpu()],
            [edited_motion_transformed['rots_rest'].detach().cpu(),
             source_motion_transformed['rots_rest'].detach().cpu()],
            output_path='./output_movie.mp4',
            text='',
            colors=['sky blue', 'red']
        )

    def transform_motion(self, motion):
        """Transform motion - called from within GPU-decorated function"""
        from transform3d import transform_body_pose, rotate_body_degrees
        motion_aa = transform_body_pose(motion[:, 3:], '6d->aa')
        trans = motion[..., :3].detach().cpu()
        rots_aa = motion_aa.detach().cpu()
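        # NOTE (explanatory, assumed): rotate the body rotations and trajectory by
        # 180 degrees (offset=np.pi) about the up axis, presumably so the rendered
        # bodies face the camera in the output video.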
        rots_rotated, trans_rotated = rotate_body_degrees(
            transform_body_pose(rots_aa, 'aa->rot'),
            trans,
            offset=np.pi
        )
        rots_rotated_aa = transform_body_pose(rots_rotated, 'rot->aa')
        return {
            'trans': trans_rotated,
            'rots_init': rots_rotated_aa[:, 0],   # global (root) orientation
            'rots_rest': rots_rotated_aa[:, 1:]   # remaining body joint rotations
        }


# Gradio Interface
def create_gradio_interface():
    editor = MotionEditor()

    def process_and_show_video(input_text, random_key_state):
        return editor.process_motion(input_text, random_key_state)

    def random_source_motion(set_to_pick):
        from dataset_utils import load_motionfix
        mfix_train, mfix_test = load_motionfix(editor.MFIX_p)
        current_set = {
            'all': mfix_test | mfix_train,
            'train': mfix_train,
            'test': mfix_test
        }[set_to_pick]
        random_key = random.choice(list(current_set.keys()))
        motion = current_set[random_key]['motion_a']
        text_annot = current_set[random_key]['annotation']
        # should add one more text_annot
        return gr.update(value=motion, visible=True), random_key, text_annot

    def clear():
        return ""

    # Gradio UI
    with gr.Blocks(css=CUSTOM_CSS) as demo:
        gr.HTML(WEBSITE)
        random_key_state = gr.State()

        with gr.Row():
            with gr.Column(scale=5):
                gr.HTML(WEB_source)
                with gr.Row():
                    random_button = gr.Button("Random", scale=0)
                    # clear_button_retrieval = gr.Button("Clear", scale=0)

                # Example videos grid with buttons
                # suggested_edit_text = gr.Textbox(
                #     placeholder="Texts likely to edit the motion:",
                #     label="Suggested Edit Text",
                #     value=''
                # )
                set_to_pick = gr.Radio(
                    ['all', 'train', 'test'],
                    value='all',
                    label="Set to pick from"
                )
                retrieved_video_output = gr.Video(
                    label="Retrieved Motion",
                    height=360,
                    width=480,
                    visible=False  # Initially hidden
                )
                gr.HTML(("""<div class="embed_hidden" style="text-align: center;">
                <h1>Examples</h1></div>"""))

                with gr.Row():
                    # First example
                    with gr.Column():
                        gr.Video(value=example_videos[0],
                                 height=180, width=240,
                                 label="Example 1")
                        example_button1 = gr.Button("Select Ex. 1",
                                                    elem_classes=["fit-text"])
                    # Second example
                    with gr.Column():
                        gr.Video(value=example_videos[1],
                                 height=180, width=240,
                                 label="Example 2")
                        example_button2 = gr.Button("Select Ex. 2",
                                                    elem_classes=["fit-text"])

                with gr.Row():
                    # Third example
                    with gr.Column():
                        gr.Video(value=example_videos[2],
                                 height=180, width=240,
                                 label="Example 3")
                        example_button3 = gr.Button("Select Ex. 3",
                                                    elem_classes=["fit-text"])
                    # Fourth example
                    with gr.Column():
                        gr.Video(value=example_videos[3],
                                 height=180, width=240,
                                 label="Example 4")
                        example_button4 = gr.Button("Select Ex. 4",
                                                    elem_classes=["fit-text"])

            with gr.Column(scale=5):
                gr.HTML(WEB_target)
                with gr.Row():
                    clear_button_edit = gr.Button("Clear", scale=0)
                    edit_button = gr.Button("Edit", scale=0)
                input_text = gr.Textbox(
                    placeholder="Type the edit text you want:",
                    label="Input Text",
                    value=DEFAULT_TEXT
                )
                video_output = gr.Video(
                    label="Generated Video",
                    height=360,
                    width=480
                )

        # Event handlers
        edit_button.click(
            process_and_show_video,
            inputs=[input_text, random_key_state],
            outputs=video_output
        )
        random_button.click(
            random_source_motion,
            inputs=set_to_pick,
            outputs=[
                retrieved_video_output,
                # suggested_edit_text,
                random_key_state,
                input_text
            ]
        )

        # def load_example_video(example_path):
        #     # motion = current_set[random_key]['motion_a']
        #     # text_annot = current_set[random_key]['annotation']
        #     import ipdb; ipdb.set_trace()
        #     return gr.update(value=example_path, visible=True)

        def load_example(example_video, example_key, example_text):
            # Update all outputs
            return (
                gr.update(value=example_video, visible=True),  # Update video output
                # example_text,  # Update suggested edit text
                example_key,  # Update random key state
                example_text  # Update input text
            )

        example_button1.click(
            fn=lambda: load_example(example_videos[0], example_keys[0], example_texts[0]),
            inputs=None,
            outputs=[
                retrieved_video_output,
                # suggested_edit_text,
                random_key_state,
                input_text
            ]
        )
        example_button2.click(
            fn=lambda: load_example(example_videos[1], example_keys[1], example_texts[1]),
            inputs=None,
            outputs=[
                retrieved_video_output,
                # suggested_edit_text,
                random_key_state,
                input_text
            ]
        )
        example_button3.click(
            fn=lambda: load_example(example_videos[2], example_keys[2], example_texts[2]),
            inputs=None,
            outputs=[
                retrieved_video_output,
                # suggested_edit_text,
                random_key_state,
                input_text
            ]
        )
        example_button4.click(
            fn=lambda: load_example(example_videos[3], example_keys[3], example_texts[3]),
            inputs=None,
            outputs=[
                retrieved_video_output,
                # suggested_edit_text,
                random_key_state,
                input_text
            ]
        )

        clear_button_edit.click(clear, outputs=input_text)
        # clear_button_retrieval.click(clear, outputs=suggested_edit_text)
        gr.Markdown(CREDITS)

    return demo
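
# NOTE (explanatory): CUSTOM_CSS below is referenced inside create_gradio_interface()
# above; this works because the constant is defined before the function is actually
# called in __main__, so the name resolves at call time.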

# Constants
CUSTOM_CSS = """
.gradio-row { display: flex; gap: 20px; }
.gradio-column { flex: 1; }
.gradio-container { display: flex; flex-direction: column; gap: 10px; }
.gradio-button-row { display: flex; gap: 10px; }
.gradio-textbox-row { display: flex; gap: 10px; align-items: center; }
.gradio-edit-row { gap: 10px; align-items: center; }
.gradio-textbox-with-button { display: flex; align-items: center; }
.gradio-textbox-with-button input { flex-grow: 1; }
button.fit-text {
    width: auto;            /* Automatically adjusts to the text length */
    padding: 10px 20px;     /* Adjust padding for a better look */
    font-size: 12px;        /* Control font size */
    text-align: center;     /* Center the text */
    margin: 0 auto;         /* Center the button horizontally */
    display: inline-block;  /* Prevent it from stretching */
}
"""

if __name__ == "__main__":
    demo = create_gradio_interface()
    demo.launch(share=True)