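# Gradio demo for text-driven motion editing on the MotionFix data: pick or
# retrieve a source motion, type an edit text, and a TMED diffusion denoiser
# generates the edited motion, rendered side by side with the source.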
from geometry_utils import diffout2motion
import gradio as gr
import spaces
import torch
import os
from pathlib import Path
import smplx
from body_renderer import get_render
import numpy as np
from download_deps import get_smpl_models, download_models, download_model_config
from download_deps import download_tmr, download_motionfix, download_motionfix_dataset
from download_deps import download_embeddings
from website import CREDITS, WEB_source, WEB_target, WEBSITE
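# The SMPL body-model download is gated behind an access token; EGL provides a
# headless OpenGL context for offscreen rendering.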
access_token_smpl = os.environ.get('HF_SMPL_TOKEN')
os.environ["PYOPENGL_PLATFORM"] = "egl"

zero = torch.Tensor([0]).cuda()
print(zero.device)  # <-- 'cpu' at module scope 🤗
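# ZeroGPU sanity check kept from the Spaces starter template: greet() below
# should print 'cuda:0' once the GPU context is active.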
DEFAULT_TEXT = "do it slower " | |
def greet(n): | |
print(zero.device) # <-- 'cuda:0' 🤗 | |
try: | |
number = float(n) | |
except ValueError: | |
return "Invalid input. Please enter a number." | |
return f"Hello {zero + number} Tensor" | |
def clear():
    return ""
@spaces.GPU  # assumption: ZeroGPU Space; gives this function a CUDA context
def show_video(input_text, key_to_use):
    # Heavy model code is imported lazily so app startup stays light.
    from normalization import Normalizer
    from diffusion import create_diffusion
    from text_encoder import ClipTextEncoder
    from tmed_denoiser import TMED_denoiser
    from feature_extractor import FEAT_GET_METHODS

    normalizer = Normalizer()
    model_ckpt = download_models()
    infeats = download_model_config()
    checkpoint = torch.load(model_ckpt)
    # Look up the source/target motion pair for the picked dataset key.
    ds_sample = MFIX_DATASET_DICT[key_to_use]
    # Extract each configured feature, add a batch dim, and move to the GPU.
    data_dict_source = {f'{feat}_source': FEAT_GET_METHODS[feat](ds_sample['motion_source'])[None].to('cuda')
                        for feat in infeats}
    data_dict_target = {f'{feat}_target': FEAT_GET_METHODS[feat](ds_sample['motion_target'])[None].to('cuda')
                        for feat in infeats}
    full_batch = data_dict_source | data_dict_target
    in_batch = normalizer.norm_and_cat(full_batch, infeats)
    source_motion_norm = in_batch['source']
    target_motion_norm = in_batch['target']
    seqlen_src = source_motion_norm.shape[0]
    seqlen_tgt = target_motion_norm.shape[0]
    # Checkpoint keys carry a 'denoiser.' prefix from training; strip it so
    # they match the bare TMED module.
    checkpoint = {k.replace('denoiser.', ''): v for k, v in checkpoint.items()}
    tmed_denoiser = TMED_denoiser().to('cuda')
    tmed_denoiser.load_state_dict(checkpoint, strict=False)
    tmed_denoiser.eval()
    text_encoder = ClipTextEncoder()
    texts_cond = [input_text]
    diffusion_process = create_diffusion(timestep_respacing=None,
                                         learn_sigma=False, sigma_small=True,
                                         diffusion_steps=300,
                                         noise_schedule='squaredcos_cap_v2',
                                         predict_xstart=True)
    bsz = 1
    # Prepend two empty prompts per input: the denoiser expects separate
    # unconditional, motion-only, and motion+text streams for the two guidance
    # scales (gd_text, gd_motion) passed below.
    no_of_texts = len(texts_cond)
    texts_cond = [''] * no_of_texts + texts_cond
    texts_cond = [''] * no_of_texts + texts_cond
    text_emb, text_mask = text_encoder(texts_cond)
    cond_emb_motion = source_motion_norm
    cond_motion_mask = torch.ones((bsz, seqlen_src), dtype=bool, device='cuda')
    mask_target = torch.ones((bsz, seqlen_tgt), dtype=bool, device='cuda')
    diff_out = tmed_denoiser._diffusion_reverse(text_emb.to(cond_emb_motion.device),
                                                text_mask.to(cond_emb_motion.device),
                                                cond_emb_motion,
                                                cond_motion_mask,
                                                mask_target,
                                                diffusion_process,
                                                init_vec=None,
                                                init_from='noise',
                                                gd_text=2.0,
                                                gd_motion=2.0,
                                                steps_num=300)
    # Denormalize and convert the denoiser output (and the normalized source,
    # for reference) back to global motion.
    edited_motion = diffout2motion(diff_out.permute(1, 0, 2), normalizer).squeeze()
    gt_source = diffout2motion(source_motion_norm.permute(1, 0, 2),
                               normalizer).squeeze()
    SMPL_MODELS_PATH = str(Path(get_smpl_models()))
    body_model = smplx.SMPLHLayer(f"{SMPL_MODELS_PATH}/smplh",
                                  model_type='smplh',
                                  gender='neutral', ext='npz')

    from transform3d import transform_body_pose, rotate_body_degrees
    # Features are laid out as [translation (3) | rotations (6d)]; the SMPL-H
    # layer wants axis-angle.
    edited_motion_aa = transform_body_pose(edited_motion[:, 3:], '6d->aa')
    gt_source_aa = transform_body_pose(gt_source[:, 3:], '6d->aa')
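    # Rendering prep: clear any previous render, detach both motions to CPU,
    # and rotate orientation and translation by pi (presumably to face the
    # render camera).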
    if os.path.exists('./output_movie.mp4'):
        os.remove('./output_movie.mp4')
    gen_motion_trans = edited_motion[..., :3].detach().cpu()
    gen_motion_rots_aa = edited_motion_aa.detach().cpu()
    source_motion_trans = gt_source[..., :3].detach().cpu()
    source_motion_rots_aa = gt_source_aa.detach().cpu()
    gen_rots_rotated, gen_trans_rotated = rotate_body_degrees(
        transform_body_pose(gen_motion_rots_aa, 'aa->rot'),
        gen_motion_trans, offset=np.pi)
    src_rots_rotated, src_trans_rotated = rotate_body_degrees(
        transform_body_pose(source_motion_rots_aa, 'aa->rot'),
        source_motion_trans, offset=np.pi)
    src_rots_rotated_aa = transform_body_pose(src_rots_rotated, 'rot->aa')
    gen_rots_rotated_aa = transform_body_pose(gen_rots_rotated, 'rot->aa')
    # Render edited (sky blue) and source (red) motions in one clip; the first
    # axis-angle joint is the global orientation, the rest are the body pose.
    fname = get_render(body_model,
                       [gen_trans_rotated, src_trans_rotated],
                       [gen_rots_rotated_aa[:, 0], src_rots_rotated_aa[:, 0]],
                       [gen_rots_rotated_aa[:, 1:], src_rots_rotated_aa[:, 1:]],
                       output_path='./output_movie.mp4',
                       text='', colors=['sky blue', 'red'])
    print(fname)
    print(os.path.abspath(fname))
    return fname
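# One-time startup downloads: MotionFix data, precomputed TMR source-motion
# embeddings, and the packed dataset dict that show_video indexes by key.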
from retrieval_loader import get_tmr_model
from dataset_utils import load_motionfix

MFIX_p = download_motionfix() + '/motionfix'
SOURCE_MOTS_p = download_embeddings() + '/embeddings'
MFIX_DATASET_DICT = download_motionfix_dataset()
def random_source_motion(set_to_pick):
    import random
    mfix_train, mfix_test = load_motionfix(MFIX_p)
    if set_to_pick == 'all':
        current_set = mfix_test | mfix_train
    elif set_to_pick == 'train':
        current_set = mfix_train
    else:  # 'test'
        current_set = mfix_test
    random_key = random.choice(list(current_set.keys()))
    curvid = current_set[random_key]['motion_a']
    text_annot = current_set[random_key]['annotation']
    # The annotation is returned twice: it pre-fills both the suggestion box
    # and the editable input text.
    return curvid, text_annot, random_key, text_annot
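# Text-to-motion retrieval with TMR: embed the query text, score it against
# the precomputed source-motion embeddings, and return the best-matching clip.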
def retrieve_video(retrieve_text):
    from gen_utils import read_json
    tmr_text_encoder = get_tmr_model(download_tmr())
    # Precomputed TMR embeddings for every source motion, with matching key ids.
    motion_embeds = torch.load(SOURCE_MOTS_p + '/source_motions_embeddings.pt')
    motion_keyids = np.array(read_json(SOURCE_MOTS_p + '/keyids_embeddings.json'))
    mfix_train, mfix_test = load_motionfix(MFIX_p)
    all_mots = mfix_test | mfix_train
    scores = tmr_text_encoder.compute_scores(retrieve_text, embs=motion_embeds)
    sorted_idxs = np.argsort(-scores)
    best_keyids = motion_keyids[sorted_idxs]
    top_mot = best_keyids[0]
    curvid = all_mots[top_mot]['motion_a']
    text_annot = all_mots[top_mot]['annotation']
    return curvid, text_annot
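# UI: the left column picks (or, when re-enabled, retrieves) a source motion
# and previews it; the right column takes the edit text and shows the result.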
with gr.Blocks(css=""" | |
.gradio-row { | |
display: flex; | |
gap: 20px; | |
} | |
.gradio-column { | |
flex: 1; | |
} | |
.gradio-container { | |
display: flex; | |
flex-direction: column; | |
gap: 10px; | |
} | |
.gradio-button-row { | |
display: flex; | |
gap: 10px; | |
} | |
.gradio-textbox-row { | |
display: flex; | |
gap: 10px; | |
align-items: center; | |
} | |
.gradio-edit-row { | |
gap: 10px; | |
align-items: center; | |
} | |
.gradio-textbox-with-button { | |
display: flex; | |
align-items: center; | |
} | |
.gradio-textbox-with-button input { | |
flex-grow: 1; | |
} | |
""") as demo: | |
    gr.Markdown(WEBSITE)
    # Holds the dataset key of the currently picked source motion across callbacks.
    random_key_state = gr.State()
    with gr.Row(elem_classes="gradio-row"):
        with gr.Column(scale=5, elem_classes="gradio-column"):
            gr.Markdown(WEB_source)
            with gr.Row(elem_classes="gradio-button-row"):
                # iterative_button = gr.Button("Iterative")
                # retrieve_button = gr.Button("TMRetrieve")
                random_button = gr.Button("Random")
            with gr.Row(elem_classes="gradio-textbox-row"):
                with gr.Column(scale=5, elem_classes="gradio-textbox-with-button"):
                    # Text-based retrieval is currently disabled:
                    # retrieve_text = gr.Textbox(placeholder="Type the text for the motion you want to retrieve:",
                    #                            show_label=True, label="Retrieval Text",
                    #                            value=DEFAULT_TEXT)
                    clear_button_retrieval = gr.Button("Clear", scale=0)
            with gr.Row(elem_classes="gradio-textbox-row"):
                suggested_edit_text = gr.Textbox(placeholder="Texts likely to edit the motion:",
                                                 show_label=True, label="Suggested Edit Text",
                                                 value='')
            # Example rendered pair, kept for the commented-out default below.
            example_video_url = 'https://motion-editing.s3.eu-central-1.amazonaws.com/collection_wo_walks_runs/rendered_pairs/011327_120_240-002682_120_240.mp4'
            set_to_pick = gr.Radio(['all', 'train', 'test'],
                                   value='all',
                                   label="Set to pick from",
                                   info="Pick the source motion from the whole dataset or only the train/test split.")
            retrieved_video_output = gr.Video(label="Retrieved Motion",
                                              # value=example_video_url,
                                              height=360, width=480)
with gr.Column(scale=5, elem_id="gradio-column"): | |
gr.Markdown(WEB_target) | |
with gr.Row(elem_id="gradio-edit-row"): | |
clear_button_edit = gr.Button("Clear", scale=0) | |
edit_button = gr.Button("Edit", scale=0) | |
with gr.Row(elem_id="gradio-textbox-row"): | |
input_text = gr.Textbox(placeholder="Type the edit text you want:", | |
show_label=False, label="Input Text", | |
value=DEFAULT_TEXT) | |
video_output = gr.Video(label="Generated Video", height=360, | |
width=480) | |
    def process_and_show_video(input_text, random_key_state):
        return show_video(input_text, random_key_state)

    def process_and_retrieve_video(input_text):
        return retrieve_video(input_text)
    edit_button.click(process_and_show_video,
                      inputs=[input_text, random_key_state],
                      outputs=video_output)
    # retrieve_button.click(process_and_retrieve_video, inputs=retrieve_text,
    #                       outputs=[retrieved_video_output, suggested_edit_text])
    random_button.click(random_source_motion, inputs=set_to_pick,
                        outputs=[retrieved_video_output,
                                 suggested_edit_text,
                                 random_key_state,
                                 input_text])
    clear_button_edit.click(clear, outputs=input_text)
    # clear_button_retrieval.click(clear, outputs=retrieve_text)
    gr.Markdown(CREDITS)

demo.launch(share=True)