# motionfix-demo / OLD_app.py
from geometry_utils import diffout2motion
import gradio as gr
import spaces
import torch
import os
from pathlib import Path
import smplx
from body_renderer import get_render
import numpy as np
from download_deps import get_smpl_models, download_models, download_model_config
from download_deps import download_tmr, download_motionfix, download_motionfix_dataset
from download_deps import download_embeddings
from website import CREDITS, WEB_source, WEB_target, WEBSITE
# import cv2
# import moderngl
# ctx = moderngl.create_context(standalone=True)
# print(ctx)
# sdk_version: 5.5.0
access_token_smpl = os.environ.get('HF_SMPL_TOKEN')
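# Force PyOpenGL to use EGL so rendering works headless (no display) on the GPU.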
os.environ["PYOPENGL_PLATFORM"] = "egl"
zero = torch.Tensor([0]).cuda()
print(zero.device) # <-- 'cuda:0' 🤗
DEFAULT_TEXT = "do it slower"
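# Minimal GPU sanity-check function; it is not wired to the UI below.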
@spaces.GPU
def greet(n):
print(zero.device) # <-- 'cuda:0' 🤗
try:
number = float(n)
except ValueError:
return "Invalid input. Please enter a number."
return f"Hello {zero + number} Tensor"
def clear():
return ""
def show_video(input_text, key_to_use):
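    """Edit the source motion stored under `key_to_use` according to `input_text`,
    run the diffusion-based editor, and return the path of the rendered video."""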
from normalization import Normalizer
normalizer = Normalizer()
from diffusion import create_diffusion
from text_encoder import ClipTextEncoder
from tmed_denoiser import TMED_denoiser
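    # Download model weights and the input-feature config, then load the checkpoint.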
model_ckpt = download_models()
infeats = download_model_config()
checkpoint = torch.load(model_ckpt)
# motion_to_edit = download_motion_from_dataset(key_to_use)
# ds_sample = joblib.load(motion_to_edit)
ds_sample = MFIX_DATASET_DICT[key_to_use]
from feature_extractor import FEAT_GET_METHODS
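    # Extract the configured input features for the source and target motions and move them to the GPU.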
data_dict_source = {f'{feat}_source': FEAT_GET_METHODS[feat](ds_sample['motion_source'])[None].to('cuda')
for feat in infeats}
data_dict_target = {f'{feat}_target': FEAT_GET_METHODS[feat](ds_sample['motion_target'])[None].to('cuda')
for feat in infeats}
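    # Merge source/target features, then normalize and concatenate them into model inputs.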
full_batch = data_dict_source | data_dict_target
in_batch = normalizer.norm_and_cat(full_batch, infeats)
source_motion_norm = in_batch['source']
target_motion_norm = in_batch['target']
    seqlen_src = source_motion_norm.shape[0]
    seqlen_tgt = target_motion_norm.shape[0]
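    # Strip the 'denoiser.' prefix from checkpoint keys so they match TMED_denoiser's state dict.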
checkpoint = {k.replace('denoiser.', ''): v for k, v in checkpoint.items()}
tmed_denoiser = TMED_denoiser().to('cuda')
tmed_denoiser.load_state_dict(checkpoint, strict=False)
tmed_denoiser.eval()
text_encoder = ClipTextEncoder()
texts_cond = [input_text]
diffusion_process = create_diffusion(timestep_respacing=None,
learn_sigma=False, sigma_small=True,
diffusion_steps=300,
noise_schedule='squaredcos_cap_v2',
predict_xstart=True)
bsz = 1
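    # Classifier-free guidance: prepend two empty prompts per edit text, one for each
    # guidance signal (text and source motion) used in the reverse diffusion below.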
no_of_texts = len(texts_cond)
texts_cond = ['']*no_of_texts + texts_cond
texts_cond = ['']*no_of_texts + texts_cond
text_emb, text_mask = text_encoder(texts_cond)
cond_emb_motion = source_motion_norm
cond_motion_mask = torch.ones((bsz, seqlen_src),
dtype=bool, device='cuda')
mask_target = torch.ones((bsz, seqlen_tgt),
dtype=bool, device='cuda')
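    # Reverse diffusion: sample the edited motion from noise, conditioned on the
    # source motion and the edit text.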
diff_out = tmed_denoiser._diffusion_reverse(text_emb.to(cond_emb_motion.device),
text_mask.to(cond_emb_motion.device),
cond_emb_motion,
cond_motion_mask,
mask_target,
diffusion_process,
init_vec=None,
init_from='noise',
gd_text=2.0,
gd_motion=2.0,
steps_num=300)
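    # Un-normalize the diffusion output and the source motion back to motion features.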
edited_motion = diffout2motion(diff_out.permute(1,0,2), normalizer).squeeze()
gt_source = diffout2motion(source_motion_norm.permute(1,0,2),
normalizer).squeeze()
# aitrenderer = get_renderer()
# SMPL_LAYER = SMPLLayer(model_type='smplh', ext='npz', gender='neutral')
# edited_mot_to_render = pack_to_render(rots=edited_motion[..., 3:],
# trans=edited_motion[..., :3])
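    # Neutral SMPL-H body model used to turn poses into meshes for rendering.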
SMPL_MODELS_PATH = str(Path(get_smpl_models()))
    body_model = smplx.SMPLHLayer(f"{SMPL_MODELS_PATH}/smplh",
                                  model_type='smplh',
                                  gender='neutral', ext='npz')
from transform3d import transform_body_pose
edited_motion_aa = transform_body_pose(edited_motion[:, 3:],
'6d->aa')
gt_source_aa = transform_body_pose(gt_source[:, 3:],
'6d->aa')
if os.path.exists('./output_movie.mp4'):
os.remove('./output_movie.mp4')
from transform3d import rotate_body_degrees
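    # Move results to CPU and rotate both motions by 180 degrees before rendering.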
gen_motion_trans = edited_motion[..., :3].detach().cpu()
gen_motion_rots_aa = edited_motion_aa.detach().cpu()
source_motion_trans = gt_source[..., :3].detach().cpu()
source_motion_rots_aa = gt_source_aa.detach().cpu()
gen_rots_rotated, gen_trans_rotated = rotate_body_degrees(transform_body_pose(
gen_motion_rots_aa,
'aa->rot'),
gen_motion_trans, offset=np.pi)
src_rots_rotated, src_trans_rotated = rotate_body_degrees(transform_body_pose(
source_motion_rots_aa,
'aa->rot'),
source_motion_trans, offset=np.pi)
src_rots_rotated_aa = transform_body_pose(src_rots_rotated,
'rot->aa')
gen_rots_rotated_aa = transform_body_pose(gen_rots_rotated,
'rot->aa')
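    # Render the edited (sky blue) and source (red) motions into a single video.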
fname = get_render(body_model,
[gen_trans_rotated, src_trans_rotated],
[gen_rots_rotated_aa[:, 0], src_rots_rotated_aa[:, 0]],
[gen_rots_rotated_aa[:, 1:], src_rots_rotated_aa[:, 1:]],
output_path='./output_movie.mp4',
text='', colors=['sky blue', 'red'])
# fname = render_motion(AIT_RENDERER, [edited_mot_to_render],
# f"movie_example--{str(xx)}",
# pose_repr='aa',
# color=[color_map['generated']],
# smpl_layer=SMPL_LAYER)
print(fname)
print(os.path.abspath(fname))
return fname
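# Fetch the MotionFix data, precomputed source-motion embeddings, and the packed dataset dict once at startup.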
MFIX_p = download_motionfix() + '/motionfix'
SOURCE_MOTS_p = download_embeddings() + '/embeddings'
MFIX_DATASET_DICT = download_motionfix_dataset()
def random_source_motion(set_to_pick):
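    """Pick a random source motion and its annotation text from the selected split."""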
mfix_train, mfix_test = load_motionfix(MFIX_p)
if set_to_pick == 'all':
current_set = mfix_test | mfix_train
elif set_to_pick == 'train':
current_set = mfix_train
elif set_to_pick == 'test':
current_set = mfix_test
import random
random_key = random.choice(list(current_set.keys()))
curvid = current_set[random_key]['motion_a']
text_annot = current_set[random_key]['annotation']
return curvid, text_annot, random_key, text_annot
def retrieve_video(retrieve_text):
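    """Retrieve the source motion whose precomputed TMR embedding best matches `retrieve_text`."""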
tmr_text_encoder = get_tmr_model(download_tmr())
# text_encoded = tmr_text_encoder([retrieve_text])
motion_embeds = None
from gen_utils import read_json
import numpy as np
    motion_embeds = torch.load(SOURCE_MOTS_p + '/source_motions_embeddings.pt')
    motion_keyids = np.array(read_json(SOURCE_MOTS_p + '/keyids_embeddings.json'))
mfix_train, mfix_test = load_motionfix(MFIX_p)
all_mots = mfix_test | mfix_train
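    # Score every precomputed motion embedding against the query text and keep the best match.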
scores = tmr_text_encoder.compute_scores(retrieve_text, embs=motion_embeds)
sorted_idxs = np.argsort(-scores)
best_keyids = motion_keyids[sorted_idxs]
# best_scores = scores[sorted_idxs]
top_mot = best_keyids[0]
curvid = all_mots[top_mot]['motion_a']
text_annot = all_mots[top_mot]['annotation']
return curvid, text_annot
with gr.Blocks(css="""
.gradio-row {
display: flex;
gap: 20px;
}
.gradio-column {
flex: 1;
}
.gradio-container {
display: flex;
flex-direction: column;
gap: 10px;
}
.gradio-button-row {
display: flex;
gap: 10px;
}
.gradio-textbox-row {
display: flex;
gap: 10px;
align-items: center;
}
.gradio-edit-row {
gap: 10px;
align-items: center;
}
.gradio-textbox-with-button {
display: flex;
align-items: center;
}
.gradio-textbox-with-button input {
flex-grow: 1;
}
""") as demo:
gr.Markdown(WEBSITE)
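    # Stores the key of the currently shown source motion between callbacks.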
random_key_state = gr.State()
with gr.Row(elem_id="gradio-row"):
with gr.Column(scale=5, elem_id="gradio-column"):
gr.Markdown(WEB_source)
with gr.Row(elem_id="gradio-button-row"):
# iterative_button = gr.Button("Iterative")
# retrieve_button = gr.Button("TMRetrieve")
random_button = gr.Button("Random")
with gr.Row(elem_id="gradio-textbox-row"):
with gr.Column(scale=5, elem_id="gradio-textbox-with-button"):
# retrieve_text = gr.Textbox(placeholder="Type the text for the motion you want to Retrieve:",
# show_label=True, label="Retrieval Text",
# value=DEFAULT_TEXT)
clear_button_retrieval = gr.Button("Clear", scale=0)
with gr.Row(elem_id="gradio-textbox-row"):
            suggested_edit_text = gr.Textbox(placeholder="Suggested texts for editing this motion:",
show_label=True, label="Suggested Edit Text",
value='')
xxx = 'https://motion-editing.s3.eu-central-1.amazonaws.com/collection_wo_walks_runs/rendered_pairs/011327_120_240-002682_120_240.mp4'
set_to_pick = gr.Radio(['all', 'train', 'test'],
value='all',
label="Set to pick from",
info="Motion will be picked from whole dataset or test or train data.")
retrieved_video_output = gr.Video(label="Retrieved Motion",
# value=xxx,
height=360, width=480)
with gr.Column(scale=5, elem_id="gradio-column"):
gr.Markdown(WEB_target)
with gr.Row(elem_id="gradio-edit-row"):
clear_button_edit = gr.Button("Clear", scale=0)
edit_button = gr.Button("Edit", scale=0)
with gr.Row(elem_id="gradio-textbox-row"):
input_text = gr.Textbox(placeholder="Type the edit text you want:",
show_label=False, label="Input Text",
value=DEFAULT_TEXT)
video_output = gr.Video(label="Generated Video", height=360,
width=480)
def process_and_show_video(input_text, random_key_state):
fname = show_video(input_text, random_key_state)
return fname
def process_and_retrieve_video(input_text):
fname = retrieve_video(input_text)
return fname
from retrieval_loader import get_tmr_model
from dataset_utils import load_motionfix
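    # Button wiring: "Edit" runs the diffusion editor on the selected source motion;
    # "Random" samples a new source motion and pre-fills the edit text with its annotation.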
edit_button.click(process_and_show_video, inputs=[input_text, random_key_state], outputs=video_output)
# retrieve_button.click(process_and_retrieve_video, inputs=retrieve_text, outputs=[retrieved_video_output, suggested_edit_text])
random_button.click(random_source_motion, inputs=set_to_pick,
outputs=[retrieved_video_output,
suggested_edit_text,
random_key_state,
input_text])
print(random_key_state)
clear_button_edit.click(clear, outputs=input_text)
# clear_button_retrieval.click(clear, outputs=retrieve_text)
gr.Markdown(CREDITS)
demo.launch(share=True)