# motionfix-demo / app.py
import os
from pathlib import Path
import gradio as gr
import spaces
import torch
import smplx
import numpy as np
from website import CREDITS, WEB_source, WEB_target, WEBSITE
from download_deps import get_smpl_models, download_models, download_model_config
from download_deps import download_tmr, download_motionfix, download_motionfix_dataset
from download_deps import download_embeddings
import random
# DO NOT initialize CUDA here

DEFAULT_TEXT = "do it slower"

# Force PyOpenGL to render headlessly through EGL (the Space has no display).
os.environ['PYOPENGL_PLATFORM'] = 'egl'
os.environ['LD_LIBRARY_PATH'] = '/usr/lib/x86_64-linux-gnu:/usr/lib/x86_64-linux-gnu/nvidia/current:' + os.environ.get('LD_LIBRARY_PATH', '')
# Optional debug check: list the EGL libraries visible to the dynamic linker
import subprocess
try:
result = subprocess.run(['ldconfig', '-p'], capture_output=True, text=True)
egl_libs = [line for line in result.stdout.split('\n') if 'EGL' in line]
print("Available EGL libraries:", egl_libs)
except Exception as e:
print(f"Error finding libraries: {e}")
# Example videos
example_videos = [
"./examples/1919.mp4",
"./examples/5376.mp4",
"./examples/1259.mp4",
"./examples/3686.mp4",
"./examples/1289.mp4",
"./examples/1893.mp4",
"./examples/3262.mp4",
"./examples/6117.mp4",
"./examples/1031.mp4",
"./examples/6247.mp4",
]
# Dataset keys for the example videos above
example_keys = [
"001919",
"005376",
"001259",
"003686",
"001289",
"001893",
"003262",
"006117",
"001031",
"006247",
]
# Suggested edit texts for the example videos above
example_texts = [
"mirror",
"move in a smaller circle",
"less deep",
"turn back faster",
"cross your legs",
"step to the right",
"start sitting down a bit later",
"start a bit later, hold elbow lower at the end",
"extend the arm further back and catch higher",
"hold right arm higher",
]
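
# The three lists above are parallel: video i shows the motion for dataset
# key i, and text i is a suggested edit prompt for it.
assert len(example_videos) == len(example_keys) == len(example_texts)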
# Pre-built video components for the first four examples (currently unused;
# the demo builds its example grid inside create_gradio_interface instead).
example_video_outputs = [gr.Video(label=f"Example {i+1}",
                                  value=example_videos[i])
                         for i in range(4)]
class MotionEditor:
def __init__(self):
# Don't initialize any CUDA components in __init__
self.is_initialized = False
self.MFIX_p = download_motionfix() + '/motionfix'
# self.SOURCE_MOTS_p = download_embeddings() + '/embeddings'
self.MFIX_DATASET_DICT = download_motionfix_dataset()
self.model_ckpt_path = download_models("899_bs128_zipped") # small_model_zipped_last/last_zipped
self.model_cfg = download_model_config('bs_128_conf') # small_model_config / big_model_config
self.model_config_feats = self.model_cfg.model.input_feats
@spaces.GPU
def initialize_if_needed(self):
"""Initialize models only when needed, within a GPU-decorated function"""
if self.is_initialized:
return
if not torch.cuda.is_available():
raise RuntimeError("CUDA is not available")
print(f"Current CUDA device: {torch.cuda.current_device()}")
print(f"CUDA device name: {torch.cuda.get_device_name(0)}")
# Check total and available memory
total_memory = torch.cuda.get_device_properties(0).total_memory
reserved_memory = torch.cuda.memory_reserved(0)
allocated_memory = torch.cuda.memory_allocated(0)
print(f"Total GPU Memory: {total_memory / 1e9} GB")
print(f"Reserved Memory: {reserved_memory / 1e9} GB")
print(f"Allocated Memory: {allocated_memory / 1e9} GB")
from normalization import Normalizer
from diffusion import create_diffusion
from text_encoder import ClipTextEncoder
from tmed_denoiser import TMED_denoiser
# Initialize components
self.device = torch.device('cuda')
self.normalizer = Normalizer()
self.text_encoder = ClipTextEncoder()
# Load models and configs
model_ckpt = self.model_ckpt_path
self.infeats = self.model_config_feats
checkpoint = torch.load(model_ckpt, map_location=self.device)
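        # Checkpoint keys carry a 'denoiser.' prefix; strip it so they match
        # TMED_denoiser's own state_dict names.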
checkpoint = {k.replace('denoiser.', ''): v for k, v in checkpoint.items()}
# Setup denoiser
self.tmed_denoiser = TMED_denoiser(latent_dim=self.model_cfg.model.latent_dim,
num_layers=8,
ff_size=1024,
num_heads=4).to(self.device)
self.tmed_denoiser.load_state_dict(checkpoint, strict=False)
self.tmed_denoiser.eval()
# Setup diffusion
self.diffusion = create_diffusion(
timestep_respacing=None,
learn_sigma=False,
sigma_small=True,
diffusion_steps=self.model_cfg.model.diff_params.num_train_timesteps,
noise_schedule='squaredcos_cap_v2',
predict_xstart=True
)
# Setup SMPL model
smpl_models_path = str(Path(get_smpl_models()))
self.body_model = smplx.SMPLHLayer(
f"{smpl_models_path}/smplh",
model_type='smplh',
gender='neutral',
ext='npz'
)
self.is_initialized = True
@spaces.GPU(duration=360)
def process_motion(self, input_text, key_to_use):
"""Main processing function, GPU-decorated"""
self.initialize_if_needed()
# Load dataset sample
ds_sample = self.MFIX_DATASET_DICT[key_to_use]
# Process features
data_dict = self.process_features(ds_sample)
source_motion_norm, target_motion_norm = self.normalize_motions(data_dict)
source_motion = self.denormalize_motion(source_motion_norm)
# Generate edited motion
edited_motion = self.generate_edited_motion(
input_text,
source_motion_norm,
target_motion_norm
)
# Render result
return self.render_result(edited_motion, source_motion)
def process_features(self, ds_sample):
"""Process features - called from within GPU-decorated function"""
from feature_extractor import FEAT_GET_METHODS
data_dict = {}
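        # [None] below adds a batch dimension of 1 before moving each feature to the GPU.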
for feat in self.infeats:
data_dict[f'{feat}_source'] = FEAT_GET_METHODS[feat](
ds_sample['motion_source']
)[None].to(self.device)
data_dict[f'{feat}_target'] = FEAT_GET_METHODS[feat](
ds_sample['motion_target']
)[None].to(self.device)
return data_dict
def normalize_motions(self, data_dict):
"""Normalize motions - called from within GPU-decorated function"""
batch = self.normalizer.norm_and_cat(data_dict, self.infeats)
return batch['source'], batch['target']
def generate_edited_motion(self, input_text, source_motion, target_motion):
"""Generate edited motion - called from within GPU-decorated function"""
        # Encode two empty prompts plus the edit text; the empties feed the
        # two-way (text / source-motion) classifier-free guidance below.
        texts_cond = [''] * 2 + [input_text]
text_emb, text_mask = self.text_encoder(texts_cond)
        # Setup masks (all ones: a single, unpadded sample)
bsz = 1
seqlen_src = source_motion.shape[0]
seqlen_tgt = target_motion.shape[0]
cond_motion_mask = torch.ones((bsz, seqlen_src), dtype=bool, device=self.device)
mask_target = torch.ones((bsz, seqlen_tgt), dtype=bool, device=self.device)
# Generate diffusion output
        diff_out = self.tmed_denoiser._diffusion_reverse(
            text_emb.to(self.device),
            text_mask.to(self.device),
            source_motion,
            cond_motion_mask,
            mask_target,
            self.diffusion,
            init_vec=None,
            init_from='noise',
            gd_text=2.0,    # text-conditioning guidance scale
            gd_motion=3.0,  # source-motion-conditioning guidance scale
            steps_num=self.model_cfg.model.diff_params.num_train_timesteps
        )
return self.denormalize_motion(diff_out)
def denormalize_motion(self, diff_out):
"""Denormalize motion - called from within GPU-decorated function"""
from geometry_utils import diffout2motion
return diffout2motion(diff_out.permute(1, 0, 2), self.normalizer).squeeze()
def render_result(self, edited_motion, source_motion):
"""Render result - called from within GPU-decorated function"""
        from body_renderer import get_render
# Transform motions
edited_motion_transformed = self.transform_motion(edited_motion)
source_motion_transformed = self.transform_motion(source_motion)
# Render video
if os.path.exists('./output_movie.mp4'):
os.remove('./output_movie.mp4')
return get_render(
self.body_model,
[edited_motion_transformed['trans'].detach().cpu(),
source_motion_transformed['trans'].detach().cpu()],
[edited_motion_transformed['rots_init'].detach().cpu(),
source_motion_transformed['rots_init'].detach().cpu()],
[edited_motion_transformed['rots_rest'].detach().cpu(),
source_motion_transformed['rots_rest'].detach().cpu()],
output_path='./output_movie.mp4',
text='',
colors=['sky blue', 'red']
)
def transform_motion(self, motion):
"""Transform motion - called from within GPU-decorated function"""
from transform3d import transform_body_pose, rotate_body_degrees
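        # Features are laid out as [tx, ty, tz, 6d rotations...]; split off the
        # translation and convert the rotations from 6d to axis-angle.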
motion_aa = transform_body_pose(motion[:, 3:], '6d->aa')
trans = motion[..., :3].detach().cpu()
rots_aa = motion_aa.detach().cpu()
rots_rotated, trans_rotated = rotate_body_degrees(
transform_body_pose(rots_aa, 'aa->rot'),
trans,
offset=np.pi
)
rots_rotated_aa = transform_body_pose(rots_rotated, 'rot->aa')
return {
'trans': trans_rotated,
'rots_init': rots_rotated_aa[:, 0],
'rots_rest': rots_rotated_aa[:, 1:]
}
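
# A minimal usage sketch outside Gradio (hypothetical driver; it assumes the
# downloads above succeed, a CUDA GPU is available, and "001919" is a valid
# MotionFix key, as in example_keys):
#
#   editor = MotionEditor()
#   editor.process_motion("do it slower", "001919")
#   # renders the edited vs. source comparison to ./output_movie.mp4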
# Gradio Interface
def create_gradio_interface():
editor = MotionEditor()
@spaces.GPU
def process_and_show_video(input_text, random_key_state):
return editor.process_motion(input_text, random_key_state)
def random_source_motion(set_to_pick):
from dataset_utils import load_motionfix
mfix_train, mfix_test = load_motionfix(editor.MFIX_p)
current_set = {
'all': mfix_test | mfix_train,
'train': mfix_train,
'test': mfix_test
}[set_to_pick]
random_key = random.choice(list(current_set.keys()))
motion = current_set[random_key]['motion_a']
text_annot = current_set[random_key]['annotation']
# should add one more text_annot
return gr.update(value=motion,
visible=True), random_key, text_annot
def clear():
return ""
# Gradio UI
with gr.Blocks(css=CUSTOM_CSS) as demo:
gr.HTML(WEBSITE)
random_key_state = gr.State()
with gr.Row():
with gr.Column(scale=5):
gr.HTML(WEB_source)
with gr.Row():
random_button = gr.Button("Random", scale=0)
# clear_button_retrieval = gr.Button("Clear", scale=0)
# Example videos grid with buttons
# suggested_edit_text = gr.Textbox(
# placeholder="Texts likely to edit the motion:",
# label="Suggested Edit Text",
# value=''
# )
set_to_pick = gr.Radio(
['all', 'train', 'test'],
value='all',
label="Set to pick from"
)
retrieved_video_output = gr.Video(
label="Retrieved Motion",
height=360,
width=480,
visible=False # Initially hidden
)
                gr.HTML("""<div class="embed_hidden" style="text-align: center;">
                <h1>Examples</h1></div>""")
                # Example grid: 2x2 videos, each with a "Select" button.
                example_buttons = []
                for row_start in (0, 2):
                    with gr.Row():
                        for i in range(row_start, row_start + 2):
                            with gr.Column():
                                gr.Video(value=example_videos[i],
                                         height=180, width=240,
                                         label=f"Example {i + 1}")
                                example_buttons.append(
                                    gr.Button(f"Select Ex. {i + 1}",
                                              elem_classes=["fit-text"]))
with gr.Column(scale=5):
gr.HTML(WEB_target)
with gr.Row():
clear_button_edit = gr.Button("Clear", scale=0)
edit_button = gr.Button("Edit", scale=0)
input_text = gr.Textbox(
placeholder="Type the edit text you want:",
label="Input Text",
value=DEFAULT_TEXT
)
video_output = gr.Video(
label="Generated Video",
height=360,
width=480
)
# Event handlers
edit_button.click(
process_and_show_video,
inputs=[input_text, random_key_state],
outputs=video_output
)
random_button.click(
random_source_motion,
inputs=set_to_pick,
outputs=[
retrieved_video_output,
# suggested_edit_text,
random_key_state,
input_text
]
)
def load_example(example_video, example_key, example_text):
# Update all outputs
return (
gr.update(value=example_video, visible=True), # Update video output
# example_text, # Update suggested edit text
example_key, # Update random key state
example_text # Update input text
)
        # Wire the example buttons; the default-arg lambda freezes i per button.
        for i, btn in enumerate(example_buttons):
            btn.click(
                fn=lambda i=i: load_example(example_videos[i],
                                            example_keys[i],
                                            example_texts[i]),
                inputs=None,
                outputs=[
                    retrieved_video_output,
                    random_key_state,
                    input_text
                ]
            )
clear_button_edit.click(clear, outputs=input_text)
# clear_button_retrieval.click(clear, outputs=suggested_edit_text)
gr.Markdown(CREDITS)
return demo
# Constants (CUSTOM_CSS is looked up at call time inside create_gradio_interface)
CUSTOM_CSS = """
.gradio-row { display: flex; gap: 20px; }
.gradio-column { flex: 1; }
.gradio-container { display: flex; flex-direction: column; gap: 10px; }
.gradio-button-row { display: flex; gap: 10px; }
.gradio-textbox-row { display: flex; gap: 10px; align-items: center; }
.gradio-edit-row { gap: 10px; align-items: center; }
.gradio-textbox-with-button { display: flex; align-items: center; }
.gradio-textbox-with-button input { flex-grow: 1; }
button.fit-text {
width: auto; /* Automatically adjusts to the text length */
padding: 10px 20px; /* Adjust padding for a better look */
font-size: 12px; /* Control font size */
text-align: center; /* Center the text */
margin: 0 auto; /* Center the button horizontally */
display: inline-block; /* Prevent it from stretching */
}
"""
if __name__ == "__main__":
demo = create_gradio_interface()
demo.launch(share=True)