diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..bdbc63062285e11eca08241cd0db63e5e30d4fd5 --- /dev/null +++ b/.gitignore @@ -0,0 +1,165 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so +.DS_Store +pyglet +app2.py +render.py +cache + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+#.idea/ diff --git a/README.md b/README.md index 6522c0728af61e1a74e60b2a58062e572f562729..b2fbb5d3901ec87d97f3541c4473666170b108d6 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,13 @@ --- title: MotionGPT -emoji: 🌖 -colorFrom: blue -colorTo: indigo +emoji: 🏃 +colorFrom: yellow +colorTo: blue sdk: gradio sdk_version: 3.43.2 app_file: app.py pinned: false -license: cc +license: mit --- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..28e50679f321ce06de66eb202ffad4b04fc6cf06 --- /dev/null +++ b/app.py @@ -0,0 +1,511 @@ +import gradio as gr +import random +import torch +import time +import cv2 +import os +import numpy as np +import OpenGL.GL as gl +import pytorch_lightning as pl +import moviepy.editor as mp +from pathlib import Path +from mGPT.data.build_data import build_data +from mGPT.models.build_model import build_model +from mGPT.config import parse_args +from scipy.spatial.transform import Rotation as RRR +import mGPT.render.matplot.plot_3d_global as plot_3d +from mGPT.render.pyrender.hybrik_loc2rot import HybrIKJointsToRotmat +from mGPT.render.pyrender.smpl_render import SMPLRender +from transformers import WhisperProcessor, WhisperForConditionalGeneration +import librosa +from huggingface_hub import snapshot_download + +os.environ["PYOPENGL_PLATFORM"] = "egl" +os.environ["MESA_GL_VERSION_OVERRIDE"] = "4.1" +os.system('pip install /home/user/app/pyrender') + +# Load model +cfg = parse_args(phase="webui") # parse config file +cfg.FOLDER = 'cache' +output_dir = Path(cfg.FOLDER) +output_dir.mkdir(parents=True, exist_ok=True) +pl.seed_everything(cfg.SEED_VALUE) +if torch.cuda.is_available(): + device = torch.device("cuda") +else: + device = torch.device("cpu") + +model_path = snapshot_download(repo_id="bill-jiang/MotionGPT-base") + +datamodule = build_data(cfg, phase="test") +model = build_model(cfg, datamodule) +state_dict = torch.load(f'{model_path}/motiongpt_s3_h3d.tar', + map_location="cpu")["state_dict"] +model.load_state_dict(state_dict) +model.to(device) + +audio_processor = WhisperProcessor.from_pretrained(cfg.model.whisper_path) +audio_model = WhisperForConditionalGeneration.from_pretrained( + cfg.model.whisper_path).to(device) +forced_decoder_ids_zh = audio_processor.get_decoder_prompt_ids( + language="zh", task="translate") +forced_decoder_ids_en = audio_processor.get_decoder_prompt_ids( + language="en", task="translate") + +# HTML Style + +Video_Components = """ +
+        [HTML: generated-motion video player with video and .npy download buttons]
+""" + +Video_Components_example = """ +
+        [HTML: example video player]
+""" + +Text_Components = """ +

{msg}

+""" + + +def motion_token_to_string(motion_token, lengths, codebook_size=512): + motion_string = [] + for i in range(motion_token.shape[0]): + motion_i = motion_token[i].cpu( + ) if motion_token.device.type == 'cuda' else motion_token[i] + motion_list = motion_i.tolist()[:lengths[i]] + motion_string.append( + (f'' + + ''.join([f'' for i in motion_list]) + + f'')) + return motion_string + + +def render_motion(data, feats, method='fast'): + fname = time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime( + time.time())) + str(np.random.randint(10000, 99999)) + video_fname = fname + '.mp4' + feats_fname = fname + '.npy' + output_npy_path = os.path.join(output_dir, feats_fname) + output_mp4_path = os.path.join(output_dir, video_fname) + np.save(output_npy_path, feats) + + if method == 'slow': + if len(data.shape) == 4: + data = data[0] + data = data - data[0, 0] + pose_generator = HybrIKJointsToRotmat() + pose = pose_generator(data) + pose = np.concatenate([ + pose, + np.stack([np.stack([np.eye(3)] * pose.shape[0], 0)] * 2, 1) + ], 1) + shape = [768, 768] + render = SMPLRender(cfg.RENDER.SMPL_MODEL_PATH) + + if not os.environ.get("PYOPENGL_PLATFORM"): + os.environ["DISPLAY"] = ":0.0" + os.environ["PYOPENGL_PLATFORM"] = "egl" + + size = (shape[1], shape[0]) + fps = 20.0 + fourcc = cv2.VideoWriter_fourcc('M', 'P', '4', 'V') + videoWriter = cv2.VideoWriter(output_mp4_path, fourcc, fps, size) + r = RRR.from_rotvec(np.array([np.pi, 0.0, 0.0])) + pose[:, 0] = np.matmul(r.as_matrix().reshape(1, 3, 3), pose[:, 0]) + for i in range(data.shape[0]): + img = np.zeros([shape[0], shape[1], 3]) + aroot = data[[i], 0] + np.array([[0.0, 0.0, 30.0]]) + aroot[:, 1] = -aroot[:, 1] + params = dict(pred_shape=np.zeros([1, 10]), + pred_root=aroot, + pred_pose=pose[[i]]) + renderImg = render.render(img.copy(), params) + renderImg = (renderImg * 255).astype(np.uint8) + videoWriter.write(renderImg) + videoWriter.release() + output_video_h264_name = output_mp4_path[:-4] + '_h264.mp4' + command = 'ffmpeg -y -i {} -vcodec h264 {}'.format( + output_mp4_path, output_video_h264_name) + os.system(command) + output_mp4_path = output_video_h264_name + video_fname = video_fname[:-4] + '_h264.mp4' + elif method == 'fast': + output_gif_path = output_mp4_path[:-4] + '.gif' + if len(data.shape) == 3: + data = data[None] + if isinstance(data, torch.Tensor): + data = data.cpu().numpy() + pose_vis = plot_3d.draw_to_batch(data, [''], [output_gif_path]) + out_video = mp.VideoFileClip(output_gif_path) + out_video.write_videofile(output_mp4_path) + + return output_mp4_path, video_fname, output_npy_path, feats_fname + + +def load_motion(motion_uploaded, method): + file = motion_uploaded['file'] + + feats = torch.tensor(np.load(file), device=model.device) + if len(feats.shape) == 2: + feats = feats[None] + # feats = model.datamodule.normalize(feats) + + # Motion tokens + motion_lengths = feats.shape[0] + motion_token, _ = model.vae.encode(feats) + + motion_token_string = model.lm.motion_token_to_string( + motion_token, [motion_token.shape[1]])[0] + motion_token_length = motion_token.shape[1] + + # Motion rendered + joints = model.datamodule.feats2joints(feats.cpu()).cpu().numpy() + output_mp4_path, video_fname, output_npy_path, joints_fname = render_motion( + joints, + feats.to('cpu').numpy(), method) + + motion_uploaded.update({ + "feats": feats, + "joints": joints, + "motion_video": output_mp4_path, + "motion_video_fname": video_fname, + "motion_joints": output_npy_path, + "motion_joints_fname": joints_fname, + "motion_lengths": motion_lengths, 
+ "motion_token": motion_token, + "motion_token_string": motion_token_string, + "motion_token_length": motion_token_length, + }) + + return motion_uploaded + + +def add_text(history, text, motion_uploaded, data_stored, method): + data_stored = data_stored + [{'user_input': text}] + + text = f"""

{text}

""" + history = history + [(text, None)] + if 'file' in motion_uploaded.keys(): + motion_uploaded = load_motion(motion_uploaded, method) + output_mp4_path = motion_uploaded['motion_video'] + video_fname = motion_uploaded['motion_video_fname'] + output_npy_path = motion_uploaded['motion_joints'] + joints_fname = motion_uploaded['motion_joints_fname'] + history = history + [(Video_Components.format( + video_path=output_mp4_path, + video_fname=video_fname, + motion_path=output_npy_path, + motion_fname=joints_fname), None)] + + return history, gr.update(value="", + interactive=False), motion_uploaded, data_stored + + +def add_audio(history, audio_path, data_stored, language='en'): + audio, sampling_rate = librosa.load(audio_path, sr=16000) + input_features = audio_processor( + audio, sampling_rate, return_tensors="pt" + ).input_features # whisper training sampling rate, do not modify + input_features = torch.Tensor(input_features).to(device) + + if language == 'English': + forced_decoder_ids = forced_decoder_ids_en + else: + forced_decoder_ids = forced_decoder_ids_zh + predicted_ids = audio_model.generate(input_features, + forced_decoder_ids=forced_decoder_ids) + text_input = audio_processor.batch_decode(predicted_ids, + skip_special_tokens=True) + text_input = str(text_input).strip('[]"') + data_stored = data_stored + [{'user_input': text_input}] + gr.update(value=data_stored, interactive=False) + history = history + [(text_input, None)] + + return history, data_stored + + +def add_file(history, file, txt, motion_uploaded): + motion_uploaded['file'] = file.name + txt = txt.replace(" ", "") + " " + return history, gr.update(value=txt, interactive=True), motion_uploaded + + +def bot(history, motion_uploaded, data_stored, method): + + motion_length, motion_token_string = motion_uploaded[ + "motion_lengths"], motion_uploaded["motion_token_string"] + + input = data_stored[-1]['user_input'] + prompt = model.lm.placeholder_fulfill(input, motion_length, + motion_token_string, "") + data_stored[-1]['model_input'] = prompt + batch = { + "length": [motion_length], + "text": [prompt], + } + + outputs = model(batch, task="t2m") + out_feats = outputs["feats"][0] + out_lengths = outputs["length"][0] + out_joints = outputs["joints"][:out_lengths].detach().cpu().numpy() + out_texts = outputs["texts"][0] + output_mp4_path, video_fname, output_npy_path, joints_fname = render_motion( + out_joints, + out_feats.to('cpu').numpy(), method) + + motion_uploaded = { + "feats": None, + "joints": None, + "motion_video": None, + "motion_lengths": 0, + "motion_token": None, + "motion_token_string": '', + "motion_token_length": 0, + } + + data_stored[-1]['model_output'] = { + "feats": out_feats, + "joints": out_joints, + "length": out_lengths, + "texts": out_texts, + "motion_video": output_mp4_path, + "motion_video_fname": video_fname, + "motion_joints": output_npy_path, + "motion_joints_fname": joints_fname, + } + + if '' == out_texts: + response = [ + Video_Components.format(video_path=output_mp4_path, + video_fname=video_fname, + motion_path=output_npy_path, + motion_fname=joints_fname) + ] + elif '' in out_texts: + response = [ + Text_Components.format( + msg=out_texts.split("")[0]), + Video_Components.format(video_path=output_mp4_path, + video_fname=video_fname, + motion_path=output_npy_path, + motion_fname=joints_fname), + Text_Components.format( + msg=out_texts.split("")[1]), + ] + else: + response = f"""

{out_texts}

""" + + history[-1][1] = "" + for character in response: + history[-1][1] += character + time.sleep(0.02) + yield history, motion_uploaded, data_stored + + +def bot_example(history, responses): + for response in responses: + history[-1][1] = "" + for character in response: + history[-1][1] += character + time.sleep(0.02) + yield history, motion_uploaded, data_stored + + +# Examples +chat_instruct = [ + (None, + "**👋 Hi, I'm MotionGPT! I can generate realistic human motion from text, or generate text from motion.**" + ), + (None, + "You can chat with me in pure text like generating human motion following your descriptions." + ), + (None, + "After generation, you can click the button in the top right of generation human motion result to download the human motion video or feature stored in .npy format." + ), + (None, + "With the human motion feature file downloaded or got from dataset, you are able to ask me to translate it!" + ), + (None, + "Of courser, you can also purely chat with me and let me give you human motion in text, here are some examples!" + ), + (None, + "We provide two motion visulization methods. The default fast method is skeleton line ploting which is like the examples below:" + ), + (None, + Video_Components_example.format(video_path="assets/videos/t2m_0.mp4", + video_fname="example1.mp4")), + (None, + "And the slow method is SMPL model rendering which is more realistic but slower." + ), + (None, + Video_Components_example.format(video_path="assets/videos/t2m_0.mp4", + video_fname="example1.mp4")), + (None, "👉 Follow the examples and try yourself!"), +] + +t2m_examples = [ + (None, + "You can chat with me in pure text, following are some examples of text-to-motion generation!" + ), + ("Generate a person is walking forwards, but stumbles and steps back, then carries on forward.", + Video_Components_example.format(video_path="assets/videos/t2m_0.mp4", + video_fname="example1.mp4")), + ("Generate a person is walking forwards, but stumbles and steps back, then carries on forward.", + Video_Components_example.format(video_path="assets/videos/t2m_0.mp4", + video_fname="example1.mp4")), + ("Generate a person is walking forwards, but stumbles and steps back, then carries on forward.", + Video_Components_example.format(video_path="assets/videos/t2m_0.mp4", + video_fname="example1.mp4")), +] + +m2t_examples = [ + (None, + "With the human motion feature file downloaded or got from dataset, you are able to ask me to translate it, here are some examples!" + ), + ("Please explain the movement shown in [Motion_tokens] using natural language.", + None), + (Video_Components_example.format(video_path="assets/videos/m2t_0.mp4", + video_fname="example2.mp4"), + "a person walks forward then does a backwards z-shape movement to its left side. then back to the right." + ), + ("Please explain the movement shown in [Motion_tokens] using natural language.", + None), + (Video_Components_example.format(video_path="assets/videos/m2t_0.mp4", + video_fname="example2.mp4"), + "a person walks forward then does a backwards z-shape movement to its left side. then back to the right." + ), +] + +t2t_examples = [ + (None, + "Of courser, you can also purely chat with me and let me give you human motion in text, here are some examples!" + ), + ('Depict a motion as like you have seen it.', + "The person walks while swaying their hips along a curved path to the left slowly then stops to look down at the edge of the grey platform at something." 
+ ), + ('Depict a motion as like you have seen it.', + "The person walks while swaying their hips along a curved path to the left slowly then stops to look down at the edge of the grey platform at something." + ), +] + +Init_chatbot = [ + (None, + "**👋 Hi, I'm MotionGPT! I can generate realistic human motion from text, or generate text from motion.**" + ) +] + t2m_examples[:3] + m2t_examples[:2] + t2t_examples[:2] + chat_instruct[-4:] + +with open("assets/css/custom.css", "r", encoding="utf-8") as f: + customCSS = f.read() + +with gr.Blocks(css=customCSS) as demo: + + # Variables + motion_uploaded = gr.State({ + "feats": None, + "joints": None, + "motion_video": None, + "motion_lengths": 0, + "motion_token": None, + "motion_token_string": '', + "motion_token_length": 0, + }) + data_stored = gr.State([]) + + gr.Markdown("# MotionGPT") + + chatbot = gr.Chatbot(Init_chatbot, + elem_id="mGPT", + height=600, + label="MotionGPT", + avatar_images=(None, + ("assets/images/avatar_bot.jpg")), + bubble_full_width=False) + + with gr.Row(): + with gr.Column(scale=0.85): + with gr.Row(): + txt = gr.Textbox( + label="Text", + show_label=False, + placeholder= + "Enter text and press ENTER or speak to input. You can also upload motion.", + container=False) + + with gr.Row(): + aud = gr.Audio(source="microphone", + label="Speak input", + type='filepath') + btn = gr.UploadButton("📁 Upload motion", + elem_id="upload", + file_types=["file"], + variant='primary') + regen = gr.Button("🔄 Regenerate", elem_id="regen") + clear = gr.ClearButton([txt, chatbot, aud], value='🗑️ Clear') + + with gr.Row(): + gr.Markdown(''' + ### You can get more examples (pre-generated for faster response) by clicking the buttons below: + ''') + + with gr.Row(): + instruct = gr.Button("Instructions", elem_id="instruction") + t2m_eg = gr.Button("Text-to-Motion", elem_id="t2m") + m2t_eg = gr.Button("Motion-to-Text", elem_id="m2t") + t2t_eg = gr.Button("Random description", elem_id="t2t") + + with gr.Column(scale=0.15, min_width=150): + method = gr.Dropdown(["slow", "fast"], + label="Visulization method", + interactive=True, + elem_id="method", + value="fast") + + language = gr.Dropdown(["English", "中文"], + label="Speech language", + interactive=True, + elem_id="language", + value="English") + + txt_msg = txt.submit( + add_text, [chatbot, txt, motion_uploaded, data_stored, method], + [chatbot, txt, motion_uploaded, data_stored], + queue=False).then(bot, [chatbot, motion_uploaded, data_stored, method], + [chatbot, motion_uploaded, data_stored]) + + txt_msg.then(lambda: gr.update(interactive=True), None, [txt], queue=False) + + file_msg = btn.upload(add_file, [chatbot, btn, txt, motion_uploaded], + [chatbot, txt, motion_uploaded], + queue=False) + aud_msg = aud.stop_recording( + add_audio, [chatbot, aud, data_stored, language], + [chatbot, data_stored], + queue=False).then(bot, [chatbot, motion_uploaded, data_stored, method], + [chatbot, motion_uploaded, data_stored]) + regen_msg = regen.click(bot, + [chatbot, motion_uploaded, data_stored, method], + [chatbot, motion_uploaded, data_stored], + queue=False) + chatbot.change(scroll_to_output=True) + +demo.queue() + +if __name__ == "__main__": + demo.launch(debug=True) diff --git a/assets/css/custom.css b/assets/css/custom.css new file mode 100644 index 0000000000000000000000000000000000000000..c0ba92a840084d6b943777ddc8f2281cc3665b0d --- /dev/null +++ b/assets/css/custom.css @@ -0,0 +1,359 @@ +/* Borrowed from https://huggingface.co/spaces/project-baize/chat-with-baize */ + +:root { + 
--chatbot-color-light: #f6f6f6; + --chatbot-color-dark: #121111; +} + +/* Light mode (default) */ +#mGPT { + background-color: var(--chatbot-color-light) !important; + color: #000000 !important; +} +[data-testid='bot'] { + background-color: #ffffff !important; +} +[data-testid='user'] { + background-color: #95ec69 !important; +} + +/* Dark mode */ +.dark #mGPT { + background-color: var(--chatbot-color-dark) !important; + color: #ffffff !important; +} +.dark [data-testid='bot'] { + background-color: #2c2c2c !important; +} + +.dark [data-testid='user'] { + background-color: #26b561 !important; +} + +#mGPT { + height: 100%; + min-height: 500px; +} + +[class*='message-buttons'] { + visibility: hidden; +} + +[class*='message'] { + border: none; + font-size: var(--text-xl) !important; + line-height: var(--line-xl) !important; +} +/* [data-testid='bot'] { + max-width: 85%; + width: auto !important; + border-bottom-left-radius: 0 !important; +} +[data-testid='user'] { + max-width: 85%; + width: auto !important; + border-bottom-right-radius: 0 !important; +} */ + +/* Text & Video */ +#method { + line-height: 1.95 !important; +} + +.side-content { + max-width: 340px; +} + +/* @media only screen and (min-width: 768px) { + .side-content { + float: left; + overflow-wrap: break-word; + padding-right: 2rem; + } + + .side-video { + float: right; + } +} */ + +/* Buttom */ +#upload { + color: #000000; +} + +.videodl-button { + position: absolute; + left: 80%; + top: 5px; + width: 24px; + height: 24px; +} + +.videodl-button svg { + width: 24px; + height: 24px; +} + +.npydl-button { + position: absolute; + left: 90%; + top: 5px; + width: 24px; + height: 24px; +} + +.npydl-button svg { + width: 24px; + height: 24px; +} + +/* Table */ +table { + margin: 1em 0; + border-collapse: collapse; + empty-cells: show; +} +td, +th { + border: 1.2px solid var(--border-color-primary) !important; + padding: 0.2em; +} +thead { + background-color: rgba(175, 184, 193, 0.2); +} +thead th { + padding: 0.5em 0.2em; +} +/* Inline code */ +#mGPT code { + display: inline; + white-space: break-spaces; + border-radius: 6px; + margin: 0 2px 0 2px; + padding: 0.2em 0.4em 0.1em 0.4em; + background-color: rgba(175, 184, 193, 0.2); +} +/* Code block */ +#mGPT pre code { + display: block; + overflow: auto; + white-space: pre; + background-color: hsla(0, 0%, 0%, 80%) !important; + border-radius: 10px; + padding: 1.4em 1.2em 0em 1.4em; + margin: 1.2em 2em 1.2em 0.5em; + color: #fff; + box-shadow: 6px 6px 16px hsla(0, 0%, 0%, 0.2); +} +/* Hightlight */ +#mGPT .highlight { + background-color: transparent; +} +#mGPT .highlight .hll { + background-color: #49483e; +} +#mGPT .highlight .c { + color: #75715e; +} /* Comment */ +#mGPT .highlight .err { + color: #960050; + background-color: #1e0010; +} /* Error */ +#mGPT .highlight .k { + color: #66d9ef; +} /* Keyword */ +#mGPT .highlight .l { + color: #ae81ff; +} /* Literal */ +#mGPT .highlight .n { + color: #f8f8f2; +} /* Name */ +#mGPT .highlight .o { + color: #f92672; +} /* Operator */ +#mGPT .highlight .p { + color: #f8f8f2; +} /* Punctuation */ +#mGPT .highlight .ch { + color: #75715e; +} /* Comment.Hashbang */ +#mGPT .highlight .cm { + color: #75715e; +} /* Comment.Multiline */ +#mGPT .highlight .cp { + color: #75715e; +} /* Comment.Preproc */ +#mGPT .highlight .cpf { + color: #75715e; +} /* Comment.PreprocFile */ +#mGPT .highlight .c1 { + color: #75715e; +} /* Comment.Single */ +#mGPT .highlight .cs { + color: #75715e; +} /* Comment.Special */ +#mGPT .highlight .gd { + color: #f92672; +} /* 
Generic.Deleted */ +#mGPT .highlight .ge { + font-style: italic; +} /* Generic.Emph */ +#mGPT .highlight .gi { + color: #a6e22e; +} /* Generic.Inserted */ +#mGPT .highlight .gs { + font-weight: bold; +} /* Generic.Strong */ +#mGPT .highlight .gu { + color: #75715e; +} /* Generic.Subheading */ +#mGPT .highlight .kc { + color: #66d9ef; +} /* Keyword.Constant */ +#mGPT .highlight .kd { + color: #66d9ef; +} /* Keyword.Declaration */ +#mGPT .highlight .kn { + color: #f92672; +} /* Keyword.Namespace */ +#mGPT .highlight .kp { + color: #66d9ef; +} /* Keyword.Pseudo */ +#mGPT .highlight .kr { + color: #66d9ef; +} /* Keyword.Reserved */ +#mGPT .highlight .kt { + color: #66d9ef; +} /* Keyword.Type */ +#mGPT .highlight .ld { + color: #e6db74; +} /* Literal.Date */ +#mGPT .highlight .m { + color: #ae81ff; +} /* Literal.Number */ +#mGPT .highlight .s { + color: #e6db74; +} /* Literal.String */ +#mGPT .highlight .na { + color: #a6e22e; +} /* Name.Attribute */ +#mGPT .highlight .nb { + color: #f8f8f2; +} /* Name.Builtin */ +#mGPT .highlight .nc { + color: #a6e22e; +} /* Name.Class */ +#mGPT .highlight .no { + color: #66d9ef; +} /* Name.Constant */ +#mGPT .highlight .nd { + color: #a6e22e; +} /* Name.Decorator */ +#mGPT .highlight .ni { + color: #f8f8f2; +} /* Name.Entity */ +#mGPT .highlight .ne { + color: #a6e22e; +} /* Name.Exception */ +#mGPT .highlight .nf { + color: #a6e22e; +} /* Name.Function */ +#mGPT .highlight .nl { + color: #f8f8f2; +} /* Name.Label */ +#mGPT .highlight .nn { + color: #f8f8f2; +} /* Name.Namespace */ +#mGPT .highlight .nx { + color: #a6e22e; +} /* Name.Other */ +#mGPT .highlight .py { + color: #f8f8f2; +} /* Name.Property */ +#mGPT .highlight .nt { + color: #f92672; +} /* Name.Tag */ +#mGPT .highlight .nv { + color: #f8f8f2; +} /* Name.Variable */ +#mGPT .highlight .ow { + color: #f92672; +} /* Operator.Word */ +#mGPT .highlight .w { + color: #f8f8f2; +} /* Text.Whitespace */ +#mGPT .highlight .mb { + color: #ae81ff; +} /* Literal.Number.Bin */ +#mGPT .highlight .mf { + color: #ae81ff; +} /* Literal.Number.Float */ +#mGPT .highlight .mh { + color: #ae81ff; +} /* Literal.Number.Hex */ +#mGPT .highlight .mi { + color: #ae81ff; +} /* Literal.Number.Integer */ +#mGPT .highlight .mo { + color: #ae81ff; +} /* Literal.Number.Oct */ +#mGPT .highlight .sa { + color: #e6db74; +} /* Literal.String.Affix */ +#mGPT .highlight .sb { + color: #e6db74; +} /* Literal.String.Backtick */ +#mGPT .highlight .sc { + color: #e6db74; +} /* Literal.String.Char */ +#mGPT .highlight .dl { + color: #e6db74; +} /* Literal.String.Delimiter */ +#mGPT .highlight .sd { + color: #e6db74; +} /* Literal.String.Doc */ +#mGPT .highlight .s2 { + color: #e6db74; +} /* Literal.String.Double */ +#mGPT .highlight .se { + color: #ae81ff; +} /* Literal.String.Escape */ +#mGPT .highlight .sh { + color: #e6db74; +} /* Literal.String.Heredoc */ +#mGPT .highlight .si { + color: #e6db74; +} /* Literal.String.Interpol */ +#mGPT .highlight .sx { + color: #e6db74; +} /* Literal.String.Other */ +#mGPT .highlight .sr { + color: #e6db74; +} /* Literal.String.Regex */ +#mGPT .highlight .s1 { + color: #e6db74; +} /* Literal.String.Single */ +#mGPT .highlight .ss { + color: #e6db74; +} /* Literal.String.Symbol */ +#mGPT .highlight .bp { + color: #f8f8f2; +} /* Name.Builtin.Pseudo */ +#mGPT .highlight .fm { + color: #a6e22e; +} /* Name.Function.Magic */ +#mGPT .highlight .vc { + color: #f8f8f2; +} /* Name.Variable.Class */ +#mGPT .highlight .vg { + color: #f8f8f2; +} /* Name.Variable.Global */ +#mGPT .highlight .vi { + color: #f8f8f2; 
+} /* Name.Variable.Instance */ +#mGPT .highlight .vm { + color: #f8f8f2; +} /* Name.Variable.Magic */ +#mGPT .highlight .il { + color: #ae81ff; +} /* Literal.Number.Integer.Long */ diff --git a/assets/images/avatar_bot.jpg b/assets/images/avatar_bot.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ac39e708f460a7d7a336e9f694389f4bd9b5dc52 Binary files /dev/null and b/assets/images/avatar_bot.jpg differ diff --git a/assets/meta/mean.npy b/assets/meta/mean.npy new file mode 100644 index 0000000000000000000000000000000000000000..6c57414d9cf6242bb4b4bab4c33df5e2cc9d2f91 --- /dev/null +++ b/assets/meta/mean.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0bdb5ba69a3a9e34d71990db15bc535ebc024c8d95ddb5574196f96058faa7d3 +size 2232 diff --git a/assets/meta/mean_eval.npy b/assets/meta/mean_eval.npy new file mode 100644 index 0000000000000000000000000000000000000000..6c57414d9cf6242bb4b4bab4c33df5e2cc9d2f91 --- /dev/null +++ b/assets/meta/mean_eval.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0bdb5ba69a3a9e34d71990db15bc535ebc024c8d95ddb5574196f96058faa7d3 +size 2232 diff --git a/assets/meta/std.npy b/assets/meta/std.npy new file mode 100644 index 0000000000000000000000000000000000000000..93c6b7ae4c2fa23dd21c10a27da1b6966168b35b --- /dev/null +++ b/assets/meta/std.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6a5f7d60301c9465972fc225f8ad0ee8f957e7720431189123eb6d15873a9557 +size 2232 diff --git a/assets/meta/std_eval.npy b/assets/meta/std_eval.npy new file mode 100644 index 0000000000000000000000000000000000000000..93c6b7ae4c2fa23dd21c10a27da1b6966168b35b --- /dev/null +++ b/assets/meta/std_eval.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6a5f7d60301c9465972fc225f8ad0ee8f957e7720431189123eb6d15873a9557 +size 2232 diff --git a/assets/videos/m2t_0.mp4 b/assets/videos/m2t_0.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..9b63d8cdbc4fb04e7b2fd7a22e77a073bc857050 Binary files /dev/null and b/assets/videos/m2t_0.mp4 differ diff --git a/assets/videos/t2m_0.mp4 b/assets/videos/t2m_0.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..849a488d78975aa537df1998b9bccb36ff4fd604 Binary files /dev/null and b/assets/videos/t2m_0.mp4 differ diff --git a/configs/assets.yaml b/configs/assets.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1b09c27fc206f91a4ccad6e610662d303ffb1922 --- /dev/null +++ b/configs/assets.yaml @@ -0,0 +1,32 @@ +CONFIG_FOLDER: configs # Config files path +FOLDER: experiments # Experiment files saving path + +TEST: + FOLDER: results # Testing files saving path + +DATASET: + TASK_ROOT: deps/mGPT_instructions + SMPL_PATH: deps/smpl + TRANSFORM_PATH: deps/transforms/ + WORD_VERTILIZER_PATH: deps/glove/ + KIT: + ROOT: datasets/kit-ml # KIT directory + SPLIT_ROOT: datasets/kit-ml # KIT splits directory + MEAN_STD_PATH: deps/t2m/ + HUMANML3D: + ROOT: datasets/humanml3d # HumanML3D directory + SPLIT_ROOT: datasets/humanml3d # HumanML3D splits directory + MEAN_STD_PATH: deps/t2m/ + +METRIC: + TM2T: + t2m_path: deps/t2m/ # path for tm2t evaluator + +model: + whisper_path: openai/whisper-large-v2 # path for whisper model, webui only + +RENDER: + BLENDER_PATH: libs/blender-2.93.2-linux-x64/blender + SMPL_MODEL_PATH: deps/smpl/smpl_models/smpl + MODEL_PATH: deps/smpl/smpl_models/ + FACES_PATH: deps/smplh/smplh.faces diff --git a/configs/default.yaml b/configs/default.yaml new file mode 100644 
index 0000000000000000000000000000000000000000..2db9a51287d006c38ba9f0577bb915b2aff02888 --- /dev/null +++ b/configs/default.yaml @@ -0,0 +1,141 @@ +SEED_VALUE: 1234 # Seed value +DEBUG: True # Debug mode +FULL_CONFIG: false + +TRAIN: + SPLIT: 'train' # Training split name + NUM_WORKERS: 8 # Number of workers + BATCH_SIZE: 8 # Size of batches + END_EPOCH: 2000 # End epoch + + RESUME: '' # Experiment path to be resumed training + PRETRAINED_VAE: '' # Pretrained vae/vqvae model path + PRETRAINED: '' # Pretrained model path + + OPTIM: + target: AdamW + params: + lr: 2e-4 + betas: [0.9, 0.99] + weight_decay: 0.0 + + LR_SCHEDULER: + target: CosineAnnealingLR + params: + T_max: ${eval:${LOGGER.VAL_EVERY_STEPS} * 100} + eta_min: 1e-6 + +EVAL: + SPLIT: 'val' # Validation split name + BATCH_SIZE: 16 # Validation Batch size + NUM_WORKERS: 8 # Validation Batch size + +TEST: + CHECKPOINTS: '' # Pretrained model path + SPLIT: 'test' # Testing split name + BATCH_SIZE: 16 # Testing Batch size + NUM_WORKERS: 8 # Testing Batch size + + SAVE_PREDICTIONS: False # Weather to save predictions + COUNT_TIME: False # Weather to count time during test + REPLICATION_TIMES: 20 # Number of times to replicate the test + REP_I: 0 # For counting replication times + +model: + target: mGPT.models.mgpt.MotionGPT + params: + condition: 'text' + task: 't2m' + lm: ${lm.default} + motion_vae: ${vq.default} + + # Related parameters + stage: ${TRAIN.STAGE} + debug: ${DEBUG} + codebook_size: ${model.params.motion_vae.params.code_num} + metrics_dict: ${METRIC.TYPE} + +LOSS: + LAMBDA_REC: 1.0 # Lambda for reconstruction losses + LAMBDA_JOINT: 1.0 # Lambda for joint losses + + LAMBDA_LATENT: 1e-5 # Lambda for latent losses + LAMBDA_KL: 1e-5 # Lambda for kl losses + LAMBDA_GEN: 1.0 # Lambda for text-motion generation losses + LAMBDA_CROSS: 1.0 # Lambda for cross-reconstruction losses + LAMBDA_CYCLE: 1.0 # Lambda for cycle losses + LAMBDA_PRIOR: 0.0 # Lambda for diffusion prior losses + + LAMBDA_VELOCITY: 0.5 # Lambda for velocity losses + LAMBDA_COMMIT: 0.02 # Lambda for commitment losses + + ABLATION: + RECONS_LOSS: 'l1_smooth' + +METRIC: + TASK: 't2m' + FORCE_IN_METER: True + DIST_SYNC_ON_STEP: True + MM_NUM_SAMPLES: 100 # Number of samples for multimodal test + MM_NUM_REPEATS: 30 # Number of repeats for multimodal test + MM_NUM_TIMES: 10 # Number of times to repeat the multimodal test + DIVERSITY_TIMES: 300 # Number of times to repeat the diversity test + TM2T: ${evaluator.tm2t} + +DATASET: + target: mGPT.data.HumanML3D.HumanML3DDataModule + CODE_PATH: 'VQVAE' + TASK_ROOT: '' + TASK_PATH: '' + NFEATS: 263 + KIT: + MAX_MOTION_LEN: 196 + MIN_MOTION_LEN: 24 + MAX_TEXT_LEN: 20 + PICK_ONE_TEXT: true + FRAME_RATE: 12.5 + UNIT_LEN: 4 + HUMANML3D: + MAX_MOTION_LEN: 196 + MIN_MOTION_LEN: 40 + MAX_TEXT_LEN: 20 + PICK_ONE_TEXT: true + FRAME_RATE: 20.0 + UNIT_LEN: 4 + STD_TEXT: False + +ABLATION: + # For MotionGPT + use_length: False + predict_ratio: 0.2 + inbetween_ratio: 0.25 + image_size: 256 + + # For Motion-latent-diffusion + VAE_TYPE: 'actor' # vae ablation: actor or mcross + VAE_ARCH: 'encoder_decoder' # mdiffusion vae architecture + PE_TYPE: 'actor' # mdiffusion mld or actor + DIFF_PE_TYPE: 'actor' # mdiffusion mld or actor + SKIP_CONNECT: False # skip connection for denoiser va + MLP_DIST: False # use linear to expand mean and std rather expand token nums + IS_DIST: False # Mcross distribution kl + PREDICT_EPSILON: True # noise or motion + +LOGGER: + VAL_EVERY_STEPS: 10 + LOGGERS: ['tensorboard', 'wandb'] + TENSORBOARD: + target: 
pytorch_lightning.loggers.TensorBoardLogger + params: + save_dir: ${FOLDER_EXP} + name: 'tensorboard' + version: '' + WANDB: + target: pytorch_lightning.loggers.WandbLogger + params: + project: null + offline: False + id: null + version: '' + name: ${NAME} + save_dir: ${FOLDER_EXP} diff --git a/configs/evaluator/tm2t.yaml b/configs/evaluator/tm2t.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2cf48229d451e8756930fdd220441ece761221d4 --- /dev/null +++ b/configs/evaluator/tm2t.yaml @@ -0,0 +1,19 @@ +t2m_textencoder: + target: mGPT.archs.tm2t_evaluator.TextEncoderBiGRUCo + params: + word_size: 300 + pos_size: 15 + hidden_size: 512 + output_size: 512 +t2m_moveencoder: + target: mGPT.archs.tm2t_evaluator.MovementConvEncoder + params: + input_size: ${eval:${DATASET.NFEATS} - 4} + hidden_size: 512 + output_size: 512 +t2m_motionencoder: + target: mGPT.archs.tm2t_evaluator.MotionEncoderBiGRUCo + params: + input_size: ${evaluator.tm2t.t2m_moveencoder.params.output_size} + hidden_size: 1024 + output_size: 512 diff --git a/configs/lm/default.yaml b/configs/lm/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..14d9d4cfe68ac2a2a91ae00c347dbe5abf19314f --- /dev/null +++ b/configs/lm/default.yaml @@ -0,0 +1,7 @@ +target: mGPT.archs.mgpt_lm.MLM +params: + model_type: t5 + model_path: google/flan-t5-base + stage: ${TRAIN.STAGE} + motion_codebook_size: ${model.params.codebook_size} + ablation: ${ABLATION} diff --git a/configs/render.yaml b/configs/render.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5e8ca3f7623207bb1c7a4f7cc81e77313a2c752b --- /dev/null +++ b/configs/render.yaml @@ -0,0 +1,23 @@ +NAME: '___render_do_not_need_name__' # Experiment name +ACCELERATOR: 'gpu' # Devices optioncal: “cpu”, “gpu”, “tpu”, “ipu”, “hpu”, “mps, “auto” +DEVICE: [0] # Index of gpus eg. [0] or [0,1,2,3] + +RENDER: + FOLDER: '___no_need__' + INPUT_MODE: 'npy' + DIR: '' + NPY: '___no_need__' + DENOISING: True + OLDRENDER: True + # ["ultra", "high", "med", "low"] + # RES: 'high' + RES: 'med' + DOWNSAMPLE: False + FPS: 20.0 + CANONICALIZE: True + EXACT_FRAME: 0.5 + NUM: 8 + MODE: '___no_need__' #sequence frame video + VID_EXT: mp4 + ALWAYS_ON_FLOOR: false + GT: false diff --git a/configs/vq/default.yaml b/configs/vq/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..27ac6fa7c498748cb5b9c92d0cc050b31a677fd3 --- /dev/null +++ b/configs/vq/default.yaml @@ -0,0 +1,15 @@ +target: mGPT.archs.mgpt_vq.VQVae +params: + quantizer: 'ema_reset' + code_num: 512 + code_dim: 512 + output_emb_width: 512 + down_t: 2 + stride_t: 2 + width: 512 + depth: 3 + dilation_growth_rate: 3 + norm: None + activation: 'relu' + nfeats: ${DATASET.NFEATS} + ablation: ${ABLATION} diff --git a/configs/webui.yaml b/configs/webui.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9a05f7f0d88d110f06627e1cd878c5ec83c8169b --- /dev/null +++ b/configs/webui.yaml @@ -0,0 +1,74 @@ +NAME: Webui # Experiment name +DEBUG: False # Debug mode +ACCELERATOR: 'cpu' # Devices optioncal: “cpu”, “gpu”, “tpu”, “ipu”, “hpu”, “mps, “auto” +DEVICE: [0] # Index of gpus eg. 
[0] or [0,1,2,3] + +# Training configuration +TRAIN: + #--------------------------------- + STAGE: lm_instruct + DATASETS: ['humanml3d'] # Training datasets + NUM_WORKERS: 32 # Number of workers + BATCH_SIZE: 16 # Size of batches + START_EPOCH: 0 # Start epochMMOTIONENCODER + END_EPOCH: 99999 # End epoch + ABLATION: + pkeep: 0.5 + OPTIM: + TYPE: AdamW # Optimizer type + LR: 2e-4 # Learning rate + WEIGHT_DECAY: 0.0 + LR_SCHEDULER: [100, 200, 300, 400] + GAMMA: 0.8 + +# Evaluating Configuration +EVAL: + DATASETS: ['humanml3d'] # Evaluating datasets + BATCH_SIZE: 32 # Evaluating Batch size + SPLIT: test + +# Test Configuration +TEST: + CHECKPOINTS: checkpoints/MotionGPT-base/motiongpt_s3_h3d.ckpt + DATASETS: ['humanml3d'] # training datasets + SPLIT: test + BATCH_SIZE: 32 # training Batch size + MEAN: False + NUM_SAMPLES: 1 + FACT: 1 + +# Datasets Configuration +DATASET: + JOINT_TYPE: 'humanml3d' # join type + CODE_PATH: 'VQBEST' +METRIC: + TYPE: ['TM2TMetrics'] +# Losses Configuration +LOSS: + TYPE: t2mgpt # Losses type + LAMBDA_FEATURE: 1.0 + LAMBDA_VELOCITY: 0.5 + LAMBDA_COMMIT: 0.02 + LAMBDA_CLS: 1.0 + LAMBDA_M2T2M: 1.0 + LAMBDA_T2M2T: 10.0 + ABLATION: + RECONS_LOSS: 'l1_smooth' + +# Model Configuration +model: + target: mGPT.models.mgpt.MotionGPT + params: + condition: 'text' + task: 't2m' + lm: ${lm.default} + motion_vae: ${vq.default} + +# Logger configuration +LOGGER: + LOG_EVERY_STEPS: 5 + VAL_EVERY_STEPS: 10 + TENSORBOARD: True + wandb: + params: + project: null diff --git a/deps/smpl/smpl_models/SMPL_downsample_index.pkl b/deps/smpl/smpl_models/SMPL_downsample_index.pkl new file mode 100644 index 0000000000000000000000000000000000000000..7bb54c4f1e03340ad58b60485abaed1641d68d47 --- /dev/null +++ b/deps/smpl/smpl_models/SMPL_downsample_index.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e5b783c1677079397ee4bc26df5c72d73b8bb393bea41fa295b951187443daec +size 3556 diff --git a/deps/smpl/smpl_models/gmm_08.pkl b/deps/smpl/smpl_models/gmm_08.pkl new file mode 100644 index 0000000000000000000000000000000000000000..c97a1d7ef396581e56ce74a12cc39175680ce028 --- /dev/null +++ b/deps/smpl/smpl_models/gmm_08.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e1374908aae055a2afa01a2cd9a169bc6cfec1ceb7aa590e201a47b383060491 +size 839127 diff --git a/deps/smpl/smpl_models/neutral_smpl_mean_params.h5 b/deps/smpl/smpl_models/neutral_smpl_mean_params.h5 new file mode 100644 index 0000000000000000000000000000000000000000..b6ecce2a748128cfde09b219ccc74307de50bbae --- /dev/null +++ b/deps/smpl/smpl_models/neutral_smpl_mean_params.h5 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ac9b474c74daec0253ed084720f662059336e976850f08a4a9a3f76d06613776 +size 4848 diff --git a/deps/smpl/smpl_models/smpl.faces b/deps/smpl/smpl_models/smpl.faces new file mode 100644 index 0000000000000000000000000000000000000000..3b1493c4a3853b78df703c0417e08b7464b99e75 Binary files /dev/null and b/deps/smpl/smpl_models/smpl.faces differ diff --git a/deps/smpl/smpl_models/smpl.tar.gz b/deps/smpl/smpl_models/smpl.tar.gz new file mode 100644 index 0000000000000000000000000000000000000000..283061b808f77e0db81b81e3398641d0dbab57b9 --- /dev/null +++ b/deps/smpl/smpl_models/smpl.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bf4793af6b29677b0841c58db392642cb70b477890dc91de01128c7f34738d8d +size 45 diff --git a/deps/smpl/smpl_models/smpl/SMPL_FEMALE.pkl b/deps/smpl/smpl_models/smpl/SMPL_FEMALE.pkl new file mode 100644 index 
0000000000000000000000000000000000000000..92a201f4839bd95c1c1986437c7c6a02d7d1ae99 --- /dev/null +++ b/deps/smpl/smpl_models/smpl/SMPL_FEMALE.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a583c1b98e4afc19042641f1bae5cd8a1f712a6724886291a7627ec07acd408d +size 39056454 diff --git a/deps/smpl/smpl_models/smpl/SMPL_MALE.pkl b/deps/smpl/smpl_models/smpl/SMPL_MALE.pkl new file mode 100644 index 0000000000000000000000000000000000000000..43dfecc57d9b7aa99cd2398df818ba252be7f605 --- /dev/null +++ b/deps/smpl/smpl_models/smpl/SMPL_MALE.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0e8c0bbbbc635dcb166ed29c303fb4bef16ea5f623e5a89263495a9e403575bd +size 39056404 diff --git a/deps/smpl/smpl_models/smpl/SMPL_NEUTRAL.pkl b/deps/smpl/smpl_models/smpl/SMPL_NEUTRAL.pkl new file mode 100644 index 0000000000000000000000000000000000000000..26574fd104c4b69467f3c7c3516a8508d8a1a36e --- /dev/null +++ b/deps/smpl/smpl_models/smpl/SMPL_NEUTRAL.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:98e65c74ad9b998783132f00880d1025a8d64b158e040e6ef13a557e5098bc42 +size 39001280 diff --git a/deps/smpl/smpl_models/smpl/readme.txt b/deps/smpl/smpl_models/smpl/readme.txt new file mode 100644 index 0000000000000000000000000000000000000000..ea6871b75940d38c2de4c4fccaa0ef3c4c858f67 --- /dev/null +++ b/deps/smpl/smpl_models/smpl/readme.txt @@ -0,0 +1 @@ +This directory leaves for SMPL models diff --git a/deps/smpl/smpl_models/smplh/SMPLH_FEMALE.npz b/deps/smpl/smpl_models/smplh/SMPLH_FEMALE.npz new file mode 100644 index 0000000000000000000000000000000000000000..ec5dac2f22900da972133e7c7c759bf17d6e4443 --- /dev/null +++ b/deps/smpl/smpl_models/smplh/SMPLH_FEMALE.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f0fba73ef2494b26de243c1d88a1dbe1047e5566128cf7222c942089543f4560 +size 39708434 diff --git a/deps/smpl/smpl_models/smplh/SMPLH_MALE.npz b/deps/smpl/smpl_models/smplh/SMPLH_MALE.npz new file mode 100644 index 0000000000000000000000000000000000000000..77701908a303d027fe8a7c95989a7b956fc02889 --- /dev/null +++ b/deps/smpl/smpl_models/smplh/SMPLH_MALE.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:10b617fdd329557937d6fe38e8a542afab236a8887522d9da0bd42e7f2b76eaa +size 39686902 diff --git a/deps/smpl/smpl_models/smplh/SMPLH_NEUTRAL.npz b/deps/smpl/smpl_models/smplh/SMPLH_NEUTRAL.npz new file mode 100644 index 0000000000000000000000000000000000000000..358e126e49758ef7e42c62000e9ffc3cd41eebb0 --- /dev/null +++ b/deps/smpl/smpl_models/smplh/SMPLH_NEUTRAL.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:42969b34d8cd383e172515a7bca6ff3b2c37aa2c5c78088c69d20e517fa96026 +size 39708959 diff --git a/deps/smpl/smpl_models/smplh/mano_v1_2.zip b/deps/smpl/smpl_models/smplh/mano_v1_2.zip new file mode 100644 index 0000000000000000000000000000000000000000..a343e0ce79585fa35ec6d7f421f2a101b47780fa --- /dev/null +++ b/deps/smpl/smpl_models/smplh/mano_v1_2.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:50976831790ea9657d8110e0c94e50e90eaf35cd76169f0b27e5d32f3fcd951f +size 175200815 diff --git a/deps/smpl/smpl_models/smplh/smplh.faces b/deps/smpl/smpl_models/smplh/smplh.faces new file mode 100644 index 0000000000000000000000000000000000000000..20c84e393f281fa5d98769a90a33a89cd00d8fea Binary files /dev/null and b/deps/smpl/smpl_models/smplh/smplh.faces differ diff --git a/deps/smpl/smpl_models/smplh/smplh.tar.xz b/deps/smpl/smpl_models/smplh/smplh.tar.xz 
new file mode 100644 index 0000000000000000000000000000000000000000..cac9d77988a33a50d9a6db15acd958198ccb122c --- /dev/null +++ b/deps/smpl/smpl_models/smplh/smplh.tar.xz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:46d5b8687be48c91181fa88271feff3a5e83aa62a481fad8a0bcb9254b2a74f1 +size 113231292 diff --git a/deps/smpl/smpl_models/smplx_parts_segm.pkl b/deps/smpl/smpl_models/smplx_parts_segm.pkl new file mode 100644 index 0000000000000000000000000000000000000000..77ce98631741ba3887d689077baf35422d39299d --- /dev/null +++ b/deps/smpl/smpl_models/smplx_parts_segm.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bb69c10801205c9cfb5353fdeb1b9cc5ade53d14c265c3339421cdde8b9c91e7 +size 1323168 diff --git a/mGPT/__init__.py b/mGPT/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mGPT/archs/__init__.py b/mGPT/archs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mGPT/archs/mgpt_lm.py b/mGPT/archs/mgpt_lm.py new file mode 100644 index 0000000000000000000000000000000000000000..c30307adbcdbc9c8137f8e021c991fac783dbf0f --- /dev/null +++ b/mGPT/archs/mgpt_lm.py @@ -0,0 +1,592 @@ +import os +from typing import List, Union +import numpy as np +import math +import time +import heapq +import torch +from torch import Tensor, nn +from torch.distributions.distribution import Distribution +from transformers import AutoModelForSeq2SeqLM, T5ForConditionalGeneration, T5Tokenizer, AutoTokenizer, GPT2LMHeadModel, GPT2Tokenizer +import random +from typing import Optional +from .tools.token_emb import NewTokenEmb + + +class MLM(nn.Module): + + def __init__( + self, + model_path: str, + model_type: str = "t5", + stage: str = "lm_pretrain", + new_token_type: str = "insert", + motion_codebook_size: int = 512, + framerate: float = 20.0, + down_t: int = 4, + predict_ratio: float = 0.2, + inbetween_ratio: float = 0.25, + max_length: int = 256, + lora: bool = False, + quota_ratio: float = 0.5, + noise_density: float = 0.15, + mean_noise_span_length: int = 3, + **kwargs, + ) -> None: + + super().__init__() + + # Parameters + self.m_codebook_size = motion_codebook_size + self.max_length = max_length + self.framerate = framerate + self.down_t = down_t + self.predict_ratio = predict_ratio + self.inbetween_ratio = inbetween_ratio + self.noise_density = noise_density + self.mean_noise_span_length = mean_noise_span_length + self.quota_ratio = quota_ratio + self.stage = stage + + # Instantiate language model + self.tokenizer = AutoTokenizer.from_pretrained(model_path, legacy=True) + if model_type == "t5": + self.language_model = T5ForConditionalGeneration.from_pretrained( + model_path) + self.lm_type = 'encdec' + elif model_type == "gpt2": + self.language_model = GPT2LMHeadModel.from_pretrained(model_path) + self.lm_type = 'dec' + else: + raise ValueError("type must be either seq2seq or conditional") + + if self.lm_type == 'dec': + self.tokenizer.pad_token = self.tokenizer.eos_token + + # Add motion tokens + self.tokenizer.add_tokens( + [f'' for i in range(self.m_codebook_size + 3)]) + + if new_token_type == "insert": + self.language_model.resize_token_embeddings(len(self.tokenizer)) + elif new_token_type == "mlp": + shared = NewTokenEmb(self.language_model.shared, + self.m_codebook_size + 3) + # lm_head = NewTokenEmb(self.language_model.lm_head, + # self.m_codebook_size + 3) + 
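        # NOTE (illustrative sketch, not from the original patch): the surrounding block
        # extends the pretrained LM vocabulary with `motion_codebook_size + 3` special
        # motion tokens (codebook entries plus three extra marker tokens; the exact tag
        # strings are not preserved in this copy of the diff). The same Hugging Face
        # pattern in isolation, with a hypothetical tag format:
        #
        #     from transformers import AutoTokenizer, T5ForConditionalGeneration
        #
        #     tok = AutoTokenizer.from_pretrained("google/flan-t5-base")
        #     lm = T5ForConditionalGeneration.from_pretrained("google/flan-t5-base")
        #     new_tokens = [f"<motion_id_{i}>" for i in range(512 + 3)]  # hypothetical names
        #     tok.add_tokens(new_tokens)            # register the new special tokens
        #     lm.resize_token_embeddings(len(tok))  # grow embedding/lm_head rows to match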
self.language_model.resize_token_embeddings(len(self.tokenizer)) + self.language_model.shared = shared + # self.language_model.lm_head = lm_head + + # Lora + if lora: + from peft import LoraConfig, TaskType, get_peft_model, get_peft_model_state_dict + from peft.utils.other import fsdp_auto_wrap_policy + peft_config = LoraConfig( + bias="none", + task_type="CAUSAL_LM", + # inference_mode=False, + r=8, + lora_alpha=16, + lora_dropout=0.05) + self.language_model = get_peft_model(self.language_model, + peft_config) + + def forward(self, texts: List[str], motion_tokens: Tensor, + lengths: List[int], tasks: dict): + if self.lm_type == 'encdec': + return self.forward_encdec(texts, motion_tokens, lengths, tasks) + elif self.lm_type == 'dec': + return self.forward_dec(texts, motion_tokens, lengths, tasks) + else: + raise NotImplementedError("Only conditional_multitask supported") + + def forward_encdec( + self, + texts: List[str], + motion_tokens: Tensor, + lengths: List[int], + tasks: dict, + ): + + # Tensor to string + motion_strings = self.motion_token_to_string(motion_tokens, lengths) + + # Supervised or unsupervised + # condition = random.choice( + # ['text', 'motion', 'supervised', 'supervised', 'supervised']) + condition = random.choice(['supervised', 'supervised', 'supervised']) + + if condition == 'text': + inputs = texts + outputs = texts + elif condition == 'motion': + inputs = motion_strings + outputs = motion_strings + else: + inputs, outputs = self.template_fulfill(tasks, lengths, + motion_strings, texts) + + # Tokenize + source_encoding = self.tokenizer(inputs, + padding='max_length', + max_length=self.max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt") + + source_attention_mask = source_encoding.attention_mask.to( + motion_tokens.device) + source_input_ids = source_encoding.input_ids.to(motion_tokens.device) + + if condition in ['text', 'motion']: + batch_size, expandend_input_length = source_input_ids.shape + mask_indices = np.asarray([ + self.random_spans_noise_mask(expandend_input_length) + for i in range(batch_size) + ]) + target_mask = ~mask_indices + input_ids_sentinel = self.create_sentinel_ids( + mask_indices.astype(np.int8)) + target_sentinel = self.create_sentinel_ids( + target_mask.astype(np.int8)) + + labels_input_ids = self.filter_input_ids(source_input_ids, + target_sentinel) + source_input_ids = self.filter_input_ids(source_input_ids, + input_ids_sentinel) + + else: + target_inputs = self.tokenizer(outputs, + padding='max_length', + max_length=self.max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt") + + labels_input_ids = target_inputs.input_ids.to(motion_tokens.device) + lables_attention_mask = target_inputs.attention_mask.to( + motion_tokens.device) + + labels_input_ids[labels_input_ids == 0] = -100 + outputs = self.language_model( + input_ids=source_input_ids, + attention_mask=source_attention_mask + if condition == 'supervised' else None, + labels=labels_input_ids, + decoder_attention_mask=lables_attention_mask + if condition == 'supervised' else None, + ) + + return outputs + + def forward_dec( + self, + texts: List[str], + motion_tokens: Tensor, + lengths: List[int], + tasks: dict, + ): + self.tokenizer.padding_side = "right" + + # Tensor to string + motion_strings = self.motion_token_to_string(motion_tokens, lengths) + + # Supervised or unsupervised + condition = random.choice( + ['text', 'motion', 'supervised', 'supervised', 'supervised']) + + 
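        # NOTE (descriptive comment added for clarity, not from the original patch):
        # each decoder-only training step samples one of three regimes: with probability
        # 1/5 plain text modelling (labels = the captions), with 1/5 plain motion-token
        # modelling (labels = the motion strings), and with 3/5 supervised instruction
        # tuning, where template_fulfill() builds a prompt/answer pair that is serialized
        # as  input + ' \n ' + output + eos_token  before tokenization.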
if condition == 'text': + labels = texts + elif condition == 'motion': + labels = motion_strings + else: + inputs, outputs = self.template_fulfill(tasks, lengths, + motion_strings, texts) + labels = [] + for i in range(len(inputs)): + labels.append(inputs[i] + ' \n ' + outputs[i] + + self.tokenizer.eos_token) + + # Tokenize + inputs = self.tokenizer(labels, + padding='max_length', + max_length=self.max_length, + truncation=True, + return_attention_mask=True, + return_tensors="pt") + + labels_input_ids = inputs.input_ids.to(motion_tokens.device) + lables_attention_mask = inputs.attention_mask.to(motion_tokens.device) + + # print(labels_input_ids[0:5]) + + outputs = self.language_model(input_ids=labels_input_ids, + attention_mask=lables_attention_mask, + labels=inputs["input_ids"]) + + return outputs + + def generate_direct(self, + texts: List[str], + max_length: int = 256, + num_beams: int = 1, + do_sample: bool = True, + bad_words_ids: List[int] = None): + + # Device + self.device = self.language_model.device + + # Tokenize + if self.lm_type == 'dec': + texts = [text + " \n " for text in texts] + + source_encoding = self.tokenizer(texts, + padding='max_length', + max_length=self.max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt") + + source_input_ids = source_encoding.input_ids.to(self.device) + source_attention_mask = source_encoding.attention_mask.to(self.device) + + if self.lm_type == 'encdec': + outputs = self.language_model.generate( + source_input_ids, + max_length=max_length, + num_beams=num_beams, + do_sample=do_sample, + bad_words_ids=bad_words_ids, + ) + elif self.lm_type == 'dec': + outputs = self.language_model.generate( + input_ids=source_input_ids, + attention_mask=source_attention_mask, + pad_token_id=self.tokenizer.pad_token_id, + do_sample=do_sample, + max_new_tokens=max_length) + self.tokenizer.padding_side = 'left' + + outputs_string = self.tokenizer.batch_decode(outputs, + skip_special_tokens=True) + + print(texts[:2]) + print(outputs_string[:2]) + + outputs_tokens, cleaned_text = self.motion_string_to_token( + outputs_string) + + return outputs_tokens, cleaned_text + + def generate_conditional(self, + texts: Optional[List[str]] = None, + motion_tokens: Optional[Tensor] = None, + lengths: Optional[List[int]] = None, + task: str = "t2m", + with_len: bool = False, + stage: str = 'train', + tasks: dict = None): + + self.device = self.language_model.device + + if task in ["t2m", "m2m", "pred", "inbetween"]: + + if task == "t2m": + assert texts is not None + motion_strings = [''] * len(texts) + if not with_len: + if tasks is None: + tasks = [{ + 'input': + ['Generate motion: '], + 'output': [''] + }] * len(texts) + + lengths = [0] * len(texts) + else: + tasks = [{ + 'input': [ + 'Generate motion with frames: ' + ], + 'output': [''] + }] * len(texts) + + elif task == "pred": + assert motion_tokens is not None and lengths is not None + texts = [''] * len(lengths) + tasks = [{ + 'input': ['Predict motion: '], + 'output': [''] + }] * len(lengths) + + motion_strings_old = self.motion_token_to_string( + motion_tokens, lengths) + motion_strings = [] + for i, length in enumerate(lengths): + split = length // 5 + motion_strings.append( + '>'.join(motion_strings_old[i].split('>')[:split]) + + '>') + + elif task == "inbetween": + assert motion_tokens is not None and lengths is not None + texts = [''] * len(lengths) + tasks = [{ + 'input': [ + "Complete the masked motion: " + ], + 'output': [''] + }] * len(lengths) + 
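        # NOTE (illustrative sketch, not from the original patch): the t2m / pred /
        # inbetween branches above only differ in how prompt templates and motion strings
        # are prepared; all of them then go through template_fulfill() and
        # generate_direct(). A minimal usage sketch, assuming `lm` is a weight-loaded MLM
        # instance and `codes` is a 1-D tensor of VQ-VAE codebook indices:
        #
        #     gen = lm.generate_conditional(texts=["a person walks forward"], task="t2m")
        #     # -> list with one LongTensor of motion-codebook ids (decode via VQVae.decode)
        #
        #     caption = lm.generate_conditional(motion_tokens=codes[None],
        #                                       lengths=[codes.shape[0]], task="m2t")
        #     # -> list with one natural-language description of the motion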
motion_strings = self.motion_token_to_string( + motion_tokens, lengths) + + inputs, outputs = self.template_fulfill(tasks, lengths, + motion_strings, texts, + stage) + + outputs_tokens, cleaned_text = self.generate_direct(inputs, + max_length=128, + num_beams=1, + do_sample=True) + + return outputs_tokens + + elif task == "m2t": + assert motion_tokens is not None and lengths is not None + + motion_strings = self.motion_token_to_string( + motion_tokens, lengths) + + if not with_len: + tasks = [{ + 'input': ['Generate text: '], + 'output': [''] + }] * len(lengths) + else: + tasks = [{ + 'input': [ + 'Generate text with frames: ' + ], + 'output': [''] + }] * len(lengths) + + texts = [''] * len(lengths) + + inputs, outputs = self.template_fulfill(tasks, lengths, + motion_strings, texts) + outputs_tokens, cleaned_text = self.generate_direct( + inputs, + max_length=40, + num_beams=1, + do_sample=False, + # bad_words_ids=self.bad_words_ids + ) + return cleaned_text + + def motion_token_to_string(self, motion_token: Tensor, lengths: List[int]): + motion_string = [] + for i in range(len(motion_token)): + motion_i = motion_token[i].cpu( + ) if motion_token[i].device.type == 'cuda' else motion_token[i] + motion_list = motion_i.tolist()[:lengths[i]] + motion_string.append( + (f'' + + ''.join([f'' for i in motion_list]) + + f'')) + return motion_string + + def motion_token_list_to_string(self, motion_token: Tensor): + motion_string = [] + for i in range(len(motion_token)): + motion_i = motion_token[i].cpu( + ) if motion_token[i].device.type == 'cuda' else motion_token[i] + motion_list = motion_i.tolist() + motion_string.append( + (f'' + + ''.join([f'' for i in motion_list]) + + f'')) + return motion_string + + def motion_string_to_token(self, motion_string: List[str]): + motion_tokens = [] + output_string = [] + for i in range(len(motion_string)): + string = self.get_middle_str( + motion_string[i], f'', + f'') + string_list = string.split('><') + token_list = [ + int(i.split('_')[-1].replace('>', '')) + for i in string_list[1:-1] + ] + if len(token_list) == 0: + token_list = [0] + token_list_padded = torch.tensor(token_list, + dtype=int).to(self.device) + motion_tokens.append(token_list_padded) + output_string.append(motion_string[i].replace( + string, '')) + + return motion_tokens, output_string + + def placeholder_fulfill(self, prompt: str, length: int, motion_string: str, + text: str): + + seconds = math.floor(length / self.framerate) + motion_splited = motion_string.split('>') + token_length = length / self.down_t + predict_head = int(token_length * self.predict_ratio + 1) + masked_head = int(token_length * self.inbetween_ratio + 1) + masked_tail = int(token_length * (1 - self.inbetween_ratio) + 1) + + motion_predict_head = '>'.join( + motion_splited[:predict_head] + ) + f'>' + motion_predict_last = f'' + '>'.join( + motion_splited[predict_head:]) + + motion_masked = '>'.join( + motion_splited[:masked_head] + ) + '>' + f'' * ( + masked_tail - masked_head) + '>'.join(motion_splited[masked_tail:]) + + if random.random() < self.quota_ratio: + text = f'\"{text}\"' + + prompt = prompt.replace('', text).replace( + '', + motion_string).replace('', f'{length}').replace( + '', '%.1f' % seconds).replace( + '', motion_predict_head).replace( + '', + motion_predict_last).replace( + '', motion_masked) + + return prompt + + def template_fulfill(self, + tasks, + lengths, + motion_strings, + texts, + stage='test'): + inputs = [] + outputs = [] + for i in range(len(lengths)): + input_template = 
random.choice(tasks[i]['input']) + output_template = random.choice(tasks[i]['output']) + length = lengths[i] + inputs.append( + self.placeholder_fulfill(input_template, length, + motion_strings[i], texts[i])) + outputs.append( + self.placeholder_fulfill(output_template, length, + motion_strings[i], texts[i])) + + return inputs, outputs + + def get_middle_str(self, content, startStr, endStr): + try: + startIndex = content.index(startStr) + if startIndex >= 0: + startIndex += len(startStr) + endIndex = content.index(endStr) + except: + return f'' + + return f'' + content[ + startIndex:endIndex] + f'' + + def random_spans_noise_mask(self, length): + # From https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py + + orig_length = length + + num_noise_tokens = int(np.round(length * self.noise_density)) + # avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens. + num_noise_tokens = min(max(num_noise_tokens, 1), length - 1) + num_noise_spans = int( + np.round(num_noise_tokens / self.mean_noise_span_length)) + + # avoid degeneracy by ensuring positive number of noise spans + num_noise_spans = max(num_noise_spans, 1) + num_nonnoise_tokens = length - num_noise_tokens + + # pick the lengths of the noise spans and the non-noise spans + def _random_segmentation(num_items, num_segments): + """Partition a sequence of items randomly into non-empty segments. + Args: + num_items: an integer scalar > 0 + num_segments: an integer scalar in [1, num_items] + Returns: + a Tensor with shape [num_segments] containing positive integers that add + up to num_items + """ + mask_indices = np.arange(num_items - 1) < (num_segments - 1) + np.random.shuffle(mask_indices) + first_in_segment = np.pad(mask_indices, [[1, 0]]) + segment_id = np.cumsum(first_in_segment) + # count length of sub segments assuming that list is sorted + _, segment_length = np.unique(segment_id, return_counts=True) + return segment_length + + noise_span_lengths = _random_segmentation(num_noise_tokens, + num_noise_spans) + nonnoise_span_lengths = _random_segmentation(num_nonnoise_tokens, + num_noise_spans) + + interleaved_span_lengths = np.reshape( + np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1), + [num_noise_spans * 2], + ) + span_starts = np.cumsum(interleaved_span_lengths)[:-1] + span_start_indicator = np.zeros((length, ), dtype=np.int8) + span_start_indicator[span_starts] = True + span_num = np.cumsum(span_start_indicator) + is_noise = np.equal(span_num % 2, 1) + + return is_noise[:orig_length] + + def create_sentinel_ids(self, mask_indices): + # From https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_t5_mlm_flax.py + start_indices = mask_indices - np.roll(mask_indices, 1, + axis=-1) * mask_indices + start_indices[:, 0] = mask_indices[:, 0] + + sentinel_ids = np.where(start_indices != 0, + np.cumsum(start_indices, axis=-1), + start_indices) + sentinel_ids = np.where(sentinel_ids != 0, + (len(self.tokenizer) - sentinel_ids), 0) + sentinel_ids -= mask_indices - start_indices + + return sentinel_ids + + def filter_input_ids(self, input_ids, sentinel_ids): + # From https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_t5_mlm_flax.py + batch_size = input_ids.shape[0] + + input_ids_full = np.where(sentinel_ids != 0, sentinel_ids, + input_ids.to('cpu')) + + # input_ids tokens and sentinel tokens are >= 0, tokens < 0 are + # masked tokens coming after sentinel 
tokens and should be removed + input_ids = input_ids_full[input_ids_full >= 0].reshape( + (batch_size, -1)) + input_ids = np.concatenate( + [ + input_ids, + np.full((batch_size, 1), + self.tokenizer.eos_token_id, + dtype=np.int32), + ], + axis=-1, + ) + + input_ids = torch.tensor(input_ids, device=self.device) + + return input_ids diff --git a/mGPT/archs/mgpt_vq.py b/mGPT/archs/mgpt_vq.py new file mode 100644 index 0000000000000000000000000000000000000000..077dc4896b26b88291f9a227574ffeaeaa593d3d --- /dev/null +++ b/mGPT/archs/mgpt_vq.py @@ -0,0 +1,190 @@ +# Partially from https://github.com/Mael-zys/T2M-GPT + +from typing import List, Optional, Union +import torch +import torch.nn as nn +from torch import Tensor, nn +from torch.distributions.distribution import Distribution +from .tools.resnet import Resnet1D +from .tools.quantize_cnn import QuantizeEMAReset, Quantizer, QuantizeEMA, QuantizeReset +from collections import OrderedDict + + +class VQVae(nn.Module): + + def __init__(self, + nfeats: int, + quantizer: str = "ema_reset", + code_num=512, + code_dim=512, + output_emb_width=512, + down_t=3, + stride_t=2, + width=512, + depth=3, + dilation_growth_rate=3, + norm=None, + activation: str = "relu", + **kwargs) -> None: + + super().__init__() + + self.code_dim = code_dim + + self.encoder = Encoder(nfeats, + output_emb_width, + down_t, + stride_t, + width, + depth, + dilation_growth_rate, + activation=activation, + norm=norm) + + self.decoder = Decoder(nfeats, + output_emb_width, + down_t, + stride_t, + width, + depth, + dilation_growth_rate, + activation=activation, + norm=norm) + + if quantizer == "ema_reset": + self.quantizer = QuantizeEMAReset(code_num, code_dim, mu=0.99) + elif quantizer == "orig": + self.quantizer = Quantizer(code_num, code_dim, beta=1.0) + elif quantizer == "ema": + self.quantizer = QuantizeEMA(code_num, code_dim, mu=0.99) + elif quantizer == "reset": + self.quantizer = QuantizeReset(code_num, code_dim) + + def preprocess(self, x): + # (bs, T, Jx3) -> (bs, Jx3, T) + x = x.permute(0, 2, 1) + return x + + def postprocess(self, x): + # (bs, Jx3, T) -> (bs, T, Jx3) + x = x.permute(0, 2, 1) + return x + + def forward(self, features: Tensor): + # Preprocess + x_in = self.preprocess(features) + + # Encode + x_encoder = self.encoder(x_in) + + # quantization + x_quantized, loss, perplexity = self.quantizer(x_encoder) + + # decoder + x_decoder = self.decoder(x_quantized) + x_out = self.postprocess(x_decoder) + + return x_out, loss, perplexity + + def encode( + self, + features: Tensor, + ) -> Union[Tensor, Distribution]: + + N, T, _ = features.shape + x_in = self.preprocess(features) + x_encoder = self.encoder(x_in) + x_encoder = self.postprocess(x_encoder) + x_encoder = x_encoder.contiguous().view(-1, + x_encoder.shape[-1]) # (NT, C) + code_idx = self.quantizer.quantize(x_encoder) + code_idx = code_idx.view(N, -1) + + # latent, dist + return code_idx, None + + def decode(self, z: Tensor): + + x_d = self.quantizer.dequantize(z) + x_d = x_d.view(1, -1, self.code_dim).permute(0, 2, 1).contiguous() + + # decoder + x_decoder = self.decoder(x_d) + x_out = self.postprocess(x_decoder) + return x_out + + +class Encoder(nn.Module): + + def __init__(self, + input_emb_width=3, + output_emb_width=512, + down_t=3, + stride_t=2, + width=512, + depth=3, + dilation_growth_rate=3, + activation='relu', + norm=None): + super().__init__() + + blocks = [] + filter_t, pad_t = stride_t * 2, stride_t // 2 + blocks.append(nn.Conv1d(input_emb_width, width, 3, 1, 1)) + blocks.append(nn.ReLU()) + + for 
i in range(down_t): + input_dim = width + block = nn.Sequential( + nn.Conv1d(input_dim, width, filter_t, stride_t, pad_t), + Resnet1D(width, + depth, + dilation_growth_rate, + activation=activation, + norm=norm), + ) + blocks.append(block) + blocks.append(nn.Conv1d(width, output_emb_width, 3, 1, 1)) + self.model = nn.Sequential(*blocks) + + def forward(self, x): + return self.model(x) + + +class Decoder(nn.Module): + + def __init__(self, + input_emb_width=3, + output_emb_width=512, + down_t=3, + stride_t=2, + width=512, + depth=3, + dilation_growth_rate=3, + activation='relu', + norm=None): + super().__init__() + blocks = [] + + filter_t, pad_t = stride_t * 2, stride_t // 2 + blocks.append(nn.Conv1d(output_emb_width, width, 3, 1, 1)) + blocks.append(nn.ReLU()) + for i in range(down_t): + out_dim = width + block = nn.Sequential( + Resnet1D(width, + depth, + dilation_growth_rate, + reverse_dilation=True, + activation=activation, + norm=norm), nn.Upsample(scale_factor=2, + mode='nearest'), + nn.Conv1d(width, out_dim, 3, 1, 1)) + blocks.append(block) + blocks.append(nn.Conv1d(width, width, 3, 1, 1)) + blocks.append(nn.ReLU()) + blocks.append(nn.Conv1d(width, input_emb_width, 3, 1, 1)) + self.model = nn.Sequential(*blocks) + + def forward(self, x): + return self.model(x) diff --git a/mGPT/archs/tm2t_evaluator.py b/mGPT/archs/tm2t_evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..4654440f9cbdb1c0bb908ba9459c15c30ce10be6 --- /dev/null +++ b/mGPT/archs/tm2t_evaluator.py @@ -0,0 +1,111 @@ +import torch +import torch.nn as nn +from torch.nn.utils.rnn import pack_padded_sequence + + +class MovementConvEncoder(nn.Module): + def __init__(self, input_size, hidden_size, output_size): + super(MovementConvEncoder, self).__init__() + self.main = nn.Sequential( + nn.Conv1d(input_size, hidden_size, 4, 2, 1), + nn.Dropout(0.2, inplace=True), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv1d(hidden_size, output_size, 4, 2, 1), + nn.Dropout(0.2, inplace=True), + nn.LeakyReLU(0.2, inplace=True), + ) + self.out_net = nn.Linear(output_size, output_size) + # self.main.apply(init_weight) + # self.out_net.apply(init_weight) + + def forward(self, inputs): + inputs = inputs.permute(0, 2, 1) + outputs = self.main(inputs).permute(0, 2, 1) + # print(outputs.shape) + return self.out_net(outputs) + + +class MotionEncoderBiGRUCo(nn.Module): + def __init__(self, input_size, hidden_size, output_size): + super(MotionEncoderBiGRUCo, self).__init__() + + self.input_emb = nn.Linear(input_size, hidden_size) + self.gru = nn.GRU( + hidden_size, hidden_size, batch_first=True, bidirectional=True + ) + self.output_net = nn.Sequential( + nn.Linear(hidden_size * 2, hidden_size), + nn.LayerNorm(hidden_size), + nn.LeakyReLU(0.2, inplace=True), + nn.Linear(hidden_size, output_size), + ) + + # self.input_emb.apply(init_weight) + # self.output_net.apply(init_weight) + self.hidden_size = hidden_size + self.hidden = nn.Parameter( + torch.randn((2, 1, self.hidden_size), requires_grad=True) + ) + + # input(batch_size, seq_len, dim) + def forward(self, inputs, m_lens): + num_samples = inputs.shape[0] + + input_embs = self.input_emb(inputs) + hidden = self.hidden.repeat(1, num_samples, 1) + + cap_lens = m_lens.data.tolist() + + # emb = pack_padded_sequence(input=input_embs, lengths=cap_lens, batch_first=True) + emb = input_embs + + gru_seq, gru_last = self.gru(emb, hidden) + + gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1) + + return self.output_net(gru_last) + + +class TextEncoderBiGRUCo(nn.Module): + def 
__init__(self, word_size, pos_size, hidden_size, output_size): + super(TextEncoderBiGRUCo, self).__init__() + + self.pos_emb = nn.Linear(pos_size, word_size) + self.input_emb = nn.Linear(word_size, hidden_size) + self.gru = nn.GRU( + hidden_size, hidden_size, batch_first=True, bidirectional=True + ) + self.output_net = nn.Sequential( + nn.Linear(hidden_size * 2, hidden_size), + nn.LayerNorm(hidden_size), + nn.LeakyReLU(0.2, inplace=True), + nn.Linear(hidden_size, output_size), + ) + + # self.input_emb.apply(init_weight) + # self.pos_emb.apply(init_weight) + # self.output_net.apply(init_weight) + # self.linear2.apply(init_weight) + # self.batch_size = batch_size + self.hidden_size = hidden_size + self.hidden = nn.Parameter( + torch.randn((2, 1, self.hidden_size), requires_grad=True) + ) + + # input(batch_size, seq_len, dim) + def forward(self, word_embs, pos_onehot, cap_lens): + num_samples = word_embs.shape[0] + + pos_embs = self.pos_emb(pos_onehot) + inputs = word_embs + pos_embs + input_embs = self.input_emb(inputs) + hidden = self.hidden.repeat(1, num_samples, 1) + + cap_lens = cap_lens.data.tolist() + emb = pack_padded_sequence(input=input_embs, lengths=cap_lens, batch_first=True) + + gru_seq, gru_last = self.gru(emb, hidden) + + gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1) + + return self.output_net(gru_last) diff --git a/mGPT/archs/tools/embeddings.py b/mGPT/archs/tools/embeddings.py new file mode 100644 index 0000000000000000000000000000000000000000..b53470ce9491addf70d28726571330c834a88e68 --- /dev/null +++ b/mGPT/archs/tools/embeddings.py @@ -0,0 +1,322 @@ +# This file is taken from signjoey repository +import math + +import torch +from torch import Tensor, nn + + +def get_activation(activation_type): + if activation_type == "relu": + return nn.ReLU() + elif activation_type == "relu6": + return nn.ReLU6() + elif activation_type == "prelu": + return nn.PReLU() + elif activation_type == "selu": + return nn.SELU() + elif activation_type == "celu": + return nn.CELU() + elif activation_type == "gelu": + return nn.GELU() + elif activation_type == "sigmoid": + return nn.Sigmoid() + elif activation_type == "softplus": + return nn.Softplus() + elif activation_type == "softshrink": + return nn.Softshrink() + elif activation_type == "softsign": + return nn.Softsign() + elif activation_type == "tanh": + return nn.Tanh() + elif activation_type == "tanhshrink": + return nn.Tanhshrink() + else: + raise ValueError("Unknown activation type {}".format(activation_type)) + + +class MaskedNorm(nn.Module): + """ + Original Code from: + https://discuss.pytorch.org/t/batchnorm-for-different-sized-samples-in-batch/44251/8 + """ + + def __init__(self, norm_type, num_groups, num_features): + super().__init__() + self.norm_type = norm_type + if self.norm_type == "batch": + self.norm = nn.BatchNorm1d(num_features=num_features) + elif self.norm_type == "group": + self.norm = nn.GroupNorm(num_groups=num_groups, num_channels=num_features) + elif self.norm_type == "layer": + self.norm = nn.LayerNorm(normalized_shape=num_features) + else: + raise ValueError("Unsupported Normalization Layer") + + self.num_features = num_features + + def forward(self, x: Tensor, mask: Tensor): + if self.training: + reshaped = x.reshape([-1, self.num_features]) + reshaped_mask = mask.reshape([-1, 1]) > 0 + selected = torch.masked_select(reshaped, reshaped_mask).reshape( + [-1, self.num_features] + ) + batch_normed = self.norm(selected) + scattered = reshaped.masked_scatter(reshaped_mask, batch_normed) + return 
scattered.reshape([x.shape[0], -1, self.num_features]) + else: + reshaped = x.reshape([-1, self.num_features]) + batched_normed = self.norm(reshaped) + return batched_normed.reshape([x.shape[0], -1, self.num_features]) + + +# TODO (Cihan): Spatial and Word Embeddings are pretty much the same +# We might as well convert them into a single module class. +# Only difference is the lut vs linear layers. +class Embeddings(nn.Module): + + """ + Simple embeddings class + """ + + # pylint: disable=unused-argument + def __init__( + self, + embedding_dim: int = 64, + num_heads: int = 8, + scale: bool = False, + scale_factor: float = None, + norm_type: str = None, + activation_type: str = None, + vocab_size: int = 0, + padding_idx: int = 1, + freeze: bool = False, + **kwargs + ): + """ + Create new embeddings for the vocabulary. + Use scaling for the Transformer. + + :param embedding_dim: + :param scale: + :param vocab_size: + :param padding_idx: + :param freeze: freeze the embeddings during training + """ + super().__init__() + + self.embedding_dim = embedding_dim + self.vocab_size = vocab_size + self.lut = nn.Embedding(vocab_size, self.embedding_dim, padding_idx=padding_idx) + + self.norm_type = norm_type + if self.norm_type: + self.norm = MaskedNorm( + norm_type=norm_type, num_groups=num_heads, num_features=embedding_dim + ) + + self.activation_type = activation_type + if self.activation_type: + self.activation = get_activation(activation_type) + + self.scale = scale + if self.scale: + if scale_factor: + self.scale_factor = scale_factor + else: + self.scale_factor = math.sqrt(self.embedding_dim) + + if freeze: + freeze_params(self) + + # pylint: disable=arguments-differ + def forward(self, x: Tensor, mask: Tensor = None) -> Tensor: + """ + Perform lookup for input `x` in the embedding table. + + :param mask: token masks + :param x: index in the vocabulary + :return: embedded representation for `x` + """ + + x = self.lut(x) + + if self.norm_type: + x = self.norm(x, mask) + + if self.activation_type: + x = self.activation(x) + + if self.scale: + return x * self.scale_factor + else: + return x + + def __repr__(self): + return "%s(embedding_dim=%d, vocab_size=%d)" % ( + self.__class__.__name__, + self.embedding_dim, + self.vocab_size, + ) + + +class SpatialEmbeddings(nn.Module): + + """ + Simple Linear Projection Layer + (For encoder outputs to predict glosses) + """ + + # pylint: disable=unused-argument + def __init__( + self, + embedding_dim: int, + input_size: int, + num_heads: int, + freeze: bool = False, + norm_type: str = "batch", + activation_type: str = "softsign", + scale: bool = False, + scale_factor: float = None, + **kwargs + ): + """ + Create new embeddings for the vocabulary. + Use scaling for the Transformer. 
+ + :param embedding_dim: + :param input_size: + :param freeze: freeze the embeddings during training + """ + super().__init__() + + self.embedding_dim = embedding_dim + self.input_size = input_size + self.ln = nn.Linear(self.input_size, self.embedding_dim) + + self.norm_type = norm_type + if self.norm_type: + self.norm = MaskedNorm( + norm_type=norm_type, num_groups=num_heads, num_features=embedding_dim + ) + + self.activation_type = activation_type + if self.activation_type: + self.activation = get_activation(activation_type) + + self.scale = scale + if self.scale: + if scale_factor: + self.scale_factor = scale_factor + else: + self.scale_factor = math.sqrt(self.embedding_dim) + + if freeze: + freeze_params(self) + + # pylint: disable=arguments-differ + def forward(self, x: Tensor, mask: Tensor) -> Tensor: + """ + :param mask: frame masks + :param x: input frame features + :return: embedded representation for `x` + """ + + x = self.ln(x) + + if self.norm_type: + x = self.norm(x, mask) + + if self.activation_type: + x = self.activation(x) + + if self.scale: + return x * self.scale_factor + else: + return x + + def __repr__(self): + return "%s(embedding_dim=%d, input_size=%d)" % ( + self.__class__.__name__, + self.embedding_dim, + self.input_size, + ) + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the + embeddings. :return: an [N x dim] Tensor of positional embeddings. 
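+    Note: if embedding_dim is odd, the concatenated sin/cos embedding is
+    zero-padded with one extra column so the output width matches embedding_dim.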
+ """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device + ) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +class TimestepEmbedding(nn.Module): + def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "silu"): + super().__init__() + + self.linear_1 = nn.Linear(channel, time_embed_dim) + self.act = None + if act_fn == "silu": + self.act = nn.SiLU() + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim) + + def forward(self, sample): + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + + sample = self.linear_2(sample) + return sample + + +class Timesteps(nn.Module): + def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + + def forward(self, timesteps): + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + ) + return t_emb diff --git a/mGPT/archs/tools/quantize_cnn.py b/mGPT/archs/tools/quantize_cnn.py new file mode 100644 index 0000000000000000000000000000000000000000..98ca858a7af1eb74fccc58f4305e96247a7cd7c3 --- /dev/null +++ b/mGPT/archs/tools/quantize_cnn.py @@ -0,0 +1,414 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +class QuantizeEMAReset(nn.Module): + def __init__(self, nb_code, code_dim, mu): + super().__init__() + self.nb_code = nb_code + self.code_dim = code_dim + self.mu = mu + self.reset_codebook() + + def reset_codebook(self): + self.init = False + self.code_sum = None + self.code_count = None + device = "cuda" if torch.cuda.is_available() else "cpu" + self.register_buffer('codebook', torch.zeros(self.nb_code, self.code_dim).to(device)) + + def _tile(self, x): + nb_code_x, code_dim = x.shape + if nb_code_x < self.nb_code: + n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x + std = 0.01 / np.sqrt(code_dim) + out = x.repeat(n_repeats, 1) + out = out + torch.randn_like(out) * std + else : + out = x + return out + + def init_codebook(self, x): + out = self._tile(x) + self.codebook = out[:self.nb_code] + self.code_sum = self.codebook.clone() + self.code_count = torch.ones(self.nb_code, device=self.codebook.device) + self.init = True + + @torch.no_grad() + def compute_perplexity(self, code_idx) : + # Calculate new centres + code_onehot = torch.zeros(self.nb_code, code_idx.shape[0], device=code_idx.device) # nb_code, N * L + code_onehot.scatter_(0, code_idx.view(1, code_idx.shape[0]), 1) + + code_count = code_onehot.sum(dim=-1) # nb_code + prob = code_count / torch.sum(code_count) + perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7))) + return perplexity + + @torch.no_grad() + def update_codebook(self, x, code_idx): + + code_onehot = 
torch.zeros(self.nb_code, x.shape[0], device=x.device) # nb_code, N * L + code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1) + + code_sum = torch.matmul(code_onehot, x) # nb_code, w + code_count = code_onehot.sum(dim=-1) # nb_code + + out = self._tile(x) + code_rand = out[:self.nb_code] + + # Update centres + self.code_sum = self.mu * self.code_sum + (1. - self.mu) * code_sum # w, nb_code + self.code_count = self.mu * self.code_count + (1. - self.mu) * code_count # nb_code + + usage = (self.code_count.view(self.nb_code, 1) >= 1.0).float() + code_update = self.code_sum.view(self.nb_code, self.code_dim) / self.code_count.view(self.nb_code, 1) + + self.codebook = usage * code_update + (1 - usage) * code_rand + prob = code_count / torch.sum(code_count) + perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7))) + + + return perplexity + + def preprocess(self, x): + # NCT -> NTC -> [NT, C] + x = x.permute(0, 2, 1).contiguous() + x = x.view(-1, x.shape[-1]) + return x + + def quantize(self, x): + # Calculate latent code x_l + k_w = self.codebook.t() + distance = torch.sum(x ** 2, dim=-1, keepdim=True) - 2 * torch.matmul(x, k_w) + torch.sum(k_w ** 2, dim=0, + keepdim=True) # (N * L, b) + _, code_idx = torch.min(distance, dim=-1) + return code_idx + + def dequantize(self, code_idx): + x = F.embedding(code_idx, self.codebook) + return x + + + def forward(self, x): + N, width, T = x.shape + + # Preprocess + x = self.preprocess(x) + + # Init codebook if not inited + if self.training and not self.init: + self.init_codebook(x) + + # quantize and dequantize through bottleneck + code_idx = self.quantize(x) + x_d = self.dequantize(code_idx) + + # Update embeddings + if self.training: + perplexity = self.update_codebook(x, code_idx) + else : + perplexity = self.compute_perplexity(code_idx) + + # Loss + commit_loss = F.mse_loss(x, x_d.detach()) + + # Passthrough + x_d = x + (x_d - x).detach() + + # Postprocess + x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous() #(N, DIM, T) + + return x_d, commit_loss, perplexity + + + +class Quantizer(nn.Module): + def __init__(self, n_e, e_dim, beta): + super(Quantizer, self).__init__() + + self.e_dim = e_dim + self.n_e = n_e + self.beta = beta + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + def forward(self, z): + + N, width, T = z.shape + z = self.preprocess(z) + assert z.shape[-1] == self.e_dim + z_flattened = z.contiguous().view(-1, self.e_dim) + + # B x V + d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ + torch.sum(self.embedding.weight**2, dim=1) - 2 * \ + torch.matmul(z_flattened, self.embedding.weight.t()) + # B x 1 + min_encoding_indices = torch.argmin(d, dim=1) + z_q = self.embedding(min_encoding_indices).view(z.shape) + + # compute loss for embedding + loss = torch.mean((z_q - z.detach())**2) + self.beta * \ + torch.mean((z_q.detach() - z)**2) + + # preserve gradients + z_q = z + (z_q - z).detach() + z_q = z_q.view(N, T, -1).permute(0, 2, 1).contiguous() #(N, DIM, T) + + min_encodings = F.one_hot(min_encoding_indices, self.n_e).type(z.dtype) + e_mean = torch.mean(min_encodings, dim=0) + perplexity = torch.exp(-torch.sum(e_mean*torch.log(e_mean + 1e-10))) + return z_q, loss, perplexity + + def quantize(self, z): + + assert z.shape[-1] == self.e_dim + + # B x V + d = torch.sum(z ** 2, dim=1, keepdim=True) + \ + torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \ + torch.matmul(z, self.embedding.weight.t()) + # B x 1 + min_encoding_indices = 
torch.argmin(d, dim=1) + return min_encoding_indices + + def dequantize(self, indices): + + index_flattened = indices.view(-1) + z_q = self.embedding(index_flattened) + z_q = z_q.view(indices.shape + (self.e_dim, )).contiguous() + return z_q + + def preprocess(self, x): + # NCT -> NTC -> [NT, C] + x = x.permute(0, 2, 1).contiguous() + x = x.view(-1, x.shape[-1]) + return x + + + +class QuantizeReset(nn.Module): + def __init__(self, nb_code, code_dim): + super().__init__() + self.nb_code = nb_code + self.code_dim = code_dim + self.reset_codebook() + self.codebook = nn.Parameter(torch.randn(nb_code, code_dim)) + + def reset_codebook(self): + self.init = False + self.code_count = None + + def _tile(self, x): + nb_code_x, code_dim = x.shape + if nb_code_x < self.nb_code: + n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x + std = 0.01 / np.sqrt(code_dim) + out = x.repeat(n_repeats, 1) + out = out + torch.randn_like(out) * std + else : + out = x + return out + + def init_codebook(self, x): + out = self._tile(x) + self.codebook = nn.Parameter(out[:self.nb_code]) + self.code_count = torch.ones(self.nb_code, device=self.codebook.device) + self.init = True + + @torch.no_grad() + def compute_perplexity(self, code_idx) : + # Calculate new centres + code_onehot = torch.zeros(self.nb_code, code_idx.shape[0], device=code_idx.device) # nb_code, N * L + code_onehot.scatter_(0, code_idx.view(1, code_idx.shape[0]), 1) + + code_count = code_onehot.sum(dim=-1) # nb_code + prob = code_count / torch.sum(code_count) + perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7))) + return perplexity + + def update_codebook(self, x, code_idx): + + code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device) # nb_code, N * L + code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1) + + code_count = code_onehot.sum(dim=-1) # nb_code + + out = self._tile(x) + code_rand = out[:self.nb_code] + + # Update centres + self.code_count = code_count # nb_code + usage = (self.code_count.view(self.nb_code, 1) >= 1.0).float() + + self.codebook.data = usage * self.codebook.data + (1 - usage) * code_rand + prob = code_count / torch.sum(code_count) + perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7))) + + + return perplexity + + def preprocess(self, x): + # NCT -> NTC -> [NT, C] + x = x.permute(0, 2, 1).contiguous() + x = x.view(-1, x.shape[-1]) + return x + + def quantize(self, x): + # Calculate latent code x_l + k_w = self.codebook.t() + distance = torch.sum(x ** 2, dim=-1, keepdim=True) - 2 * torch.matmul(x, k_w) + torch.sum(k_w ** 2, dim=0, + keepdim=True) # (N * L, b) + _, code_idx = torch.min(distance, dim=-1) + return code_idx + + def dequantize(self, code_idx): + x = F.embedding(code_idx, self.codebook) + return x + + + def forward(self, x): + N, width, T = x.shape + # Preprocess + x = self.preprocess(x) + # Init codebook if not inited + if self.training and not self.init: + self.init_codebook(x) + # quantize and dequantize through bottleneck + code_idx = self.quantize(x) + x_d = self.dequantize(code_idx) + # Update embeddings + if self.training: + perplexity = self.update_codebook(x, code_idx) + else : + perplexity = self.compute_perplexity(code_idx) + + # Loss + commit_loss = F.mse_loss(x, x_d.detach()) + + # Passthrough + x_d = x + (x_d - x).detach() + + # Postprocess + x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous() #(N, DIM, T) + + return x_d, commit_loss, perplexity + + +class QuantizeEMA(nn.Module): + def __init__(self, nb_code, code_dim, mu): + super().__init__() + 
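+        # Same interface as QuantizeEMAReset: codebook entries are updated with
+        # an exponential moving average (decay mu), but unused codes are never
+        # re-sampled from the current batch.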
self.nb_code = nb_code + self.code_dim = code_dim + self.mu = mu + self.reset_codebook() + + def reset_codebook(self): + self.init = False + self.code_sum = None + self.code_count = None + self.register_buffer('codebook', torch.zeros(self.nb_code, self.code_dim).cuda()) + + def _tile(self, x): + nb_code_x, code_dim = x.shape + if nb_code_x < self.nb_code: + n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x + std = 0.01 / np.sqrt(code_dim) + out = x.repeat(n_repeats, 1) + out = out + torch.randn_like(out) * std + else : + out = x + return out + + def init_codebook(self, x): + out = self._tile(x) + self.codebook = out[:self.nb_code] + self.code_sum = self.codebook.clone() + self.code_count = torch.ones(self.nb_code, device=self.codebook.device) + self.init = True + + @torch.no_grad() + def compute_perplexity(self, code_idx) : + # Calculate new centres + code_onehot = torch.zeros(self.nb_code, code_idx.shape[0], device=code_idx.device) # nb_code, N * L + code_onehot.scatter_(0, code_idx.view(1, code_idx.shape[0]), 1) + + code_count = code_onehot.sum(dim=-1) # nb_code + prob = code_count / torch.sum(code_count) + perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7))) + return perplexity + + @torch.no_grad() + def update_codebook(self, x, code_idx): + + code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device) # nb_code, N * L + code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1) + + code_sum = torch.matmul(code_onehot, x) # nb_code, w + code_count = code_onehot.sum(dim=-1) # nb_code + + # Update centres + self.code_sum = self.mu * self.code_sum + (1. - self.mu) * code_sum # w, nb_code + self.code_count = self.mu * self.code_count + (1. - self.mu) * code_count # nb_code + + code_update = self.code_sum.view(self.nb_code, self.code_dim) / self.code_count.view(self.nb_code, 1) + + self.codebook = code_update + prob = code_count / torch.sum(code_count) + perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7))) + + return perplexity + + def preprocess(self, x): + # NCT -> NTC -> [NT, C] + x = x.permute(0, 2, 1).contiguous() + x = x.view(-1, x.shape[-1]) + return x + + def quantize(self, x): + # Calculate latent code x_l + k_w = self.codebook.t() + distance = torch.sum(x ** 2, dim=-1, keepdim=True) - 2 * torch.matmul(x, k_w) + torch.sum(k_w ** 2, dim=0, + keepdim=True) # (N * L, b) + _, code_idx = torch.min(distance, dim=-1) + return code_idx + + def dequantize(self, code_idx): + x = F.embedding(code_idx, self.codebook) + return x + + + def forward(self, x): + N, width, T = x.shape + + # Preprocess + x = self.preprocess(x) + + # Init codebook if not inited + if self.training and not self.init: + self.init_codebook(x) + + # quantize and dequantize through bottleneck + code_idx = self.quantize(x) + x_d = self.dequantize(code_idx) + + # Update embeddings + if self.training: + perplexity = self.update_codebook(x, code_idx) + else : + perplexity = self.compute_perplexity(code_idx) + + # Loss + commit_loss = F.mse_loss(x, x_d.detach()) + + # Passthrough + x_d = x + (x_d - x).detach() + + # Postprocess + x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous() #(N, DIM, T) + + return x_d, commit_loss, perplexity diff --git a/mGPT/archs/tools/resnet.py b/mGPT/archs/tools/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..062346e3ba2fc4d6ae5636f228c5b7565bdb62b7 --- /dev/null +++ b/mGPT/archs/tools/resnet.py @@ -0,0 +1,82 @@ +import torch.nn as nn +import torch + +class nonlinearity(nn.Module): + def __init__(self): + super().__init__() + + 
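+    # Swish / SiLU activation: x * sigmoid(x). Selected by ResConv1DBlock below
+    # when activation='silu'.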
def forward(self, x): + # swish + return x * torch.sigmoid(x) + +class ResConv1DBlock(nn.Module): + def __init__(self, n_in, n_state, dilation=1, activation='silu', norm=None, dropout=None): + super().__init__() + padding = dilation + self.norm = norm + if norm == "LN": + self.norm1 = nn.LayerNorm(n_in) + self.norm2 = nn.LayerNorm(n_in) + elif norm == "GN": + self.norm1 = nn.GroupNorm(num_groups=32, num_channels=n_in, eps=1e-6, affine=True) + self.norm2 = nn.GroupNorm(num_groups=32, num_channels=n_in, eps=1e-6, affine=True) + elif norm == "BN": + self.norm1 = nn.BatchNorm1d(num_features=n_in, eps=1e-6, affine=True) + self.norm2 = nn.BatchNorm1d(num_features=n_in, eps=1e-6, affine=True) + + else: + self.norm1 = nn.Identity() + self.norm2 = nn.Identity() + + if activation == "relu": + self.activation1 = nn.ReLU() + self.activation2 = nn.ReLU() + + elif activation == "silu": + self.activation1 = nonlinearity() + self.activation2 = nonlinearity() + + elif activation == "gelu": + self.activation1 = nn.GELU() + self.activation2 = nn.GELU() + + + + self.conv1 = nn.Conv1d(n_in, n_state, 3, 1, padding, dilation) + self.conv2 = nn.Conv1d(n_state, n_in, 1, 1, 0,) + + + def forward(self, x): + x_orig = x + if self.norm == "LN": + x = self.norm1(x.transpose(-2, -1)) + x = self.activation1(x.transpose(-2, -1)) + else: + x = self.norm1(x) + x = self.activation1(x) + + x = self.conv1(x) + + if self.norm == "LN": + x = self.norm2(x.transpose(-2, -1)) + x = self.activation2(x.transpose(-2, -1)) + else: + x = self.norm2(x) + x = self.activation2(x) + + x = self.conv2(x) + x = x + x_orig + return x + +class Resnet1D(nn.Module): + def __init__(self, n_in, n_depth, dilation_growth_rate=1, reverse_dilation=True, activation='relu', norm=None): + super().__init__() + + blocks = [ResConv1DBlock(n_in, n_in, dilation=dilation_growth_rate ** depth, activation=activation, norm=norm) for depth in range(n_depth)] + if reverse_dilation: + blocks = blocks[::-1] + + self.model = nn.Sequential(*blocks) + + def forward(self, x): + return self.model(x) \ No newline at end of file diff --git a/mGPT/archs/tools/token_emb.py b/mGPT/archs/tools/token_emb.py new file mode 100644 index 0000000000000000000000000000000000000000..9bfa610d89f4b68fb38599250b554737a7aa1b6b --- /dev/null +++ b/mGPT/archs/tools/token_emb.py @@ -0,0 +1,73 @@ + +from torch import Tensor, nn + +class NewTokenEmb(nn.Module): + """ + For adding new tokens to a pretrained model + """ + + def __init__(self, + old_embeddings: nn.Embedding, + new_num_tokens: int = None) -> None: + + super().__init__() + + self.num_tokens = old_embeddings.num_embeddings + new_num_tokens + self.old_num_tokens = old_embeddings.num_embeddings + self.new_num_tokens = new_num_tokens + self.embedding_dim = old_embeddings.embedding_dim + + # For text embeddings + self.text_embeddings = nn.Embedding( + self.num_tokens, + self.embedding_dim, + device=old_embeddings.weight.device, + dtype=old_embeddings.weight.dtype) + with torch.no_grad(): + self.text_embeddings.weight.data[:old_embeddings. 
+ num_embeddings] = old_embeddings.weight.data + self.text_embeddings.weight.data[ + self.old_num_tokens:] = torch.zeros( + self.new_num_tokens, + self.embedding_dim, + dtype=old_embeddings.weight.dtype, + device=old_embeddings.weight.device) + self.text_embeddings.weight.requires_grad_(False) + + # For motion embeddings + self.motion_embeddings = nn.Embedding( + new_num_tokens, + self.embedding_dim, + device=old_embeddings.weight.device, + dtype=old_embeddings.weight.dtype) + with torch.no_grad(): + self.motion_embeddings.weight.data[:self. + old_num_tokens] = torch.zeros( + new_num_tokens, + self.embedding_dim, + dtype=old_embeddings.weight. + dtype, + device=old_embeddings. + weight.device) + self.word2motionProj = nn.Linear(self.old_num_tokens, new_num_tokens) + + def forward(self, input: Tensor) -> Tensor: + + with torch.no_grad(): + self.motion_embeddings.weight.data[:self. + old_num_tokens] = torch.zeros( + self.new_num_tokens, + self.embedding_dim, + dtype=self.motion_embeddings + .weight.dtype, + device=self. + motion_embeddings.weight. + device) + + self.motion_embeddings.weight.data[ + self.old_num_tokens:] = self.word2motionProj( + self.text_embeddings.weight.data[:self.old_num_tokens].permute( + 1, 0)).permute(1, 0) + + return self.text_embeddings(input) + self.motion_embeddings(input) + diff --git a/mGPT/archs/tools/transformer_layers.py b/mGPT/archs/tools/transformer_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..7b53429a5168de69cb581a2016c00f3560da0e1f --- /dev/null +++ b/mGPT/archs/tools/transformer_layers.py @@ -0,0 +1,285 @@ +# -*- coding: utf-8 -*- +import math +import torch +import torch.nn as nn +from torch import Tensor + +# Took from https://github.com/joeynmt/joeynmt/blob/fb66afcbe1beef9acd59283bcc084c4d4c1e6343/joeynmt/transformer_layers.py + + +# pylint: disable=arguments-differ +class MultiHeadedAttention(nn.Module): + """ + Multi-Head Attention module from "Attention is All You Need" + + Implementation modified from OpenNMT-py. + https://github.com/OpenNMT/OpenNMT-py + """ + + def __init__(self, num_heads: int, size: int, dropout: float = 0.1): + """ + Create a multi-headed attention layer. + :param num_heads: the number of heads + :param size: model size (must be divisible by num_heads) + :param dropout: probability of dropping a unit + """ + super().__init__() + + assert size % num_heads == 0 + + self.head_size = head_size = size // num_heads + self.model_size = size + self.num_heads = num_heads + + self.k_layer = nn.Linear(size, num_heads * head_size) + self.v_layer = nn.Linear(size, num_heads * head_size) + self.q_layer = nn.Linear(size, num_heads * head_size) + + self.output_layer = nn.Linear(size, size) + self.softmax = nn.Softmax(dim=-1) + self.dropout = nn.Dropout(dropout) + + def forward(self, k: Tensor, v: Tensor, q: Tensor, mask: Tensor = None): + """ + Computes multi-headed attention. + + :param k: keys [B, M, D] with M being the sentence length. + :param v: values [B, M, D] + :param q: query [B, M, D] + :param mask: optional mask [B, 1, M] or [B, M, M] + :return: + """ + batch_size = k.size(0) + num_heads = self.num_heads + + # project the queries (q), keys (k), and values (v) + k = self.k_layer(k) + v = self.v_layer(v) + q = self.q_layer(q) + + # reshape q, k, v for our computation to [batch_size, num_heads, ..] 
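+        # i.e. [B, M, num_heads * head_size] -> [B, num_heads, M, head_size],
+        # so scores and context are computed independently per head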
+ k = k.view(batch_size, -1, num_heads, self.head_size).transpose(1, 2) + v = v.view(batch_size, -1, num_heads, self.head_size).transpose(1, 2) + q = q.view(batch_size, -1, num_heads, self.head_size).transpose(1, 2) + + # compute scores + q = q / math.sqrt(self.head_size) + + # batch x num_heads x query_len x key_len + scores = torch.matmul(q, k.transpose(2, 3)) + # torch.Size([48, 8, 183, 183]) + + # apply the mask (if we have one) + # we add a dimension for the heads to it below: [B, 1, 1, M] + if mask is not None: + scores = scores.masked_fill(~mask.unsqueeze(1), float('-inf')) + + # apply attention dropout and compute context vectors. + attention = self.softmax(scores) + attention = self.dropout(attention) + # torch.Size([48, 8, 183, 183]) [bs, nheads, time, time] (for decoding) + + # v: torch.Size([48, 8, 183, 32]) (32 is 256/8) + # get context vector (select values with attention) and reshape + # back to [B, M, D] + context = torch.matmul(attention, v) # torch.Size([48, 8, 183, 32]) + context = context.transpose(1, 2).contiguous().view( + batch_size, -1, num_heads * self.head_size) + # torch.Size([48, 183, 256]) put back to 256 (combine the heads) + + output = self.output_layer(context) + # torch.Size([48, 183, 256]): 1 output per time step + + return output + + +# pylint: disable=arguments-differ +class PositionwiseFeedForward(nn.Module): + """ + Position-wise Feed-forward layer + Projects to ff_size and then back down to input_size. + """ + + def __init__(self, input_size, ff_size, dropout=0.1): + """ + Initializes position-wise feed-forward layer. + :param input_size: dimensionality of the input. + :param ff_size: dimensionality of intermediate representation + :param dropout: + """ + super().__init__() + self.layer_norm = nn.LayerNorm(input_size, eps=1e-6) + self.pwff_layer = nn.Sequential( + nn.Linear(input_size, ff_size), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(ff_size, input_size), + nn.Dropout(dropout), + ) + + def forward(self, x): + x_norm = self.layer_norm(x) + return self.pwff_layer(x_norm) + x + + +# pylint: disable=arguments-differ +class PositionalEncoding(nn.Module): + """ + Pre-compute position encodings (PE). + In forward pass, this adds the position-encodings to the + input for as many time steps as necessary. + + Implementation based on OpenNMT-py. + https://github.com/OpenNMT/OpenNMT-py + """ + + def __init__(self, size: int = 0, max_len: int = 5000): + """ + Positional Encoding with maximum length max_len + :param size: + :param max_len: + :param dropout: + """ + if size % 2 != 0: + raise ValueError("Cannot use sin/cos positional encoding with " + "odd dim (got dim={:d})".format(size)) + pe = torch.zeros(max_len, size) + position = torch.arange(0, max_len).unsqueeze(1) + div_term = torch.exp((torch.arange(0, size, 2, dtype=torch.float) * + -(math.log(10000.0) / size))) + pe[:, 0::2] = torch.sin(position.float() * div_term) + pe[:, 1::2] = torch.cos(position.float() * div_term) + pe = pe.unsqueeze(0) # shape: [1, size, max_len] + super().__init__() + self.register_buffer('pe', pe) + self.dim = size + + def forward(self, emb): + """Embed inputs. + Args: + emb (FloatTensor): Sequence of word vectors + ``(seq_len, batch_size, self.dim)`` + """ + # Add position encodings + return emb + self.pe[:, :emb.size(1)] + + +class TransformerEncoderLayer(nn.Module): + """ + One Transformer encoder layer has a Multi-head attention layer plus + a position-wise feed-forward layer. 
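+    Layer norm is applied before each sub-layer (pre-norm), with dropout and a
+    residual connection applied afterwards.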
+ """ + + def __init__(self, + size: int = 0, + ff_size: int = 0, + num_heads: int = 0, + dropout: float = 0.1): + """ + A single Transformer layer. + :param size: + :param ff_size: + :param num_heads: + :param dropout: + """ + super().__init__() + + self.layer_norm = nn.LayerNorm(size, eps=1e-6) + self.src_src_att = MultiHeadedAttention(num_heads, + size, + dropout=dropout) + self.feed_forward = PositionwiseFeedForward(size, + ff_size=ff_size, + dropout=dropout) + self.dropout = nn.Dropout(dropout) + self.size = size + + # pylint: disable=arguments-differ + def forward(self, x: Tensor, mask: Tensor) -> Tensor: + """ + Forward pass for a single transformer encoder layer. + First applies layer norm, then self attention, + then dropout with residual connection (adding the input to the result), + and then a position-wise feed-forward layer. + + :param x: layer input + :param mask: input mask + :return: output tensor + """ + x_norm = self.layer_norm(x) + h = self.src_src_att(x_norm, x_norm, x_norm, mask) + h = self.dropout(h) + x + o = self.feed_forward(h) + return o + + +class TransformerDecoderLayer(nn.Module): + """ + Transformer decoder layer. + + Consists of self-attention, source-attention, and feed-forward. + """ + + def __init__(self, + size: int = 0, + ff_size: int = 0, + num_heads: int = 0, + dropout: float = 0.1): + """ + Represents a single Transformer decoder layer. + + It attends to the source representation and the previous decoder states. + + :param size: model dimensionality + :param ff_size: size of the feed-forward intermediate layer + :param num_heads: number of heads + :param dropout: dropout to apply to input + """ + super().__init__() + self.size = size + + self.trg_trg_att = MultiHeadedAttention(num_heads, + size, + dropout=dropout) + self.src_trg_att = MultiHeadedAttention(num_heads, + size, + dropout=dropout) + + self.feed_forward = PositionwiseFeedForward(size, + ff_size=ff_size, + dropout=dropout) + + self.x_layer_norm = nn.LayerNorm(size, eps=1e-6) + self.dec_layer_norm = nn.LayerNorm(size, eps=1e-6) + + self.dropout = nn.Dropout(dropout) + + # pylint: disable=arguments-differ + def forward(self, + x: Tensor = None, + memory: Tensor = None, + src_mask: Tensor = None, + trg_mask: Tensor = None) -> Tensor: + """ + Forward pass of a single Transformer decoder layer. 
+ + :param x: inputs + :param memory: source representations + :param src_mask: source mask + :param trg_mask: target mask (so as to not condition on future steps) + :return: output tensor + """ + # decoder/target self-attention + x_norm = self.x_layer_norm(x) # torch.Size([48, 183, 256]) + h1 = self.trg_trg_att(x_norm, x_norm, x_norm, mask=trg_mask) + h1 = self.dropout(h1) + x + + # source-target attention + h1_norm = self.dec_layer_norm( + h1) # torch.Size([48, 183, 256]) (same for memory) + h2 = self.src_trg_att(memory, memory, h1_norm, mask=src_mask) + + # final position-wise feed-forward layer + o = self.feed_forward(self.dropout(h2) + h1) + + return o diff --git a/mGPT/callback.py b/mGPT/callback.py new file mode 100644 index 0000000000000000000000000000000000000000..19d6421c92ec58b0aacb4b51dd0c5af2f8e8ec0b --- /dev/null +++ b/mGPT/callback.py @@ -0,0 +1,200 @@ +import os +from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning.callbacks import Callback, RichProgressBar, ModelCheckpoint + + +def build_callbacks(cfg, logger=None, phase='test', **kwargs): + callbacks = [] + logger = logger + + # Rich Progress Bar + callbacks.append(progressBar()) + + # Checkpoint Callback + if phase == 'train': + callbacks.extend(getCheckpointCallback(cfg, logger=logger, **kwargs)) + + return callbacks + +def getCheckpointCallback(cfg, logger=None, **kwargs): + callbacks = [] + # Logging + metric_monitor = { + "loss_total": "total/train", + "Train_jf": "recons/text2jfeats/train", + "Val_jf": "recons/text2jfeats/val", + "Train_rf": "recons/text2rfeats/train", + "Val_rf": "recons/text2rfeats/val", + "APE root": "Metrics/APE_root", + "APE mean pose": "Metrics/APE_mean_pose", + "AVE root": "Metrics/AVE_root", + "AVE mean pose": "Metrics/AVE_mean_pose", + "R_TOP_1": "Metrics/R_precision_top_1", + "R_TOP_2": "Metrics/R_precision_top_2", + "R_TOP_3": "Metrics/R_precision_top_3", + "gt_R_TOP_3": "Metrics/gt_R_precision_top_3", + "FID": "Metrics/FID", + "gt_FID": "Metrics/gt_FID", + "Diversity": "Metrics/Diversity", + "MM dist": "Metrics/Matching_score", + "Accuracy": "Metrics/accuracy", + } + callbacks.append( + progressLogger(logger,metric_monitor=metric_monitor,log_every_n_steps=1)) + + # Save 10 latest checkpoints + checkpointParams = { + 'dirpath': os.path.join(cfg.FOLDER_EXP, "checkpoints"), + 'filename': "{epoch}", + 'monitor': "step", + 'mode': "max", + 'every_n_epochs': cfg.LOGGER.VAL_EVERY_STEPS, + 'save_top_k': 8, + 'save_last': True, + 'save_on_train_epoch_end': True + } + callbacks.append(ModelCheckpoint(**checkpointParams)) + + # Save checkpoint every n*10 epochs + checkpointParams.update({ + 'every_n_epochs': + cfg.LOGGER.VAL_EVERY_STEPS * 10, + 'save_top_k': + -1, + 'save_last': + False + }) + callbacks.append(ModelCheckpoint(**checkpointParams)) + + metrics = cfg.METRIC.TYPE + metric_monitor_map = { + 'TemosMetric': { + 'Metrics/APE_root': { + 'abbr': 'APEroot', + 'mode': 'min' + }, + }, + 'TM2TMetrics': { + 'Metrics/FID': { + 'abbr': 'FID', + 'mode': 'min' + }, + 'Metrics/R_precision_top_3': { + 'abbr': 'R3', + 'mode': 'max' + } + }, + 'MRMetrics': { + 'Metrics/MPJPE': { + 'abbr': 'MPJPE', + 'mode': 'min' + } + }, + 'HUMANACTMetrics': { + 'Metrics/Accuracy': { + 'abbr': 'Accuracy', + 'mode': 'max' + } + }, + 'UESTCMetrics': { + 'Metrics/Accuracy': { + 'abbr': 'Accuracy', + 'mode': 'max' + } + }, + 'UncondMetrics': { + 'Metrics/FID': { + 'abbr': 'FID', + 'mode': 'min' + } + } + } + + checkpointParams.update({ + 'every_n_epochs': cfg.LOGGER.VAL_EVERY_STEPS, + 'save_top_k': 
1, + }) + + for metric in metrics: + if metric in metric_monitor_map.keys(): + metric_monitors = metric_monitor_map[metric] + + # Delete R3 if training VAE + if cfg.TRAIN.STAGE == 'vae' and metric == 'TM2TMetrics': + del metric_monitors['Metrics/R_precision_top_3'] + + for metric_monitor in metric_monitors: + checkpointParams.update({ + 'filename': + metric_monitor_map[metric][metric_monitor]['mode'] + + "-" + + metric_monitor_map[metric][metric_monitor]['abbr'] + + "{ep}", + 'monitor': + metric_monitor, + 'mode': + metric_monitor_map[metric][metric_monitor]['mode'], + }) + callbacks.append( + ModelCheckpoint(**checkpointParams)) + return callbacks + +class progressBar(RichProgressBar): + def __init__(self, ): + super().__init__() + + def get_metrics(self, trainer, model): + # Don't show the version number + items = super().get_metrics(trainer, model) + items.pop("v_num", None) + return items + +class progressLogger(Callback): + def __init__(self, + logger, + metric_monitor: dict, + precision: int = 3, + log_every_n_steps: int = 1): + # Metric to monitor + self.logger = logger + self.metric_monitor = metric_monitor + self.precision = precision + self.log_every_n_steps = log_every_n_steps + + def on_train_start(self, trainer: Trainer, pl_module: LightningModule, + **kwargs) -> None: + self.logger.info("Training started") + + def on_train_end(self, trainer: Trainer, pl_module: LightningModule, + **kwargs) -> None: + self.logger.info("Training done") + + def on_validation_epoch_end(self, trainer: Trainer, + pl_module: LightningModule, **kwargs) -> None: + if trainer.sanity_checking: + self.logger.info("Sanity checking ok.") + + def on_train_epoch_end(self, + trainer: Trainer, + pl_module: LightningModule, + padding=False, + **kwargs) -> None: + metric_format = f"{{:.{self.precision}e}}" + line = f"Epoch {trainer.current_epoch}" + if padding: + line = f"{line:>{len('Epoch xxxx')}}" # Right padding + + if trainer.current_epoch % self.log_every_n_steps == 0: + metrics_str = [] + + losses_dict = trainer.callback_metrics + for metric_name, dico_name in self.metric_monitor.items(): + if dico_name in losses_dict: + metric = losses_dict[dico_name].item() + metric = metric_format.format(metric) + metric = f"{metric_name} {metric}" + metrics_str.append(metric) + + line = line + ": " + " ".join(metrics_str) + + self.logger.info(line) diff --git a/mGPT/config.py b/mGPT/config.py new file mode 100644 index 0000000000000000000000000000000000000000..f5376142daa4a9c3da779bdc9c1a8df67912da5f --- /dev/null +++ b/mGPT/config.py @@ -0,0 +1,217 @@ +import importlib +from argparse import ArgumentParser +from omegaconf import OmegaConf +from os.path import join as pjoin +import os +import glob + + +def get_module_config(cfg, filepath="./configs"): + """ + Load yaml config files from subfolders + """ + + yamls = glob.glob(pjoin(filepath, '*', '*.yaml')) + yamls = [y.replace(filepath, '') for y in yamls] + for yaml in yamls: + nodes = yaml.replace('.yaml', '').replace('/', '.') + nodes = nodes[1:] if nodes[0] == '.' 
else nodes + OmegaConf.update(cfg, nodes, OmegaConf.load('./configs' + yaml)) + + return cfg + + +def get_obj_from_str(string, reload=False): + """ + Get object from string + """ + + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def instantiate_from_config(config): + """ + Instantiate object from config + """ + if not "target" in config: + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def resume_config(cfg: OmegaConf): + """ + Resume model and wandb + """ + + if cfg.TRAIN.RESUME: + resume = cfg.TRAIN.RESUME + if os.path.exists(resume): + # Checkpoints + cfg.TRAIN.PRETRAINED = pjoin(resume, "checkpoints", "last.ckpt") + # Wandb + wandb_files = os.listdir(pjoin(resume, "wandb", "latest-run")) + wandb_run = [item for item in wandb_files if "run-" in item][0] + cfg.LOGGER.WANDB.params.id = wandb_run.replace("run-","").replace(".wandb", "") + else: + raise ValueError("Resume path is not right.") + + return cfg + +def parse_args(phase="train"): + """ + Parse arguments and load config files + """ + + parser = ArgumentParser() + group = parser.add_argument_group("Training options") + + # Assets + group.add_argument( + "--cfg_assets", + type=str, + required=False, + default="./configs/assets.yaml", + help="config file for asset paths", + ) + + # Default config + if phase in ["train", "test"]: + cfg_defualt = "./configs/default.yaml" + elif phase == "render": + cfg_defualt = "./configs/render.yaml" + elif phase == "webui": + cfg_defualt = "./configs/webui.yaml" + + group.add_argument( + "--cfg", + type=str, + required=False, + default=cfg_defualt, + help="config file", + ) + + # Parse for each phase + if phase in ["train", "test"]: + group.add_argument("--batch_size", + type=int, + required=False, + help="training batch size") + group.add_argument("--num_nodes", + type=int, + required=False, + help="number of nodes") + group.add_argument("--device", + type=int, + nargs="+", + required=False, + help="training device") + group.add_argument("--task", + type=str, + required=False, + help="evaluation task type") + group.add_argument("--nodebug", + action="store_true", + required=False, + help="debug or not") + + + if phase == "demo": + group.add_argument( + "--example", + type=str, + required=False, + help="input text and lengths with txt format", + ) + group.add_argument( + "--out_dir", + type=str, + required=False, + help="output dir", + ) + group.add_argument("--task", + type=str, + required=False, + help="evaluation task type") + + if phase == "render": + group.add_argument("--npy", + type=str, + required=False, + default=None, + help="npy motion files") + group.add_argument("--dir", + type=str, + required=False, + default=None, + help="npy motion folder") + group.add_argument("--fps", + type=int, + required=False, + default=30, + help="render fps") + group.add_argument( + "--mode", + type=str, + required=False, + default="sequence", + help="render target: video, sequence, frame", + ) + + params = parser.parse_args() + + # Load yaml config files + OmegaConf.register_new_resolver("eval", eval) + cfg_assets = OmegaConf.load(params.cfg_assets) + cfg_base = OmegaConf.load(pjoin(cfg_assets.CONFIG_FOLDER, 'default.yaml')) + cfg_exp = OmegaConf.merge(cfg_base, OmegaConf.load(params.cfg)) + if not cfg_exp.FULL_CONFIG: + cfg_exp = get_module_config(cfg_exp, 
cfg_assets.CONFIG_FOLDER) + cfg = OmegaConf.merge(cfg_exp, cfg_assets) + + # Update config with arguments + if phase in ["train", "test"]: + cfg.TRAIN.BATCH_SIZE = params.batch_size if params.batch_size else cfg.TRAIN.BATCH_SIZE + cfg.DEVICE = params.device if params.device else cfg.DEVICE + cfg.NUM_NODES = params.num_nodes if params.num_nodes else cfg.NUM_NODES + cfg.model.params.task = params.task if params.task else cfg.model.params.task + cfg.DEBUG = not params.nodebug if params.nodebug is not None else cfg.DEBUG + + # Force no debug in test + if phase == "test": + cfg.DEBUG = False + cfg.DEVICE = [0] + print("Force no debugging and one gpu when testing") + + if phase == "demo": + cfg.DEMO.RENDER = params.render + cfg.DEMO.FRAME_RATE = params.frame_rate + cfg.DEMO.EXAMPLE = params.example + cfg.DEMO.TASK = params.task + cfg.TEST.FOLDER = params.out_dir if params.out_dir else cfg.TEST.FOLDER + os.makedirs(cfg.TEST.FOLDER, exist_ok=True) + + if phase == "render": + if params.npy: + cfg.RENDER.NPY = params.npy + cfg.RENDER.INPUT_MODE = "npy" + if params.dir: + cfg.RENDER.DIR = params.dir + cfg.RENDER.INPUT_MODE = "dir" + if params.fps: + cfg.RENDER.FPS = float(params.fps) + cfg.RENDER.MODE = params.mode + + # Debug mode + if cfg.DEBUG: + cfg.NAME = "debug--" + cfg.NAME + cfg.LOGGER.WANDB.params.offline = True + cfg.LOGGER.VAL_EVERY_STEPS = 1 + + # Resume config + cfg = resume_config(cfg) + + return cfg diff --git a/mGPT/data/HumanML3D.py b/mGPT/data/HumanML3D.py new file mode 100644 index 0000000000000000000000000000000000000000..380d43c482ca097ad499cbee17b5b8f49318f0bb --- /dev/null +++ b/mGPT/data/HumanML3D.py @@ -0,0 +1,117 @@ +import numpy as np +import torch +from os.path import join as pjoin +from .humanml.utils.word_vectorizer import WordVectorizer +from .humanml.scripts.motion_process import (process_file, recover_from_ric) +from . 
import BASEDataModule +from .humanml import Text2MotionDatasetEval, Text2MotionDataset, Text2MotionDatasetCB, MotionDataset, MotionDatasetVQ, Text2MotionDatasetToken, Text2MotionDatasetM2T +from .utils import humanml3d_collate + + +class HumanML3DDataModule(BASEDataModule): + def __init__(self, cfg, **kwargs): + + super().__init__(collate_fn=humanml3d_collate) + self.cfg = cfg + self.save_hyperparameters(logger=False) + + # Basic info of the dataset + cfg.DATASET.JOINT_TYPE = 'humanml3d' + self.name = "humanml3d" + self.njoints = 22 + + # Path to the dataset + data_root = cfg.DATASET.HUMANML3D.ROOT + self.hparams.data_root = data_root + self.hparams.text_dir = pjoin(data_root, "texts") + self.hparams.motion_dir = pjoin(data_root, 'new_joint_vecs') + + # Mean and std of the dataset + self.hparams.mean = np.load(pjoin('assets/meta', "mean.npy")) + self.hparams.std = np.load(pjoin('assets/meta', "std.npy")) + + # Mean and std for fair evaluation + self.hparams.mean_eval = np.load(pjoin('assets/meta', "mean_eval.npy")) + self.hparams.std_eval = np.load(pjoin('assets/meta', "std_eval.npy")) + + # Length of the dataset + self.hparams.max_motion_length = cfg.DATASET.HUMANML3D.MAX_MOTION_LEN + self.hparams.min_motion_length = cfg.DATASET.HUMANML3D.MIN_MOTION_LEN + self.hparams.max_text_len = cfg.DATASET.HUMANML3D.MAX_TEXT_LEN + self.hparams.unit_length = cfg.DATASET.HUMANML3D.UNIT_LEN + + # Additional parameters + self.hparams.debug = cfg.DEBUG + self.hparams.stage = cfg.TRAIN.STAGE + + # Dataset switch + self.DatasetEval = Text2MotionDatasetEval + + if cfg.TRAIN.STAGE == "vae": + if cfg.model.params.motion_vae.target.split('.')[-1].lower() == "vqvae": + self.hparams.win_size = 64 + self.Dataset = MotionDatasetVQ + else: + self.Dataset = MotionDataset + elif 'lm' in cfg.TRAIN.STAGE: + self.hparams.code_path = cfg.DATASET.CODE_PATH + self.hparams.task_path = cfg.DATASET.TASK_PATH + self.hparams.std_text = cfg.DATASET.HUMANML3D.STD_TEXT + self.Dataset = Text2MotionDatasetCB + elif cfg.TRAIN.STAGE == "token": + self.Dataset = Text2MotionDatasetToken + self.DatasetEval = Text2MotionDatasetToken + elif cfg.TRAIN.STAGE == "m2t": + self.Dataset = Text2MotionDatasetM2T + self.DatasetEval = Text2MotionDatasetM2T + else: + self.Dataset = Text2MotionDataset + + # Get additional info of the dataset + self.nfeats = 263 + cfg.DATASET.NFEATS = self.nfeats + + + def feats2joints(self, features): + mean = torch.tensor(self.hparams.mean).to(features) + std = torch.tensor(self.hparams.std).to(features) + features = features * std + mean + return recover_from_ric(features, self.njoints) + + def joints2feats(self, features): + features = process_file(features, self.njoints)[0] + return features + + def normalize(self, features): + mean = torch.tensor(self.hparams.mean).to(features) + std = torch.tensor(self.hparams.std).to(features) + features = (features - mean) / std + return features + + def denormalize(self, features): + mean = torch.tensor(self.hparams.mean).to(features) + std = torch.tensor(self.hparams.std).to(features) + features = features * std + mean + return features + + def renorm4t2m(self, features): + # renorm to t2m norms for using t2m evaluators + ori_mean = torch.tensor(self.hparams.mean).to(features) + ori_std = torch.tensor(self.hparams.std).to(features) + eval_mean = torch.tensor(self.hparams.mean_eval).to(features) + eval_std = torch.tensor(self.hparams.std_eval).to(features) + features = features * ori_std + ori_mean + features = (features - eval_mean) / eval_std + return features + + def 
mm_mode(self, mm_on=True): + if mm_on: + self.is_mm = True + self.name_list = self.test_dataset.name_list + self.mm_list = np.random.choice(self.name_list, + self.cfg.METRIC.MM_NUM_SAMPLES, + replace=False) + self.test_dataset.name_list = self.mm_list + else: + self.is_mm = False + self.test_dataset.name_list = self.name_list diff --git a/mGPT/data/Kit.py b/mGPT/data/Kit.py new file mode 100644 index 0000000000000000000000000000000000000000..1eecaa101a765f6899d6bf7ab6aae66b7deb43ab --- /dev/null +++ b/mGPT/data/Kit.py @@ -0,0 +1,88 @@ +import numpy as np +import torch +from os.path import join as pjoin +from .humanml.utils.word_vectorizer import WordVectorizer +from .humanml.scripts.motion_process import (process_file, recover_from_ric) +from .HumanML3D import HumanML3DDataModule +from .humanml import Text2MotionDatasetEval, Text2MotionDataset, Text2MotionDatasetCB, MotionDataset, MotionDatasetVQ, Text2MotionDatasetToken + + +class KitDataModule(HumanML3DDataModule): + def __init__(self, cfg, **kwargs): + + super().__init__(cfg, **kwargs) + + # Basic info of the dataset + self.name = "kit" + self.njoints = 21 + + # Path to the dataset + data_root = cfg.DATASET.KIT.ROOT + self.hparams.data_root = data_root + self.hparams.text_dir = pjoin(data_root, "texts") + self.hparams.motion_dir = pjoin(data_root, 'new_joint_vecs') + + # Mean and std of the dataset + dis_data_root = pjoin(cfg.DATASET.KIT.MEAN_STD_PATH, 'kit', + "VQVAEV3_CB1024_CMT_H1024_NRES3", "meta") + self.hparams.mean = np.load(pjoin(dis_data_root, "mean.npy")) + self.hparams.std = np.load(pjoin(dis_data_root, "std.npy")) + + # Mean and std for fair evaluation + dis_data_root_eval = pjoin(cfg.DATASET.KIT.MEAN_STD_PATH, 't2m', + "Comp_v6_KLD005", "meta") + self.hparams.mean_eval = np.load(pjoin(dis_data_root_eval, "mean.npy")) + self.hparams.std_eval = np.load(pjoin(dis_data_root_eval, "std.npy")) + + # Length of the dataset + self.hparams.max_motion_length = cfg.DATASET.KIT.MAX_MOTION_LEN + self.hparams.min_motion_length = cfg.DATASET.KIT.MIN_MOTION_LEN + self.hparams.max_text_len = cfg.DATASET.KIT.MAX_TEXT_LEN + self.hparams.unit_length = cfg.DATASET.KIT.UNIT_LEN + + # Get additional info of the dataset + self._sample_set = self.get_sample_set(overrides={"split": "test", "tiny": True}) + self.nfeats = self._sample_set.nfeats + cfg.DATASET.NFEATS = self.nfeats + + def feats2joints(self, features): + mean = torch.tensor(self.hparams.mean).to(features) + std = torch.tensor(self.hparams.std).to(features) + features = features * std + mean + return recover_from_ric(features, self.njoints) + + def joints2feats(self, features): + features = process_file(features, self.njoints)[0] + # mean = torch.tensor(self.hparams.mean).to(features) + # std = torch.tensor(self.hparams.std).to(features) + # features = (features - mean) / std + return features + + def normalize(self, features): + mean = torch.tensor(self.hparams.mean).to(features) + std = torch.tensor(self.hparams.std).to(features) + features = (features - mean) / std + return features + + def renorm4t2m(self, features): + # renorm to t2m norms for using t2m evaluators + ori_mean = torch.tensor(self.hparams.mean).to(features) + ori_std = torch.tensor(self.hparams.std).to(features) + eval_mean = torch.tensor(self.hparams.mean_eval).to(features) + eval_std = torch.tensor(self.hparams.std_eval).to(features) + features = features * ori_std + ori_mean + features = (features - eval_mean) / eval_std + return features + + def mm_mode(self, mm_on=True): + # random select samples for mm + if mm_on: 
+ self.is_mm = True + self.name_list = self.test_dataset.name_list + self.mm_list = np.random.choice(self.name_list, + self.cfg.METRIC.MM_NUM_SAMPLES, + replace=False) + self.test_dataset.name_list = self.mm_list + else: + self.is_mm = False + self.test_dataset.name_list = self.name_list diff --git a/mGPT/data/__init__.py b/mGPT/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..948f49c9d3779f8c48f861365c0830bd7061cf9d --- /dev/null +++ b/mGPT/data/__init__.py @@ -0,0 +1,103 @@ +import pytorch_lightning as pl +from torch.utils.data import DataLoader + + +class BASEDataModule(pl.LightningDataModule): + def __init__(self, collate_fn): + super().__init__() + + self.dataloader_options = {"collate_fn": collate_fn} + self.persistent_workers = True + self.is_mm = False + + self._train_dataset = None + self._val_dataset = None + self._test_dataset = None + + def get_sample_set(self, overrides={}): + sample_params = self.hparams.copy() + sample_params.update(overrides) + return self.DatasetEval(**sample_params) + + @property + def train_dataset(self): + if self._train_dataset is None: + self._train_dataset = self.Dataset(split=self.cfg.TRAIN.SPLIT, + **self.hparams) + return self._train_dataset + + @property + def val_dataset(self): + if self._val_dataset is None: + params = self.hparams.copy() + params['code_path'] = None + params['split'] = self.cfg.EVAL.SPLIT + self._val_dataset = self.DatasetEval(**params) + return self._val_dataset + + @property + def test_dataset(self): + if self._test_dataset is None: + # self._test_dataset = self.DatasetEval(split=self.cfg.TEST.SPLIT, + # **self.hparams) + params = self.hparams.copy() + params['code_path'] = None + params['split'] = self.cfg.TEST.SPLIT + self._test_dataset = self.DatasetEval( **params) + return self._test_dataset + + def setup(self, stage=None): + # Use the getter the first time to load the data + if stage in (None, "fit"): + _ = self.train_dataset + _ = self.val_dataset + if stage in (None, "test"): + _ = self.test_dataset + + def train_dataloader(self): + dataloader_options = self.dataloader_options.copy() + dataloader_options["batch_size"] = self.cfg.TRAIN.BATCH_SIZE + dataloader_options["num_workers"] = self.cfg.TRAIN.NUM_WORKERS + return DataLoader( + self.train_dataset, + shuffle=False, + persistent_workers=True, + **dataloader_options, + ) + + def predict_dataloader(self): + dataloader_options = self.dataloader_options.copy() + dataloader_options[ + "batch_size"] = 1 if self.is_mm else self.cfg.TEST.BATCH_SIZE + dataloader_options["num_workers"] = self.cfg.TEST.NUM_WORKERS + dataloader_options["shuffle"] = False + return DataLoader( + self.test_dataset, + persistent_workers=True, + **dataloader_options, + ) + + def val_dataloader(self): + # overrides batch_size and num_workers + dataloader_options = self.dataloader_options.copy() + dataloader_options["batch_size"] = self.cfg.EVAL.BATCH_SIZE + dataloader_options["num_workers"] = self.cfg.EVAL.NUM_WORKERS + dataloader_options["shuffle"] = False + return DataLoader( + self.val_dataset, + persistent_workers=True, + **dataloader_options, + ) + + def test_dataloader(self): + # overrides batch_size and num_workers + dataloader_options = self.dataloader_options.copy() + dataloader_options[ + "batch_size"] = 1 if self.is_mm else self.cfg.TEST.BATCH_SIZE + dataloader_options["num_workers"] = self.cfg.TEST.NUM_WORKERS + dataloader_options["shuffle"] = False + return DataLoader( + self.test_dataset, + persistent_workers=True, + **dataloader_options, + ) diff 
--git a/mGPT/data/build_data.py b/mGPT/data/build_data.py new file mode 100644 index 0000000000000000000000000000000000000000..057a556ad891709ac412ffe3bf54aab9aa3731c4 --- /dev/null +++ b/mGPT/data/build_data.py @@ -0,0 +1,15 @@ +from omegaconf import OmegaConf +from os.path import join as pjoin +from mGPT.config import instantiate_from_config + + +def build_data(cfg, phase="train"): + data_config = OmegaConf.to_container(cfg.DATASET, resolve=True) + data_config['params'] = {'cfg': cfg, 'phase': phase} + if isinstance(data_config['target'], str): + return instantiate_from_config(data_config) + elif isinstance(data_config['target'], list): + data_config_tmp = data_config.copy() + data_config_tmp['params']['dataModules'] = data_config['target'] + data_config_tmp['target'] = 'mGPT.data.Concat.ConcatDataModule' + return instantiate_from_config(data_config_tmp) diff --git a/mGPT/data/humanml/README.md b/mGPT/data/humanml/README.md new file mode 100755 index 0000000000000000000000000000000000000000..4bf224f6b341e21f549a27a000d8400c4909c6c1 --- /dev/null +++ b/mGPT/data/humanml/README.md @@ -0,0 +1 @@ +This code is based on https://github.com/EricGuo5513/text-to-motion.git \ No newline at end of file diff --git a/mGPT/data/humanml/__init__.py b/mGPT/data/humanml/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8e753581956dcf7c8d92d0934e64486c884f60ac --- /dev/null +++ b/mGPT/data/humanml/__init__.py @@ -0,0 +1,7 @@ +from .dataset_t2m import Text2MotionDataset +from .dataset_t2m_eval import Text2MotionDatasetEval +from .dataset_t2m_cb import Text2MotionDatasetCB +from .dataset_t2m_token import Text2MotionDatasetToken +from .dataset_t2m_m2t import Text2MotionDatasetM2T +from .dataset_m import MotionDataset +from .dataset_m_vq import MotionDatasetVQ diff --git a/mGPT/data/humanml/common/quaternion.py b/mGPT/data/humanml/common/quaternion.py new file mode 100755 index 0000000000000000000000000000000000000000..dca3d890080a4e91e3f275f442b0aed006562881 --- /dev/null +++ b/mGPT/data/humanml/common/quaternion.py @@ -0,0 +1,423 @@ +# Copyright (c) 2018-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import torch +import numpy as np + +_EPS4 = np.finfo(float).eps * 4.0 + +_FLOAT_EPS = np.finfo(np.float64).eps + +# PyTorch-backed implementations +def qinv(q): + assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)' + mask = torch.ones_like(q) + mask[..., 1:] = -mask[..., 1:] + return q * mask + + +def qinv_np(q): + assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)' + return qinv(torch.from_numpy(q).float()).numpy() + + +def qnormalize(q): + assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)' + return q / torch.norm(q, dim=-1, keepdim=True) + + +def qmul(q, r): + """ + Multiply quaternion(s) q with quaternion(s) r. + Expects two equally-sized tensors of shape (*, 4), where * denotes any number of dimensions. + Returns q*r as a tensor of shape (*, 4).
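(Aside: a minimal sanity-check sketch for the quaternion helpers in this file, not part of the patch itself. It assumes the repository root is on `PYTHONPATH` so that `mGPT.data.humanml.common.quaternion` is importable; the checks follow directly from the docstrings of `qmul`, `qrot`, and `qinv`.)

```python
# Illustrative sketch only: quick checks for qmul/qrot/qinv,
# assuming the repository root is on PYTHONPATH.
import torch
from mGPT.data.humanml.common.quaternion import qinv, qmul, qnormalize, qrot

q = qnormalize(torch.randn(8, 4))  # batch of unit quaternions in (w, x, y, z) order
r = qnormalize(torch.randn(8, 4))
v = torch.randn(8, 3)              # batch of 3D vectors

# Composition: rotating by r and then by q equals rotating by the product qmul(q, r).
assert torch.allclose(qrot(qmul(q, r), v), qrot(q, qrot(r, v)), atol=1e-4)

# Inverse: q * q^{-1} is the identity rotation, so v is unchanged.
assert torch.allclose(qrot(qmul(q, qinv(q)), v), v, atol=1e-4)
```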
+ """ + assert q.shape[-1] == 4 + assert r.shape[-1] == 4 + + original_shape = q.shape + + # Compute outer product + terms = torch.bmm(r.view(-1, 4, 1), q.view(-1, 1, 4)) + + w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3] + x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2] + y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1] + z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0] + return torch.stack((w, x, y, z), dim=1).view(original_shape) + + +def qrot(q, v): + """ + Rotate vector(s) v about the rotation described by quaternion(s) q. + Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v, + where * denotes any number of dimensions. + Returns a tensor of shape (*, 3). + """ + assert q.shape[-1] == 4 + assert v.shape[-1] == 3 + assert q.shape[:-1] == v.shape[:-1] + + original_shape = list(v.shape) + # print(q.shape) + q = q.contiguous().view(-1, 4) + v = v.contiguous().view(-1, 3) + + qvec = q[:, 1:] + uv = torch.cross(qvec, v, dim=1) + uuv = torch.cross(qvec, uv, dim=1) + return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape) + + +def qeuler(q, order, epsilon=0, deg=True): + """ + Convert quaternion(s) q to Euler angles. + Expects a tensor of shape (*, 4), where * denotes any number of dimensions. + Returns a tensor of shape (*, 3). + """ + assert q.shape[-1] == 4 + + original_shape = list(q.shape) + original_shape[-1] = 3 + q = q.view(-1, 4) + + q0 = q[:, 0] + q1 = q[:, 1] + q2 = q[:, 2] + q3 = q[:, 3] + + if order == 'xyz': + x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) + y = torch.asin(torch.clamp(2 * (q1 * q3 + q0 * q2), -1 + epsilon, 1 - epsilon)) + z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3)) + elif order == 'yzx': + x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) + y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3)) + z = torch.asin(torch.clamp(2 * (q1 * q2 + q0 * q3), -1 + epsilon, 1 - epsilon)) + elif order == 'zxy': + x = torch.asin(torch.clamp(2 * (q0 * q1 + q2 * q3), -1 + epsilon, 1 - epsilon)) + y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) + z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q1 * q1 + q3 * q3)) + elif order == 'xzy': + x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) + y = torch.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3)) + z = torch.asin(torch.clamp(2 * (q0 * q3 - q1 * q2), -1 + epsilon, 1 - epsilon)) + elif order == 'yxz': + x = torch.asin(torch.clamp(2 * (q0 * q1 - q2 * q3), -1 + epsilon, 1 - epsilon)) + y = torch.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2 * (q1 * q1 + q2 * q2)) + z = torch.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) + elif order == 'zyx': + x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) + y = torch.asin(torch.clamp(2 * (q0 * q2 - q1 * q3), -1 + epsilon, 1 - epsilon)) + z = torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3)) + else: + raise + + if deg: + return torch.stack((x, y, z), dim=1).view(original_shape) * 180 / np.pi + else: + return torch.stack((x, y, z), dim=1).view(original_shape) + + +# Numpy-backed implementations + +def qmul_np(q, r): + q = torch.from_numpy(q).contiguous().float() + r = torch.from_numpy(r).contiguous().float() + return qmul(q, r).numpy() + + +def qrot_np(q, v): + q = torch.from_numpy(q).contiguous().float() + v = torch.from_numpy(v).contiguous().float() + return qrot(q, v).numpy() + + +def qeuler_np(q, order, 
epsilon=0, use_gpu=False): + if use_gpu: + q = torch.from_numpy(q).cuda().float() + return qeuler(q, order, epsilon).cpu().numpy() + else: + q = torch.from_numpy(q).contiguous().float() + return qeuler(q, order, epsilon).numpy() + + +def qfix(q): + """ + Enforce quaternion continuity across the time dimension by selecting + the representation (q or -q) with minimal distance (or, equivalently, maximal dot product) + between two consecutive frames. + + Expects a tensor of shape (L, J, 4), where L is the sequence length and J is the number of joints. + Returns a tensor of the same shape. + """ + assert len(q.shape) == 3 + assert q.shape[-1] == 4 + + result = q.copy() + dot_products = np.sum(q[1:] * q[:-1], axis=2) + mask = dot_products < 0 + mask = (np.cumsum(mask, axis=0) % 2).astype(bool) + result[1:][mask] *= -1 + return result + + +def euler2quat(e, order, deg=True): + """ + Convert Euler angles to quaternions. + """ + assert e.shape[-1] == 3 + + original_shape = list(e.shape) + original_shape[-1] = 4 + + e = e.view(-1, 3) + + ## if euler angles in degrees + if deg: + e = e * np.pi / 180. + + x = e[:, 0] + y = e[:, 1] + z = e[:, 2] + + rx = torch.stack((torch.cos(x / 2), torch.sin(x / 2), torch.zeros_like(x), torch.zeros_like(x)), dim=1) + ry = torch.stack((torch.cos(y / 2), torch.zeros_like(y), torch.sin(y / 2), torch.zeros_like(y)), dim=1) + rz = torch.stack((torch.cos(z / 2), torch.zeros_like(z), torch.zeros_like(z), torch.sin(z / 2)), dim=1) + + result = None + for coord in order: + if coord == 'x': + r = rx + elif coord == 'y': + r = ry + elif coord == 'z': + r = rz + else: + raise + if result is None: + result = r + else: + result = qmul(result, r) + + # Reverse antipodal representation to have a non-negative "w" + if order in ['xyz', 'yzx', 'zxy']: + result *= -1 + + return result.view(original_shape) + + +def expmap_to_quaternion(e): + """ + Convert axis-angle rotations (aka exponential maps) to quaternions. + Stable formula from "Practical Parameterization of Rotations Using the Exponential Map". + Expects a tensor of shape (*, 3), where * denotes any number of dimensions. + Returns a tensor of shape (*, 4). + """ + assert e.shape[-1] == 3 + + original_shape = list(e.shape) + original_shape[-1] = 4 + e = e.reshape(-1, 3) + + theta = np.linalg.norm(e, axis=1).reshape(-1, 1) + w = np.cos(0.5 * theta).reshape(-1, 1) + xyz = 0.5 * np.sinc(0.5 * theta / np.pi) * e + return np.concatenate((w, xyz), axis=1).reshape(original_shape) + + +def euler_to_quaternion(e, order): + """ + Convert Euler angles to quaternions. + """ + assert e.shape[-1] == 3 + + original_shape = list(e.shape) + original_shape[-1] = 4 + + e = e.reshape(-1, 3) + + x = e[:, 0] + y = e[:, 1] + z = e[:, 2] + + rx = np.stack((np.cos(x / 2), np.sin(x / 2), np.zeros_like(x), np.zeros_like(x)), axis=1) + ry = np.stack((np.cos(y / 2), np.zeros_like(y), np.sin(y / 2), np.zeros_like(y)), axis=1) + rz = np.stack((np.cos(z / 2), np.zeros_like(z), np.zeros_like(z), np.sin(z / 2)), axis=1) + + result = None + for coord in order: + if coord == 'x': + r = rx + elif coord == 'y': + r = ry + elif coord == 'z': + r = rz + else: + raise + if result is None: + result = r + else: + result = qmul_np(result, r) + + # Reverse antipodal representation to have a non-negative "w" + if order in ['xyz', 'yzx', 'zxy']: + result *= -1 + + return result.reshape(original_shape) + + +def quaternion_to_matrix(quaternions): + """ + Convert rotations given as quaternions to rotation matrices. 
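(Aside: the continuous 6D rotation representation used below stores the first two columns of the rotation matrix, and `cont6d_to_matrix` rebuilds the third via a cross-product Gram-Schmidt step, so the mapping round-trips exactly for proper rotations. A small sketch, again assuming `mGPT` is importable:)

```python
# Illustrative sketch only: round-trip check for the continuous 6D rotation
# representation, assuming mGPT is importable.
import torch
from mGPT.data.humanml.common.quaternion import (cont6d_to_matrix, qnormalize,
                                                 quaternion_to_cont6d,
                                                 quaternion_to_matrix)

q = qnormalize(torch.randn(16, 4))     # random unit quaternions (w, x, y, z)
mat = quaternion_to_matrix(q)          # (16, 3, 3) rotation matrices
cont6d = quaternion_to_cont6d(q)       # (16, 6): the first two matrix columns
mat_rec = cont6d_to_matrix(cont6d)     # Gram-Schmidt back to (16, 3, 3)

assert torch.allclose(mat, mat_rec, atol=1e-4)
```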
+ Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + r, i, j, k = torch.unbind(quaternions, -1) + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def quaternion_to_matrix_np(quaternions): + q = torch.from_numpy(quaternions).contiguous().float() + return quaternion_to_matrix(q).numpy() + + +def quaternion_to_cont6d_np(quaternions): + rotation_mat = quaternion_to_matrix_np(quaternions) + cont_6d = np.concatenate([rotation_mat[..., 0], rotation_mat[..., 1]], axis=-1) + return cont_6d + + +def quaternion_to_cont6d(quaternions): + rotation_mat = quaternion_to_matrix(quaternions) + cont_6d = torch.cat([rotation_mat[..., 0], rotation_mat[..., 1]], dim=-1) + return cont_6d + + +def cont6d_to_matrix(cont6d): + assert cont6d.shape[-1] == 6, "The last dimension must be 6" + x_raw = cont6d[..., 0:3] + y_raw = cont6d[..., 3:6] + + x = x_raw / torch.norm(x_raw, dim=-1, keepdim=True) + z = torch.cross(x, y_raw, dim=-1) + z = z / torch.norm(z, dim=-1, keepdim=True) + + y = torch.cross(z, x, dim=-1) + + x = x[..., None] + y = y[..., None] + z = z[..., None] + + mat = torch.cat([x, y, z], dim=-1) + return mat + + +def cont6d_to_matrix_np(cont6d): + q = torch.from_numpy(cont6d).contiguous().float() + return cont6d_to_matrix(q).numpy() + + +def qpow(q0, t, dtype=torch.float): + ''' q0 : tensor of quaternions + t: tensor of powers + ''' + q0 = qnormalize(q0) + theta0 = torch.acos(q0[..., 0]) + + ## if theta0 is close to zero, add epsilon to avoid NaNs + mask = (theta0 <= 10e-10) * (theta0 >= -10e-10) + theta0 = (1 - mask) * theta0 + mask * 10e-10 + v0 = q0[..., 1:] / torch.sin(theta0).view(-1, 1) + + if isinstance(t, torch.Tensor): + q = torch.zeros(t.shape + q0.shape) + theta = t.view(-1, 1) * theta0.view(1, -1) + else: ## if t is a number + q = torch.zeros(q0.shape) + theta = t * theta0 + + q[..., 0] = torch.cos(theta) + q[..., 1:] = v0 * torch.sin(theta).unsqueeze(-1) + + return q.to(dtype) + + +def qslerp(q0, q1, t): + ''' + q0: starting quaternion + q1: ending quaternion + t: array of points along the way + + Returns: + Tensor of Slerps: t.shape + q0.shape + ''' + + q0 = qnormalize(q0) + q1 = qnormalize(q1) + q_ = qpow(qmul(q1, qinv(q0)), t) + + return qmul(q_, + q0.contiguous().view(torch.Size([1] * len(t.shape)) + q0.shape).expand(t.shape + q0.shape).contiguous()) + + +def qbetween(v0, v1): + ''' + find the quaternion used to rotate v0 to v1 + ''' + assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)' + assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)' + + v = torch.cross(v0, v1) + w = torch.sqrt((v0 ** 2).sum(dim=-1, keepdim=True) * (v1 ** 2).sum(dim=-1, keepdim=True)) + (v0 * v1).sum(dim=-1, + keepdim=True) + return qnormalize(torch.cat([w, v], dim=-1)) + + +def qbetween_np(v0, v1): + ''' + find the quaternion used to rotate v0 to v1 + ''' + assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)' + assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)' + + v0 = torch.from_numpy(v0).float() + v1 = torch.from_numpy(v1).float() + return qbetween(v0, v1).numpy() + + +def lerp(p0, p1, t): + if not isinstance(t, torch.Tensor): + t = 
torch.Tensor([t]) + + new_shape = t.shape + p0.shape + new_view_t = t.shape + torch.Size([1] * len(p0.shape)) + new_view_p = torch.Size([1] * len(t.shape)) + p0.shape + p0 = p0.view(new_view_p).expand(new_shape) + p1 = p1.view(new_view_p).expand(new_shape) + t = t.view(new_view_t).expand(new_shape) + + return p0 + t * (p1 - p0) diff --git a/mGPT/data/humanml/common/skeleton.py b/mGPT/data/humanml/common/skeleton.py new file mode 100644 index 0000000000000000000000000000000000000000..b2ae85ad14df8c1a8d77e689b1cffbc6c814a979 --- /dev/null +++ b/mGPT/data/humanml/common/skeleton.py @@ -0,0 +1,199 @@ +from .quaternion import * +import scipy.ndimage.filters as filters + +class Skeleton(object): + def __init__(self, offset, kinematic_tree, device): + self.device = device + self._raw_offset_np = offset.numpy() + self._raw_offset = offset.clone().detach().to(device).float() + self._kinematic_tree = kinematic_tree + self._offset = None + self._parents = [0] * len(self._raw_offset) + self._parents[0] = -1 + for chain in self._kinematic_tree: + for j in range(1, len(chain)): + self._parents[chain[j]] = chain[j-1] + + def njoints(self): + return len(self._raw_offset) + + def offset(self): + return self._offset + + def set_offset(self, offsets): + self._offset = offsets.clone().detach().to(self.device).float() + + def kinematic_tree(self): + return self._kinematic_tree + + def parents(self): + return self._parents + + # joints (batch_size, joints_num, 3) + def get_offsets_joints_batch(self, joints): + assert len(joints.shape) == 3 + _offsets = self._raw_offset.expand(joints.shape[0], -1, -1).clone() + for i in range(1, self._raw_offset.shape[0]): + _offsets[:, i] = torch.norm(joints[:, i] - joints[:, self._parents[i]], p=2, dim=1)[:, None] * _offsets[:, i] + + self._offset = _offsets.detach() + return _offsets + + # joints (joints_num, 3) + def get_offsets_joints(self, joints): + assert len(joints.shape) == 2 + _offsets = self._raw_offset.clone() + for i in range(1, self._raw_offset.shape[0]): + # print(joints.shape) + _offsets[i] = torch.norm(joints[i] - joints[self._parents[i]], p=2, dim=0) * _offsets[i] + + self._offset = _offsets.detach() + return _offsets + + # face_joint_idx should follow the order of right hip, left hip, right shoulder, left shoulder + # joints (batch_size, joints_num, 3) + def inverse_kinematics_np(self, joints, face_joint_idx, smooth_forward=False): + assert len(face_joint_idx) == 4 + '''Get Forward Direction''' + l_hip, r_hip, sdr_r, sdr_l = face_joint_idx + across1 = joints[:, r_hip] - joints[:, l_hip] + across2 = joints[:, sdr_r] - joints[:, sdr_l] + across = across1 + across2 + across = across / np.sqrt((across**2).sum(axis=-1))[:, np.newaxis] + # print(across1.shape, across2.shape) + + # forward (batch_size, 3) + forward = np.cross(np.array([[0, 1, 0]]), across, axis=-1) + if smooth_forward: + forward = filters.gaussian_filter1d(forward, 20, axis=0, mode='nearest') + # forward (batch_size, 3) + forward = forward / np.sqrt((forward**2).sum(axis=-1))[..., np.newaxis] + + '''Get Root Rotation''' + target = np.array([[0,0,1]]).repeat(len(forward), axis=0) + root_quat = qbetween_np(forward, target) + + '''Inverse Kinematics''' + # quat_params (batch_size, joints_num, 4) + # print(joints.shape[:-1]) + quat_params = np.zeros(joints.shape[:-1] + (4,)) + # print(quat_params.shape) + root_quat[0] = np.array([[1.0, 0.0, 0.0, 0.0]]) + quat_params[:, 0] = root_quat + # quat_params[0, 0] = np.array([[1.0, 0.0, 0.0, 0.0]]) + for chain in self._kinematic_tree: + R = root_quat + for j in 
range(len(chain) - 1): + # (batch, 3) + u = self._raw_offset_np[chain[j+1]][np.newaxis,...].repeat(len(joints), axis=0) + # print(u.shape) + # (batch, 3) + v = joints[:, chain[j+1]] - joints[:, chain[j]] + v = v / np.sqrt((v**2).sum(axis=-1))[:, np.newaxis] + # print(u.shape, v.shape) + rot_u_v = qbetween_np(u, v) + + R_loc = qmul_np(qinv_np(R), rot_u_v) + + quat_params[:,chain[j + 1], :] = R_loc + R = qmul_np(R, R_loc) + + return quat_params + + # Be sure root joint is at the beginning of kinematic chains + def forward_kinematics(self, quat_params, root_pos, skel_joints=None, do_root_R=True): + # quat_params (batch_size, joints_num, 4) + # joints (batch_size, joints_num, 3) + # root_pos (batch_size, 3) + if skel_joints is not None: + offsets = self.get_offsets_joints_batch(skel_joints) + if len(self._offset.shape) == 2: + offsets = self._offset.expand(quat_params.shape[0], -1, -1) + joints = torch.zeros(quat_params.shape[:-1] + (3,)).to(self.device) + joints[:, 0] = root_pos + for chain in self._kinematic_tree: + if do_root_R: + R = quat_params[:, 0] + else: + R = torch.tensor([[1.0, 0.0, 0.0, 0.0]]).expand(len(quat_params), -1).detach().to(self.device) + for i in range(1, len(chain)): + R = qmul(R, quat_params[:, chain[i]]) + offset_vec = offsets[:, chain[i]] + joints[:, chain[i]] = qrot(R, offset_vec) + joints[:, chain[i-1]] + return joints + + # Be sure root joint is at the beginning of kinematic chains + def forward_kinematics_np(self, quat_params, root_pos, skel_joints=None, do_root_R=True): + # quat_params (batch_size, joints_num, 4) + # joints (batch_size, joints_num, 3) + # root_pos (batch_size, 3) + if skel_joints is not None: + skel_joints = torch.from_numpy(skel_joints) + offsets = self.get_offsets_joints_batch(skel_joints) + if len(self._offset.shape) == 2: + offsets = self._offset.expand(quat_params.shape[0], -1, -1) + offsets = offsets.numpy() + joints = np.zeros(quat_params.shape[:-1] + (3,)) + joints[:, 0] = root_pos + for chain in self._kinematic_tree: + if do_root_R: + R = quat_params[:, 0] + else: + R = np.array([[1.0, 0.0, 0.0, 0.0]]).repeat(len(quat_params), axis=0) + for i in range(1, len(chain)): + R = qmul_np(R, quat_params[:, chain[i]]) + offset_vec = offsets[:, chain[i]] + joints[:, chain[i]] = qrot_np(R, offset_vec) + joints[:, chain[i - 1]] + return joints + + def forward_kinematics_cont6d_np(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True): + # cont6d_params (batch_size, joints_num, 6) + # joints (batch_size, joints_num, 3) + # root_pos (batch_size, 3) + if skel_joints is not None: + skel_joints = torch.from_numpy(skel_joints) + offsets = self.get_offsets_joints_batch(skel_joints) + if len(self._offset.shape) == 2: + offsets = self._offset.expand(cont6d_params.shape[0], -1, -1) + offsets = offsets.numpy() + joints = np.zeros(cont6d_params.shape[:-1] + (3,)) + joints[:, 0] = root_pos + for chain in self._kinematic_tree: + if do_root_R: + matR = cont6d_to_matrix_np(cont6d_params[:, 0]) + else: + matR = np.eye(3)[np.newaxis, :].repeat(len(cont6d_params), axis=0) + for i in range(1, len(chain)): + matR = np.matmul(matR, cont6d_to_matrix_np(cont6d_params[:, chain[i]])) + offset_vec = offsets[:, chain[i]][..., np.newaxis] + # print(matR.shape, offset_vec.shape) + joints[:, chain[i]] = np.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]] + return joints + + def forward_kinematics_cont6d(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True): + # cont6d_params (batch_size, joints_num, 6) + # joints (batch_size, joints_num, 3) + # 
root_pos (batch_size, 3) + if skel_joints is not None: + # skel_joints = torch.from_numpy(skel_joints) + offsets = self.get_offsets_joints_batch(skel_joints) + if len(self._offset.shape) == 2: + offsets = self._offset.expand(cont6d_params.shape[0], -1, -1) + joints = torch.zeros(cont6d_params.shape[:-1] + (3,)).to(cont6d_params.device) + joints[..., 0, :] = root_pos + for chain in self._kinematic_tree: + if do_root_R: + matR = cont6d_to_matrix(cont6d_params[:, 0]) + else: + matR = torch.eye(3).expand((len(cont6d_params), -1, -1)).detach().to(cont6d_params.device) + for i in range(1, len(chain)): + matR = torch.matmul(matR, cont6d_to_matrix(cont6d_params[:, chain[i]])) + offset_vec = offsets[:, chain[i]].unsqueeze(-1) + # print(matR.shape, offset_vec.shape) + joints[:, chain[i]] = torch.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]] + return joints + + + + + diff --git a/mGPT/data/humanml/dataset_m.py b/mGPT/data/humanml/dataset_m.py new file mode 100644 index 0000000000000000000000000000000000000000..241cabcfbaa0778922e052a4cd66721215a9d051 --- /dev/null +++ b/mGPT/data/humanml/dataset_m.py @@ -0,0 +1,156 @@ +import os +import rich +import random +import pickle +import codecs as cs +import numpy as np +from torch.utils import data +from rich.progress import track +from os.path import join as pjoin + + +class MotionDataset(data.Dataset): + def __init__( + self, + data_root, + split, + mean, + std, + max_motion_length=196, + min_motion_length=20, + unit_length=4, + fps=20, + tmpFile=True, + tiny=False, + debug=False, + **kwargs, + ): + + # restrian the length of motion and text + self.max_motion_length = max_motion_length + self.min_motion_length = min_motion_length + self.unit_length = unit_length + + # Data mean and std + self.mean = mean + self.std = std + + # Data path + split_file = pjoin(data_root, split + '.txt') + motion_dir = pjoin(data_root, 'new_joint_vecs') + text_dir = pjoin(data_root, 'texts') + + # Data id list + self.id_list = [] + with cs.open(split_file, "r") as f: + for line in f.readlines(): + self.id_list.append(line.strip()) + + # Debug mode + if tiny or debug: + enumerator = enumerate( + track( + self.id_list, + f"Loading HumanML3D {split}", + )) + maxdata = 100 + subset = '_tiny' + else: + enumerator = enumerate(self.id_list) + maxdata = 1e10 + subset = '' + + new_name_list = [] + motion_dict = {} + + # Fast loading + if os.path.exists(pjoin(data_root, f'tmp/{split}{subset}_motion.pkl')): + with rich.progress.open(pjoin(data_root, f'tmp/{split}{subset}_motion.pkl'), + 'rb', description=f"Loading HumanML3D {split}") as file: + motion_dict = pickle.load(file) + with open(pjoin(data_root, f'tmp/{split}{subset}_index.pkl'), 'rb') as file: + new_name_list = pickle.load(file) + else: + for idx, name in enumerator: + if len(new_name_list) > maxdata: + break + try: + motion = [np.load(pjoin(motion_dir, name + ".npy"))] + + # Read text + with cs.open(pjoin(text_dir, name + '.txt')) as f: + text_data = [] + flag = False + lines = f.readlines() + + for line in lines: + try: + line_split = line.strip().split('#') + f_tag = float(line_split[2]) + to_tag = float(line_split[3]) + f_tag = 0.0 if np.isnan(f_tag) else f_tag + to_tag = 0.0 if np.isnan(to_tag) else to_tag + + if f_tag == 0.0 and to_tag == 0.0: + flag = True + else: + motion_new = [tokens[int(f_tag*fps/unit_length) : int(to_tag*fps/unit_length)] for tokens in motion if int(f_tag*fps/unit_length) < int(to_tag*fps/unit_length)] + + if len(motion_new) == 0: + continue + new_name = '%s_%f_%f'%(name, f_tag, 
to_tag) + + motion_dict[new_name] = { + 'motion': motion_new, + "length": [len(m[0]) for m in motion_new]} + new_name_list.append(new_name) + except: + pass + + if flag: + motion_dict[name] = { + 'motion': motion, + "length": [len(m[0]) for m in motion]} + new_name_list.append(name) + except: + pass + + if tmpFile: + os.makedirs(pjoin(data_root, 'tmp'), exist_ok=True) + + with open(pjoin(data_root, f'tmp/{split}{subset}_motion.pkl'),'wb') as file: + pickle.dump(motion_dict, file) + with open(pjoin(data_root, f'tmp/{split}{subset}_index.pkl'), 'wb') as file: + pickle.dump(new_name_list, file) + + self.motion_dict = motion_dict + self.name_list = new_name_list + self.nfeats = motion_dict[new_name_list[0]]['motion'][0].shape[1] + + def __len__(self): + return len(self.name_list) + + def __getitem__(self, item): + data = self.motion_dict[self.name_list[item]] + motion_list, m_length = data["motion"], data["length"] + + # Randomly select a motion + motion = random.choice(motion_list) + + # Crop the motions in to times of 4, and introduce small variations + if self.unit_length < 10: + coin2 = np.random.choice(["single", "single", "double"]) + else: + coin2 = "single" + + if coin2 == "double": + m_length = (m_length // self.unit_length - 1) * self.unit_length + elif coin2 == "single": + m_length = (m_length // self.unit_length) * self.unit_length + idx = random.randint(0, len(motion) - m_length) + motion = motion[idx:idx + m_length] + + # Z Normalization + motion = (motion - self.mean) / self.std + + return None, motion, m_length, None, None, None, None, diff --git a/mGPT/data/humanml/dataset_m_vq.py b/mGPT/data/humanml/dataset_m_vq.py new file mode 100644 index 0000000000000000000000000000000000000000..dfa3ae5086d8da32dd49a0d4766200df13111619 --- /dev/null +++ b/mGPT/data/humanml/dataset_m_vq.py @@ -0,0 +1,54 @@ +import random +import codecs as cs +import numpy as np +from torch.utils import data +from rich.progress import track +from os.path import join as pjoin +from .dataset_m import MotionDataset +from .dataset_t2m import Text2MotionDataset + + +class MotionDatasetVQ(Text2MotionDataset): + def __init__( + self, + data_root, + split, + mean, + std, + max_motion_length, + min_motion_length, + win_size, + unit_length=4, + fps=20, + tmpFile=True, + tiny=False, + debug=False, + **kwargs, + ): + super().__init__(data_root, split, mean, std, max_motion_length, + min_motion_length, unit_length, fps, tmpFile, tiny, + debug, **kwargs) + + # Filter out the motions that are too short + self.window_size = win_size + name_list = list(self.name_list) + for name in self.name_list: + motion = self.data_dict[name]["motion"] + if motion.shape[0] < self.window_size: + name_list.remove(name) + self.data_dict.pop(name) + self.name_list = name_list + + def __len__(self): + return len(self.name_list) + + def __getitem__(self, item): + idx = self.pointer + item + data = self.data_dict[self.name_list[idx]] + motion, length = data["motion"], data["length"] + + idx = random.randint(0, motion.shape[0] - self.window_size) + motion = motion[idx:idx + self.window_size] + motion = (motion - self.mean) / self.std + + return None, motion, length, None, None, None, None, diff --git a/mGPT/data/humanml/dataset_t2m.py b/mGPT/data/humanml/dataset_t2m.py new file mode 100644 index 0000000000000000000000000000000000000000..34606817e390562ba3776db0c8e5aa82f6720b40 --- /dev/null +++ b/mGPT/data/humanml/dataset_t2m.py @@ -0,0 +1,211 @@ +import os +import rich +import random +import pickle +import codecs as cs +import numpy as np +from 
torch.utils import data +from rich.progress import track +from os.path import join as pjoin + + +class Text2MotionDataset(data.Dataset): + + def __init__( + self, + data_root, + split, + mean, + std, + max_motion_length=196, + min_motion_length=40, + unit_length=4, + fps=20, + tmpFile=True, + tiny=False, + debug=False, + **kwargs, + ): + + # restrian the length of motion and text + self.max_length = 20 + self.max_motion_length = max_motion_length + self.min_motion_length = min_motion_length + self.unit_length = unit_length + + # Data mean and std + self.mean = mean + self.std = std + + # Data path + split_file = pjoin(data_root, split + '.txt') + motion_dir = pjoin(data_root, 'new_joint_vecs') + text_dir = pjoin(data_root, 'texts') + + # Data id list + self.id_list = [] + with cs.open(split_file, "r") as f: + for line in f.readlines(): + self.id_list.append(line.strip()) + + # Debug mode + if tiny or debug: + enumerator = enumerate(self.id_list) + maxdata = 100 + subset = '_tiny' + else: + enumerator = enumerate( + track( + self.id_list, + f"Loading HumanML3D {split}", + )) + maxdata = 1e10 + subset = '' + + new_name_list = [] + length_list = [] + data_dict = {} + + # Fast loading + if os.path.exists(pjoin(data_root, f'tmp/{split}{subset}_data.pkl')): + if tiny or debug: + with open(pjoin(data_root, f'tmp/{split}{subset}_data.pkl'), + 'rb') as file: + data_dict = pickle.load(file) + else: + with rich.progress.open( + pjoin(data_root, f'tmp/{split}{subset}_data.pkl'), + 'rb', + description=f"Loading HumanML3D {split}") as file: + data_dict = pickle.load(file) + with open(pjoin(data_root, f'tmp/{split}{subset}_index.pkl'), + 'rb') as file: + name_list = pickle.load(file) + for name in new_name_list: + length_list.append(data_dict[name]['length']) + + else: + for idx, name in enumerator: + if len(new_name_list) > maxdata: + break + try: + motion = np.load(pjoin(motion_dir, name + ".npy")) + if (len(motion)) < self.min_motion_length or (len(motion) + >= 200): + continue + + # Read text + text_data = [] + flag = False + with cs.open(pjoin(text_dir, name + '.txt')) as f: + lines = f.readlines() + for line in lines: + text_dict = {} + line_split = line.strip().split('#') + caption = line_split[0] + t_tokens = line_split[1].split(' ') + f_tag = float(line_split[2]) + to_tag = float(line_split[3]) + f_tag = 0.0 if np.isnan(f_tag) else f_tag + to_tag = 0.0 if np.isnan(to_tag) else to_tag + + text_dict['caption'] = caption + text_dict['tokens'] = t_tokens + if f_tag == 0.0 and to_tag == 0.0: + flag = True + text_data.append(text_dict) + else: + motion_new = motion[int(f_tag * + fps):int(to_tag * fps)] + if (len(motion_new) + ) < self.min_motion_length or ( + len(motion_new) >= 200): + continue + new_name = random.choice( + 'ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name + while new_name in new_name_list: + new_name = random.choice( + 'ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name + name_count = 1 + while new_name in data_dict: + new_name += '_' + name_count + name_count += 1 + data_dict[new_name] = { + 'motion': motion_new, + "length": len(motion_new), + 'text': [text_dict] + } + new_name_list.append(new_name) + length_list.append(len(motion_new)) + + if flag: + data_dict[name] = { + 'motion': motion, + "length": len(motion), + 'text': text_data + } + new_name_list.append(name) + length_list.append(len(motion)) + except: + pass + + name_list, length_list = zip( + *sorted(zip(new_name_list, length_list), key=lambda x: x[1])) + + if tmpFile: + os.makedirs(pjoin(data_root, 'tmp'), exist_ok=True) + with 
open(pjoin(data_root, f'tmp/{split}{subset}_data.pkl'), + 'wb') as file: + pickle.dump(data_dict, file) + with open(pjoin(data_root, f'tmp/{split}{subset}_index.pkl'), + 'wb') as file: + pickle.dump(name_list, file) + + self.length_arr = np.array(length_list) + self.data_dict = data_dict + self.name_list = name_list + self.nfeats = data_dict[name_list[0]]['motion'].shape[1] + self.reset_max_len(self.max_length) + + def reset_max_len(self, length): + assert length <= self.max_motion_length + self.pointer = np.searchsorted(self.length_arr, length) + print("Pointer Pointing at %d" % self.pointer) + self.max_length = length + + def __len__(self): + return len(self.name_list) - self.pointer + + def __getitem__(self, item): + idx = self.pointer + item + data = self.data_dict[self.name_list[idx]] + motion, m_length, text_list = data["motion"], data["length"], data[ + "text"] + + # Randomly select a caption + text_data = random.choice(text_list) + caption = text_data["caption"] + + all_captions = [ + ' '.join([token.split('/')[0] for token in text_dic['tokens']]) + for text_dic in text_list + ] + + # Crop the motions in to times of 4, and introduce small variations + if self.unit_length < 10: + coin2 = np.random.choice(["single", "single", "double"]) + else: + coin2 = "single" + + if coin2 == "double": + m_length = (m_length // self.unit_length - 1) * self.unit_length + elif coin2 == "single": + m_length = (m_length // self.unit_length) * self.unit_length + + idx = random.randint(0, len(motion) - m_length) + motion = motion[idx:idx + m_length] + + # Z Normalization + motion = (motion - self.mean) / self.std + + return caption, motion, m_length, None, None, None, None, all_captions diff --git a/mGPT/data/humanml/dataset_t2m_cb.py b/mGPT/data/humanml/dataset_t2m_cb.py new file mode 100644 index 0000000000000000000000000000000000000000..6f7e1feb2d4f4a824ff8af5128d8cf78b56f92b8 --- /dev/null +++ b/mGPT/data/humanml/dataset_t2m_cb.py @@ -0,0 +1,211 @@ +import rich +import random +import pickle +import os +import numpy as np +import codecs as cs +from torch.utils import data +from os.path import join as pjoin +from rich.progress import track +import json +import spacy + +class Text2MotionDatasetCB(data.Dataset): + def __init__( + self, + data_root, + split, + mean, + std, + max_motion_length=196, + min_motion_length=20, + unit_length=4, + fps=20, + tmpFile=True, + tiny=False, + debug=False, + stage='lm_pretrain', + code_path='VQVAE', + task_path=None, + std_text=False, + **kwargs, + ): + self.tiny = tiny + self.unit_length = unit_length + + # Data mean and std + self.mean = mean + self.std = std + + # Data path + split = 'train' + split_file = pjoin(data_root, split + '.txt') + motion_dir = pjoin(data_root, code_path) + text_dir = pjoin(data_root, 'texts') + + if task_path: + instructions = task_path + elif stage == 'lm_pretrain': + instructions = pjoin(data_root, 'template_pretrain.json') + elif stage in ['lm_instruct', "lm_rl"]: + instructions = pjoin(data_root, 'template_instructions.json') + else: + raise NotImplementedError(f"stage {stage} not implemented") + + # Data id list + self.id_list = [] + with cs.open(split_file, "r") as f: + for line in f.readlines(): + self.id_list.append(line.strip()) + + # Debug mode + if tiny or debug: + enumerator = enumerate(self.id_list) + maxdata = 100 + subset = '_tiny' + else: + enumerator = enumerate( + track( + self.id_list, + f"Loading HumanML3D {split}", + )) + maxdata = 1e10 + subset = '' + + new_name_list = [] + data_dict = {} + + # Fast loading + for i, 
name in enumerator: + if len(new_name_list) > maxdata: + break + try: + # Load motion tokens + m_token_list = np.load(pjoin(motion_dir, f'{name}.npy')) + # Read text + with cs.open(pjoin(text_dir, name + '.txt')) as f: + text_data = [] + flag = False + lines = f.readlines() + + for line in lines: + try: + text_dict = {} + line_split = line.strip().split('#') + caption = line_split[0] + t_tokens = line_split[1].split(' ') + f_tag = float(line_split[2]) + to_tag = float(line_split[3]) + f_tag = 0.0 if np.isnan(f_tag) else f_tag + to_tag = 0.0 if np.isnan(to_tag) else to_tag + + text_dict['caption'] = caption + text_dict['tokens'] = t_tokens + if f_tag == 0.0 and to_tag == 0.0: + flag = True + text_data.append(text_dict) + else: + m_token_list_new = [ + tokens[int(f_tag * fps / unit_length + ):int(to_tag * fps / + unit_length)] + for tokens in m_token_list + if int(f_tag * fps / unit_length) < + int(to_tag * fps / unit_length) + ] + + if len(m_token_list_new) == 0: + continue + new_name = '%s_%f_%f' % (name, f_tag, + to_tag) + + data_dict[new_name] = { + 'm_token_list': m_token_list_new, + 'text': [text_dict] + } + new_name_list.append(new_name) + except: + pass + + if flag: + data_dict[name] = { + 'm_token_list': m_token_list, + 'text': text_data + } + new_name_list.append(name) + except: + pass + + if tmpFile: + os.makedirs(pjoin(data_root, 'tmp'), exist_ok=True) + with open( + pjoin(data_root, + f'tmp/{split}{subset}_tokens_data.pkl'), + 'wb') as file: + pickle.dump(data_dict, file) + with open( + pjoin(data_root, + f'tmp/{split}{subset}_tokens_index.pkl'), + 'wb') as file: + pickle.dump(new_name_list, file) + + self.data_dict = data_dict + self.name_list = new_name_list + self.nlp = spacy.load('en_core_web_sm') + self.std_text = std_text + self.instructions = json.load(open(instructions, 'r')) + self.tasks = [] + for task in self.instructions.keys(): + for subtask in self.instructions[task].keys(): + self.tasks.append(self.instructions[task][subtask]) + + def __len__(self): + return len(self.name_list) * len(self.tasks) + + def __getitem__(self, item): + data_idx = item % len(self.name_list) + task_idx = item // len(self.name_list) + + data = self.data_dict[self.name_list[data_idx]] + m_token_list, text_list = data['m_token_list'], data['text'] + + m_tokens = random.choice(m_token_list) + text_data = random.choice(text_list) + caption = text_data['caption'] + if self.std_text: + doc = self.nlp(caption) + word_list = [] + pos_list = [] + for token in doc: + word = token.text + if not word.isalpha(): + continue + if (token.pos_ == 'NOUN' + or token.pos_ == 'VERB') and (word != 'left'): + word_list.append(token.lemma_) + else: + word_list.append(word) + pos_list.append(token.pos_) + + caption = ' '.join(word_list) + + all_captions = [ + ' '.join([token.split('/')[0] for token in text_dic['tokens']]) + for text_dic in text_list + ] + + coin = np.random.choice([False, False, True]) + + if coin: + # drop one token at the head or tail + coin2 = np.random.choice([True, False]) + if coin2: + m_tokens = m_tokens[:-1] + else: + m_tokens = m_tokens[1:] + + m_tokens_len = m_tokens.shape[0] + + tasks = self.tasks[task_idx] + + return caption, m_tokens, m_tokens_len, None, None, None, None, all_captions, tasks diff --git a/mGPT/data/humanml/dataset_t2m_eval.py b/mGPT/data/humanml/dataset_t2m_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..4162981c1362ecc2f1637dad7415951325c36271 --- /dev/null +++ b/mGPT/data/humanml/dataset_t2m_eval.py @@ -0,0 +1,92 @@ +import random +import 
numpy as np +from .dataset_t2m import Text2MotionDataset + + +class Text2MotionDatasetEval(Text2MotionDataset): + + def __init__( + self, + data_root, + split, + mean, + std, + w_vectorizer, + max_motion_length=196, + min_motion_length=40, + unit_length=4, + fps=20, + tmpFile=True, + tiny=False, + debug=False, + **kwargs, + ): + super().__init__(data_root, split, mean, std, max_motion_length, + min_motion_length, unit_length, fps, tmpFile, tiny, + debug, **kwargs) + + self.w_vectorizer = w_vectorizer + + + def __getitem__(self, item): + # Get text data + idx = self.pointer + item + data = self.data_dict[self.name_list[idx]] + motion, m_length, text_list = data["motion"], data["length"], data["text"] + + all_captions = [ + ' '.join([token.split('/')[0] for token in text_dic['tokens']]) + for text_dic in text_list + ] + + if len(all_captions) > 3: + all_captions = all_captions[:3] + elif len(all_captions) == 2: + all_captions = all_captions + all_captions[0:1] + elif len(all_captions) == 1: + all_captions = all_captions * 3 + + # Randomly select a caption + text_data = random.choice(text_list) + caption, tokens = text_data["caption"], text_data["tokens"] + + # Text + max_text_len = 20 + if len(tokens) < max_text_len: + # pad with "unk" + tokens = ["sos/OTHER"] + tokens + ["eos/OTHER"] + sent_len = len(tokens) + tokens = tokens + ["unk/OTHER"] * (max_text_len + 2 - sent_len) + else: + # crop + tokens = tokens[:max_text_len] + tokens = ["sos/OTHER"] + tokens + ["eos/OTHER"] + sent_len = len(tokens) + pos_one_hots = [] + word_embeddings = [] + for token in tokens: + word_emb, pos_oh = self.w_vectorizer[token] + pos_one_hots.append(pos_oh[None, :]) + word_embeddings.append(word_emb[None, :]) + pos_one_hots = np.concatenate(pos_one_hots, axis=0) + word_embeddings = np.concatenate(word_embeddings, axis=0) + + # Random crop + if self.unit_length < 10: + coin2 = np.random.choice(["single", "single", "double"]) + else: + coin2 = "single" + + if coin2 == "double": + m_length = (m_length // self.unit_length - 1) * self.unit_length + elif coin2 == "single": + m_length = (m_length // self.unit_length) * self.unit_length + + idx = random.randint(0, len(motion) - m_length) + motion = motion[idx:idx + m_length] + + # Z Normalization + motion = (motion - self.mean) / self.std + + return caption, motion, m_length, word_embeddings, pos_one_hots, sent_len, "_".join( + tokens), all_captions diff --git a/mGPT/data/humanml/dataset_t2m_m2t.py b/mGPT/data/humanml/dataset_t2m_m2t.py new file mode 100644 index 0000000000000000000000000000000000000000..259078c87a6f66d095404616ad9dba55a59b490a --- /dev/null +++ b/mGPT/data/humanml/dataset_t2m_m2t.py @@ -0,0 +1,119 @@ +import random +import numpy as np +from torch.utils import data +from .dataset_t2m import Text2MotionDataset +import codecs as cs +from os.path import join as pjoin + + +class Text2MotionDatasetM2T(data.Dataset): + + def __init__( + self, + data_root, + split, + mean, + std, + max_motion_length=196, + min_motion_length=40, + unit_length=4, + fps=20, + tmpFile=True, + tiny=False, + debug=False, + **kwargs, + ): + + self.max_motion_length = max_motion_length + self.min_motion_length = min_motion_length + self.unit_length = unit_length + + # Data mean and std + self.mean = mean + self.std = std + + # Data path + split_file = pjoin(data_root, split + '.txt') + motion_dir = pjoin(data_root, 'new_joint_vecs') + text_dir = pjoin(data_root, 'texts') + + # Data id list + self.id_list = [] + with cs.open(split_file, "r") as f: + for line in f.readlines(): + 
self.id_list.append(line.strip()) + + new_name_list = [] + length_list = [] + data_dict = {} + for name in self.id_list: + # try: + motion = np.load(pjoin(motion_dir, name + '.npy')) + if (len(motion)) < self.min_motion_length or (len(motion) >= 200): + continue + + + text_data = [] + flag = False + + with cs.open(pjoin(text_dir, name + '.txt')) as f: + for line in f.readlines(): + text_dict = {} + line_split = line.strip().split('#') + caption = line_split[0] + tokens = line_split[1].split(' ') + f_tag = float(line_split[2]) + to_tag = float(line_split[3]) + f_tag = 0.0 if np.isnan(f_tag) else f_tag + to_tag = 0.0 if np.isnan(to_tag) else to_tag + + text_dict['caption'] = caption + text_dict['tokens'] = tokens + if f_tag == 0.0 and to_tag == 0.0: + flag = True + text_data.append(text_dict) + else: + try: + n_motion = motion[int(f_tag*20) : int(to_tag*20)] + + if (len(n_motion)) < min_motion_length or (len(n_motion) >= 200): + continue + + new_name = "%s_%f_%f"%(name, f_tag, to_tag) + data_dict[new_name] = {'motion': n_motion, + 'length': len(n_motion), + 'text':[text_dict]} + new_name_list.append(new_name) + except: + print(line_split) + print(line_split[2], line_split[3], f_tag, to_tag, name) + if flag: + data_dict[name] = {'motion': motion, + 'length': len(motion), + 'name': name, + 'text': text_data} + + new_name_list.append(name) + length_list.append(len(motion)) + # except: + # # Some motion may not exist in KIT dataset + # pass + + self.length_arr = np.array(length_list) + self.data_dict = data_dict + self.name_list = new_name_list + self.nfeats = motion.shape[-1] + + + def __len__(self): + return len(self.data_dict) + + def __getitem__(self, item): + name = self.name_list[item] + data = self.data_dict[name] + motion, m_length = data['motion'], data['length'] + + "Z Normalization" + motion = (motion - self.mean) / self.std + + return name, motion, m_length, True, True, True, True, True, True diff --git a/mGPT/data/humanml/dataset_t2m_token.py b/mGPT/data/humanml/dataset_t2m_token.py new file mode 100644 index 0000000000000000000000000000000000000000..d9f1e5435244897af485c860b5c57385d02d7791 --- /dev/null +++ b/mGPT/data/humanml/dataset_t2m_token.py @@ -0,0 +1,86 @@ +import random +import numpy as np +from torch.utils import data +from .dataset_t2m import Text2MotionDataset +import codecs as cs +from os.path import join as pjoin + + +class Text2MotionDatasetToken(data.Dataset): + + def __init__( + self, + data_root, + split, + mean, + std, + max_motion_length=196, + min_motion_length=40, + unit_length=4, + fps=20, + tmpFile=True, + tiny=False, + debug=False, + **kwargs, + ): + + self.max_motion_length = max_motion_length + self.min_motion_length = min_motion_length + self.unit_length = unit_length + + # Data mean and std + self.mean = mean + self.std = std + + # Data path + split_file = pjoin(data_root, split + '.txt') + motion_dir = pjoin(data_root, 'new_joint_vecs') + text_dir = pjoin(data_root, 'texts') + + # Data id list + self.id_list = [] + with cs.open(split_file, "r") as f: + for line in f.readlines(): + self.id_list.append(line.strip()) + + new_name_list = [] + length_list = [] + data_dict = {} + for name in self.id_list: + try: + motion = np.load(pjoin(motion_dir, name + '.npy')) + if (len(motion)) < self.min_motion_length or (len(motion) >= 200): + continue + + data_dict[name] = {'motion': motion, + 'length': len(motion), + 'name': name} + new_name_list.append(name) + length_list.append(len(motion)) + except: + # Some motion may not exist in KIT dataset + pass + + 
self.length_arr = np.array(length_list) + self.data_dict = data_dict + self.name_list = new_name_list + self.nfeats = motion.shape[-1] + + + def __len__(self): + return len(self.data_dict) + + def __getitem__(self, item): + name = self.name_list[item] + data = self.data_dict[name] + motion, m_length = data['motion'], data['length'] + + m_length = (m_length // self.unit_length) * self.unit_length + + idx = random.randint(0, len(motion) - m_length) + motion = motion[idx:idx+m_length] + + "Z Normalization" + motion = (motion - self.mean) / self.std + + return name, motion, m_length, True, True, True, True, True, True diff --git a/mGPT/data/humanml/scripts/motion_process.py b/mGPT/data/humanml/scripts/motion_process.py new file mode 100644 index 0000000000000000000000000000000000000000..8b3395cfa348685d3df88375943bfe3c80980253 --- /dev/null +++ b/mGPT/data/humanml/scripts/motion_process.py @@ -0,0 +1,529 @@ +from os.path import join as pjoin + +from ..common.skeleton import Skeleton +import numpy as np +import os +from ..common.quaternion import * +from ..utils.paramUtil import * + +import torch +from tqdm import tqdm + +# positions (batch, joint_num, 3) +def uniform_skeleton(positions, target_offset): + src_skel = Skeleton(n_raw_offsets, kinematic_chain, 'cpu') + src_offset = src_skel.get_offsets_joints(torch.from_numpy(positions[0])) + src_offset = src_offset.numpy() + tgt_offset = target_offset.numpy() + # print(src_offset) + # print(tgt_offset) + '''Calculate Scale Ratio as the ratio of legs''' + src_leg_len = np.abs(src_offset[l_idx1]).max() + np.abs(src_offset[l_idx2]).max() + tgt_leg_len = np.abs(tgt_offset[l_idx1]).max() + np.abs(tgt_offset[l_idx2]).max() + + scale_rt = tgt_leg_len / src_leg_len + # print(scale_rt) + src_root_pos = positions[:, 0] + tgt_root_pos = src_root_pos * scale_rt + + '''Inverse Kinematics''' + quat_params = src_skel.inverse_kinematics_np(positions, face_joint_indx) + # print(quat_params.shape) + + '''Forward Kinematics''' + src_skel.set_offset(target_offset) + new_joints = src_skel.forward_kinematics_np(quat_params, tgt_root_pos) + return new_joints + + +def extract_features(positions, feet_thre, n_raw_offsets, kinematic_chain, face_joint_indx, fid_r, fid_l): + global_positions = positions.copy() + """ Get Foot Contacts """ + + def foot_detect(positions, thres): + velfactor, heightfactor = np.array([thres, thres]), np.array([3.0, 2.0]) + + feet_l_x = (positions[1:, fid_l, 0] - positions[:-1, fid_l, 0]) ** 2 + feet_l_y = (positions[1:, fid_l, 1] - positions[:-1, fid_l, 1]) ** 2 + feet_l_z = (positions[1:, fid_l, 2] - positions[:-1, fid_l, 2]) ** 2 + # feet_l_h = positions[:-1,fid_l,1] + # feet_l = (((feet_l_x + feet_l_y + feet_l_z) < velfactor) & (feet_l_h < heightfactor)).astype(np.float64) + feet_l = ((feet_l_x + feet_l_y + feet_l_z) < velfactor).astype(np.float64) + + feet_r_x = (positions[1:, fid_r, 0] - positions[:-1, fid_r, 0]) ** 2 + feet_r_y = (positions[1:, fid_r, 1] - positions[:-1, fid_r, 1]) ** 2 + feet_r_z = (positions[1:, fid_r, 2] - positions[:-1, fid_r, 2]) ** 2 + # feet_r_h = positions[:-1,fid_r,1] + # feet_r = (((feet_r_x + feet_r_y + feet_r_z) < velfactor) & (feet_r_h < heightfactor)).astype(np.float64) + feet_r = (((feet_r_x + feet_r_y + feet_r_z) < velfactor)).astype(np.float64) + return feet_l, feet_r + + # + feet_l, feet_r = foot_detect(positions, feet_thre) + # feet_l, feet_r = foot_detect(positions, 0.002) + + '''Quaternion and Cartesian representation''' + r_rot = None + + def get_rifke(positions): + '''Local pose''' + positions[..., 0] 
-= positions[:, 0:1, 0] + positions[..., 2] -= positions[:, 0:1, 2] + '''All pose face Z+''' + positions = qrot_np(np.repeat(r_rot[:, None], positions.shape[1], axis=1), positions) + return positions + + def get_quaternion(positions): + skel = Skeleton(n_raw_offsets, kinematic_chain, "cpu") + # (seq_len, joints_num, 4) + quat_params = skel.inverse_kinematics_np(positions, face_joint_indx, smooth_forward=False) + + '''Fix Quaternion Discontinuity''' + quat_params = qfix(quat_params) + # (seq_len, 4) + r_rot = quat_params[:, 0].copy() + # print(r_rot[0]) + '''Root Linear Velocity''' + # (seq_len - 1, 3) + velocity = (positions[1:, 0] - positions[:-1, 0]).copy() + # print(r_rot.shape, velocity.shape) + velocity = qrot_np(r_rot[1:], velocity) + '''Root Angular Velocity''' + # (seq_len - 1, 4) + r_velocity = qmul_np(r_rot[1:], qinv_np(r_rot[:-1])) + quat_params[1:, 0] = r_velocity + # (seq_len, joints_num, 4) + return quat_params, r_velocity, velocity, r_rot + + def get_cont6d_params(positions): + skel = Skeleton(n_raw_offsets, kinematic_chain, "cpu") + # (seq_len, joints_num, 4) + quat_params = skel.inverse_kinematics_np(positions, face_joint_indx, smooth_forward=True) + + '''Quaternion to continuous 6D''' + cont_6d_params = quaternion_to_cont6d_np(quat_params) + # (seq_len, 4) + r_rot = quat_params[:, 0].copy() + # print(r_rot[0]) + '''Root Linear Velocity''' + # (seq_len - 1, 3) + velocity = (positions[1:, 0] - positions[:-1, 0]).copy() + # print(r_rot.shape, velocity.shape) + velocity = qrot_np(r_rot[1:], velocity) + '''Root Angular Velocity''' + # (seq_len - 1, 4) + r_velocity = qmul_np(r_rot[1:], qinv_np(r_rot[:-1])) + # (seq_len, joints_num, 4) + return cont_6d_params, r_velocity, velocity, r_rot + + cont_6d_params, r_velocity, velocity, r_rot = get_cont6d_params(positions) + positions = get_rifke(positions) + + # trejec = np.cumsum(np.concatenate([np.array([[0, 0, 0]]), velocity], axis=0), axis=0) + # r_rotations, r_pos = recover_ric_glo_np(r_velocity, velocity[:, [0, 2]]) + + # plt.plot(positions_b[:, 0, 0], positions_b[:, 0, 2], marker='*') + # plt.plot(ground_positions[:, 0, 0], ground_positions[:, 0, 2], marker='o', color='r') + # plt.plot(trejec[:, 0], trejec[:, 2], marker='^', color='g') + # plt.plot(r_pos[:, 0], r_pos[:, 2], marker='s', color='y') + # plt.xlabel('x') + # plt.ylabel('z') + # plt.axis('equal') + # plt.show() + + '''Root height''' + root_y = positions[:, 0, 1:2] + + '''Root rotation and linear velocity''' + # (seq_len-1, 1) rotation velocity along y-axis + # (seq_len-1, 2) linear velovity on xz plane + r_velocity = np.arcsin(r_velocity[:, 2:3]) + l_velocity = velocity[:, [0, 2]] + # print(r_velocity.shape, l_velocity.shape, root_y.shape) + root_data = np.concatenate([r_velocity, l_velocity, root_y[:-1]], axis=-1) + + '''Get Joint Rotation Representation''' + # (seq_len, (joints_num-1) *6) quaternion for skeleton joints + rot_data = cont_6d_params[:, 1:].reshape(len(cont_6d_params), -1) + + '''Get Joint Rotation Invariant Position Represention''' + # (seq_len, (joints_num-1)*3) local joint position + ric_data = positions[:, 1:].reshape(len(positions), -1) + + '''Get Joint Velocity Representation''' + # (seq_len-1, joints_num*3) + local_vel = qrot_np(np.repeat(r_rot[:-1, None], global_positions.shape[1], axis=1), + global_positions[1:] - global_positions[:-1]) + local_vel = local_vel.reshape(len(local_vel), -1) + + data = root_data + data = np.concatenate([data, ric_data[:-1]], axis=-1) + data = np.concatenate([data, rot_data[:-1]], axis=-1) + # print(dataset.shape, 
local_vel.shape) + data = np.concatenate([data, local_vel], axis=-1) + data = np.concatenate([data, feet_l, feet_r], axis=-1) + + return data + + +def process_file(positions, feet_thre): + # (seq_len, joints_num, 3) + # '''Down Sample''' + # positions = positions[::ds_num] + + '''Uniform Skeleton''' + positions = uniform_skeleton(positions, tgt_offsets) + + '''Put on Floor''' + floor_height = positions.min(axis=0).min(axis=0)[1] + positions[:, :, 1] -= floor_height + # print(floor_height) + + # plot_3d_motion("./positions_1.mp4", kinematic_chain, positions, 'title', fps=20) + + '''XZ at origin''' + root_pos_init = positions[0] + root_pose_init_xz = root_pos_init[0] * np.array([1, 0, 1]) + positions = positions - root_pose_init_xz + + # '''Move the first pose to origin ''' + # root_pos_init = positions[0] + # positions = positions - root_pos_init[0] + + '''All initially face Z+''' + r_hip, l_hip, sdr_r, sdr_l = face_joint_indx + across1 = root_pos_init[r_hip] - root_pos_init[l_hip] + across2 = root_pos_init[sdr_r] - root_pos_init[sdr_l] + across = across1 + across2 + across = across / np.sqrt((across ** 2).sum(axis=-1))[..., np.newaxis] + + # forward (3,), rotate around y-axis + forward_init = np.cross(np.array([[0, 1, 0]]), across, axis=-1) + # forward (3,) + forward_init = forward_init / np.sqrt((forward_init ** 2).sum(axis=-1))[..., np.newaxis] + + # print(forward_init) + + target = np.array([[0, 0, 1]]) + root_quat_init = qbetween_np(forward_init, target) + root_quat_init = np.ones(positions.shape[:-1] + (4,)) * root_quat_init + + positions_b = positions.copy() + + positions = qrot_np(root_quat_init, positions) + + # plot_3d_motion("./positions_2.mp4", kinematic_chain, positions, 'title', fps=20) + + '''New ground truth positions''' + global_positions = positions.copy() + + # plt.plot(positions_b[:, 0, 0], positions_b[:, 0, 2], marker='*') + # plt.plot(positions[:, 0, 0], positions[:, 0, 2], marker='o', color='r') + # plt.xlabel('x') + # plt.ylabel('z') + # plt.axis('equal') + # plt.show() + + """ Get Foot Contacts """ + + def foot_detect(positions, thres): + velfactor, heightfactor = np.array([thres, thres]), np.array([3.0, 2.0]) + + feet_l_x = (positions[1:, fid_l, 0] - positions[:-1, fid_l, 0]) ** 2 + feet_l_y = (positions[1:, fid_l, 1] - positions[:-1, fid_l, 1]) ** 2 + feet_l_z = (positions[1:, fid_l, 2] - positions[:-1, fid_l, 2]) ** 2 + # feet_l_h = positions[:-1,fid_l,1] + # feet_l = (((feet_l_x + feet_l_y + feet_l_z) < velfactor) & (feet_l_h < heightfactor)).astype(np.float64) + feet_l = ((feet_l_x + feet_l_y + feet_l_z) < velfactor).astype(np.float64) + + feet_r_x = (positions[1:, fid_r, 0] - positions[:-1, fid_r, 0]) ** 2 + feet_r_y = (positions[1:, fid_r, 1] - positions[:-1, fid_r, 1]) ** 2 + feet_r_z = (positions[1:, fid_r, 2] - positions[:-1, fid_r, 2]) ** 2 + # feet_r_h = positions[:-1,fid_r,1] + # feet_r = (((feet_r_x + feet_r_y + feet_r_z) < velfactor) & (feet_r_h < heightfactor)).astype(np.float64) + feet_r = (((feet_r_x + feet_r_y + feet_r_z) < velfactor)).astype(np.float64) + return feet_l, feet_r + # + feet_l, feet_r = foot_detect(positions, feet_thre) + # feet_l, feet_r = foot_detect(positions, 0.002) + + '''Quaternion and Cartesian representation''' + r_rot = None + + def get_rifke(positions): + '''Local pose''' + positions[..., 0] -= positions[:, 0:1, 0] + positions[..., 2] -= positions[:, 0:1, 2] + '''All pose face Z+''' + positions = qrot_np(np.repeat(r_rot[:, None], positions.shape[1], axis=1), positions) + return positions + + def get_quaternion(positions): + 
skel = Skeleton(n_raw_offsets, kinematic_chain, "cpu") + # (seq_len, joints_num, 4) + quat_params = skel.inverse_kinematics_np(positions, face_joint_indx, smooth_forward=False) + + '''Fix Quaternion Discontinuity''' + quat_params = qfix(quat_params) + # (seq_len, 4) + r_rot = quat_params[:, 0].copy() + # print(r_rot[0]) + '''Root Linear Velocity''' + # (seq_len - 1, 3) + velocity = (positions[1:, 0] - positions[:-1, 0]).copy() + # print(r_rot.shape, velocity.shape) + velocity = qrot_np(r_rot[1:], velocity) + '''Root Angular Velocity''' + # (seq_len - 1, 4) + r_velocity = qmul_np(r_rot[1:], qinv_np(r_rot[:-1])) + quat_params[1:, 0] = r_velocity + # (seq_len, joints_num, 4) + return quat_params, r_velocity, velocity, r_rot + + def get_cont6d_params(positions): + skel = Skeleton(n_raw_offsets, kinematic_chain, "cpu") + # (seq_len, joints_num, 4) + quat_params = skel.inverse_kinematics_np(positions, face_joint_indx, smooth_forward=True) + + '''Quaternion to continuous 6D''' + cont_6d_params = quaternion_to_cont6d_np(quat_params) + # (seq_len, 4) + r_rot = quat_params[:, 0].copy() + # print(r_rot[0]) + '''Root Linear Velocity''' + # (seq_len - 1, 3) + velocity = (positions[1:, 0] - positions[:-1, 0]).copy() + # print(r_rot.shape, velocity.shape) + velocity = qrot_np(r_rot[1:], velocity) + '''Root Angular Velocity''' + # (seq_len - 1, 4) + r_velocity = qmul_np(r_rot[1:], qinv_np(r_rot[:-1])) + # (seq_len, joints_num, 4) + return cont_6d_params, r_velocity, velocity, r_rot + + cont_6d_params, r_velocity, velocity, r_rot = get_cont6d_params(positions) + positions = get_rifke(positions) + + # trejec = np.cumsum(np.concatenate([np.array([[0, 0, 0]]), velocity], axis=0), axis=0) + # r_rotations, r_pos = recover_ric_glo_np(r_velocity, velocity[:, [0, 2]]) + + # plt.plot(positions_b[:, 0, 0], positions_b[:, 0, 2], marker='*') + # plt.plot(ground_positions[:, 0, 0], ground_positions[:, 0, 2], marker='o', color='r') + # plt.plot(trejec[:, 0], trejec[:, 2], marker='^', color='g') + # plt.plot(r_pos[:, 0], r_pos[:, 2], marker='s', color='y') + # plt.xlabel('x') + # plt.ylabel('z') + # plt.axis('equal') + # plt.show() + + '''Root height''' + root_y = positions[:, 0, 1:2] + + '''Root rotation and linear velocity''' + # (seq_len-1, 1) rotation velocity along y-axis + # (seq_len-1, 2) linear velovity on xz plane + r_velocity = np.arcsin(r_velocity[:, 2:3]) + l_velocity = velocity[:, [0, 2]] + # print(r_velocity.shape, l_velocity.shape, root_y.shape) + root_data = np.concatenate([r_velocity, l_velocity, root_y[:-1]], axis=-1) + + '''Get Joint Rotation Representation''' + # (seq_len, (joints_num-1) *6) quaternion for skeleton joints + rot_data = cont_6d_params[:, 1:].reshape(len(cont_6d_params), -1) + + '''Get Joint Rotation Invariant Position Represention''' + # (seq_len, (joints_num-1)*3) local joint position + ric_data = positions[:, 1:].reshape(len(positions), -1) + + '''Get Joint Velocity Representation''' + # (seq_len-1, joints_num*3) + local_vel = qrot_np(np.repeat(r_rot[:-1, None], global_positions.shape[1], axis=1), + global_positions[1:] - global_positions[:-1]) + local_vel = local_vel.reshape(len(local_vel), -1) + + data = root_data + data = np.concatenate([data, ric_data[:-1]], axis=-1) + data = np.concatenate([data, rot_data[:-1]], axis=-1) + # print(dataset.shape, local_vel.shape) + data = np.concatenate([data, local_vel], axis=-1) + data = np.concatenate([data, feet_l, feet_r], axis=-1) + + return data, global_positions, positions, l_velocity + + +# Recover global angle and positions for rotation 
dataset +# root_rot_velocity (B, seq_len, 1) +# root_linear_velocity (B, seq_len, 2) +# root_y (B, seq_len, 1) +# ric_data (B, seq_len, (joint_num - 1)*3) +# rot_data (B, seq_len, (joint_num - 1)*6) +# local_velocity (B, seq_len, joint_num*3) +# foot contact (B, seq_len, 4) +def recover_root_rot_pos(data): + rot_vel = data[..., 0] + r_rot_ang = torch.zeros_like(rot_vel).to(data.device) + '''Get Y-axis rotation from rotation velocity''' + r_rot_ang[..., 1:] = rot_vel[..., :-1] + r_rot_ang = torch.cumsum(r_rot_ang, dim=-1) + + r_rot_quat = torch.zeros(data.shape[:-1] + (4,)).to(data.device) + r_rot_quat[..., 0] = torch.cos(r_rot_ang) + r_rot_quat[..., 2] = torch.sin(r_rot_ang) + + r_pos = torch.zeros(data.shape[:-1] + (3,)).to(data.device) + r_pos[..., 1:, [0, 2]] = data[..., :-1, 1:3] + '''Add Y-axis rotation to root position''' + r_pos = qrot(qinv(r_rot_quat), r_pos) + + r_pos = torch.cumsum(r_pos, dim=-2) + + r_pos[..., 1] = data[..., 3] + return r_rot_quat, r_pos + + +def recover_from_rot(data, joints_num, skeleton): + r_rot_quat, r_pos = recover_root_rot_pos(data) + + r_rot_cont6d = quaternion_to_cont6d(r_rot_quat) + + start_indx = 1 + 2 + 1 + (joints_num - 1) * 3 + end_indx = start_indx + (joints_num - 1) * 6 + cont6d_params = data[..., start_indx:end_indx] + # print(r_rot_cont6d.shape, cont6d_params.shape, r_pos.shape) + cont6d_params = torch.cat([r_rot_cont6d, cont6d_params], dim=-1) + cont6d_params = cont6d_params.view(-1, joints_num, 6) + + positions = skeleton.forward_kinematics_cont6d(cont6d_params, r_pos) + + return positions + +def recover_rot(data): + # dataset [bs, seqlen, 263/251] HumanML/KIT + joints_num = 22 if data.shape[-1] == 263 else 21 + r_rot_quat, r_pos = recover_root_rot_pos(data) + r_pos_pad = torch.cat([r_pos, torch.zeros_like(r_pos)], dim=-1).unsqueeze(-2) + r_rot_cont6d = quaternion_to_cont6d(r_rot_quat) + start_indx = 1 + 2 + 1 + (joints_num - 1) * 3 + end_indx = start_indx + (joints_num - 1) * 6 + cont6d_params = data[..., start_indx:end_indx] + cont6d_params = torch.cat([r_rot_cont6d, cont6d_params], dim=-1) + cont6d_params = cont6d_params.view(-1, joints_num, 6) + cont6d_params = torch.cat([cont6d_params, r_pos_pad], dim=-2) + return cont6d_params + + +def recover_from_ric(data, joints_num): + r_rot_quat, r_pos = recover_root_rot_pos(data) + positions = data[..., 4:(joints_num - 1) * 3 + 4] + positions = positions.view(positions.shape[:-1] + (-1, 3)) + + '''Add Y-axis rotation to local joints''' + positions = qrot(qinv(r_rot_quat[..., None, :]).expand(positions.shape[:-1] + (4,)), positions) + + '''Add root XZ to joints''' + positions[..., 0] += r_pos[..., 0:1] + positions[..., 2] += r_pos[..., 2:3] + + '''Concate root and joints''' + positions = torch.cat([r_pos.unsqueeze(-2), positions], dim=-2) + + return positions +''' +For Text2Motion Dataset +''' +''' +if __name__ == "__main__": + example_id = "000021" + # Lower legs + l_idx1, l_idx2 = 5, 8 + # Right/Left foot + fid_r, fid_l = [8, 11], [7, 10] + # Face direction, r_hip, l_hip, sdr_r, sdr_l + face_joint_indx = [2, 1, 17, 16] + # l_hip, r_hip + r_hip, l_hip = 2, 1 + joints_num = 22 + # ds_num = 8 + data_dir = '../dataset/pose_data_raw/joints/' + save_dir1 = '../dataset/pose_data_raw/new_joints/' + save_dir2 = '../dataset/pose_data_raw/new_joint_vecs/' + + n_raw_offsets = torch.from_numpy(t2m_raw_offsets) + kinematic_chain = t2m_kinematic_chain + + # Get offsets of target skeleton + example_data = np.load(os.path.join(data_dir, example_id + '.npy')) + example_data = 
example_data.reshape(len(example_data), -1, 3) + example_data = torch.from_numpy(example_data) + tgt_skel = Skeleton(n_raw_offsets, kinematic_chain, 'cpu') + # (joints_num, 3) + tgt_offsets = tgt_skel.get_offsets_joints(example_data[0]) + # print(tgt_offsets) + + source_list = os.listdir(data_dir) + frame_num = 0 + for source_file in tqdm(source_list): + source_data = np.load(os.path.join(data_dir, source_file))[:, :joints_num] + try: + dataset, ground_positions, positions, l_velocity = process_file(source_data, 0.002) + rec_ric_data = recover_from_ric(torch.from_numpy(dataset).unsqueeze(0).float(), joints_num) + np.save(pjoin(save_dir1, source_file), rec_ric_data.squeeze().numpy()) + np.save(pjoin(save_dir2, source_file), dataset) + frame_num += dataset.shape[0] + except Exception as e: + print(source_file) + print(e) + + print('Total clips: %d, Frames: %d, Duration: %fm' % + (len(source_list), frame_num, frame_num / 20 / 60)) +''' + +if __name__ == "__main__": + example_id = "03950_gt" + # Lower legs + l_idx1, l_idx2 = 17, 18 + # Right/Left foot + fid_r, fid_l = [14, 15], [19, 20] + # Face direction, r_hip, l_hip, sdr_r, sdr_l + face_joint_indx = [11, 16, 5, 8] + # l_hip, r_hip + r_hip, l_hip = 11, 16 + joints_num = 21 + # ds_num = 8 + data_dir = '../dataset/kit_mocap_dataset/joints/' + save_dir1 = '../dataset/kit_mocap_dataset/new_joints/' + save_dir2 = '../dataset/kit_mocap_dataset/new_joint_vecs/' + + n_raw_offsets = torch.from_numpy(kit_raw_offsets) + kinematic_chain = kit_kinematic_chain + + '''Get offsets of target skeleton''' + example_data = np.load(os.path.join(data_dir, example_id + '.npy')) + example_data = example_data.reshape(len(example_data), -1, 3) + example_data = torch.from_numpy(example_data) + tgt_skel = Skeleton(n_raw_offsets, kinematic_chain, 'cpu') + # (joints_num, 3) + tgt_offsets = tgt_skel.get_offsets_joints(example_data[0]) + # print(tgt_offsets) + + source_list = os.listdir(data_dir) + frame_num = 0 + '''Read source dataset''' + for source_file in tqdm(source_list): + source_data = np.load(os.path.join(data_dir, source_file))[:, :joints_num] + try: + name = ''.join(source_file[:-7].split('_')) + '.npy' + data, ground_positions, positions, l_velocity = process_file(source_data, 0.05) + rec_ric_data = recover_from_ric(torch.from_numpy(data).unsqueeze(0).float(), joints_num) + if np.isnan(rec_ric_data.numpy()).any(): + print(source_file) + continue + np.save(pjoin(save_dir1, name), rec_ric_data.squeeze().numpy()) + np.save(pjoin(save_dir2, name), data) + frame_num += data.shape[0] + except Exception as e: + print(source_file) + print(e) + + print('Total clips: %d, Frames: %d, Duration: %fm' % + (len(source_list), frame_num, frame_num / 12.5 / 60)) diff --git a/mGPT/data/humanml/utils/paramUtil.py b/mGPT/data/humanml/utils/paramUtil.py new file mode 100755 index 0000000000000000000000000000000000000000..a9f1708b85ca80a9051cb3675cec9b999a0d0e2b --- /dev/null +++ b/mGPT/data/humanml/utils/paramUtil.py @@ -0,0 +1,63 @@ +import numpy as np + +# Define a kinematic tree for the skeletal struture +kit_kinematic_chain = [[0, 11, 12, 13, 14, 15], [0, 16, 17, 18, 19, 20], [0, 1, 2, 3, 4], [3, 5, 6, 7], [3, 8, 9, 10]] + +kit_raw_offsets = np.array( + [ + [0, 0, 0], + [0, 1, 0], + [0, 1, 0], + [0, 1, 0], + [0, 1, 0], + [1, 0, 0], + [0, -1, 0], + [0, -1, 0], + [-1, 0, 0], + [0, -1, 0], + [0, -1, 0], + [1, 0, 0], + [0, -1, 0], + [0, -1, 0], + [0, 0, 1], + [0, 0, 1], + [-1, 0, 0], + [0, -1, 0], + [0, -1, 0], + [0, 0, 1], + [0, 0, 1] + ] +) + +t2m_raw_offsets = np.array([[0,0,0], 
+ [1,0,0], + [-1,0,0], + [0,1,0], + [0,-1,0], + [0,-1,0], + [0,1,0], + [0,-1,0], + [0,-1,0], + [0,1,0], + [0,0,1], + [0,0,1], + [0,1,0], + [1,0,0], + [-1,0,0], + [0,0,1], + [0,-1,0], + [0,-1,0], + [0,-1,0], + [0,-1,0], + [0,-1,0], + [0,-1,0]]) + +t2m_kinematic_chain = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21], [9, 13, 16, 18, 20]] +t2m_left_hand_chain = [[20, 22, 23, 24], [20, 34, 35, 36], [20, 25, 26, 27], [20, 31, 32, 33], [20, 28, 29, 30]] +t2m_right_hand_chain = [[21, 43, 44, 45], [21, 46, 47, 48], [21, 40, 41, 42], [21, 37, 38, 39], [21, 49, 50, 51]] + + +kit_tgt_skel_id = '03950' + +t2m_tgt_skel_id = '000021' + diff --git a/mGPT/data/humanml/utils/word_vectorizer.py b/mGPT/data/humanml/utils/word_vectorizer.py new file mode 100755 index 0000000000000000000000000000000000000000..d27205820c6ce17cac2e0f923808b35c0ba5f0eb --- /dev/null +++ b/mGPT/data/humanml/utils/word_vectorizer.py @@ -0,0 +1,79 @@ +import numpy as np +import pickle +from os.path import join as pjoin + +POS_enumerator = { + 'VERB': 0, + 'NOUN': 1, + 'DET': 2, + 'ADP': 3, + 'NUM': 4, + 'AUX': 5, + 'PRON': 6, + 'ADJ': 7, + 'ADV': 8, + 'Loc_VIP': 9, + 'Body_VIP': 10, + 'Obj_VIP': 11, + 'Act_VIP': 12, + 'Desc_VIP': 13, + 'OTHER': 14, +} + +Loc_list = ('left', 'right', 'clockwise', 'counterclockwise', 'anticlockwise', 'forward', 'back', 'backward', + 'up', 'down', 'straight', 'curve') + +Body_list = ('arm', 'chin', 'foot', 'feet', 'face', 'hand', 'mouth', 'leg', 'waist', 'eye', 'knee', 'shoulder', 'thigh') + +Obj_List = ('stair', 'dumbbell', 'chair', 'window', 'floor', 'car', 'ball', 'handrail', 'baseball', 'basketball') + +Act_list = ('walk', 'run', 'swing', 'pick', 'bring', 'kick', 'put', 'squat', 'throw', 'hop', 'dance', 'jump', 'turn', + 'stumble', 'dance', 'stop', 'sit', 'lift', 'lower', 'raise', 'wash', 'stand', 'kneel', 'stroll', + 'rub', 'bend', 'balance', 'flap', 'jog', 'shuffle', 'lean', 'rotate', 'spin', 'spread', 'climb') + +Desc_list = ('slowly', 'carefully', 'fast', 'careful', 'slow', 'quickly', 'happy', 'angry', 'sad', 'happily', 'angrily', 'sadly') + +VIP_dict = { + 'Loc_VIP': Loc_list, + 'Body_VIP': Body_list, + 'Obj_VIP': Obj_List, + 'Act_VIP': Act_list, + 'Desc_VIP': Desc_list, +} + + +class WordVectorizer(object): + def __init__(self, meta_root, prefix): + vectors = np.load(pjoin(meta_root, '%s_data.npy'%prefix)) + words = pickle.load(open(pjoin(meta_root, '%s_words.pkl'%prefix), 'rb')) + word2idx = pickle.load(open(pjoin(meta_root, '%s_idx.pkl'%prefix), 'rb')) + self.word2vec = {w: vectors[word2idx[w]] for w in words} + + def _get_pos_ohot(self, pos): + pos_vec = np.zeros(len(POS_enumerator)) + if pos in POS_enumerator: + pos_vec[POS_enumerator[pos]] = 1 + else: + pos_vec[POS_enumerator['OTHER']] = 1 + return pos_vec + + def __len__(self): + return len(self.word2vec) + + def __getitem__(self, item): + word, pos = item.split('/') + if word in self.word2vec: + word_vec = self.word2vec[word] + vip_pos = None + for key, values in VIP_dict.items(): + if word in values: + vip_pos = key + break + if vip_pos is not None: + pos_vec = self._get_pos_ohot(vip_pos) + else: + pos_vec = self._get_pos_ohot(pos) + else: + word_vec = self.word2vec['unk'] + pos_vec = self._get_pos_ohot('OTHER') + return word_vec, pos_vec diff --git a/mGPT/data/tools/__init__.py b/mGPT/data/tools/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..04a49ec769271d1c97b1a043fd04af5aa63e746d --- /dev/null +++ b/mGPT/data/tools/__init__.py @@ -0,0 +1,2 @@ +from .tensors import 
lengths_to_mask +from .collate import collate_text_and_length, collate_pairs_and_text, collate_datastruct_and_text, collate_tensor_with_padding diff --git a/mGPT/data/tools/collate.py b/mGPT/data/tools/collate.py new file mode 100755 index 0000000000000000000000000000000000000000..fec416cb4df7d720218bda2875ea492ade1dce09 --- /dev/null +++ b/mGPT/data/tools/collate.py @@ -0,0 +1,99 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2020 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +from typing import List, Dict +from torch import Tensor + + +def collate_tensor_with_padding(batch: List[Tensor]) -> Tensor: + dims = batch[0].dim() + max_size = [max([b.size(i) for b in batch]) for i in range(dims)] + size = (len(batch),) + tuple(max_size) + canvas = batch[0].new_zeros(size=size) + for i, b in enumerate(batch): + sub_tensor = canvas[i] + for d in range(dims): + sub_tensor = sub_tensor.narrow(d, 0, b.size(d)) + sub_tensor.add_(b) + return canvas + + +def collate_datastruct_and_text(lst_elements: List) -> Dict: + collate_datastruct = lst_elements[0]["datastruct"].transforms.collate + + batch = { + # Collate with padding for the datastruct + "datastruct": collate_datastruct([x["datastruct"] for x in lst_elements]), + # Collate normally for the length + "length": [x["length"] for x in lst_elements], + # Collate the text + "text": [x["text"] for x in lst_elements]} + + # add keyid for example + otherkeys = [x for x in lst_elements[0].keys() if x not in batch] + for key in otherkeys: + batch[key] = [x[key] for x in lst_elements] + + return batch + +def collate_length_and_text(lst_elements: List) -> Dict: + + batch = { + "length_0": [x["length_0"] for x in lst_elements], + "length_1": [x["length_1"] for x in lst_elements], + "length_transition": [x["length_transition"] for x in lst_elements], + "length_1_with_transition": [x["length_1_with_transition"] for x in lst_elements], + "text_0": [x["text_0"] for x in lst_elements], + "text_1": [x["text_1"] for x in lst_elements] + } + + return batch + +def collate_pairs_and_text(lst_elements: List, ) -> Dict: + if 'features_0' not in lst_elements[0]: # test set + collate_datastruct = lst_elements[0]["datastruct"].transforms.collate + batch = {"datastruct": collate_datastruct([x["datastruct"] for x in lst_elements]), + "length_0": [x["length_0"] for x in lst_elements], + "length_1": [x["length_1"] for x in lst_elements], + "length_transition": [x["length_transition"] for x in lst_elements], + "length_1_with_transition": [x["length_1_with_transition"] for x in lst_elements], + "text_0": [x["text_0"] for x in lst_elements], + "text_1": [x["text_1"] for x in lst_elements] + } + + else: + batch = {"motion_feats_0": collate_tensor_with_padding([el["features_0"] for el in lst_elements]), + "motion_feats_1": collate_tensor_with_padding([el["features_1"] for el in lst_elements]), + "motion_feats_1_with_transition": collate_tensor_with_padding([el["features_1_with_transition"] for 
el in lst_elements]), + "length_0": [x["length_0"] for x in lst_elements], + "length_1": [x["length_1"] for x in lst_elements], + "length_transition": [x["length_transition"] for x in lst_elements], + "length_1_with_transition": [x["length_1_with_transition"] for x in lst_elements], + "text_0": [x["text_0"] for x in lst_elements], + "text_1": [x["text_1"] for x in lst_elements] + } + return batch + + +def collate_text_and_length(lst_elements: Dict) -> Dict: + batch = {"length": [x["length"] for x in lst_elements], + "text": [x["text"] for x in lst_elements]} + + # add keyid for example + otherkeys = [x for x in lst_elements[0].keys() if x not in batch and x != "datastruct"] + for key in otherkeys: + batch[key] = [x[key] for x in lst_elements] + return batch diff --git a/mGPT/data/tools/easyconvert.py b/mGPT/data/tools/easyconvert.py new file mode 100644 index 0000000000000000000000000000000000000000..3649a93f947d47beb872fdc3f933d0b81fc56b37 --- /dev/null +++ b/mGPT/data/tools/easyconvert.py @@ -0,0 +1,72 @@ +from .geometry import * + +def nfeats_of(rottype): + if rottype in ["rotvec", "axisangle"]: + return 3 + elif rottype in ["rotquat", "quaternion"]: + return 4 + elif rottype in ["rot6d", "6drot", "rotation6d"]: + return 6 + elif rottype in ["rotmat"]: + return 9 + else: + return TypeError("This rotation type doesn't have features.") + + +def axis_angle_to(newtype, rotations): + if newtype in ["matrix"]: + rotations = axis_angle_to_matrix(rotations) + return rotations + elif newtype in ["rotmat"]: + rotations = axis_angle_to_matrix(rotations) + rotations = matrix_to("rotmat", rotations) + return rotations + elif newtype in ["rot6d", "6drot", "rotation6d"]: + rotations = axis_angle_to_matrix(rotations) + rotations = matrix_to("rot6d", rotations) + return rotations + elif newtype in ["rotquat", "quaternion"]: + rotations = axis_angle_to_quaternion(rotations) + return rotations + elif newtype in ["rotvec", "axisangle"]: + return rotations + else: + raise NotImplementedError + + +def matrix_to(newtype, rotations): + if newtype in ["matrix"]: + return rotations + if newtype in ["rotmat"]: + rotations = rotations.reshape((*rotations.shape[:-2], 9)) + return rotations + elif newtype in ["rot6d", "6drot", "rotation6d"]: + rotations = matrix_to_rotation_6d(rotations) + return rotations + elif newtype in ["rotquat", "quaternion"]: + rotations = matrix_to_quaternion(rotations) + return rotations + elif newtype in ["rotvec", "axisangle"]: + rotations = matrix_to_axis_angle(rotations) + return rotations + else: + raise NotImplementedError + + +def to_matrix(oldtype, rotations): + if oldtype in ["matrix"]: + return rotations + if oldtype in ["rotmat"]: + rotations = rotations.reshape((*rotations.shape[:-2], 3, 3)) + return rotations + elif oldtype in ["rot6d", "6drot", "rotation6d"]: + rotations = rotation_6d_to_matrix(rotations) + return rotations + elif oldtype in ["rotquat", "quaternion"]: + rotations = quaternion_to_matrix(rotations) + return rotations + elif oldtype in ["rotvec", "axisangle"]: + rotations = axis_angle_to_matrix(rotations) + return rotations + else: + raise NotImplementedError diff --git a/mGPT/data/tools/geometry.py b/mGPT/data/tools/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..e6eafa2e1f2459a0f6f5ad1280c71e6a9625549e --- /dev/null +++ b/mGPT/data/tools/geometry.py @@ -0,0 +1,566 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
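Editorial note: an illustrative sketch (not part of the original patch) of how the rotation-type helpers in mGPT/data/tools/easyconvert.py above are typically chained. The import path and tensor shapes below are assumptions based on the file layout introduced in this diff.

    import torch
    from mGPT.data.tools.easyconvert import axis_angle_to, to_matrix, nfeats_of

    aa = torch.randn(16, 24, 3)            # batch of axis-angle rotations
    rot6d = axis_angle_to("rot6d", aa)     # -> (16, 24, 6) continuous 6D features
    mats = to_matrix("rot6d", rot6d)       # -> (16, 24, 3, 3) rotation matrices
    assert rot6d.shape[-1] == nfeats_of("rot6d")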
+# Check PYTORCH3D_LICENCE before use + +import functools +from typing import Optional + +import torch +import torch.nn.functional as F + + +""" +The transformation matrices returned from the functions in this file assume +the points on which the transformation will be applied are column vectors. +i.e. the R matrix is structured as + + R = [ + [Rxx, Rxy, Rxz], + [Ryx, Ryy, Ryz], + [Rzx, Rzy, Rzz], + ] # (3, 3) + +This matrix can be applied to column vectors by post multiplication +by the points e.g. + + points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point + transformed_points = R * points + +To apply the same matrix to points which are row vectors, the R matrix +can be transposed and pre multiplied by the points: + +e.g. + points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point + transformed_points = points * R.transpose(1, 0) +""" + + +# Added +def matrix_of_angles(cos, sin, inv=False, dim=2): + assert dim in [2, 3] + sin = -sin if inv else sin + if dim == 2: + row1 = torch.stack((cos, -sin), axis=-1) + row2 = torch.stack((sin, cos), axis=-1) + return torch.stack((row1, row2), axis=-2) + elif dim == 3: + row1 = torch.stack((cos, -sin, 0*cos), axis=-1) + row2 = torch.stack((sin, cos, 0*cos), axis=-1) + row3 = torch.stack((0*sin, 0*cos, 1+0*cos), axis=-1) + return torch.stack((row1, row2, row3),axis=-2) + + +def quaternion_to_matrix(quaternions): + """ + Convert rotations given as quaternions to rotation matrices. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + r, i, j, k = torch.unbind(quaternions, -1) + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def _copysign(a, b): + """ + Return a tensor where each element has the absolute value taken from the, + corresponding element of a, with sign taken from the corresponding + element of b. This is like the standard copysign floating-point operation, + but is not careful about negative 0 and NaN. + + Args: + a: source tensor. + b: tensor whose signs will be used, of the same shape as a. + + Returns: + Tensor of the same shape as a with the signs of b. + """ + signs_differ = (a < 0) != (b < 0) + return torch.where(signs_differ, -a, a) + + +def _sqrt_positive_part(x): + """ + Returns torch.sqrt(torch.max(0, x)) + but with a zero subgradient where x is 0. + """ + ret = torch.zeros_like(x) + positive_mask = x > 0 + ret[positive_mask] = torch.sqrt(x[positive_mask]) + return ret + + +def matrix_to_quaternion(matrix): + """ + Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + quaternions with real part first, as tensor of shape (..., 4). 
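Editorial note: a minimal illustrative example for matrix_to_quaternion above (not part of the original docstring): the identity rotation maps to the identity quaternion, and leading batch dimensions are preserved.

    import torch
    matrix_to_quaternion(torch.eye(3))                         # tensor([1., 0., 0., 0.])
    matrix_to_quaternion(torch.eye(3).expand(8, 3, 3)).shape   # torch.Size([8, 4])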
+ """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.") + m00 = matrix[..., 0, 0] + m11 = matrix[..., 1, 1] + m22 = matrix[..., 2, 2] + o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22) + x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22) + y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22) + z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22) + o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2]) + o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0]) + o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1]) + return torch.stack((o0, o1, o2, o3), -1) + + +def _axis_angle_rotation(axis: str, angle): + """ + Return the rotation matrices for one of the rotations about an axis + of which Euler angles describe, for each value of the angle given. + + Args: + axis: Axis label "X" or "Y or "Z". + angle: any shape tensor of Euler angles in radians + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + + cos = torch.cos(angle) + sin = torch.sin(angle) + one = torch.ones_like(angle) + zero = torch.zeros_like(angle) + + if axis == "X": + R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) + if axis == "Y": + R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) + if axis == "Z": + R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) + + return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) + + +def euler_angles_to_matrix(euler_angles, convention: str): + """ + Convert rotations given as Euler angles in radians to rotation matrices. + + Args: + euler_angles: Euler angles in radians as tensor of shape (..., 3). + convention: Convention string of three uppercase letters from + {"X", "Y", and "Z"}. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3: + raise ValueError("Invalid input euler angles.") + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1)) + return functools.reduce(torch.matmul, matrices) + + +def _angle_from_tan( + axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool +): + """ + Extract the first or third Euler angle from the two members of + the matrix which are positive constant times its sine and cosine. + + Args: + axis: Axis label "X" or "Y or "Z" for the angle we are finding. + other_axis: Axis label "X" or "Y or "Z" for the middle axis in the + convention. + data: Rotation matrices as tensor of shape (..., 3, 3). + horizontal: Whether we are looking for the angle for the third axis, + which means the relevant entries are in the same row of the + rotation matrix. If not, they are in the same column. + tait_bryan: Whether the first and third axes in the convention differ. + + Returns: + Euler Angles in radians for each matrix in data as a tensor + of shape (...). 
+ """ + + i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis] + if horizontal: + i2, i1 = i1, i2 + even = (axis + other_axis) in ["XY", "YZ", "ZX"] + if horizontal == even: + return torch.atan2(data[..., i1], data[..., i2]) + if tait_bryan: + return torch.atan2(-data[..., i2], data[..., i1]) + return torch.atan2(data[..., i2], -data[..., i1]) + + +def _index_from_letter(letter: str): + if letter == "X": + return 0 + if letter == "Y": + return 1 + if letter == "Z": + return 2 + + +def matrix_to_euler_angles(matrix, convention: str): + """ + Convert rotations given as rotation matrices to Euler angles in radians. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + convention: Convention string of three uppercase letters. + + Returns: + Euler angles in radians as tensor of shape (..., 3). + """ + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.") + i0 = _index_from_letter(convention[0]) + i2 = _index_from_letter(convention[2]) + tait_bryan = i0 != i2 + if tait_bryan: + central_angle = torch.asin( + matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0) + ) + else: + central_angle = torch.acos(matrix[..., i0, i0]) + + o = ( + _angle_from_tan( + convention[0], convention[1], matrix[..., i2], False, tait_bryan + ), + central_angle, + _angle_from_tan( + convention[2], convention[1], matrix[..., i0, :], True, tait_bryan + ), + ) + return torch.stack(o, -1) + + +def random_quaternions( + n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate random quaternions representing rotations, + i.e. versors with nonnegative real part. + + Args: + n: Number of quaternions in a batch to return. + dtype: Type to return. + device: Desired device of returned tensor. Default: + uses the current device for the default tensor type. + requires_grad: Whether the resulting tensor should have the gradient + flag set. + + Returns: + Quaternions as tensor of shape (N, 4). + """ + o = torch.randn((n, 4), dtype=dtype, device=device, requires_grad=requires_grad) + s = (o * o).sum(1) + o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None] + return o + + +def random_rotations( + n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate random rotations as 3x3 rotation matrices. + + Args: + n: Number of rotation matrices in a batch to return. + dtype: Type to return. + device: Device of returned tensor. Default: if None, + uses the current device for the default tensor type. + requires_grad: Whether the resulting tensor should have the gradient + flag set. + + Returns: + Rotation matrices as tensor of shape (n, 3, 3). + """ + quaternions = random_quaternions( + n, dtype=dtype, device=device, requires_grad=requires_grad + ) + return quaternion_to_matrix(quaternions) + + +def random_rotation( + dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate a single random 3x3 rotation matrix. + + Args: + dtype: Type to return + device: Device of returned tensor. 
Default: if None, + uses the current device for the default tensor type + requires_grad: Whether the resulting tensor should have the gradient + flag set + + Returns: + Rotation matrix as tensor of shape (3, 3). + """ + return random_rotations(1, dtype, device, requires_grad)[0] + + +def standardize_quaternion(quaternions): + """ + Convert a unit quaternion to a standard form: one in which the real + part is non negative. + + Args: + quaternions: Quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Standardized quaternions as tensor of shape (..., 4). + """ + return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions) + + +def quaternion_raw_multiply(a, b): + """ + Multiply two quaternions. + Usual torch rules for broadcasting apply. + + Args: + a: Quaternions as tensor of shape (..., 4), real part first. + b: Quaternions as tensor of shape (..., 4), real part first. + + Returns: + The product of a and b, a tensor of quaternions shape (..., 4). + """ + aw, ax, ay, az = torch.unbind(a, -1) + bw, bx, by, bz = torch.unbind(b, -1) + ow = aw * bw - ax * bx - ay * by - az * bz + ox = aw * bx + ax * bw + ay * bz - az * by + oy = aw * by - ax * bz + ay * bw + az * bx + oz = aw * bz + ax * by - ay * bx + az * bw + return torch.stack((ow, ox, oy, oz), -1) + + +def quaternion_multiply(a, b): + """ + Multiply two quaternions representing rotations, returning the quaternion + representing their composition, i.e. the versor with nonnegative real part. + Usual torch rules for broadcasting apply. + + Args: + a: Quaternions as tensor of shape (..., 4), real part first. + b: Quaternions as tensor of shape (..., 4), real part first. + + Returns: + The product of a and b, a tensor of quaternions of shape (..., 4). + """ + ab = quaternion_raw_multiply(a, b) + return standardize_quaternion(ab) + + +def quaternion_invert(quaternion): + """ + Given a quaternion representing rotation, get the quaternion representing + its inverse. + + Args: + quaternion: Quaternions as tensor of shape (..., 4), with real part + first, which must be versors (unit quaternions). + + Returns: + The inverse, a tensor of quaternions of shape (..., 4). + """ + + return quaternion * quaternion.new_tensor([1, -1, -1, -1]) + + +def quaternion_apply(quaternion, point): + """ + Apply the rotation given by a quaternion to a 3D point. + Usual torch rules for broadcasting apply. + + Args: + quaternion: Tensor of quaternions, real part first, of shape (..., 4). + point: Tensor of 3D points of shape (..., 3). + + Returns: + Tensor of rotated points of shape (..., 3). + """ + if point.size(-1) != 3: + raise ValueError(f"Points are not in 3D, f{point.shape}.") + real_parts = point.new_zeros(point.shape[:-1] + (1,)) + point_as_quaternion = torch.cat((real_parts, point), -1) + out = quaternion_raw_multiply( + quaternion_raw_multiply(quaternion, point_as_quaternion), + quaternion_invert(quaternion), + ) + return out[..., 1:] + + +def axis_angle_to_matrix(axis_angle): + """ + Convert rotations given as axis/angle to rotation matrices. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle)) + + +def matrix_to_axis_angle(matrix): + """ + Convert rotations given as rotation matrices to axis/angle. 
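Editorial note: an illustrative round-trip check for the axis-angle conversions above (assumed usage, not part of the original docstring).

    import torch
    aa = torch.tensor([[0.0, 1.5, 0.0]])                 # 1.5 rad about the Y axis
    rec = matrix_to_axis_angle(axis_angle_to_matrix(aa))
    assert torch.allclose(rec, aa, atol=1e-5)            # conversion is invertible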
+ + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. + """ + return quaternion_to_axis_angle(matrix_to_quaternion(matrix)) + + +def axis_angle_to_quaternion(axis_angle): + """ + Convert rotations given as axis/angle to quaternions. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + quaternions with real part first, as tensor of shape (..., 4). + """ + angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) + half_angles = 0.5 * angles + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + ) + quaternions = torch.cat( + [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1 + ) + return quaternions + + +def quaternion_to_axis_angle(quaternions): + """ + Convert rotations given as quaternions to axis/angle. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. + """ + norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True) + half_angles = torch.atan2(norms, quaternions[..., :1]) + angles = 2 * half_angles + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + ) + return quaternions[..., 1:] / sin_half_angles_over_angles + + +def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor: + """ + Converts 6D rotation representation by Zhou et al. [1] to rotation matrix + using Gram--Schmidt orthogonalisation per Section B of [1]. + Args: + d6: 6D rotation representation, of size (*, 6) + + Returns: + batch of rotation matrices of size (*, 3, 3) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. + Retrieved from http://arxiv.org/abs/1812.07035 + """ + + a1, a2 = d6[..., :3], d6[..., 3:] + b1 = F.normalize(a1, dim=-1) + b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 + b2 = F.normalize(b2, dim=-1) + b3 = torch.cross(b1, b2, dim=-1) + return torch.stack((b1, b2, b3), dim=-2) + + +def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor: + """ + Converts rotation matrices to 6D rotation representation by Zhou et al. [1] + by dropping the last row. Note that 6D representation is not unique. 
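Editorial note: an illustrative shape and round-trip check for the 6D rotation helpers in this file (assumed usage, not part of the original docstring); the 6D features are simply the first two rows of the rotation matrix, flattened.

    import torch
    R = torch.eye(3).expand(8, 3, 3)
    d6 = matrix_to_rotation_6d(R)          # torch.Size([8, 6])
    R_rec = rotation_6d_to_matrix(d6)      # torch.Size([8, 3, 3])
    assert torch.allclose(R_rec, R)        # Gram-Schmidt recovers the rotation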
+ Args: + matrix: batch of rotation matrices of size (*, 3, 3) + + Returns: + 6D rotation representation, of size (*, 6) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. + Retrieved from http://arxiv.org/abs/1812.07035 + """ + return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6) diff --git a/mGPT/data/tools/tensors.py b/mGPT/data/tools/tensors.py new file mode 100755 index 0000000000000000000000000000000000000000..6bcc051117961dff9dce513e679c50ddb1d327b7 --- /dev/null +++ b/mGPT/data/tools/tensors.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2020 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +from typing import List, Dict +import torch +from torch import Tensor + + +def lengths_to_mask(lengths: List[int], device: torch.device) -> Tensor: + lengths = torch.tensor(lengths, device=device) + max_len = max(lengths) + mask = torch.arange(max_len, device=device).expand(len(lengths), max_len) < lengths.unsqueeze(1) + return mask diff --git a/mGPT/data/transforms/__init__.py b/mGPT/data/transforms/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..394dfc566af25c5f7e16a1469f2b8bb625c04a57 --- /dev/null +++ b/mGPT/data/transforms/__init__.py @@ -0,0 +1,15 @@ +from .base import Transform +from .smpl import SMPLTransform +from .xyz import XYZTransform + +# rots2rfeats +from .rots2rfeats import Rots2Rfeats +from .rots2rfeats import Globalvelandy + +# rots2joints +from .rots2joints import Rots2Joints +from .rots2joints import SMPLH, SMPLX + +# joints2jfeats +from .joints2jfeats import Joints2Jfeats +from .joints2jfeats import Rifke diff --git a/mGPT/data/transforms/base.py b/mGPT/data/transforms/base.py new file mode 100755 index 0000000000000000000000000000000000000000..1c60a6021e8dda4c27bdd8365ba2e298ae6acf76 --- /dev/null +++ b/mGPT/data/transforms/base.py @@ -0,0 +1,84 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2020 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. 
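Editorial note: a small illustrative example (not part of the original patch) for lengths_to_mask from mGPT/data/tools/tensors.py above, which turns per-sample sequence lengths into a boolean padding mask; the import path is assumed from this diff's layout.

    import torch
    from mGPT.data.tools import lengths_to_mask

    lengths_to_mask([2, 4], torch.device("cpu"))
    # tensor([[ True,  True, False, False],
    #         [ True,  True,  True,  True]])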
+# +# Contact: ps-license@tuebingen.mpg.de + +from dataclasses import dataclass, fields + + +class Transform: + + def collate(self, lst_datastruct): + from ..tools import collate_tensor_with_padding + example = lst_datastruct[0] + + def collate_or_none(key): + if example[key] is None: + return None + key_lst = [x[key] for x in lst_datastruct] + return collate_tensor_with_padding(key_lst) + + kwargs = {key: collate_or_none(key) for key in example.datakeys} + + return self.Datastruct(**kwargs) + + +# Inspired from SMPLX library +# need to define "datakeys" and transforms +@dataclass +class Datastruct: + + def __getitem__(self, key): + return getattr(self, key) + + def __setitem__(self, key, value): + self.__dict__[key] = value + + def get(self, key, default=None): + return getattr(self, key, default) + + def __iter__(self): + return self.keys() + + def keys(self): + keys = [t.name for t in fields(self)] + return iter(keys) + + def values(self): + values = [getattr(self, t.name) for t in fields(self)] + return iter(values) + + def items(self): + data = [(t.name, getattr(self, t.name)) for t in fields(self)] + return iter(data) + + def to(self, *args, **kwargs): + for key in self.datakeys: + if self[key] is not None: + self[key] = self[key].to(*args, **kwargs) + return self + + @property + def device(self): + return self[self.datakeys[0]].device + + def detach(self): + + def detach_or_none(tensor): + if tensor is not None: + return tensor.detach() + return None + + kwargs = {key: detach_or_none(self[key]) for key in self.datakeys} + return self.transforms.Datastruct(**kwargs) diff --git a/mGPT/data/transforms/identity.py b/mGPT/data/transforms/identity.py new file mode 100755 index 0000000000000000000000000000000000000000..ec12e7ff04e8c2f18d889ceb64fc19189b52231c --- /dev/null +++ b/mGPT/data/transforms/identity.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2020 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. 
+# +# Contact: ps-license@tuebingen.mpg.de + +from typing import Optional +from torch import Tensor + +from .base import Datastruct, dataclass, Transform + + +class IdentityTransform(Transform): + def __init__(self, **kwargs): + return + + def Datastruct(self, **kwargs): + return IdentityDatastruct(**kwargs) + + def __repr__(self): + return "IdentityTransform()" + + +@dataclass +class IdentityDatastruct(Datastruct): + transforms: IdentityTransform + + features: Optional[Tensor] = None + + def __post_init__(self): + self.datakeys = ["features"] + + def __len__(self): + return len(self.rfeats) diff --git a/mGPT/data/transforms/joints2jfeats/__init__.py b/mGPT/data/transforms/joints2jfeats/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..0a924e845912842ec042b5b3195b8da7aee3f252 --- /dev/null +++ b/mGPT/data/transforms/joints2jfeats/__init__.py @@ -0,0 +1,2 @@ +from .base import Joints2Jfeats +from .rifke import Rifke diff --git a/mGPT/data/transforms/joints2jfeats/base.py b/mGPT/data/transforms/joints2jfeats/base.py new file mode 100755 index 0000000000000000000000000000000000000000..03d6f5fb10bf34cc98df41402a6e4f875e6c28af --- /dev/null +++ b/mGPT/data/transforms/joints2jfeats/base.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2020 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +from typing import Optional + +import torch +from torch import Tensor, nn +from pathlib import Path +import os + + +class Joints2Jfeats(nn.Module): + + def __init__(self, + path: Optional[str] = None, + normalization: bool = False, + eps: float = 1e-12, + **kwargs) -> None: + if normalization and path is None: + raise TypeError( + "You should provide a path if normalization is on.") + + super().__init__() + self.normalization = normalization + self.eps = eps + # workaround for cluster local/sync + if path is not None: + # rel_p = path.split('/') + # rel_p = rel_p[rel_p.index('deps'):] + # rel_p = '/'.join(rel_p) + pass + if normalization: + mean_path = Path(path) / "jfeats_mean.pt" + std_path = Path(path) / "jfeats_std.pt" + self.register_buffer('mean', torch.load(mean_path)) + self.register_buffer('std', torch.load(std_path)) + + def normalize(self, features: Tensor) -> Tensor: + if self.normalization: + features = (features - self.mean) / (self.std + self.eps) + return features + + def unnormalize(self, features: Tensor) -> Tensor: + if self.normalization: + features = features * self.std + self.mean + return features diff --git a/mGPT/data/transforms/joints2jfeats/rifke.py b/mGPT/data/transforms/joints2jfeats/rifke.py new file mode 100755 index 0000000000000000000000000000000000000000..c6f2a8e83d645a1fffaf3e4fc842b81f89b0950f --- /dev/null +++ b/mGPT/data/transforms/joints2jfeats/rifke.py @@ -0,0 +1,159 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. 
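Editorial note: the Joints2Jfeats base class above only handles optional z-normalization with mean/std buffers loaded from jfeats_mean.pt and jfeats_std.pt; subclasses such as Rifke below implement the actual joints-to-features mapping and call these helpers around it. In formula form (with eps = 1e-12 by default):

    normalized   = (features - mean) / (std + eps)
    unnormalized = features * std + mean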
+# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2020 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +from typing import Optional + +import torch +from einops import rearrange +from torch import Tensor +from .tools import get_forward_direction, get_floor, gaussian_filter1d # noqa +from mGPT.utils.geometry_tools import matrix_of_angles +from .base import Joints2Jfeats + + +class Rifke(Joints2Jfeats): + + def __init__(self, + jointstype: str = "mmm", + path: Optional[str] = None, + normalization: bool = False, + forward_filter: bool = False, + **kwargs) -> None: + # + # if jointstype != "mmm": + # print("This function assume that the root is the first index") + # raise NotImplementedError("This jointstype is not implemented.") + + super().__init__(path=path, normalization=normalization) + self.jointstype = jointstype + self.forward_filter = forward_filter + + def forward(self, joints: Tensor) -> Tensor: + # Joints to rotation invariant poses (Holden et. al.) + # Similar function than fke2rifke in Language2Pose repository + # Adapted to pytorch + # Put the origin center of the root joint instead of the ground projection + poses = joints.clone() + poses[..., 1] -= get_floor(poses, jointstype=self.jointstype) + + translation = poses[..., 0, :].clone() + # Let the root have the Y translation --> gravity axis + root_y = translation[..., 1] + + # Trajectory => Translation without gravity axis (Y) + trajectory = translation[..., [0, 2]] + + # Delete the root joints of the poses + poses = poses[..., 1:, :] + + # Remove the trajectory of the poses + poses[..., [0, 2]] -= trajectory[..., None, :] + + # Compute the trajectory + vel_trajectory = torch.diff(trajectory, dim=-2) + # 0 for the first one => keep the dimentionality + vel_trajectory = torch.cat( + (0 * vel_trajectory[..., [0], :], vel_trajectory), dim=-2) + + # Compute the forward direction + forward = get_forward_direction(poses, jointstype=self.jointstype) + if self.forward_filter: + # Smoothing to remove high frequencies + forward = gaussian_filter1d(forward, 2) + # normalize again to get real directions + forward = torch.nn.functional.normalize(forward, dim=-1) + # changed this also for New pytorch + angles = torch.atan2(*(forward.transpose(0, -1))).transpose(0, -1) + vel_angles = torch.diff(angles, dim=-1) + # 0 for the first one => keep the dimentionality + vel_angles = torch.cat((0 * vel_angles[..., [0]], vel_angles), dim=-1) + + # Construct the inverse rotation matrix + sin, cos = forward[..., 0], forward[..., 1] + rotations_inv = matrix_of_angles(cos, sin, inv=True) + + # Rotate the poses + poses_local = torch.einsum("...lj,...jk->...lk", poses[..., [0, 2]], + rotations_inv) + poses_local = torch.stack( + (poses_local[..., 0], poses[..., 1], poses_local[..., 1]), axis=-1) + + # stack the xyz joints into feature vectors + poses_features = rearrange(poses_local, + "... joints xyz -> ... 
(joints xyz)") + + # Rotate the vel_trajectory + vel_trajectory_local = torch.einsum("...j,...jk->...k", vel_trajectory, + rotations_inv) + + # Stack things together + features = torch.cat((root_y[..., None], poses_features, + vel_angles[..., None], vel_trajectory_local), -1) + + # Normalize if needed + features = self.normalize(features) + return features + + def inverse(self, features: Tensor) -> Tensor: + features = self.unnormalize(features) + root_y, poses_features, vel_angles, vel_trajectory_local = self.extract( + features) + + # already have the good dimensionality + angles = torch.cumsum(vel_angles, dim=-1) + # First frame should be 0, but if infered it is better to ensure it + angles = angles - angles[..., [0]] + + cos, sin = torch.cos(angles), torch.sin(angles) + rotations = matrix_of_angles(cos, sin, inv=False) + + # Get back the poses + poses_local = rearrange(poses_features, + "... (joints xyz) -> ... joints xyz", + xyz=3) + + # Rotate the poses + poses = torch.einsum("...lj,...jk->...lk", poses_local[..., [0, 2]], + rotations) + poses = torch.stack( + (poses[..., 0], poses_local[..., 1], poses[..., 1]), axis=-1) + + # Rotate the vel_trajectory + vel_trajectory = torch.einsum("...j,...jk->...k", vel_trajectory_local, + rotations) + # Integrate the trajectory + # Already have the good dimensionality + trajectory = torch.cumsum(vel_trajectory, dim=-2) + # First frame should be 0, but if infered it is better to ensure it + trajectory = trajectory - trajectory[..., [0], :] + + # Add the root joints (which is still zero) + poses = torch.cat((0 * poses[..., [0], :], poses), -2) + + # put back the root joint y + poses[..., 0, 1] = root_y + + # Add the trajectory globally + poses[..., [0, 2]] += trajectory[..., None, :] + return poses + + def extract(self, features: Tensor): + root_y = features[..., 0] + poses_features = features[..., 1:-3] + vel_angles = features[..., -3] + vel_trajectory_local = features[..., -2:] + + return root_y, poses_features, vel_angles, vel_trajectory_local diff --git a/mGPT/data/transforms/joints2jfeats/tools.py b/mGPT/data/transforms/joints2jfeats/tools.py new file mode 100755 index 0000000000000000000000000000000000000000..734e109d202c905a3e56f7e204d6512dd7de9018 --- /dev/null +++ b/mGPT/data/transforms/joints2jfeats/tools.py @@ -0,0 +1,97 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2020 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. 
+# +# Contact: ps-license@tuebingen.mpg.de + +import torch +import torch.nn.functional as F + +from mGPT.utils.joints import mmm_joints + +# Get the indexes of particular body part SMPLH case +# Feet +# LM, RM = smplh_joints.index("left_ankle"), smplh_joints.index("right_ankle") +# LF, RF = smplh_joints.index("left_foot"), smplh_joints.index("right_foot") +# # Shoulders +# LS, RS = smplh_joints.index("left_shoulder"), smplh_joints.index("right_shoulder") +# # Hips +# LH, RH = smplh_joints.index("left_hip"), smplh_joints.index("right_hip") + +# Get the indexes of particular body part +# Feet +LM, RM = mmm_joints.index("LMrot"), mmm_joints.index("RMrot") +LF, RF = mmm_joints.index("LF"), mmm_joints.index("RF") +# Shoulders +LS, RS = mmm_joints.index("LS"), mmm_joints.index("RS") +# Hips +LH, RH = mmm_joints.index("LH"), mmm_joints.index("RH") + + +def get_forward_direction(poses, jointstype="mmm"): + # assert jointstype == 'mmm' + across = poses[..., RH, :] - poses[..., LH, :] + poses[..., RS, :] - poses[ + ..., LS, :] + forward = torch.stack((-across[..., 2], across[..., 0]), axis=-1) + forward = torch.nn.functional.normalize(forward, dim=-1) + return forward + + +def get_floor(poses, jointstype="mmm"): + # assert jointstype == 'mmm' + ndim = len(poses.shape) + foot_heights = poses[..., (LM, LF, RM, RF), 1].min(-1).values + floor_height = softmin(foot_heights, softness=0.5, dim=-1) + # changed this thing Mathis version 1.11 pytorch + return floor_height[(ndim - 2) * [None]].transpose(0, -1) + + +def softmax(x, softness=1.0, dim=None): + maxi, mini = x.max(dim=dim).values, x.min(dim=dim).values + return maxi + torch.log(softness + torch.exp(mini - maxi)) + + +def softmin(x, softness=1.0, dim=0): + return -softmax(-x, softness=softness, dim=dim) + + +def gaussian_filter1d(_inputs, sigma, truncate=4.0): + # Code adapted/mixed from scipy library into pytorch + # https://github.com/scipy/scipy/blob/47bb6febaa10658c72962b9615d5d5aa2513fa3a/scipy/ndimage/filters.py#L211 + # and gaussian kernel + # https://github.com/scipy/scipy/blob/47bb6febaa10658c72962b9615d5d5aa2513fa3a/scipy/ndimage/filters.py#L179 + # Correspond to mode="nearest" and order = 0 + # But works batched + if len(_inputs.shape) == 2: + inputs = _inputs[None] + else: + inputs = _inputs + + sd = float(sigma) + radius = int(truncate * sd + 0.5) + sigma2 = sigma * sigma + x = torch.arange(-radius, + radius + 1, + device=inputs.device, + dtype=inputs.dtype) + phi_x = torch.exp(-0.5 / sigma2 * x**2) + phi_x = phi_x / phi_x.sum() + + # Conv1d weights + groups = inputs.shape[-1] + weights = torch.tile(phi_x, (groups, 1, 1)) + inputs = inputs.transpose(-1, -2) + outputs = F.conv1d(inputs, weights, padding="same", + groups=groups).transpose(-1, -2) + + return outputs.reshape(_inputs.shape) diff --git a/mGPT/data/transforms/joints2rots/config.py b/mGPT/data/transforms/joints2rots/config.py new file mode 100644 index 0000000000000000000000000000000000000000..9014befa889a0c132a05df36d6aa5a6cad4d9e08 --- /dev/null +++ b/mGPT/data/transforms/joints2rots/config.py @@ -0,0 +1,119 @@ +import numpy as np +from mGPT.utils.joints import mmm_joints, smplh2mmm_indexes + +# Map joints Name to SMPL joints idx +JOINT_MAP = { + 'MidHip': 0, + 'LHip': 1, + 'LKnee': 4, + 'LAnkle': 7, + 'LFoot': 10, + 'RHip': 2, + 'RKnee': 5, + 'RAnkle': 8, + 'RFoot': 11, + 'LShoulder': 16, + 'LElbow': 18, + 'LWrist': 20, + 'LHand': 22, + 'RShoulder': 17, + 'RElbow': 19, + 'RWrist': 21, + 'RHand': 23, + 'spine1': 3, + 'spine2': 6, + 'spine3': 9, + 'Neck': 12, + 'Head': 15, + 
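# Collar joints plus extra keypoints (nose, eyes, ears, heels) and their
+    # OpenPose-style aliases; indices past 23 assume an extended joint set
+    # beyond the 24 basic SMPL body joints.
+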
'LCollar': 13, + 'Rcollar': 14, + 'Nose': 24, + 'REye': 26, + 'LEye': 26, + 'REar': 27, + 'LEar': 28, + 'LHeel': 31, + 'RHeel': 34, + 'OP RShoulder': 17, + 'OP LShoulder': 16, + 'OP RHip': 2, + 'OP LHip': 1, + 'OP Neck': 12, +} + +mmm2smpl_correspondence = { + "root": "MidHip", + "BP": "spine1", + "BT": "spine3", + "BLN": "Neck", + "BUN": "Head", + "LS": "LShoulder", + "LE": "LElbow", + "LW": "LWrist", + "RS": "RShoulder", + "RE": "RElbow", + "RW": "RWrist", + "LH": "LHip", + "LK": "LKnee", + "LA": "LAnkle", + "LMrot": "LHeel", + "LF": "LFoot", + "RH": "RHip", + "RK": "RKnee", + "RA": "RAnkle", + "RMrot": "RHeel", + "RF": "RFoot" +} + +full_smpl_idx = range(24) +key_smpl_idx = [0, 1, 4, 7, 2, 5, 8, 17, 19, 21, 16, 18, 20] + +AMASS_JOINT_MAP = { + 'MidHip': 0, + 'LHip': 1, + 'LKnee': 4, + 'LAnkle': 7, + 'LFoot': 10, + 'RHip': 2, + 'RKnee': 5, + 'RAnkle': 8, + 'RFoot': 11, + 'LShoulder': 16, + 'LElbow': 18, + 'LWrist': 20, + 'RShoulder': 17, + 'RElbow': 19, + 'RWrist': 21, + 'spine1': 3, + 'spine2': 6, + 'spine3': 9, + 'Neck': 12, + 'Head': 15, + 'LCollar': 13, + 'Rcollar': 14, +} +amass_idx = range(22) +amass_smpl_idx = range(22) + +# cal mmm in smpl index +smpl2mmm_correspondence = { + val: key + for key, val in mmm2smpl_correspondence.items() +} +smpl2mmm_indexes = [JOINT_MAP[mmm2smpl_correspondence[x]] for x in mmm_joints] + +# cal mmm joints map +MMM_JOINT_MAP = { + val: JOINT_MAP[val] + for key, val in mmm2smpl_correspondence.items() +} + +# mmm_idx = range(21) +# mmm_smpl_dix = smpl2mmm_indexes +# mmm_smpl_dix = smplh2mmm_indexes +# todo - configable +SMPL_MODEL_DIR = "/apdcephfs/share_1227775/shingxchen/AIMotion/TMOSTData/deps/smpl_models/" +GMM_MODEL_DIR = "/apdcephfs/share_1227775/shingxchen/AIMotion/TMOSTData/deps/smpl_models/" +SMPL_MEAN_FILE = "/apdcephfs/share_1227775/shingxchen/AIMotion/TMOSTData/deps/smpl_models/neutral_smpl_mean_params.h5" +# for collsion +Part_Seg_DIR = "/apdcephfs/share_1227775/shingxchen/AIMotion/TMOSTData/deps/smpl_models/smplx_parts_segm.pkl" diff --git a/mGPT/data/transforms/joints2rots/customloss.py b/mGPT/data/transforms/joints2rots/customloss.py new file mode 100644 index 0000000000000000000000000000000000000000..2c3c3a530876113596f223324dc9dd0c002fd520 --- /dev/null +++ b/mGPT/data/transforms/joints2rots/customloss.py @@ -0,0 +1,217 @@ +import torch +import torch.nn.functional as F +import config + +# Guassian +def gmof(x, sigma): + """ + Geman-McClure error function + """ + x_squared = x ** 2 + sigma_squared = sigma ** 2 + return (sigma_squared * x_squared) / (sigma_squared + x_squared) + +# angle prior +def angle_prior(pose): + """ + Angle prior that penalizes unnatural bending of the knees and elbows + """ + # We subtract 3 because pose does not include the global rotation of the model + return torch.exp( + pose[:, [55 - 3, 58 - 3, 12 - 3, 15 - 3]] * torch.tensor([1., -1., -1, -1.], device=pose.device)) ** 2 + + +def perspective_projection(points, rotation, translation, + focal_length, camera_center): + """ + This function computes the perspective projection of a set of points. + Input: + points (bs, N, 3): 3D points + rotation (bs, 3, 3): Camera rotation + translation (bs, 3): Camera translation + focal_length (bs,) or scalar: Focal length + camera_center (bs, 2): Camera center + """ + batch_size = points.shape[0] + K = torch.zeros([batch_size, 3, 3], device=points.device) + K[:, 0, 0] = focal_length + K[:, 1, 1] = focal_length + K[:, 2, 2] = 1. 
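+    # K is the pinhole intrinsic matrix: focal length on the diagonal,
+    # principal point (the camera center) in the last column.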
+ K[:, :-1, -1] = camera_center + + # Transform points + points = torch.einsum('bij,bkj->bki', rotation, points) + points = points + translation.unsqueeze(1) + + # Apply perspective distortion + projected_points = points / points[:, :, -1].unsqueeze(-1) + + # Apply camera intrinsics + projected_points = torch.einsum('bij,bkj->bki', K, projected_points) + + return projected_points[:, :, :-1] + + +def body_fitting_loss(body_pose, betas, model_joints, camera_t, camera_center, + joints_2d, joints_conf, pose_prior, + focal_length=5000, sigma=100, pose_prior_weight=4.78, + shape_prior_weight=5, angle_prior_weight=15.2, + output='sum'): + """ + Loss function for body fitting + """ + batch_size = body_pose.shape[0] + rotation = torch.eye(3, device=body_pose.device).unsqueeze(0).expand(batch_size, -1, -1) + + projected_joints = perspective_projection(model_joints, rotation, camera_t, + focal_length, camera_center) + + # Weighted robust reprojection error + reprojection_error = gmof(projected_joints - joints_2d, sigma) + reprojection_loss = (joints_conf ** 2) * reprojection_error.sum(dim=-1) + + # Pose prior loss + pose_prior_loss = (pose_prior_weight ** 2) * pose_prior(body_pose, betas) + + # Angle prior for knees and elbows + angle_prior_loss = (angle_prior_weight ** 2) * angle_prior(body_pose).sum(dim=-1) + + # Regularizer to prevent betas from taking large values + shape_prior_loss = (shape_prior_weight ** 2) * (betas ** 2).sum(dim=-1) + + total_loss = reprojection_loss.sum(dim=-1) + pose_prior_loss + angle_prior_loss + shape_prior_loss + + if output == 'sum': + return total_loss.sum() + elif output == 'reprojection': + return reprojection_loss + + +# --- get camera fitting loss ----- +def camera_fitting_loss(model_joints, camera_t, camera_t_est, camera_center, + joints_2d, joints_conf, + focal_length=5000, depth_loss_weight=100): + """ + Loss function for camera optimization. 
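+    Matches the 2D reprojection of the torso joints (hips and shoulders),
+    preferring the OpenPose detections when all four are confident, and
+    penalizes deviation of the optimized camera depth from its estimate.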
+ """ + # Project model joints + batch_size = model_joints.shape[0] + rotation = torch.eye(3, device=model_joints.device).unsqueeze(0).expand(batch_size, -1, -1) + projected_joints = perspective_projection(model_joints, rotation, camera_t, + focal_length, camera_center) + + # get the indexed four + op_joints = ['OP RHip', 'OP LHip', 'OP RShoulder', 'OP LShoulder'] + op_joints_ind = [config.JOINT_MAP[joint] for joint in op_joints] + gt_joints = ['RHip', 'LHip', 'RShoulder', 'LShoulder'] + gt_joints_ind = [config.JOINT_MAP[joint] for joint in gt_joints] + + reprojection_error_op = (joints_2d[:, op_joints_ind] - + projected_joints[:, op_joints_ind]) ** 2 + reprojection_error_gt = (joints_2d[:, gt_joints_ind] - + projected_joints[:, gt_joints_ind]) ** 2 + + # Check if for each example in the batch all 4 OpenPose detections are valid, otherwise use the GT detections + # OpenPose joints are more reliable for this task, so we prefer to use them if possible + is_valid = (joints_conf[:, op_joints_ind].min(dim=-1)[0][:, None, None] > 0).float() + reprojection_loss = (is_valid * reprojection_error_op + (1 - is_valid) * reprojection_error_gt).sum(dim=(1, 2)) + + # Loss that penalizes deviation from depth estimate + depth_loss = (depth_loss_weight ** 2) * (camera_t[:, 2] - camera_t_est[:, 2]) ** 2 + + total_loss = reprojection_loss + depth_loss + return total_loss.sum() + + + + # #####--- body fitiing loss ----- +def body_fitting_loss_3d(body_pose, preserve_pose, + betas, model_joints, camera_translation, + j3d, pose_prior, + joints3d_conf, + sigma=100, pose_prior_weight=4.78*1.5, + shape_prior_weight=5.0, angle_prior_weight=15.2, + joint_loss_weight=500.0, + pose_preserve_weight=0.0, + use_collision=False, + model_vertices=None, model_faces=None, + search_tree=None, pen_distance=None, filter_faces=None, + collision_loss_weight=1000 + ): + """ + Loss function for body fitting + """ + batch_size = body_pose.shape[0] + + #joint3d_loss = (joint_loss_weight ** 2) * gmof((model_joints + camera_translation) - j3d, sigma).sum(dim=-1) + + joint3d_error = gmof((model_joints + camera_translation) - j3d, sigma) + + joint3d_loss_part = (joints3d_conf ** 2) * joint3d_error.sum(dim=-1) + joint3d_loss = (joint_loss_weight ** 2) * joint3d_loss_part + + # Pose prior loss + pose_prior_loss = (pose_prior_weight ** 2) * pose_prior(body_pose, betas) + # Angle prior for knees and elbows + angle_prior_loss = (angle_prior_weight ** 2) * angle_prior(body_pose).sum(dim=-1) + # Regularizer to prevent betas from taking large values + shape_prior_loss = (shape_prior_weight ** 2) * (betas ** 2).sum(dim=-1) + + collision_loss = 0.0 + # Calculate the loss due to interpenetration + if use_collision: + triangles = torch.index_select( + model_vertices, 1, + model_faces).view(batch_size, -1, 3, 3) + + with torch.no_grad(): + collision_idxs = search_tree(triangles) + + # Remove unwanted collisions + if filter_faces is not None: + collision_idxs = filter_faces(collision_idxs) + + if collision_idxs.ge(0).sum().item() > 0: + collision_loss = torch.sum(collision_loss_weight * pen_distance(triangles, collision_idxs)) + + pose_preserve_loss = (pose_preserve_weight ** 2) * ((body_pose - preserve_pose) ** 2).sum(dim=-1) + + total_loss = joint3d_loss + pose_prior_loss + angle_prior_loss + shape_prior_loss + collision_loss + pose_preserve_loss + + return total_loss.sum() + + +# #####--- get camera fitting loss ----- +def camera_fitting_loss_3d(model_joints, camera_t, camera_t_est, + j3d, joints_category="orig", depth_loss_weight=100.0): + """ + 
Loss function for camera optimization. + """ + model_joints = model_joints + camera_t + # # get the indexed four + # op_joints = ['OP RHip', 'OP LHip', 'OP RShoulder', 'OP LShoulder'] + # op_joints_ind = [config.JOINT_MAP[joint] for joint in op_joints] + # + # j3d_error_loss = (j3d[:, op_joints_ind] - + # model_joints[:, op_joints_ind]) ** 2 + + gt_joints = ['RHip', 'LHip', 'RShoulder', 'LShoulder'] + gt_joints_ind = [config.JOINT_MAP[joint] for joint in gt_joints] + + if joints_category=="orig": + select_joints_ind = [config.JOINT_MAP[joint] for joint in gt_joints] + elif joints_category=="AMASS": + select_joints_ind = [config.AMASS_JOINT_MAP[joint] for joint in gt_joints] + elif joints_category=="MMM": + select_joints_ind = [config.MMM_JOINT_MAP[joint] for joint in gt_joints] + else: + print("NO SUCH JOINTS CATEGORY!") + + j3d_error_loss = (j3d[:, select_joints_ind] - + model_joints[:, gt_joints_ind]) ** 2 + + # Loss that penalizes deviation from depth estimate + depth_loss = (depth_loss_weight**2) * (camera_t - camera_t_est)**2 + + total_loss = j3d_error_loss + depth_loss + return total_loss.sum() \ No newline at end of file diff --git a/mGPT/data/transforms/joints2rots/prior.py b/mGPT/data/transforms/joints2rots/prior.py new file mode 100644 index 0000000000000000000000000000000000000000..d85debddd185d44082f6ac14fdaa606d4deebd40 --- /dev/null +++ b/mGPT/data/transforms/joints2rots/prior.py @@ -0,0 +1,229 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. 
+#
+# Contact: ps-license@tuebingen.mpg.de
+
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+import sys
+import os
+
+import time
+import pickle
+
+import numpy as np
+
+import torch
+import torch.nn as nn
+
+DEFAULT_DTYPE = torch.float32
+
+
+def create_prior(prior_type, **kwargs):
+    if prior_type == 'gmm':
+        prior = MaxMixturePrior(**kwargs)
+    elif prior_type == 'l2':
+        return L2Prior(**kwargs)
+    elif prior_type == 'angle':
+        return SMPLifyAnglePrior(**kwargs)
+    elif prior_type == 'none' or prior_type is None:
+        # Don't use any pose prior
+        def no_prior(*args, **kwargs):
+            return 0.0
+        prior = no_prior
+    else:
+        raise ValueError('Prior {}'.format(prior_type) + ' is not implemented')
+    return prior
+
+
+class SMPLifyAnglePrior(nn.Module):
+    def __init__(self, dtype=torch.float32, **kwargs):
+        super(SMPLifyAnglePrior, self).__init__()
+
+        # Indices for the rotation angle of
+        # 55: left elbow,  90deg bend at -np.pi/2
+        # 58: right elbow, 90deg bend at np.pi/2
+        # 12: left knee,   90deg bend at np.pi/2
+        # 15: right knee,  90deg bend at np.pi/2
+        angle_prior_idxs = np.array([55, 58, 12, 15], dtype=np.int64)
+        angle_prior_idxs = torch.tensor(angle_prior_idxs, dtype=torch.long)
+        self.register_buffer('angle_prior_idxs', angle_prior_idxs)
+
+        angle_prior_signs = np.array([1, -1, -1, -1],
+                                     dtype=np.float32 if dtype == torch.float32
+                                     else np.float64)
+        angle_prior_signs = torch.tensor(angle_prior_signs,
+                                         dtype=dtype)
+        self.register_buffer('angle_prior_signs', angle_prior_signs)
+
+    def forward(self, pose, with_global_pose=False):
+        ''' Returns the angle prior loss for the given pose
+        Args:
+            pose: (Bx[23 + 1] * 3) torch tensor with the axis-angle
+            representation of the rotations of the joints of the SMPL model.
+        Kwargs:
+            with_global_pose: Whether the pose vector also contains the global
+            orientation of the SMPL model. If not then the indices must be
+            corrected.
+        Returns:
+            A size (B) tensor containing the angle prior loss for each element
+            in the batch.
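+            When with_global_pose is False, the stored joint indices are
+            shifted down by 3 to account for the missing global orientation.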
+        '''
+        angle_prior_idxs = self.angle_prior_idxs - (not with_global_pose) * 3
+        return torch.exp(pose[:, angle_prior_idxs] *
+                         self.angle_prior_signs).pow(2)
+
+
+class L2Prior(nn.Module):
+    def __init__(self, dtype=DEFAULT_DTYPE, reduction='sum', **kwargs):
+        super(L2Prior, self).__init__()
+
+    def forward(self, module_input, *args):
+        return torch.sum(module_input.pow(2))
+
+
+class MaxMixturePrior(nn.Module):
+
+    def __init__(self, prior_folder='prior',
+                 num_gaussians=6, dtype=DEFAULT_DTYPE, epsilon=1e-16,
+                 use_merged=True,
+                 **kwargs):
+        super(MaxMixturePrior, self).__init__()
+
+        if dtype == DEFAULT_DTYPE:
+            np_dtype = np.float32
+        elif dtype == torch.float64:
+            np_dtype = np.float64
+        else:
+            print('Unknown float type {}, exiting!'.format(dtype))
+            sys.exit(-1)
+
+        self.num_gaussians = num_gaussians
+        self.epsilon = epsilon
+        self.use_merged = use_merged
+        gmm_fn = 'gmm_{:02d}.pkl'.format(num_gaussians)
+
+        full_gmm_fn = os.path.join(prior_folder, gmm_fn)
+        if not os.path.exists(full_gmm_fn):
+            print('The path to the mixture prior "{}"'.format(full_gmm_fn) +
+                  ' does not exist, exiting!')
+            sys.exit(-1)
+
+        with open(full_gmm_fn, 'rb') as f:
+            gmm = pickle.load(f, encoding='latin1')
+
+        if type(gmm) == dict:
+            means = gmm['means'].astype(np_dtype)
+            covs = gmm['covars'].astype(np_dtype)
+            weights = gmm['weights'].astype(np_dtype)
+        elif 'sklearn.mixture.gmm.GMM' in str(type(gmm)):
+            means = gmm.means_.astype(np_dtype)
+            covs = gmm.covars_.astype(np_dtype)
+            weights = gmm.weights_.astype(np_dtype)
+        else:
+            print('Unknown type for the prior: {}, exiting!'.format(type(gmm)))
+            sys.exit(-1)
+
+        self.register_buffer('means', torch.tensor(means, dtype=dtype))
+
+        self.register_buffer('covs', torch.tensor(covs, dtype=dtype))
+
+        precisions = [np.linalg.inv(cov) for cov in covs]
+        precisions = np.stack(precisions).astype(np_dtype)
+
+        self.register_buffer('precisions',
+                             torch.tensor(precisions, dtype=dtype))
+
+        # The constant term:
+        sqrdets = np.array([(np.sqrt(np.linalg.det(c)))
+                            for c in gmm['covars']])
+        const = (2 * np.pi)**(69 / 2.)
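+        # 69 = 23 body joints x 3 axis-angle components (global orientation
+        # excluded), i.e. the dimensionality over which the GMM pose prior
+        # is defined.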
+ + nll_weights = np.asarray(gmm['weights'] / (const * + (sqrdets / sqrdets.min()))) + nll_weights = torch.tensor(nll_weights, dtype=dtype).unsqueeze(dim=0) + self.register_buffer('nll_weights', nll_weights) + + weights = torch.tensor(gmm['weights'], dtype=dtype).unsqueeze(dim=0) + self.register_buffer('weights', weights) + + self.register_buffer('pi_term', + torch.log(torch.tensor(2 * np.pi, dtype=dtype))) + + cov_dets = [np.log(np.linalg.det(cov.astype(np_dtype)) + epsilon) + for cov in covs] + self.register_buffer('cov_dets', + torch.tensor(cov_dets, dtype=dtype)) + + # The dimensionality of the random variable + self.random_var_dim = self.means.shape[1] + + def get_mean(self): + ''' Returns the mean of the mixture ''' + mean_pose = torch.matmul(self.weights, self.means) + return mean_pose + + def merged_log_likelihood(self, pose, betas): + diff_from_mean = pose.unsqueeze(dim=1) - self.means + + prec_diff_prod = torch.einsum('mij,bmj->bmi', + [self.precisions, diff_from_mean]) + diff_prec_quadratic = (prec_diff_prod * diff_from_mean).sum(dim=-1) + + curr_loglikelihood = 0.5 * diff_prec_quadratic - \ + torch.log(self.nll_weights) + # curr_loglikelihood = 0.5 * (self.cov_dets.unsqueeze(dim=0) + + # self.random_var_dim * self.pi_term + + # diff_prec_quadratic + # ) - torch.log(self.weights) + + min_likelihood, _ = torch.min(curr_loglikelihood, dim=1) + return min_likelihood + + def log_likelihood(self, pose, betas, *args, **kwargs): + ''' Create graph operation for negative log-likelihood calculation + ''' + likelihoods = [] + + for idx in range(self.num_gaussians): + mean = self.means[idx] + prec = self.precisions[idx] + cov = self.covs[idx] + diff_from_mean = pose - mean + + curr_loglikelihood = torch.einsum('bj,ji->bi', + [diff_from_mean, prec]) + curr_loglikelihood = torch.einsum('bi,bi->b', + [curr_loglikelihood, + diff_from_mean]) + cov_term = torch.log(torch.det(cov) + self.epsilon) + curr_loglikelihood += 0.5 * (cov_term + + self.random_var_dim * + self.pi_term) + likelihoods.append(curr_loglikelihood) + + log_likelihoods = torch.stack(likelihoods, dim=1) + min_idx = torch.argmin(log_likelihoods, dim=1) + weight_component = self.nll_weights[:, min_idx] + weight_component = -torch.log(weight_component) + + return weight_component + log_likelihoods[:, min_idx] + + def forward(self, pose, betas): + if self.use_merged: + return self.merged_log_likelihood(pose, betas) + else: + return self.log_likelihood(pose, betas) diff --git a/mGPT/data/transforms/joints2rots/smplify.py b/mGPT/data/transforms/joints2rots/smplify.py new file mode 100644 index 0000000000000000000000000000000000000000..7df51503a4a46a479a508c9fdf362cb063b93742 --- /dev/null +++ b/mGPT/data/transforms/joints2rots/smplify.py @@ -0,0 +1,284 @@ +import torch +import os, sys +import pickle +import smplx +import numpy as np +from tqdm import tqdm + +sys.path.append(os.path.dirname(__file__)) +from customloss import (camera_fitting_loss, + body_fitting_loss, + camera_fitting_loss_3d, + body_fitting_loss_3d, + ) +from prior import MaxMixturePrior +import config + + + +@torch.no_grad() +def guess_init_3d(model_joints, + j3d, + joints_category="orig"): + """Initialize the camera translation via triangle similarity, by using the torso joints . 
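+    The returned translation is the mean offset between the four torso
+    keypoints (hips and shoulders) of the target 3D joints and of the
+    current SMPL joints.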
+ :param model_joints: SMPL model with pre joints + :param j3d: 25x3 array of Kinect Joints + :returns: 3D vector corresponding to the estimated camera translation + """ + # get the indexed four + gt_joints = ['RHip', 'LHip', 'RShoulder', 'LShoulder'] + gt_joints_ind = [config.JOINT_MAP[joint] for joint in gt_joints] + + if joints_category=="orig": + joints_ind_category = [config.JOINT_MAP[joint] for joint in gt_joints] + elif joints_category=="AMASS": + joints_ind_category = [config.AMASS_JOINT_MAP[joint] for joint in gt_joints] + elif joints_category=="MMM": + joints_ind_category = [config.MMM_JOINT_MAP[joint] for joint in gt_joints] + else: + print("NO SUCH JOINTS CATEGORY!") + + sum_init_t = (j3d[:, joints_ind_category] - model_joints[:, gt_joints_ind]).sum(dim=1) + init_t = sum_init_t / 4.0 + return init_t + + +# SMPLIfy 3D +class SMPLify3D(): + """Implementation of SMPLify, use 3D joints.""" + + def __init__(self, + smplxmodel, + step_size=1e-2, + batch_size=1, + num_iters=100, + use_collision=False, + use_lbfgs=True, + joints_category="orig", + device=torch.device('cuda:0'), + ): + + # Store options + self.batch_size = batch_size + self.device = device + self.step_size = step_size + + self.num_iters = num_iters + # --- choose optimizer + self.use_lbfgs = use_lbfgs + # GMM pose prior + self.pose_prior = MaxMixturePrior(prior_folder=config.GMM_MODEL_DIR, + num_gaussians=8, + dtype=torch.float32).to(device) + # collision part + self.use_collision = use_collision + if self.use_collision: + self.part_segm_fn = config.Part_Seg_DIR + + # reLoad SMPL-X model + self.smpl = smplxmodel + + self.model_faces = smplxmodel.faces_tensor.view(-1) + + # select joint joint_category + self.joints_category = joints_category + + if joints_category=="orig": + self.smpl_index = config.full_smpl_idx + self.corr_index = config.full_smpl_idx + elif joints_category=="AMASS": + self.smpl_index = config.amass_smpl_idx + self.corr_index = config.amass_idx + # elif joints_category=="MMM": + # self.smpl_index = config.mmm_smpl_dix + # self.corr_index = config.mmm_idx + else: + self.smpl_index = None + self.corr_index = None + print("NO SUCH JOINTS CATEGORY!") + + # ---- get the man function here ------ + def __call__(self, init_pose, init_betas, init_cam_t, j3d, conf_3d=1.0, seq_ind=0): + """Perform body fitting. 
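+        The fitting runs in two stages: camera translation and global
+        orientation are optimized first, then the body pose (and, for
+        seq_ind == 0, the betas) are refined under the GMM pose prior.
+        Illustrative call, assuming an smplx body model built elsewhere
+        (names are placeholders, not part of this module):
+            smplify = SMPLify3D(smplxmodel=body_model, joints_category="AMASS")
+            verts, joints, pose, betas, cam_t, loss = smplify(
+                init_pose, init_betas, init_cam_t, j3d)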
+ Input: + init_pose: SMPL pose estimate + init_betas: SMPL betas estimate + init_cam_t: Camera translation estimate + j3d: joints 3d aka keypoints + conf_3d: confidence for 3d joints + seq_ind: index of the sequence + Returns: + vertices: Vertices of optimized shape + joints: 3D joints of optimized shape + pose: SMPL pose parameters of optimized shape + betas: SMPL beta parameters of optimized shape + camera_translation: Camera translation + """ + + # # # add the mesh inter-section to avoid + search_tree = None + pen_distance = None + filter_faces = None + + if self.use_collision: + from mesh_intersection.bvh_search_tree import BVH + import mesh_intersection.loss as collisions_loss + from mesh_intersection.filter_faces import FilterFaces + + search_tree = BVH(max_collisions=8) + + pen_distance = collisions_loss.DistanceFieldPenetrationLoss( + sigma=0.5, point2plane=False, vectorized=True, penalize_outside=True) + + if self.part_segm_fn: + # Read the part segmentation + part_segm_fn = os.path.expandvars(self.part_segm_fn) + with open(part_segm_fn, 'rb') as faces_parents_file: + face_segm_data = pickle.load(faces_parents_file, encoding='latin1') + faces_segm = face_segm_data['segm'] + faces_parents = face_segm_data['parents'] + # Create the module used to filter invalid collision pairs + filter_faces = FilterFaces( + faces_segm=faces_segm, faces_parents=faces_parents, + ign_part_pairs=None).to(device=self.device) + + + # Split SMPL pose to body pose and global orientation + body_pose = init_pose[:, 3:].detach().clone() + global_orient = init_pose[:, :3].detach().clone() + betas = init_betas.detach().clone() + + # use guess 3d to get the initial + smpl_output = self.smpl(global_orient=global_orient, + body_pose=body_pose, + betas=betas) + model_joints = smpl_output.joints + + init_cam_t = guess_init_3d(model_joints, j3d, self.joints_category).detach() + camera_translation = init_cam_t.clone() + + preserve_pose = init_pose[:, 3:].detach().clone() + # -------------Step 1: Optimize camera translation and body orientation-------- + # Optimize only camera translation and body orientation + body_pose.requires_grad = False + betas.requires_grad = False + global_orient.requires_grad = True + camera_translation.requires_grad = True + + camera_opt_params = [global_orient, camera_translation] + + if self.use_lbfgs: + camera_optimizer = torch.optim.LBFGS(camera_opt_params, max_iter=self.num_iters, + lr=self.step_size, line_search_fn='strong_wolfe') + for i in range(10): + def closure(): + camera_optimizer.zero_grad() + smpl_output = self.smpl(global_orient=global_orient, + body_pose=body_pose, + betas=betas) + model_joints = smpl_output.joints + + loss = camera_fitting_loss_3d(model_joints, camera_translation, + init_cam_t, j3d, self.joints_category) + loss.backward() + return loss + + camera_optimizer.step(closure) + else: + camera_optimizer = torch.optim.Adam(camera_opt_params, lr=self.step_size, betas=(0.9, 0.999)) + + for i in range(20): + smpl_output = self.smpl(global_orient=global_orient, + body_pose=body_pose, + betas=betas) + model_joints = smpl_output.joints + + loss = camera_fitting_loss_3d(model_joints[:, self.smpl_index], camera_translation, + init_cam_t, j3d[:, self.corr_index], self.joints_category) + camera_optimizer.zero_grad() + loss.backward() + camera_optimizer.step() + + # Fix camera translation after optimizing camera + # --------Step 2: Optimize body joints -------------------------- + # Optimize only the body pose and global orientation of the body + body_pose.requires_grad = 
True + global_orient.requires_grad = True + camera_translation.requires_grad = True + + # --- if we use the sequence, fix the shape + if seq_ind == 0: + betas.requires_grad = True + body_opt_params = [body_pose, betas, global_orient, camera_translation] + else: + betas.requires_grad = False + body_opt_params = [body_pose, global_orient, camera_translation] + + if self.use_lbfgs: + body_optimizer = torch.optim.LBFGS(body_opt_params, max_iter=self.num_iters, + lr=self.step_size, line_search_fn='strong_wolfe') + + for i in tqdm(range(self.num_iters), desc=f"LBFGS iter: "): + # for i in range(self.num_iters): + def closure(): + body_optimizer.zero_grad() + smpl_output = self.smpl(global_orient=global_orient, + body_pose=body_pose, + betas=betas) + model_joints = smpl_output.joints + model_vertices = smpl_output.vertices + + loss = body_fitting_loss_3d(body_pose, preserve_pose, betas, model_joints[:, self.smpl_index], camera_translation, + j3d[:, self.corr_index], self.pose_prior, + joints3d_conf=conf_3d, + joint_loss_weight=600.0, + pose_preserve_weight=5.0, + use_collision=self.use_collision, + model_vertices=model_vertices, model_faces=self.model_faces, + search_tree=search_tree, pen_distance=pen_distance, filter_faces=filter_faces) + loss.backward() + return loss + + body_optimizer.step(closure) + else: + body_optimizer = torch.optim.Adam(body_opt_params, lr=self.step_size, betas=(0.9, 0.999)) + + for i in range(self.num_iters): + smpl_output = self.smpl(global_orient=global_orient, + body_pose=body_pose, + betas=betas) + model_joints = smpl_output.joints + model_vertices = smpl_output.vertices + + loss = body_fitting_loss_3d(body_pose, preserve_pose, betas, model_joints[:, self.smpl_index], camera_translation, + j3d[:, self.corr_index], self.pose_prior, + joints3d_conf=conf_3d, + joint_loss_weight=600.0, + use_collision=self.use_collision, + model_vertices=model_vertices, model_faces=self.model_faces, + search_tree=search_tree, pen_distance=pen_distance, filter_faces=filter_faces) + body_optimizer.zero_grad() + loss.backward() + body_optimizer.step() + + # Get final loss value + with torch.no_grad(): + smpl_output = self.smpl(global_orient=global_orient, + body_pose=body_pose, + betas=betas, return_full_pose=True) + model_joints = smpl_output.joints + model_vertices = smpl_output.vertices + + final_loss = body_fitting_loss_3d(body_pose, preserve_pose, betas, model_joints[:, self.smpl_index], camera_translation, + j3d[:, self.corr_index], self.pose_prior, + joints3d_conf=conf_3d, + joint_loss_weight=600.0, + use_collision=self.use_collision, model_vertices=model_vertices, model_faces=self.model_faces, + search_tree=search_tree, pen_distance=pen_distance, filter_faces=filter_faces) + + vertices = smpl_output.vertices.detach() + joints = smpl_output.joints.detach() + pose = torch.cat([global_orient, body_pose], dim=-1).detach() + betas = betas.detach() + + return vertices, joints, pose, betas, camera_translation, final_loss \ No newline at end of file diff --git a/mGPT/data/transforms/rots2joints/__init__.py b/mGPT/data/transforms/rots2joints/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..7719c7018469a7c97a944d8d6d2113ef21ad01ab --- /dev/null +++ b/mGPT/data/transforms/rots2joints/__init__.py @@ -0,0 +1,3 @@ +from .base import Rots2Joints +from .smplh import SMPLH +from .smplx import SMPLX diff --git a/mGPT/data/transforms/rots2joints/base.py b/mGPT/data/transforms/rots2joints/base.py new file mode 100755 index 
0000000000000000000000000000000000000000..524f830f071a61962163aa77e895be0090d7ba35 --- /dev/null +++ b/mGPT/data/transforms/rots2joints/base.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2020 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +from typing import Optional + +import torch +from torch import Tensor, nn +from pathlib import Path +import os +# import hydra + +class Rots2Joints(nn.Module): + def __init__(self, path: Optional[str] = None, + normalization: bool = False, + eps: float = 1e-12, + **kwargs) -> None: + if normalization and path is None: + raise TypeError("You should provide a path if normalization is on.") + + super().__init__() + self.normalization = normalization + self.eps = eps + # workaround for cluster local/sync + if path is not None: + rel_p = path.split('/') + rel_p = rel_p[rel_p.index('deps'):] + rel_p = '/'.join(rel_p) + # path = hydra.utils.get_original_cwd() + '/' + rel_p + if normalization: + mean_path = Path(path) / "mean.pt" + std_path = Path(path) / "std.pt" + self.register_buffer('mean', torch.load(mean_path)) + self.register_buffer('std', torch.load(std_path)) + + def normalize(self, features: Tensor) -> Tensor: + if self.normalization: + features = (features - self.mean)/(self.std + self.eps) + return features + + def unnormalize(self, features: Tensor) -> Tensor: + if self.normalization: + features = features * self.std + self.mean + return features diff --git a/mGPT/data/transforms/rots2joints/smplh.py b/mGPT/data/transforms/rots2joints/smplh.py new file mode 100755 index 0000000000000000000000000000000000000000..90efa4ff27a99f56618de16c84a5a8e1cfa2bee7 --- /dev/null +++ b/mGPT/data/transforms/rots2joints/smplh.py @@ -0,0 +1,192 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2020 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. 
+# +# Contact: ps-license@tuebingen.mpg.de + +import contextlib +from typing import Optional + +import torch +from einops import rearrange +from torch import Tensor +from mGPT.utils.joints import smplh_to_mmm_scaling_factor +from mGPT.utils.joints import smplh2mmm_indexes +from .base import Rots2Joints + + +def slice_or_none(data, cslice): + if data is None: + return data + else: + return data[cslice] + + +class SMPLH(Rots2Joints): + + def __init__(self, + path: str, + jointstype: str = "mmm", + input_pose_rep: str = "matrix", + batch_size: int = 512, + gender="neutral", + **kwargs) -> None: + super().__init__(path=None, normalization=False) + self.batch_size = batch_size + self.input_pose_rep = input_pose_rep + self.jointstype = jointstype + self.training = False + + from smplx.body_models import SMPLHLayer + import os + # rel_p = path.split('/') + # rel_p = rel_p[rel_p.index('data'):] + # rel_p = '/'.join(rel_p) + + # Remove annoying print + with contextlib.redirect_stdout(None): + self.smplh = SMPLHLayer(path, ext="pkl", gender=gender).eval() + + self.faces = self.smplh.faces + for p in self.parameters(): + p.requires_grad = False + + def train(self, *args, **kwargs): + return self + + def forward(self, + smpl_data: dict, + jointstype: Optional[str] = None, + input_pose_rep: Optional[str] = None, + batch_size: Optional[int] = None) -> Tensor: + + # Take values from init if not specified there + jointstype = self.jointstype if jointstype is None else jointstype + batch_size = self.batch_size if batch_size is None else batch_size + input_pose_rep = self.input_pose_rep if input_pose_rep is None else input_pose_rep + + if input_pose_rep == "xyz": + raise NotImplementedError( + "You should use identity pose2joints instead") + + poses = smpl_data.rots + trans = smpl_data.trans + + from functools import reduce + import operator + save_shape_bs_len = poses.shape[:-3] + nposes = reduce(operator.mul, save_shape_bs_len, 1) + + if poses.shape[-3] == 52: + nohands = False + elif poses.shape[-3] == 22: + nohands = True + else: + raise NotImplementedError("Could not parse the poses.") + + # Convert any rotations to matrix + # from temos.tools.easyconvert import to_matrix + # matrix_poses = to_matrix(input_pose_rep, poses) + matrix_poses = poses + + # Reshaping + matrix_poses = matrix_poses.reshape((nposes, *matrix_poses.shape[-3:])) + global_orient = matrix_poses[:, 0] + + if trans is None: + trans = torch.zeros((*save_shape_bs_len, 3), + dtype=poses.dtype, + device=poses.device) + + trans_all = trans.reshape((nposes, *trans.shape[-1:])) + + body_pose = matrix_poses[:, 1:22] + if nohands: + left_hand_pose = None + right_hand_pose = None + else: + hand_pose = matrix_poses[:, 22:] + left_hand_pose = hand_pose[:, :15] + right_hand_pose = hand_pose[:, 15:] + + n = len(body_pose) + outputs = [] + for chunk in range(int((n - 1) / batch_size) + 1): + chunk_slice = slice(chunk * batch_size, (chunk + 1) * batch_size) + smpl_output = self.smplh( + global_orient=slice_or_none(global_orient, chunk_slice), + body_pose=slice_or_none(body_pose, chunk_slice), + left_hand_pose=slice_or_none(left_hand_pose, chunk_slice), + right_hand_pose=slice_or_none(right_hand_pose, chunk_slice), + transl=slice_or_none(trans_all, chunk_slice)) + + if jointstype == "vertices": + output_chunk = smpl_output.vertices + else: + joints = smpl_output.joints + output_chunk = joints + + outputs.append(output_chunk) + + outputs = torch.cat(outputs) + outputs = outputs.reshape((*save_shape_bs_len, *outputs.shape[1:])) + + # Change topology if 
needed + outputs = smplh_to(jointstype, outputs, trans) + + return outputs + + def inverse(self, joints: Tensor) -> Tensor: + raise NotImplementedError("Cannot inverse SMPLH layer.") + + +def smplh_to(jointstype, data, trans): + from mGPT.utils.joints import get_root_idx + + if "mmm" in jointstype: + from mGPT.utils.joints import smplh2mmm_indexes + indexes = smplh2mmm_indexes + data = data[..., indexes, :] + + # make it compatible with mmm + if jointstype == "mmm": + from mGPT.utils.joints import smplh_to_mmm_scaling_factor + data *= smplh_to_mmm_scaling_factor + + if jointstype == "smplmmm": + pass + elif jointstype in ["mmm", "mmmns"]: + # swap axis + data = data[..., [1, 2, 0]] + # revert left and right + data[..., 2] = -data[..., 2] + + elif jointstype == "smplnh": + from mGPT.utils.joints import smplh2smplnh_indexes + indexes = smplh2smplnh_indexes + data = data[..., indexes, :] + elif jointstype == "smplh": + pass + elif jointstype == "vertices": + pass + else: + raise NotImplementedError(f"SMPLH to {jointstype} is not implemented.") + + if jointstype != "vertices": + # shift the output in each batch + # such that it is centered on the pelvis/root on the first frame + root_joint_idx = get_root_idx(jointstype) + shift = trans[..., 0, :] - data[..., 0, root_joint_idx, :] + data += shift[..., None, None, :] + + return data diff --git a/mGPT/data/transforms/rots2joints/smplx.py b/mGPT/data/transforms/rots2joints/smplx.py new file mode 100755 index 0000000000000000000000000000000000000000..107eb57735a0344bb0d32a341310f0b6c0e6035b --- /dev/null +++ b/mGPT/data/transforms/rots2joints/smplx.py @@ -0,0 +1,201 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2020 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. 
+# +# Contact: ps-license@tuebingen.mpg.de + +import contextlib +from typing import Optional +import torch +from torch import Tensor +from mGPT.utils.joints import smplh_to_mmm_scaling_factor, smplh2mmm_indexes, get_root_idx +from mGPT.utils.easyconvert import rep_to_rep +from .base import Rots2Joints + + +def slice_or_none(data, cslice): + if data is None: + return data + else: + return data[cslice] + + +class SMPLX(Rots2Joints): + def __init__(self, + path: str, + jointstype: str = "mmm", + input_pose_rep: str = "matrix", + batch_size: int = 512, + gender="neutral", + **kwargs) -> None: + super().__init__(path=None, normalization=False) + self.batch_size = batch_size + self.input_pose_rep = input_pose_rep + self.jointstype = jointstype + self.training = False + + from smplx.body_models import SMPLXLayer + import os + # rel_p = path.split('/') + # rel_p = rel_p[rel_p.index('data'):] + # rel_p = '/'.join(rel_p) + + # Remove annoying print + with contextlib.redirect_stdout(None): + self.smplx = SMPLXLayer(path, + ext="npz", + gender=gender, + batch_size=batch_size).eval() + + self.faces = self.smplx.faces + for p in self.parameters(): + p.requires_grad = False + + def train(self, *args, **kwargs): + return self + + def forward(self, + smpl_data: dict, + jointstype: Optional[str] = None, + input_pose_rep: Optional[str] = None, + batch_size: Optional[int] = None) -> Tensor: + + # Take values from init if not specified there + jointstype = self.jointstype if jointstype is None else jointstype + batch_size = self.batch_size if batch_size is None else batch_size + input_pose_rep = self.input_pose_rep if input_pose_rep is None else input_pose_rep + + poses = smpl_data.rots + trans = smpl_data.trans + + from functools import reduce + import operator + save_shape_bs_len = poses.shape[:-3] + nposes = reduce(operator.mul, save_shape_bs_len, 1) + + + matrix_poses = rep_to_rep(self.input_pose_rep, input_pose_rep, poses) + + # Reshaping + matrix_poses = matrix_poses.reshape((nposes, *matrix_poses.shape[-3:])) + + global_orient = matrix_poses[:, 0] + + if trans is None: + trans = torch.zeros((*save_shape_bs_len, 3), + dtype=poses.dtype, + device=poses.device) + + trans_all = trans.reshape((nposes, *trans.shape[-1:])) + + body_pose = matrix_poses[:, 1:22] + + if poses.shape[-3] == 55: + nohands = False + nofaces = False + elif poses.shape[-3] == 52: + nohands = False + nofaces = True + elif poses.shape[-3] == 22: + nohands = True + nofaces = True + else: + raise NotImplementedError("Could not parse the poses.") + + if nohands: + left_hand_pose = None + right_hand_pose = None + else: + left_hand_pose = matrix_poses[:, 25:40] + right_hand_pose = matrix_poses[:, 40:55] + + if nofaces: + jaw_pose = None + leye_pose = None + reye_pose = None + else: + jaw_pose = matrix_poses[:, 22:23] + leye_pose = matrix_poses[:, 23:24] + reye_pose = matrix_poses[:, 24:25] + + n = len(body_pose) + outputs = [] + for chunk in range(int((n - 1) / batch_size) + 1): + chunk_slice = slice(chunk * batch_size, (chunk + 1) * batch_size) + smpl_output = self.smplx( + global_orient=slice_or_none(global_orient, chunk_slice), + body_pose=slice_or_none(body_pose, chunk_slice), + left_hand_pose=slice_or_none(left_hand_pose, chunk_slice), + right_hand_pose=slice_or_none(right_hand_pose, chunk_slice), + jaw_pose=slice_or_none(jaw_pose, chunk_slice), + leye_pose=slice_or_none(leye_pose, chunk_slice), + reye_pose=slice_or_none(reye_pose, chunk_slice), + transl=slice_or_none(trans_all, chunk_slice)) + + if jointstype == "vertices": + 
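                # return the posed mesh vertices instead of regressed joint positions
+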
output_chunk = smpl_output.vertices + else: + joints = smpl_output.joints + output_chunk = joints + + outputs.append(output_chunk) + + outputs = torch.cat(outputs) + outputs = outputs.reshape((*save_shape_bs_len, *outputs.shape[1:])) + + # Change topology if needed + outputs = smplx_to(jointstype, outputs, trans) + + return outputs + + def inverse(self, joints: Tensor) -> Tensor: + raise NotImplementedError("Cannot inverse SMPLX layer.") + + +def smplx_to(jointstype, data, trans): + + if "mmm" in jointstype: + indexes = smplh2mmm_indexes + data = data[..., indexes, :] + + # make it compatible with mmm + if jointstype == "mmm": + data *= smplh_to_mmm_scaling_factor + + if jointstype == "smplmmm": + pass + elif jointstype in ["mmm", "mmmns"]: + # swap axis + data = data[..., [1, 2, 0]] + # revert left and right + data[..., 2] = -data[..., 2] + + elif jointstype == "smplnh": + from mGPT.utils.joints import smplh2smplnh_indexes + indexes = smplh2smplnh_indexes + data = data[..., indexes, :] + elif jointstype == "smplh": + pass + elif jointstype == "vertices": + pass + else: + raise NotImplementedError(f"SMPLX to {jointstype} is not implemented.") + + if jointstype != "vertices": + # shift the output in each batch + # such that it is centered on the pelvis/root on the first frame + root_joint_idx = get_root_idx(jointstype) + shift = trans[..., 0, :] - data[..., 0, root_joint_idx, :] + data += shift[..., None, None, :] + + return data diff --git a/mGPT/data/transforms/rots2rfeats/__init__.py b/mGPT/data/transforms/rots2rfeats/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..29b206c50cfb3c0c5e66be5e3876456bfdaee119 --- /dev/null +++ b/mGPT/data/transforms/rots2rfeats/__init__.py @@ -0,0 +1,5 @@ +from .base import Rots2Rfeats +# from .globvel import Globalvel + +from .globvelandy import Globalvelandy +# from .rifeats import Rifeats diff --git a/mGPT/data/transforms/rots2rfeats/base.py b/mGPT/data/transforms/rots2rfeats/base.py new file mode 100755 index 0000000000000000000000000000000000000000..98c33bd3f30aebfc76a68b58e4a0e03eb53da29d --- /dev/null +++ b/mGPT/data/transforms/rots2rfeats/base.py @@ -0,0 +1,60 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2020 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. 
+# +# Contact: ps-license@tuebingen.mpg.de + +from typing import Optional + +import torch +from torch import Tensor, nn +from pathlib import Path +import os + +class Rots2Rfeats(nn.Module): + def __init__(self, path: Optional[str] = None, + normalization: bool = True, + eps: float = 1e-12, + **kwargs) -> None: + if normalization and path is None: + raise TypeError("You should provide a path if normalization is on.") + + super().__init__() + self.normalization = normalization + self.eps = eps + if normalization: + # workaround for cluster local/sync + rel_p = path.split('/') + # superhacky it is for the datatype ugly stuff change it, copy the main stuff to seperate_pairs dict + if rel_p[-1] == 'separate_pairs': + rel_p.remove('separate_pairs') + ######################################################## + # rel_p = rel_p[rel_p.index('deps'):] + rel_p = '/'.join(rel_p) + # path = hydra.utils.get_original_cwd() + '/' + rel_p + path = rel_p + mean_path = Path(path) / "rfeats_mean.pt" + std_path = Path(path) / "rfeats_std.pt" + + self.register_buffer('mean', torch.load(mean_path)) + self.register_buffer('std', torch.load(std_path)) + + def normalize(self, features: Tensor) -> Tensor: + if self.normalization: + features = (features - self.mean)/(self.std + self.eps) + return features + + def unnormalize(self, features: Tensor) -> Tensor: + if self.normalization: + features = features * self.std + self.mean + return features diff --git a/mGPT/data/transforms/rots2rfeats/globvelandy.py b/mGPT/data/transforms/rots2rfeats/globvelandy.py new file mode 100755 index 0000000000000000000000000000000000000000..fe223afd2b2a2cbf6d868a553d7f336ccc169785 --- /dev/null +++ b/mGPT/data/transforms/rots2rfeats/globvelandy.py @@ -0,0 +1,128 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2020 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. 
+# +# Contact: ps-license@tuebingen.mpg.de + +from typing import Optional + +import torch +from torch import Tensor +from einops import rearrange + +from mGPT.utils.easyconvert import rep_to_rep, nfeats_of, to_matrix +import mGPT.utils.geometry_tools as geometry_tools + +from .base import Rots2Rfeats + + +class Globalvelandy(Rots2Rfeats): + def __init__(self, + path: Optional[str] = None, + normalization: bool = False, + pose_rep: str = "rot6d", + canonicalize: bool = False, + offset: bool = True, + **kwargs) -> None: + super().__init__(path=path, normalization=normalization) + + self.canonicalize = canonicalize + self.pose_rep = pose_rep + self.nfeats = nfeats_of(pose_rep) + self.offset = offset + + def forward(self, data, data_rep='matrix', first_frame=None) -> Tensor: + + poses, trans = data.rots, data.trans + + # extract the root gravity axis + # for smpl it is the last coordinate + root_y = trans[..., 2] + trajectory = trans[..., [0, 1]] + + # Compute the difference of trajectory + vel_trajectory = torch.diff(trajectory, dim=-2) + + # 0 for the first one => keep the dimentionality + if first_frame is None: + first_frame = 0 * vel_trajectory[..., [0], :] + + vel_trajectory = torch.cat((first_frame, vel_trajectory), dim=-2) + + # first normalize the data + if self.canonicalize: + + matrix_poses = rep_to_rep(data_rep, 'matrix', poses) + global_orient = matrix_poses[..., 0, :, :] + + # remove the rotation + rot2d = rep_to_rep(data_rep, 'rotvec', poses[0, 0, ...]) + + # Remove the fist rotation along the vertical axis + rot2d[..., :2] = 0 + + if self.offset: + # add a bit more rotation + rot2d[..., 2] += torch.pi / 2 + + rot2d = rep_to_rep('rotvec', 'matrix', rot2d) + + # turn with the same amount all the rotations + global_orient = torch.einsum("...kj,...kl->...jl", rot2d, + global_orient) + + matrix_poses = torch.cat( + (global_orient[..., None, :, :], matrix_poses[..., 1:, :, :]), + dim=-3) + + poses = rep_to_rep('matrix', data_rep, matrix_poses) + + # Turn the trajectory as well + vel_trajectory = torch.einsum("...kj,...lk->...lj", + rot2d[..., :2, :2], vel_trajectory) + + poses = rep_to_rep(data_rep, self.pose_rep, poses) + features = torch.cat( + (root_y[..., None], vel_trajectory, + rearrange(poses, "... joints rot -> ... (joints rot)")), + dim=-1) + features = self.normalize(features) + + return features + + def extract(self, features): + root_y = features[..., 0] + vel_trajectory = features[..., 1:3] + poses_features = features[..., 3:] + poses = rearrange(poses_features, + "... (joints rot) -> ... 
joints rot", + rot=self.nfeats) + return root_y, vel_trajectory, poses + + def inverse(self, features, last_frame=None): + features = self.unnormalize(features) + root_y, vel_trajectory, poses = self.extract(features) + + # integrate the trajectory + trajectory = torch.cumsum(vel_trajectory, dim=-2) + if last_frame is None: + pass + # First frame should be 0, but if infered it is better to ensure it + trajectory = trajectory - trajectory[..., [0], :] + + # Get back the translation + trans = torch.cat([trajectory, root_y[..., None]], dim=-1) + matrix_poses = rep_to_rep(self.pose_rep, 'matrix', poses) + + from ..smpl import RotTransDatastruct + return RotTransDatastruct(rots=matrix_poses, trans=trans) diff --git a/mGPT/data/transforms/smpl.py b/mGPT/data/transforms/smpl.py new file mode 100755 index 0000000000000000000000000000000000000000..fc46b11cc24231db62d1f8a182c00a10fce70db6 --- /dev/null +++ b/mGPT/data/transforms/smpl.py @@ -0,0 +1,191 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2020 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +from typing import Optional +from torch import Tensor +import smplx + +from .base import Datastruct, dataclass, Transform + +from .rots2rfeats import Rots2Rfeats +from .rots2joints import Rots2Joints +from .joints2jfeats import Joints2Jfeats + + +class SMPLTransform(Transform): + def __init__(self, rots2rfeats: Rots2Rfeats, + rots2joints: Rots2Joints, + joints2jfeats: Joints2Jfeats, + **kwargs): + self.rots2rfeats = rots2rfeats + self.rots2joints = rots2joints + self.joints2jfeats = joints2jfeats + + def Datastruct(self, **kwargs): + return SMPLDatastruct(_rots2rfeats=self.rots2rfeats, + _rots2joints=self.rots2joints, + _joints2jfeats=self.joints2jfeats, + transforms=self, + **kwargs) + + def __repr__(self): + return "SMPLTransform()" + + +class RotIdentityTransform(Transform): + def __init__(self, **kwargs): + return + + def Datastruct(self, **kwargs): + return RotTransDatastruct(**kwargs) + + def __repr__(self): + return "RotIdentityTransform()" + + +@dataclass +class RotTransDatastruct(Datastruct): + rots: Tensor + trans: Tensor + + transforms: RotIdentityTransform = RotIdentityTransform() + + def __post_init__(self): + self.datakeys = ["rots", "trans"] + + def __len__(self): + return len(self.rots) + + +@dataclass +class SMPLDatastruct(Datastruct): + transforms: SMPLTransform + _rots2rfeats: Rots2Rfeats + _rots2joints: Rots2Joints + _joints2jfeats: Joints2Jfeats + + features: Optional[Tensor] = None + rots_: Optional[RotTransDatastruct] = None + rfeats_: Optional[Tensor] = None + joints_: Optional[Tensor] = None + jfeats_: Optional[Tensor] = None + vertices_: Optional[Tensor] = None + + def __post_init__(self): + self.datakeys = ['features', 'rots_', 'rfeats_', + 'joints_', 'jfeats_', 'vertices_'] + # starting point + if self.features is not None and self.rfeats_ is None: + self.rfeats_ = self.features + + @property + def rots(self): + # Cached value + if 
self.rots_ is not None: + return self.rots_ + + # self.rfeats_ should be defined + assert self.rfeats_ is not None + + self._rots2rfeats.to(self.rfeats.device) + self.rots_ = self._rots2rfeats.inverse(self.rfeats) + return self.rots_ + + @property + def rfeats(self): + # Cached value + if self.rfeats_ is not None: + return self.rfeats_ + + # self.rots_ should be defined + assert self.rots_ is not None + + self._rots2rfeats.to(self.rots.device) + self.rfeats_ = self._rots2rfeats(self.rots) + return self.rfeats_ + + @property + def joints(self): + # Cached value + if self.joints_ is not None: + return self.joints_ + + self._rots2joints.to(self.rots.device) + self.joints_ = self._rots2joints(self.rots) + return self.joints_ + + @property + def jfeats(self): + # Cached value + if self.jfeats_ is not None: + return self.jfeats_ + + self._joints2jfeats.to(self.joints.device) + self.jfeats_ = self._joints2jfeats(self.joints) + return self.jfeats_ + + @property + def vertices(self): + # Cached value + if self.vertices_ is not None: + return self.vertices_ + + self._rots2joints.to(self.rots.device) + self.vertices_ = self._rots2joints(self.rots, jointstype="vertices") + return self.vertices_ + + def __len__(self): + return len(self.rfeats) + + +def get_body_model(model_type, gender, batch_size, device='cpu', ext='pkl'): + ''' + type: smpl, smplx smplh and others. Refer to smplx tutorial + gender: male, female, neutral + batch_size: an positive integar + ''' + mtype = model_type.upper() + if gender != 'neutral': + if not isinstance(gender, str): + gender = str(gender.astype(str)).upper() + else: + gender = gender.upper() + else: + gender = gender.upper() + ext = 'npz' + body_model_path = f'data/smpl_models/{model_type}/{mtype}_{gender}.{ext}' + + body_model = smplx.create(body_model_path, model_type=type, + gender=gender, ext=ext, + use_pca=False, + num_pca_comps=12, + create_global_orient=True, + create_body_pose=True, + create_betas=True, + create_left_hand_pose=True, + create_right_hand_pose=True, + create_expression=True, + create_jaw_pose=True, + create_leye_pose=True, + create_reye_pose=True, + create_transl=True, + batch_size=batch_size) + + if device == 'cuda': + return body_model.cuda() + else: + return body_model + diff --git a/mGPT/data/transforms/xyz.py b/mGPT/data/transforms/xyz.py new file mode 100755 index 0000000000000000000000000000000000000000..7add165fb15abd6f7362ffb9af906d93964b2d7a --- /dev/null +++ b/mGPT/data/transforms/xyz.py @@ -0,0 +1,81 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2020 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. 
+# +# Contact: ps-license@tuebingen.mpg.de + +from typing import Optional +from torch import Tensor + +from .base import Datastruct, dataclass, Transform +from ..tools import collate_tensor_with_padding + +from .joints2jfeats import Joints2Jfeats + + +class XYZTransform(Transform): + def __init__(self, joints2jfeats: Joints2Jfeats, **kwargs): + self.joints2jfeats = joints2jfeats + + def Datastruct(self, **kwargs): + return XYZDatastruct(_joints2jfeats=self.joints2jfeats, + transforms=self, + **kwargs) + + def __repr__(self): + return "XYZTransform()" + + +@dataclass +class XYZDatastruct(Datastruct): + transforms: XYZTransform + _joints2jfeats: Joints2Jfeats + + features: Optional[Tensor] = None + joints_: Optional[Tensor] = None + jfeats_: Optional[Tensor] = None + + def __post_init__(self): + self.datakeys = ["features", "joints_", "jfeats_"] + # starting point + if self.features is not None and self.jfeats_ is None: + self.jfeats_ = self.features + + @property + def joints(self): + # Cached value + if self.joints_ is not None: + return self.joints_ + + # self.jfeats_ should be defined + assert self.jfeats_ is not None + + self._joints2jfeats.to(self.jfeats.device) + self.joints_ = self._joints2jfeats.inverse(self.jfeats) + return self.joints_ + + @property + def jfeats(self): + # Cached value + if self.jfeats_ is not None: + return self.jfeats_ + + # self.joints_ should be defined + assert self.joints_ is not None + + self._joints2jfeats.to(self.joints.device) + self.jfeats_ = self._joints2jfeats(self.joints) + return self.jfeats_ + + def __len__(self): + return len(self.jfeats) diff --git a/mGPT/data/utils.py b/mGPT/data/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..30714ff6562edd5136e93107fea10c25b4449c79 --- /dev/null +++ b/mGPT/data/utils.py @@ -0,0 +1,81 @@ +import torch +import rich +import pickle +import numpy as np + + +def lengths_to_mask(lengths): + max_len = max(lengths) + mask = torch.arange(max_len, device=lengths.device).expand( + len(lengths), max_len) < lengths.unsqueeze(1) + return mask + + +# padding to max length in one batch +def collate_tensors(batch): + if isinstance(batch[0], np.ndarray): + batch = [torch.tensor(b).float() for b in batch] + + dims = batch[0].dim() + max_size = [max([b.size(i) for b in batch]) for i in range(dims)] + size = (len(batch), ) + tuple(max_size) + canvas = batch[0].new_zeros(size=size) + for i, b in enumerate(batch): + sub_tensor = canvas[i] + for d in range(dims): + sub_tensor = sub_tensor.narrow(d, 0, b.size(d)) + sub_tensor.add_(b) + return canvas + +def humanml3d_collate(batch): + notnone_batches = [b for b in batch if b is not None] + EvalFlag = False if notnone_batches[0][5] is None else True + + # Sort by text length + if EvalFlag: + notnone_batches.sort(key=lambda x: x[5], reverse=True) + + # Motion only + adapted_batch = { + "motion": + collate_tensors([torch.tensor(b[1]).float() for b in notnone_batches]), + "length": [b[2] for b in notnone_batches], + } + + # Text and motion + if notnone_batches[0][0] is not None: + adapted_batch.update({ + "text": [b[0] for b in notnone_batches], + "all_captions": [b[7] for b in notnone_batches], + }) + + # Evaluation related + if EvalFlag: + adapted_batch.update({ + "text": [b[0] for b in notnone_batches], + "word_embs": + collate_tensors( + [torch.tensor(b[3]).float() for b in notnone_batches]), + "pos_ohot": + collate_tensors( + [torch.tensor(b[4]).float() for b in notnone_batches]), + "text_len": + collate_tensors([torch.tensor(b[5]) for b in 
notnone_batches]), + "tokens": [b[6] for b in notnone_batches], + }) + + # Tasks + if len(notnone_batches[0]) == 9: + adapted_batch.update({"tasks": [b[8] for b in notnone_batches]}) + + return adapted_batch + + +def load_pkl(path, description=None, progressBar=False): + if progressBar: + with rich.progress.open(path, 'rb', description=description) as file: + data = pickle.load(file) + else: + with open(path, 'rb') as file: + data = pickle.load(file) + return data diff --git a/mGPT/losses/__init__.py b/mGPT/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2fe3b35f61b038e414b3e8eb75b24b2a00a8ef2f --- /dev/null +++ b/mGPT/losses/__init__.py @@ -0,0 +1 @@ +from .base import BaseLosses diff --git a/mGPT/losses/base.py b/mGPT/losses/base.py new file mode 100644 index 0000000000000000000000000000000000000000..d3103118984ab7c37b9fd40acbd23485f02b11a1 --- /dev/null +++ b/mGPT/losses/base.py @@ -0,0 +1,61 @@ +import torch +import torch.nn as nn + +class BaseLosses(nn.Module): + def __init__(self, cfg, losses, params, losses_func, num_joints, **kwargs): + super().__init__() + + # Save parameters + self.num_joints = num_joints + self._params = params + + # Add total indicator + losses.append("total") if "total" not in losses else None + + # Register losses + for loss in losses: + self.register_buffer(loss, torch.tensor(0.0)) + self.register_buffer("count", torch.tensor(0.0)) + self.losses = losses + + # Instantiate loss functions + self._losses_func = {} + for loss in losses[:-1]: + self._losses_func[loss] = losses_func[loss](reduction='mean') + + def _update_loss(self, loss: str, outputs, inputs): + '''Update the loss and return the weighted loss.''' + # Update the loss + val = self._losses_func[loss](outputs, inputs) + # self.losses_values[loss] += val.detach() + getattr(self, loss).add_(val.detach()) + # Return a weighted sum + weighted_loss = self._params[loss] * val + return weighted_loss + + def reset(self): + '''Reset the losses to 0.''' + for loss in self.losses: + setattr(self, loss, torch.tensor(0.0, device=getattr(self, loss).device)) + setattr(self, "count", torch.tensor(0.0, device=getattr(self, "count").device)) + + def compute(self, split): + '''Compute the losses and return a dictionary with the losses.''' + count = self.count + # Loss dictionary + loss_dict = {loss: getattr(self, loss)/count for loss in self.losses} + # Format the losses for logging + log_dict = { self.loss2logname(loss, split): value.item() + for loss, value in loss_dict.items() if not torch.isnan(value)} + # Reset the losses + self.reset() + return log_dict + + def loss2logname(self, loss: str, split: str): + '''Convert the loss name to a log name.''' + if loss == "total": + log_name = f"{loss}/{split}" + else: + loss_type, name = loss.split("_") + log_name = f"{loss_type}/{name}/{split}" + return log_name diff --git a/mGPT/losses/mgpt.py b/mGPT/losses/mgpt.py new file mode 100644 index 0000000000000000000000000000000000000000..69846b2cfb59b04e8ee9fd7f51f0b6e5c624e18a --- /dev/null +++ b/mGPT/losses/mgpt.py @@ -0,0 +1,97 @@ +import torch +import torch.nn as nn +from .base import BaseLosses + + +class CommitLoss(nn.Module): + """ + Useless Wrapper + """ + def __init__(self, **kwargs): + super().__init__() + + def forward(self, commit, commit2, **kwargs): + return commit + + +class GPTLosses(BaseLosses): + + def __init__(self, cfg, stage, num_joints, **kwargs): + # Save parameters + self.stage = stage + recons_loss = cfg.LOSS.ABLATION.RECONS_LOSS + + # Define losses + losses = 
[] + params = {} + if stage == "vae": + losses.append("recons_feature") + params['recons_feature'] = cfg.LOSS.LAMBDA_FEATURE + + losses.append("recons_velocity") + params['recons_velocity'] = cfg.LOSS.LAMBDA_VELOCITY + + losses.append("vq_commit") + params['vq_commit'] = cfg.LOSS.LAMBDA_COMMIT + elif stage in ["lm_pretrain", "lm_instruct"]: + losses.append("gpt_loss") + params['gpt_loss'] = cfg.LOSS.LAMBDA_CLS + + # Define loss functions & weights + losses_func = {} + for loss in losses: + if loss.split('_')[0] == 'recons': + if recons_loss == "l1": + losses_func[loss] = nn.L1Loss + elif recons_loss == "l2": + losses_func[loss] = nn.MSELoss + elif recons_loss == "l1_smooth": + losses_func[loss] = nn.SmoothL1Loss + elif loss.split('_')[1] in [ + 'commit', 'loss', 'gpt', 'm2t2m', 't2m2t' + ]: + losses_func[loss] = CommitLoss + elif loss.split('_')[1] in ['cls', 'lm']: + losses_func[loss] = nn.CrossEntropyLoss + else: + raise NotImplementedError(f"Loss {loss} not implemented.") + + super().__init__(cfg, losses, params, losses_func, num_joints, + **kwargs) + + def update(self, rs_set): + '''Update the losses''' + total: float = 0.0 + + if self.stage in ["vae"]: + total += self._update_loss("recons_feature", rs_set['m_rst'], + rs_set['m_ref']) + # total += self._update_loss("recons_joints", rs_set['joints_rst'], rs_set['joints_ref']) + nfeats = rs_set['m_rst'].shape[-1] + if nfeats in [263, 135 + 263]: + if nfeats == 135 + 263: + vel_start = 135 + 4 + elif nfeats == 263: + vel_start = 4 + total += self._update_loss( + "recons_velocity", + rs_set['m_rst'][..., vel_start:(self.num_joints - 1) * 3 + + vel_start], + rs_set['m_ref'][..., vel_start:(self.num_joints - 1) * 3 + + vel_start]) + else: + if self._params['recons_velocity'] != 0.0: + raise NotImplementedError( + "Velocity not implemented for nfeats = {})".format(nfeats)) + total += self._update_loss("vq_commit", rs_set['loss_commit'], + rs_set['loss_commit']) + + if self.stage in ["lm_pretrain", "lm_instruct"]: + total += self._update_loss("gpt_loss", rs_set['outputs'].loss, + rs_set['outputs'].loss) + + # Update the total loss + self.total += total.detach() + self.count += 1 + + return total diff --git a/mGPT/metrics/__init__.py b/mGPT/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..33c20815c943401fb346bb0ab5a512dd9713658c --- /dev/null +++ b/mGPT/metrics/__init__.py @@ -0,0 +1 @@ +from .base import BaseMetrics diff --git a/mGPT/metrics/base.py b/mGPT/metrics/base.py new file mode 100644 index 0000000000000000000000000000000000000000..540771902f3f91bf926f5fbf1e170242f8b26181 --- /dev/null +++ b/mGPT/metrics/base.py @@ -0,0 +1,46 @@ +from torch import Tensor, nn +from os.path import join as pjoin +from .mr import MRMetrics +from .t2m import TM2TMetrics +from .mm import MMMetrics +from .m2t import M2TMetrics +from .m2m import PredMetrics + + +class BaseMetrics(nn.Module): + def __init__(self, cfg, datamodule, debug, **kwargs) -> None: + super().__init__() + + njoints = datamodule.njoints + + data_name = datamodule.name + if data_name in ["humanml3d", "kit"]: + self.TM2TMetrics = TM2TMetrics( + cfg=cfg, + dataname=data_name, + diversity_times=30 if debug else cfg.METRIC.DIVERSITY_TIMES, + dist_sync_on_step=cfg.METRIC.DIST_SYNC_ON_STEP, + ) + self.M2TMetrics = M2TMetrics( + cfg=cfg, + w_vectorizer=datamodule.hparams.w_vectorizer, + diversity_times=30 if debug else cfg.METRIC.DIVERSITY_TIMES, + dist_sync_on_step=cfg.METRIC.DIST_SYNC_ON_STEP) + self.MMMetrics = MMMetrics( + cfg=cfg, + 
mm_num_times=cfg.METRIC.MM_NUM_TIMES, + dist_sync_on_step=cfg.METRIC.DIST_SYNC_ON_STEP, + ) + + self.MRMetrics = MRMetrics( + njoints=njoints, + jointstype=cfg.DATASET.JOINT_TYPE, + dist_sync_on_step=cfg.METRIC.DIST_SYNC_ON_STEP, + ) + self.PredMetrics = PredMetrics( + cfg=cfg, + njoints=njoints, + jointstype=cfg.DATASET.JOINT_TYPE, + dist_sync_on_step=cfg.METRIC.DIST_SYNC_ON_STEP, + task=cfg.model.params.task, + ) diff --git a/mGPT/metrics/m2m.py b/mGPT/metrics/m2m.py new file mode 100644 index 0000000000000000000000000000000000000000..780d34dadf7d9b01369884b1ae3e5a6e0c6bceb3 --- /dev/null +++ b/mGPT/metrics/m2m.py @@ -0,0 +1,95 @@ +from typing import List + +import torch +from torch import Tensor +from torchmetrics import Metric + +from .utils import * + + +# motion reconstruction metric +class PredMetrics(Metric): + + def __init__(self, + cfg, + njoints: int = 22, + jointstype: str = "mmm", + force_in_meter: bool = True, + align_root: bool = True, + dist_sync_on_step=True, + task: str = "pred", + **kwargs): + super().__init__(dist_sync_on_step=dist_sync_on_step) + + self.name = 'Motion Prdiction' + self.cfg = cfg + self.jointstype = jointstype + self.align_root = align_root + self.task = task + self.force_in_meter = force_in_meter + + self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum") + self.add_state("count_seq", + default=torch.tensor(0), + dist_reduce_fx="sum") + + self.add_state("APD", + default=torch.tensor([0.0]), + dist_reduce_fx="sum") + self.add_state("ADE", + default=torch.tensor([0.0]), + dist_reduce_fx="sum") + self.add_state("FDE", + default=torch.tensor([0.0]), + dist_reduce_fx="sum") + + self.MR_metrics = ["APD", "ADE", "FDE"] + + # All metric + self.metrics = self.MR_metrics + + def compute(self, sanity_flag): + + count = self.count + count_seq = self.count_seq + mr_metrics = {} + mr_metrics["APD"] = self.APD / count_seq + mr_metrics["ADE"] = self.ADE / count_seq + mr_metrics["FDE"] = self.FDE / count_seq + + # Reset + self.reset() + + return mr_metrics + + def update(self, joints_rst: Tensor, joints_ref: Tensor, + lengths: List[int]): + + assert joints_rst.shape == joints_ref.shape + assert joints_rst.dim() == 4 + # (bs, seq, njoint=22, 3) + + self.count += sum(lengths) + self.count_seq += len(lengths) + + rst = torch.flatten(joints_rst, start_dim=2) + ref = torch.flatten(joints_ref, start_dim=2) + + for i, l in enumerate(lengths): + if self.task == "pred": + pred_start = int(l*self.cfg.ABLATION.predict_ratio) + diff = rst[i,pred_start:] - ref[i,pred_start:] + elif self.task == "inbetween": + inbetween_start = int(l*self.cfg.ABLATION.inbetween_ratio) + inbetween_end = l - int(l*self.cfg.ABLATION.inbetween_ratio) + diff = rst[i,inbetween_start:inbetween_end] - ref[i,inbetween_start:inbetween_end] + else: + print(f"Task {self.task} not implemented.") + diff = rst - ref + + dist = torch.linalg.norm(diff, dim=-1)[None] + + ade = dist.mean(dim=1) + fde = dist[:,-1] + self.ADE = self.ADE + ade + self.FDE = self.FDE + fde diff --git a/mGPT/metrics/m2t.py b/mGPT/metrics/m2t.py new file mode 100644 index 0000000000000000000000000000000000000000..3cf1f4cbdc8ba6df5f63291d0ab2dd2640dc3758 --- /dev/null +++ b/mGPT/metrics/m2t.py @@ -0,0 +1,345 @@ +from typing import List +import os +import torch +from torch import Tensor +from torchmetrics import Metric +from .utils import * +from bert_score import score as score_bert +import spacy +from mGPT.config import instantiate_from_config + +class M2TMetrics(Metric): + + def __init__(self, + cfg, + w_vectorizer, + 
dataname='humanml3d', + top_k=3, + bleu_k=4, + R_size=32, + max_text_len=40, + diversity_times=300, + dist_sync_on_step=True, + unit_length=4, + **kwargs): + super().__init__(dist_sync_on_step=dist_sync_on_step) + + self.cfg = cfg + self.dataname = dataname + self.w_vectorizer = w_vectorizer + self.name = "matching, fid, and diversity scores" + # self.text = True if cfg.TRAIN.STAGE in ["diffusion","t2m_gpt"] else False + self.max_text_len = max_text_len + self.top_k = top_k + self.bleu_k = bleu_k + self.R_size = R_size + self.diversity_times = diversity_times + self.unit_length = unit_length + + self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum") + self.add_state("count_seq", + default=torch.tensor(0), + dist_reduce_fx="sum") + + self.metrics = [] + + # Matching scores + self.add_state("Matching_score", + default=torch.tensor(0.0), + dist_reduce_fx="sum") + self.add_state("gt_Matching_score", + default=torch.tensor(0.0), + dist_reduce_fx="sum") + self.Matching_metrics = ["Matching_score", "gt_Matching_score"] + for k in range(1, top_k + 1): + self.add_state( + f"R_precision_top_{str(k)}", + default=torch.tensor(0.0), + dist_reduce_fx="sum", + ) + self.Matching_metrics.append(f"R_precision_top_{str(k)}") + for k in range(1, top_k + 1): + self.add_state( + f"gt_R_precision_top_{str(k)}", + default=torch.tensor(0.0), + dist_reduce_fx="sum", + ) + self.Matching_metrics.append(f"gt_R_precision_top_{str(k)}") + + self.metrics.extend(self.Matching_metrics) + + # NLG + for k in range(1, top_k + 1): + self.add_state( + f"Bleu_{str(k)}", + default=torch.tensor(0.0), + dist_reduce_fx="sum", + ) + self.metrics.append(f"Bleu_{str(k)}") + + self.add_state("ROUGE_L", + default=torch.tensor(0.0), + dist_reduce_fx="sum") + self.metrics.append("ROUGE_L") + + self.add_state("CIDEr", + default=torch.tensor(0.0), + dist_reduce_fx="sum") + self.metrics.append("CIDEr") + + # Chached batches + self.pred_texts = [] + self.gt_texts = [] + self.add_state("predtext_embeddings", default=[]) + self.add_state("gttext_embeddings", default=[]) + self.add_state("gtmotion_embeddings", default=[]) + + # T2M Evaluator + self._get_t2m_evaluator(cfg) + + self.nlp = spacy.load('en_core_web_sm') + + if self.cfg.model.params.task == 'm2t': + from nlgmetricverse import NLGMetricverse, load_metric + metrics = [ + load_metric("bleu", resulting_name="bleu_1", compute_kwargs={"max_order": 1}), + load_metric("bleu", resulting_name="bleu_4", compute_kwargs={"max_order": 4}), + load_metric("rouge"), + load_metric("cider"), + ] + self.nlg_evaluator = NLGMetricverse(metrics) + + def _get_t2m_evaluator(self, cfg): + """ + load T2M text encoder and motion encoder for evaluating + """ + # init module + self.t2m_textencoder = instantiate_from_config(cfg.METRIC.TM2T.t2m_textencoder) + self.t2m_moveencoder = instantiate_from_config(cfg.METRIC.TM2T.t2m_moveencoder) + self.t2m_motionencoder = instantiate_from_config(cfg.METRIC.TM2T.t2m_motionencoder) + + + # load pretrianed + if self.dataname == "kit": + dataname = "kit" + else: + dataname = "t2m" + + t2m_checkpoint = torch.load(os.path.join( + cfg.METRIC.TM2T.t2m_path, dataname, "text_mot_match/model/finest.tar"), + map_location='cpu') + self.t2m_textencoder.load_state_dict(t2m_checkpoint["text_encoder"]) + self.t2m_moveencoder.load_state_dict( + t2m_checkpoint["movement_encoder"]) + self.t2m_motionencoder.load_state_dict( + t2m_checkpoint["motion_encoder"]) + + # freeze params + self.t2m_textencoder.eval() + self.t2m_moveencoder.eval() + self.t2m_motionencoder.eval() + for p in 
self.t2m_textencoder.parameters(): + p.requires_grad = False + for p in self.t2m_moveencoder.parameters(): + p.requires_grad = False + for p in self.t2m_motionencoder.parameters(): + p.requires_grad = False + + def _process_text(self, sentence): + sentence = sentence.replace('-', '') + doc = self.nlp(sentence) + word_list = [] + pos_list = [] + for token in doc: + word = token.text + if not word.isalpha(): + continue + if (token.pos_ == 'NOUN' + or token.pos_ == 'VERB') and (word != 'left'): + word_list.append(token.lemma_) + else: + word_list.append(word) + pos_list.append(token.pos_) + return word_list, pos_list + + def _get_text_embeddings(self, texts): + word_embs = [] + pos_ohot = [] + text_lengths = [] + for i, sentence in enumerate(texts): + word_list, pos_list = self._process_text(sentence.strip()) + t_tokens = [ + '%s/%s' % (word_list[i], pos_list[i]) + for i in range(len(word_list)) + ] + + if len(t_tokens) < self.max_text_len: + # pad with "unk" + tokens = ['sos/OTHER'] + t_tokens + ['eos/OTHER'] + sent_len = len(tokens) + tokens = tokens + ['unk/OTHER' + ] * (self.max_text_len + 2 - sent_len) + else: + # crop + tokens = t_tokens[:self.max_text_len] + tokens = ['sos/OTHER'] + tokens + ['eos/OTHER'] + sent_len = len(tokens) + pos_one_hots = [] + word_embeddings = [] + for token in tokens: + word_emb, pos_oh = self.w_vectorizer[token] + pos_one_hots.append(torch.tensor(pos_oh).float()[None]) + word_embeddings.append(torch.tensor(word_emb).float()[None]) + text_lengths.append(sent_len) + pos_ohot.append(torch.cat(pos_one_hots, dim=0)[None]) + word_embs.append(torch.cat(word_embeddings, dim=0)[None]) + + word_embs = torch.cat(word_embs, dim=0).to(self.Matching_score) + pos_ohot = torch.cat(pos_ohot, dim=0).to(self.Matching_score) + text_lengths = torch.tensor(text_lengths).to(self.Matching_score) + + align_idx = np.argsort(text_lengths.data.tolist())[::-1].copy() + + # get text embeddings + text_embeddings = self.t2m_textencoder(word_embs[align_idx], + pos_ohot[align_idx], + text_lengths[align_idx]) + + original_text_embeddings = text_embeddings.clone() + + for idx, sort in enumerate(align_idx): + original_text_embeddings[sort] = text_embeddings[idx] + + return original_text_embeddings + + @torch.no_grad() + def compute(self, sanity_flag): + count = self.count.item() + count_seq = self.count_seq.item() + + # Init metrics dict + metrics = {metric: getattr(self, metric) for metric in self.metrics} + + # Jump in sanity check stage + if sanity_flag: + return metrics + + # Cat cached batches and shuffle + shuffle_idx = torch.randperm(count_seq) + all_motions = torch.cat(self.gtmotion_embeddings, + axis=0).cpu()[shuffle_idx, :] + all_gttexts = torch.cat(self.gttext_embeddings, + axis=0).cpu()[shuffle_idx, :] + all_predtexts = torch.cat(self.predtext_embeddings, + axis=0).cpu()[shuffle_idx, :] + + print("Computing metrics...") + + # Compute r-precision + assert count_seq >= self.R_size + top_k_mat = torch.zeros((self.top_k, )) + for i in range(count_seq // self.R_size): + # [bs=32, 1*256] + group_texts = all_predtexts[i * self.R_size:(i + 1) * self.R_size] + # [bs=32, 1*256] + group_motions = all_motions[i * self.R_size:(i + 1) * self.R_size] + # [bs=32, 32] + dist_mat = euclidean_distance_matrix(group_texts, + group_motions).nan_to_num() + # print(dist_mat[:5]) + self.Matching_score += dist_mat.trace() + argsmax = torch.argsort(dist_mat, dim=1) + top_k_mat += calculate_top_k(argsmax, top_k=self.top_k).sum(axis=0) + + R_count = count_seq // self.R_size * self.R_size + 
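+ # Only R_count = (count_seq // R_size) * R_size samples are averaged below; the trailing incomplete group (fewer than R_size samples) is dropped.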
metrics["Matching_score"] = self.Matching_score / R_count + for k in range(self.top_k): + metrics[f"R_precision_top_{str(k+1)}"] = top_k_mat[k] / R_count + + # Compute r-precision with gt + assert count_seq >= self.R_size + top_k_mat = torch.zeros((self.top_k, )) + for i in range(count_seq // self.R_size): + # [bs=32, 1*256] + group_texts = all_gttexts[i * self.R_size:(i + 1) * self.R_size] + # [bs=32, 1*256] + group_motions = all_motions[i * self.R_size:(i + 1) * self.R_size] + # [bs=32, 32] + dist_mat = euclidean_distance_matrix(group_texts, + group_motions).nan_to_num() + # match score + self.gt_Matching_score += dist_mat.trace() + argsmax = torch.argsort(dist_mat, dim=1) + top_k_mat += calculate_top_k(argsmax, top_k=self.top_k).sum(axis=0) + metrics["gt_Matching_score"] = self.gt_Matching_score / R_count + for k in range(self.top_k): + metrics[f"gt_R_precision_top_{str(k+1)}"] = top_k_mat[k] / R_count + + # NLP metrics + scores = self.nlg_evaluator(predictions=self.pred_texts, + references=self.gt_texts) + for k in range(1, self.bleu_k + 1): + metrics[f"Bleu_{str(k)}"] = torch.tensor(scores[f'bleu_{str(k)}'], + device=self.device) + + metrics["ROUGE_L"] = torch.tensor(scores["rouge"]["rougeL"], + device=self.device) + metrics["CIDEr"] = torch.tensor(scores["cider"]['score'],device=self.device) + + # Bert metrics + P, R, F1 = score_bert(self.pred_texts, + self.gt_texts, + lang='en', + rescale_with_baseline=True, + idf=True, + device=self.device, + verbose=False) + + metrics["Bert_F1"] = F1.mean() + + # Reset + self.reset() + self.gt_texts = [] + self.pred_texts = [] + + return {**metrics} + + @torch.no_grad() + def update(self, + feats_ref: Tensor, + pred_texts: List[str], + gt_texts: List[str], + lengths: List[int], + word_embs: Tensor = None, + pos_ohot: Tensor = None, + text_lengths: Tensor = None): + + self.count += sum(lengths) + self.count_seq += len(lengths) + + # motion encoder + m_lens = torch.tensor(lengths, device=feats_ref.device) + align_idx = np.argsort(m_lens.data.tolist())[::-1].copy() + feats_ref = feats_ref[align_idx] + m_lens = m_lens[align_idx] + m_lens = torch.div(m_lens, + self.cfg.DATASET.HUMANML3D.UNIT_LEN, + rounding_mode="floor") + ref_mov = self.t2m_moveencoder(feats_ref[..., :-4]).detach() + m_lens = m_lens // self.unit_length + ref_emb = self.t2m_motionencoder(ref_mov, m_lens) + gtmotion_embeddings = torch.flatten(ref_emb, start_dim=1).detach() + self.gtmotion_embeddings.append(gtmotion_embeddings) + + # text encoder + gttext_emb = self.t2m_textencoder(word_embs, pos_ohot, + text_lengths)[align_idx] + gttext_embeddings = torch.flatten(gttext_emb, start_dim=1).detach() + predtext_emb = self._get_text_embeddings(pred_texts)[align_idx] + predtext_embeddings = torch.flatten(predtext_emb, start_dim=1).detach() + + self.gttext_embeddings.append(gttext_embeddings) + self.predtext_embeddings.append(predtext_embeddings) + + self.pred_texts.extend(pred_texts) + self.gt_texts.extend(gt_texts) diff --git a/mGPT/metrics/mm.py b/mGPT/metrics/mm.py new file mode 100644 index 0000000000000000000000000000000000000000..165718736598da6ecc01174144136feb9950f53b --- /dev/null +++ b/mGPT/metrics/mm.py @@ -0,0 +1,129 @@ +from typing import List + +import torch +from torch import Tensor +from torchmetrics import Metric +from torchmetrics.functional import pairwise_euclidean_distance +from .utils import * +import os +from mGPT.config import instantiate_from_config + +class MMMetrics(Metric): + full_state_update = True + + def __init__(self, cfg, dataname='humanml3d', 
mm_num_times=10, dist_sync_on_step=True, **kwargs): + super().__init__(dist_sync_on_step=dist_sync_on_step) + + self.name = "MultiModality scores" + self.cfg = cfg + self.dataname = dataname + self.mm_num_times = mm_num_times + + self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum") + self.add_state("count_seq", + default=torch.tensor(0), + dist_reduce_fx="sum") + + self.metrics = ["MultiModality"] + self.add_state("MultiModality", + default=torch.tensor(0.), + dist_reduce_fx="sum") + + # chached batches + self.add_state("mm_motion_embeddings", default=[], dist_reduce_fx=None) + + # T2M Evaluator + self._get_t2m_evaluator(cfg) + + def _get_t2m_evaluator(self, cfg): + """ + load T2M text encoder and motion encoder for evaluating + """ + # init module + self.t2m_textencoder = instantiate_from_config(cfg.METRIC.TM2T.t2m_textencoder) + self.t2m_moveencoder = instantiate_from_config(cfg.METRIC.TM2T.t2m_moveencoder) + self.t2m_motionencoder = instantiate_from_config(cfg.METRIC.TM2T.t2m_motionencoder) + + # load pretrianed + if self.dataname == "kit": + dataname = "kit" + else: + dataname = "t2m" + t2m_checkpoint = torch.load(os.path.join( + cfg.METRIC.TM2T.t2m_path, dataname, + "text_mot_match/model/finest.tar"), + map_location="cpu") + + self.t2m_textencoder.load_state_dict(t2m_checkpoint["text_encoder"]) + self.t2m_moveencoder.load_state_dict( + t2m_checkpoint["movement_encoder"]) + self.t2m_motionencoder.load_state_dict( + t2m_checkpoint["motion_encoder"]) + + # freeze params + self.t2m_textencoder.eval() + self.t2m_moveencoder.eval() + self.t2m_motionencoder.eval() + for p in self.t2m_textencoder.parameters(): + p.requires_grad = False + for p in self.t2m_moveencoder.parameters(): + p.requires_grad = False + for p in self.t2m_motionencoder.parameters(): + p.requires_grad = False + + def compute(self, sanity_flag): + count = self.count.item() + count_seq = self.count_seq.item() + + # init metrics + metrics = {metric: getattr(self, metric) for metric in self.metrics} + + # if in sanity check stage then jump + if sanity_flag: + return metrics + + # cat all embeddings + all_mm_motions = torch.cat(self.mm_motion_embeddings, + axis=0).cpu().numpy() + metrics['MultiModality'] = calculate_multimodality_np( + all_mm_motions, self.mm_num_times) + + # Reset + self.reset() + + return {**metrics} + + def update( + self, + feats_rst: Tensor, + lengths_rst: List[int], + ): + self.count += sum(lengths_rst) + self.count_seq += len(lengths_rst) + + align_idx = np.argsort(lengths_rst)[::-1].copy() + feats_rst = feats_rst[align_idx] + lengths_rst = np.array(lengths_rst)[align_idx] + recmotion_embeddings = self.get_motion_embeddings( + feats_rst, lengths_rst) + cache = [0] * len(lengths_rst) + for i in range(len(lengths_rst)): + cache[align_idx[i]] = recmotion_embeddings[i:i + 1] + + mm_motion_embeddings = torch.cat(cache, axis=0).unsqueeze(0) + # self.mm_motion_embeddings.extend(cache) + # print(mm_motion_embeddings.shape) + # # store all mm motion embeddings + self.mm_motion_embeddings.append(mm_motion_embeddings) + + def get_motion_embeddings(self, feats: Tensor, lengths: List[int]): + m_lens = torch.tensor(lengths) + m_lens = torch.div(m_lens, + self.cfg.DATASET.HUMANML3D.UNIT_LEN, + rounding_mode="floor") + + mov = self.t2m_moveencoder(feats[..., :-4]).detach() + emb = self.t2m_motionencoder(mov, m_lens) + + # [bs, nlatent*ndim] <= [bs, nlatent, ndim] + return torch.flatten(emb, start_dim=1).detach() diff --git a/mGPT/metrics/mr.py b/mGPT/metrics/mr.py new file mode 100644 index 
0000000000000000000000000000000000000000..ba5129f47012e116fd5e60e0243f8bff736cb901 --- /dev/null +++ b/mGPT/metrics/mr.py @@ -0,0 +1,97 @@ +from typing import List + +import torch +from torch import Tensor +from torchmetrics import Metric + +from .utils import * + + +# motion reconstruction metric +class MRMetrics(Metric): + + def __init__(self, + njoints, + jointstype: str = "mmm", + force_in_meter: bool = True, + align_root: bool = True, + dist_sync_on_step=True, + **kwargs): + super().__init__(dist_sync_on_step=dist_sync_on_step) + + self.name = 'Motion Reconstructions' + self.jointstype = jointstype + self.align_root = align_root + self.force_in_meter = force_in_meter + + self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum") + self.add_state("count_seq", + default=torch.tensor(0), + dist_reduce_fx="sum") + + self.add_state("MPJPE", + default=torch.tensor([0.0]), + dist_reduce_fx="sum") + self.add_state("PAMPJPE", + default=torch.tensor([0.0]), + dist_reduce_fx="sum") + self.add_state("ACCEL", + default=torch.tensor([0.0]), + dist_reduce_fx="sum") + # todo + # self.add_state("ROOT", default=torch.tensor([0.0]), dist_reduce_fx="sum") + + self.MR_metrics = ["MPJPE", "PAMPJPE", "ACCEL"] + + # All metric + self.metrics = self.MR_metrics + + def compute(self, sanity_flag): + if self.force_in_meter: + # different jointstypes have different scale factors + # if self.jointstype == 'mmm': + # factor = 1000.0 + # elif self.jointstype == 'humanml3d': + # factor = 1000.0 * 0.75 / 480 + factor = 1000.0 + else: + factor = 1.0 + + count = self.count + count_seq = self.count_seq + mr_metrics = {} + mr_metrics["MPJPE"] = self.MPJPE / count * factor + mr_metrics["PAMPJPE"] = self.PAMPJPE / count * factor + # accel error: joints_gt[:-2] - 2 * joints_gt[1:-1] + joints_gt[2:] + # n-2 for each sequences + mr_metrics["ACCEL"] = self.ACCEL / (count - 2 * count_seq) * factor + + # Reset + self.reset() + + return mr_metrics + + def update(self, joints_rst: Tensor, joints_ref: Tensor, + lengths: List[int]): + assert joints_rst.shape == joints_ref.shape + assert joints_rst.dim() == 4 + # (bs, seq, njoint=22, 3) + + self.count += sum(lengths) + self.count_seq += len(lengths) + + # avoid cuda error of DDP in pampjpe + rst = joints_rst.detach().cpu() + ref = joints_ref.detach().cpu() + + # align root joints index + if self.align_root and self.jointstype in ['mmm', 'humanml3d']: + align_inds = [0] + else: + align_inds = None + + for i in range(len(lengths)): + self.MPJPE += torch.sum( + calc_mpjpe(rst[i], ref[i], align_inds=align_inds)) + self.PAMPJPE += torch.sum(calc_pampjpe(rst[i], ref[i])) + self.ACCEL += torch.sum(calc_accel(rst[i], ref[i])) diff --git a/mGPT/metrics/t2m.py b/mGPT/metrics/t2m.py new file mode 100644 index 0000000000000000000000000000000000000000..d7917c3c6f05330e039da03f657d2e78be10eedf --- /dev/null +++ b/mGPT/metrics/t2m.py @@ -0,0 +1,259 @@ +from typing import List +import os +import torch +from torch import Tensor +from torchmetrics import Metric +from torchmetrics.functional import pairwise_euclidean_distance +from .utils import * +from mGPT.config import instantiate_from_config + +class TM2TMetrics(Metric): + def __init__(self, + cfg, + dataname='humanml3d', + top_k=3, + R_size=32, + diversity_times=300, + dist_sync_on_step=True, + **kwargs): + super().__init__(dist_sync_on_step=dist_sync_on_step) + + self.cfg = cfg + self.dataname = dataname + self.name = "matching, fid, and diversity scores" + self.top_k = top_k + self.R_size = R_size + self.text = 'lm' in cfg.TRAIN.STAGE 
and cfg.model.params.task == 't2m' + self.diversity_times = diversity_times + + self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum") + self.add_state("count_seq", + default=torch.tensor(0), + dist_reduce_fx="sum") + + self.metrics = [] + + # Matching scores + if self.text: + self.add_state("Matching_score", + default=torch.tensor(0.0), + dist_reduce_fx="sum") + self.add_state("gt_Matching_score", + default=torch.tensor(0.0), + dist_reduce_fx="sum") + self.Matching_metrics = ["Matching_score", "gt_Matching_score"] + for k in range(1, top_k + 1): + self.add_state( + f"R_precision_top_{str(k)}", + default=torch.tensor(0.0), + dist_reduce_fx="sum", + ) + self.Matching_metrics.append(f"R_precision_top_{str(k)}") + for k in range(1, top_k + 1): + self.add_state( + f"gt_R_precision_top_{str(k)}", + default=torch.tensor(0.0), + dist_reduce_fx="sum", + ) + self.Matching_metrics.append(f"gt_R_precision_top_{str(k)}") + self.metrics.extend(self.Matching_metrics) + + # Fid + self.add_state("FID", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.metrics.append("FID") + + # Diversity + self.add_state("Diversity", + default=torch.tensor(0.0), + dist_reduce_fx="sum") + self.add_state("gt_Diversity", + default=torch.tensor(0.0), + dist_reduce_fx="sum") + self.metrics.extend(["Diversity", "gt_Diversity"]) + + # Chached batches + self.add_state("text_embeddings", default=[], dist_reduce_fx=None) + self.add_state("recmotion_embeddings", default=[], dist_reduce_fx=None) + self.add_state("gtmotion_embeddings", default=[], dist_reduce_fx=None) + + # T2M Evaluator + self._get_t2m_evaluator(cfg) + + def _get_t2m_evaluator(self, cfg): + """ + load T2M text encoder and motion encoder for evaluating + """ + # init module + self.t2m_textencoder = instantiate_from_config(cfg.METRIC.TM2T.t2m_textencoder) + self.t2m_moveencoder = instantiate_from_config(cfg.METRIC.TM2T.t2m_moveencoder) + self.t2m_motionencoder = instantiate_from_config(cfg.METRIC.TM2T.t2m_motionencoder) + + + # load pretrianed + if self.dataname == "kit": + dataname = "kit" + else: + dataname = "t2m" + + t2m_checkpoint = torch.load(os.path.join( + cfg.METRIC.TM2T.t2m_path, dataname, "text_mot_match/model/finest.tar"), + map_location="cpu") + + self.t2m_textencoder.load_state_dict(t2m_checkpoint["text_encoder"]) + self.t2m_moveencoder.load_state_dict( + t2m_checkpoint["movement_encoder"]) + self.t2m_motionencoder.load_state_dict( + t2m_checkpoint["motion_encoder"]) + + # freeze params + self.t2m_textencoder.eval() + self.t2m_moveencoder.eval() + self.t2m_motionencoder.eval() + for p in self.t2m_textencoder.parameters(): + p.requires_grad = False + for p in self.t2m_moveencoder.parameters(): + p.requires_grad = False + for p in self.t2m_motionencoder.parameters(): + p.requires_grad = False + + @torch.no_grad() + def compute(self, sanity_flag): + count = self.count.item() + count_seq = self.count_seq.item() + + # Init metrics dict + metrics = {metric: getattr(self, metric) for metric in self.metrics} + + # Jump in sanity check stage + if sanity_flag: + return metrics + + # Cat cached batches and shuffle + shuffle_idx = torch.randperm(count_seq) + + all_genmotions = torch.cat(self.recmotion_embeddings, + axis=0).cpu()[shuffle_idx, :] + all_gtmotions = torch.cat(self.gtmotion_embeddings, + axis=0).cpu()[shuffle_idx, :] + + # Compute text related metrics + if self.text: + all_texts = torch.cat(self.text_embeddings, + axis=0).cpu()[shuffle_idx, :] + # Compute r-precision + assert count_seq > self.R_size + top_k_mat = 
torch.zeros((self.top_k, )) + for i in range(count_seq // self.R_size): + # [bs=32, 1*256] + group_texts = all_texts[i * self.R_size:(i + 1) * self.R_size] + # [bs=32, 1*256] + group_motions = all_genmotions[i * self.R_size:(i + 1) * + self.R_size] + # dist_mat = pairwise_euclidean_distance(group_texts, group_motions) + # [bs=32, 32] + dist_mat = euclidean_distance_matrix( + group_texts, group_motions).nan_to_num() + # print(dist_mat[:5]) + self.Matching_score += dist_mat.trace() + argsmax = torch.argsort(dist_mat, dim=1) + top_k_mat += calculate_top_k(argsmax, + top_k=self.top_k).sum(axis=0) + + R_count = count_seq // self.R_size * self.R_size + metrics["Matching_score"] = self.Matching_score / R_count + for k in range(self.top_k): + metrics[f"R_precision_top_{str(k+1)}"] = top_k_mat[k] / R_count + + # Compute r-precision with gt + assert count_seq > self.R_size + top_k_mat = torch.zeros((self.top_k, )) + for i in range(count_seq // self.R_size): + # [bs=32, 1*256] + group_texts = all_texts[i * self.R_size:(i + 1) * self.R_size] + # [bs=32, 1*256] + group_motions = all_gtmotions[i * self.R_size:(i + 1) * + self.R_size] + # [bs=32, 32] + dist_mat = euclidean_distance_matrix( + group_texts, group_motions).nan_to_num() + # match score + self.gt_Matching_score += dist_mat.trace() + argsmax = torch.argsort(dist_mat, dim=1) + top_k_mat += calculate_top_k(argsmax, + top_k=self.top_k).sum(axis=0) + metrics["gt_Matching_score"] = self.gt_Matching_score / R_count + for k in range(self.top_k): + metrics[f"gt_R_precision_top_{str(k+1)}"] = top_k_mat[k] / R_count + + # tensor -> numpy for FID + all_genmotions = all_genmotions.numpy() + all_gtmotions = all_gtmotions.numpy() + + # Compute fid + mu, cov = calculate_activation_statistics_np(all_genmotions) + gt_mu, gt_cov = calculate_activation_statistics_np(all_gtmotions) + metrics["FID"] = calculate_frechet_distance_np(gt_mu, gt_cov, mu, cov) + + # Compute diversity + assert count_seq > self.diversity_times + metrics["Diversity"] = calculate_diversity_np(all_genmotions, + self.diversity_times) + metrics["gt_Diversity"] = calculate_diversity_np( + all_gtmotions, self.diversity_times) + + # Reset + self.reset() + + return {**metrics} + + @torch.no_grad() + def update(self, + feats_ref: Tensor, + feats_rst: Tensor, + lengths_ref: List[int], + lengths_rst: List[int], + word_embs: Tensor = None, + pos_ohot: Tensor = None, + text_lengths: Tensor = None): + + self.count += sum(lengths_ref) + self.count_seq += len(lengths_ref) + + # T2m motion encoder + align_idx = np.argsort(lengths_ref)[::-1].copy() + feats_ref = feats_ref[align_idx] + lengths_ref = np.array(lengths_ref)[align_idx] + gtmotion_embeddings = self.get_motion_embeddings( + feats_ref, lengths_ref) + cache = [0] * len(lengths_ref) + for i in range(len(lengths_ref)): + cache[align_idx[i]] = gtmotion_embeddings[i:i + 1] + self.gtmotion_embeddings.extend(cache) + + align_idx = np.argsort(lengths_rst)[::-1].copy() + feats_rst = feats_rst[align_idx] + lengths_rst = np.array(lengths_rst)[align_idx] + recmotion_embeddings = self.get_motion_embeddings( + feats_rst, lengths_rst) + cache = [0] * len(lengths_rst) + for i in range(len(lengths_rst)): + cache[align_idx[i]] = recmotion_embeddings[i:i + 1] + self.recmotion_embeddings.extend(cache) + + # T2m text encoder + if self.text: + text_emb = self.t2m_textencoder(word_embs, pos_ohot, text_lengths) + text_embeddings = torch.flatten(text_emb, start_dim=1).detach() + self.text_embeddings.append(text_embeddings) + + def get_motion_embeddings(self, feats: Tensor, 
lengths: List[int]): + m_lens = torch.tensor(lengths) + m_lens = torch.div(m_lens, + self.cfg.DATASET.HUMANML3D.UNIT_LEN, + rounding_mode="floor") + m_lens = m_lens // self.cfg.DATASET.HUMANML3D.UNIT_LEN + mov = self.t2m_moveencoder(feats[..., :-4]).detach() + emb = self.t2m_motionencoder(mov, m_lens) + + # [bs, nlatent*ndim] <= [bs, nlatent, ndim] + return torch.flatten(emb, start_dim=1).detach() diff --git a/mGPT/metrics/utils.py b/mGPT/metrics/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..40e536f483f8df7b278088a6315b4c6bcd51caa6 --- /dev/null +++ b/mGPT/metrics/utils.py @@ -0,0 +1,607 @@ +import numpy as np +import scipy.linalg +import torch +from torch import linalg +import sys + + +def l2_norm(x1, x2, dim): + return torch.linalg.vector_norm(x1 - x2, ord=2, dim=dim) + + +def variance(x, T, dim): + mean = x.mean(dim) + out = (x - mean)**2 + out = out.sum(dim) + return out / (T - 1) + + +def sqrtm(input): + m = input.detach().cpu().numpy().astype(np.float64_) + sqrtm = torch.from_numpy(scipy.linalg.sqrtm(m)).to(input) + return sqrtm + + +# (X - X_train)*(X - X_train) = -2X*X_train + X*X + X_train*X_train +def euclidean_distance_matrix(matrix1, matrix2): + """ + Params: + -- matrix1: N1 x D + -- matrix2: N2 x D + Returns: + -- dist: N1 x N2 + dist[i, j] == distance(matrix1[i], matrix2[j]) + """ + assert matrix1.shape[1] == matrix2.shape[1] + d1 = -2 * torch.mm(matrix1, matrix2.T) # shape (num_test, num_train) + d2 = torch.sum(torch.square(matrix1), axis=1, + keepdims=True) # shape (num_test, 1) + d3 = torch.sum(torch.square(matrix2), axis=1) # shape (num_train, ) + dists = torch.sqrt(d1 + d2 + d3) # broadcasting + return dists + + +def euclidean_distance_matrix_np(matrix1, matrix2): + """ + Params: + -- matrix1: N1 x D + -- matrix2: N2 x D + Returns: + -- dist: N1 x N2 + dist[i, j] == distance(matrix1[i], matrix2[j]) + """ + assert matrix1.shape[1] == matrix2.shape[1] + d1 = -2 * np.dot(matrix1, matrix2.T) # shape (num_test, num_train) + d2 = np.sum(np.square(matrix1), axis=1, + keepdims=True) # shape (num_test, 1) + d3 = np.sum(np.square(matrix2), axis=1) # shape (num_train, ) + dists = np.sqrt(d1 + d2 + d3) # broadcasting + return dists + + +def calculate_top_k(mat, top_k): + size = mat.shape[0] + gt_mat = (torch.unsqueeze(torch.arange(size), + 1).to(mat.device).repeat_interleave(size, 1)) + bool_mat = mat == gt_mat + correct_vec = False + top_k_list = [] + for i in range(top_k): + # print(correct_vec, bool_mat[:, i]) + correct_vec = correct_vec | bool_mat[:, i] + # print(correct_vec) + top_k_list.append(correct_vec[:, None]) + top_k_mat = torch.cat(top_k_list, dim=1) + return top_k_mat + + +def calculate_activation_statistics(activations): + """ + Params: + -- activation: num_samples x dim_feat + Returns: + -- mu: dim_feat + -- sigma: dim_feat x dim_feat + """ + activations = activations.cpu().numpy() + mu = np.mean(activations, axis=0) + sigma = np.cov(activations, rowvar=False) + return mu, sigma + + +def calculate_activation_statistics_np(activations): + """ + Params: + -- activation: num_samples x dim_feat + Returns: + -- mu: dim_feat + -- sigma: dim_feat x dim_feat + """ + mu = np.mean(activations, axis=0) + cov = np.cov(activations, rowvar=False) + return mu, cov + + +# def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): +# """Numpy implementation of the Frechet Distance. 
+# The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) +# and X_2 ~ N(mu_2, C_2) is +# d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). +# Stable version by Dougal J. Sutherland. +# Params: +# -- mu1 : Numpy array containing the activations of a layer of the +# inception net (like returned by the function 'get_predictions') +# for generated samples. +# -- mu2 : The sample mean over activations, precalculated on an +# representative data set. +# -- sigma1: The covariance matrix over activations for generated samples. +# -- sigma2: The covariance matrix over activations, precalculated on an +# representative data set. +# Returns: +# -- : The Frechet Distance. +# """ + +# mu1 = torch.atleast_1d(mu1) +# mu2 = torch.atleast_1d(mu2) + +# sigma1 = torch.atleast_2d(sigma1) +# sigma2 = torch.atleast_2d(sigma2) + +# assert mu1.shape == mu2.shape, \ +# 'Training and test mean vectors have different lengths' +# assert sigma1.shape == sigma2.shape, \ +# 'Training and test covariances have different dimensions' + +# diff = mu1 - mu2 + +# # Product might be almost singular +# # covmean, _ = sqrtm(sigma1.dot(sigma2), disp=False) +# covmean = sqrtm(torch.mm(sigma1,sigma2)) +# if not torch.isfinite(covmean).all(): +# msg = ('fid calculation produces singular product; ' +# 'adding %s to diagonal of cov estimates') % eps +# print(msg) +# offset = torch.eye(sigma1.shape[0]) * eps +# # covmean = sqrtm((sigma1 + offset).dot(sigma2 + offset)) +# covmean = sqrtm(torch.mm(sigma1+ offset,sigma2+ offset)) + +# # Numerical error might give slight imaginary component +# if torch.is_complex(covmean): +# if not torch.allclose(torch.diagonal(covmean).imag, 0, atol=1e-3): +# m = torch.max(torch.abs(covmean.imag)) +# raise ValueError('Imaginary component {}'.format(m)) +# covmean = covmean.real + +# tr_covmean = torch.trace(covmean) + +# return (diff.dot(diff) + torch.trace(sigma1) + +# torch.trace(sigma2) - 2 * tr_covmean) + + +def calculate_frechet_distance_np(mu1, sigma1, mu2, sigma2, eps=1e-6): + """Numpy implementation of the Frechet Distance. + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + Stable version by Dougal J. Sutherland. + Params: + -- mu1 : Numpy array containing the activations of a layer of the + inception net (like returned by the function 'get_predictions') + for generated samples. + -- mu2 : The sample mean over activations, precalculated on an + representative data set. + -- sigma1: The covariance matrix over activations for generated samples. + -- sigma2: The covariance matrix over activations, precalculated on an + representative data set. + Returns: + -- : The Frechet Distance. 
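+ Note: eps (default 1e-6) is added to the diagonals of the covariances if their product is nearly singular, before retrying the matrix square root.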
+ """ + + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert (mu1.shape == mu2.shape + ), "Training and test mean vectors have different lengths" + assert (sigma1.shape == sigma2.shape + ), "Training and test covariances have different dimensions" + + diff = mu1 - mu2 + # Product might be almost singular + covmean, _ = scipy.linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = ("fid calculation produces singular product; " + "adding %s to diagonal of cov estimates") % eps + print(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = scipy.linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError("Imaginary component {}".format(m)) + # print("Imaginary component {}".format(m)) + covmean = covmean.real + tr_covmean = np.trace(covmean) + + return diff.dot(diff) + np.trace(sigma1) + np.trace( + sigma2) - 2 * tr_covmean + + +def calculate_diversity(activation, diversity_times): + assert len(activation.shape) == 2 + assert activation.shape[0] > diversity_times + num_samples = activation.shape[0] + + first_indices = np.random.choice(num_samples, + diversity_times, + replace=False) + second_indices = np.random.choice(num_samples, + diversity_times, + replace=False) + dist = linalg.norm(activation[first_indices] - activation[second_indices], + axis=1) + return dist.mean() + + +def calculate_diversity_np(activation, diversity_times): + assert len(activation.shape) == 2 + assert activation.shape[0] > diversity_times + num_samples = activation.shape[0] + + first_indices = np.random.choice(num_samples, + diversity_times, + replace=False) + second_indices = np.random.choice(num_samples, + diversity_times, + replace=False) + dist = scipy.linalg.norm(activation[first_indices] - + activation[second_indices], + axis=1) + return dist.mean() + + +def calculate_multimodality_np(activation, multimodality_times): + assert len(activation.shape) == 3 + assert activation.shape[1] > multimodality_times + num_per_sent = activation.shape[1] + + first_dices = np.random.choice(num_per_sent, + multimodality_times, + replace=False) + second_dices = np.random.choice(num_per_sent, + multimodality_times, + replace=False) + dist = scipy.linalg.norm(activation[:, first_dices] - + activation[:, second_dices], + axis=2) + return dist.mean() + + +# motion reconstructions metrics + + +def batch_compute_similarity_transform_torch(S1, S2): + """ + Computes a similarity transform (sR, t) that takes + a set of 3D points S1 (3 x N) closest to a set of 3D points S2, + where R is an 3x3 rotation matrix, t 3x1 translation, s scale. + i.e. solves the orthogonal Procrutes problem. + """ + transposed = False + if S1.shape[0] != 3 and S1.shape[0] != 2: + S1 = S1.permute(0, 2, 1) + S2 = S2.permute(0, 2, 1) + transposed = True + assert S2.shape[1] == S1.shape[1] + + # 1. Remove mean. + mu1 = S1.mean(axis=-1, keepdims=True) + mu2 = S2.mean(axis=-1, keepdims=True) + + X1 = S1 - mu1 + X2 = S2 - mu2 + + # 2. Compute variance of X1 used for scale. + var1 = torch.sum(X1**2, dim=1).sum(dim=1) + + # 3. The outer product of X1 and X2. + K = X1.bmm(X2.permute(0, 2, 1)) + + # 4. Solution that Maximizes trace(R'K) is R=U*V', where U, V are + # singular vectors of K. 
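+ # Note: torch.svd is deprecated in recent PyTorch releases in favor of torch.linalg.svd, which returns Vh rather than V.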
+ U, s, V = torch.svd(K) + + # Construct Z that fixes the orientation of R to get det(R)=1. + Z = torch.eye(U.shape[1], device=S1.device).unsqueeze(0) + Z = Z.repeat(U.shape[0], 1, 1) + Z[:, -1, -1] *= torch.sign(torch.det(U.bmm(V.permute(0, 2, 1)))) + + # Construct R. + R = V.bmm(Z.bmm(U.permute(0, 2, 1))) + + # 5. Recover scale. + scale = torch.cat([torch.trace(x).unsqueeze(0) for x in R.bmm(K)]) / var1 + + # 6. Recover translation. + t = mu2 - (scale.unsqueeze(-1).unsqueeze(-1) * (R.bmm(mu1))) + + # 7. Error: + S1_hat = scale.unsqueeze(-1).unsqueeze(-1) * R.bmm(S1) + t + + if transposed: + S1_hat = S1_hat.permute(0, 2, 1) + + return S1_hat, (scale, R, t) + + +def compute_mpjpe(preds, + target, + valid_mask=None, + pck_joints=None, + sample_wise=True): + """ + Mean per-joint position error (i.e. mean Euclidean distance) + often referred to as "Protocol #1" in many papers. + """ + assert preds.shape == target.shape, print(preds.shape, + target.shape) # BxJx3 + mpjpe = torch.norm(preds - target, p=2, dim=-1) # BxJ + + if pck_joints is None: + if sample_wise: + mpjpe_seq = ((mpjpe * valid_mask.float()).sum(-1) / + valid_mask.float().sum(-1) + if valid_mask is not None else mpjpe.mean(-1)) + else: + mpjpe_seq = mpjpe[valid_mask] if valid_mask is not None else mpjpe + return mpjpe_seq + else: + mpjpe_pck_seq = mpjpe[:, pck_joints] + return mpjpe_pck_seq + + +def align_by_parts(joints, align_inds=None): + if align_inds is None: + return joints + pelvis = joints[:, align_inds].mean(1) + return joints - torch.unsqueeze(pelvis, dim=1) + + +def calc_mpjpe(preds, target, align_inds=[0], sample_wise=True, trans=None): + # Expects BxJx3 + valid_mask = target[:, :, 0] != -2.0 + # valid_mask = torch.BoolTensor(target[:, :, 0].shape) + if align_inds is not None: + preds_aligned = align_by_parts(preds, align_inds=align_inds) + if trans is not None: + preds_aligned += trans + target_aligned = align_by_parts(target, align_inds=align_inds) + else: + preds_aligned, target_aligned = preds, target + mpjpe_each = compute_mpjpe(preds_aligned, + target_aligned, + valid_mask=valid_mask, + sample_wise=sample_wise) + return mpjpe_each + + +def calc_accel(preds, target): + """ + Mean joint acceleration error + often referred to as "Protocol #1" in many papers. 
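+ Concretely, accel[t] = x[t-1] - 2*x[t] + x[t+1] (second finite difference), and the returned value is the per-frame mean over joints of ||accel_pred - accel_gt||.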
+ """ + assert preds.shape == target.shape, print(preds.shape, + target.shape) # BxJx3 + assert preds.dim() == 3 + # Expects BxJx3 + # valid_mask = torch.BoolTensor(target[:, :, 0].shape) + accel_gt = target[:-2] - 2 * target[1:-1] + target[2:] + accel_pred = preds[:-2] - 2 * preds[1:-1] + preds[2:] + normed = torch.linalg.norm(accel_pred - accel_gt, dim=-1) + accel_seq = normed.mean(1) + return accel_seq + + +def calc_pampjpe(preds, target, sample_wise=True, return_transform_mat=False): + # Expects BxJx3 + target, preds = target.float(), preds.float() + # extracting the keypoints that all samples have valid annotations + # valid_mask = (target[:, :, 0] != -2.).sum(0) == len(target) + # preds_tranformed, PA_transform = batch_compute_similarity_transform_torch(preds[:, valid_mask], target[:, valid_mask]) + # pa_mpjpe_each = compute_mpjpe(preds_tranformed, target[:, valid_mask], sample_wise=sample_wise) + + preds_tranformed, PA_transform = batch_compute_similarity_transform_torch( + preds, target) + pa_mpjpe_each = compute_mpjpe(preds_tranformed, + target, + sample_wise=sample_wise) + + if return_transform_mat: + return pa_mpjpe_each, PA_transform + else: + return pa_mpjpe_each + + +# from action2motion +def calculate_diversity_multimodality(activations, + labels, + num_labels, + diversity_times=200, + multimodality_times=20): + labels = labels.long() + num_motions = activations.shape[0] # len(labels) + + diversity = 0 + + first_indices = np.random.randint(0, num_motions, diversity_times) + second_indices = np.random.randint(0, num_motions, diversity_times) + for first_idx, second_idx in zip(first_indices, second_indices): + diversity += torch.dist(activations[first_idx, :], + activations[second_idx, :]) + diversity /= diversity_times + + multimodality = 0 + label_quotas = np.zeros(num_labels) + label_quotas[labels.unique( + )] = multimodality_times # if a label does not appear in batch, its quota remains zero + while np.any(label_quotas > 0): + # print(label_quotas) + first_idx = np.random.randint(0, num_motions) + first_label = labels[first_idx] + if not label_quotas[first_label]: + continue + + second_idx = np.random.randint(0, num_motions) + second_label = labels[second_idx] + while first_label != second_label: + second_idx = np.random.randint(0, num_motions) + second_label = labels[second_idx] + + label_quotas[first_label] -= 1 + + first_activation = activations[first_idx, :] + second_activation = activations[second_idx, :] + multimodality += torch.dist(first_activation, second_activation) + + multimodality /= (multimodality_times * num_labels) + + return diversity, multimodality + + +def calculate_fid(statistics_1, statistics_2): + return calculate_frechet_distance_np(statistics_1[0], statistics_1[1], + statistics_2[0], statistics_2[1]) + + +# from: https://github.com/abdulfatir/gan-metrics-pytorch/blob/master/kid_score.py +def polynomial_mmd_averages(codes_g, + codes_r, + n_subsets=50, + subset_size=1000, + ret_var=True, + output=sys.stdout, + **kernel_args): + m = min(codes_g.shape[0], codes_r.shape[0]) + mmds = np.zeros(n_subsets) + if ret_var: + vars = np.zeros(n_subsets) + choice = np.random.choice + + replace = subset_size < len(codes_g) + + for i in range(n_subsets): + g = codes_g[choice(len(codes_g), subset_size, replace=replace)] + r = codes_r[choice(len(codes_r), subset_size, replace=replace)] + o = polynomial_mmd(g, r, **kernel_args, var_at_m=m, ret_var=ret_var) + if ret_var: + mmds[i], vars[i] = o + else: + mmds[i] = o + + return (mmds, vars) if ret_var else mmds + + +def 
polynomial_mmd(codes_g, + codes_r, + degree=3, + gamma=None, + coef0=1, + var_at_m=None, + ret_var=True): + from sklearn.metrics.pairwise import polynomial_kernel + + # use k(x, y) = (gamma + coef0)^degree + # default gamma is 1 / dim + X = codes_g + Y = codes_r + + K_XX = polynomial_kernel(X, degree=degree, gamma=gamma, coef0=coef0) + K_YY = polynomial_kernel(Y, degree=degree, gamma=gamma, coef0=coef0) + K_XY = polynomial_kernel(X, Y, degree=degree, gamma=gamma, coef0=coef0) + + return _mmd2_and_variance(K_XX, + K_XY, + K_YY, + var_at_m=var_at_m, + ret_var=ret_var) + + +def _mmd2_and_variance(K_XX, + K_XY, + K_YY, + unit_diagonal=False, + mmd_est='unbiased', + block_size=1024, + var_at_m=None, + ret_var=True): + # based on + # https://github.com/dougalsutherland/opt-mmd/blob/master/two_sample/mmd.py + # but changed to not compute the full kernel matrix at once + m = K_XX.shape[0] + assert K_XX.shape == (m, m) + assert K_XY.shape == (m, m) + assert K_YY.shape == (m, m) + if var_at_m is None: + var_at_m = m + + # Get the various sums of kernels that we'll use + # Kts drop the diagonal, but we don't need to compute them explicitly + if unit_diagonal: + diag_X = diag_Y = 1 + sum_diag_X = sum_diag_Y = m + sum_diag2_X = sum_diag2_Y = m + else: + diag_X = np.diagonal(K_XX) + diag_Y = np.diagonal(K_YY) + + sum_diag_X = diag_X.sum() + sum_diag_Y = diag_Y.sum() + + sum_diag2_X = _sqn(diag_X) + sum_diag2_Y = _sqn(diag_Y) + + Kt_XX_sums = K_XX.sum(axis=1) - diag_X + Kt_YY_sums = K_YY.sum(axis=1) - diag_Y + K_XY_sums_0 = K_XY.sum(axis=0) + K_XY_sums_1 = K_XY.sum(axis=1) + + Kt_XX_sum = Kt_XX_sums.sum() + Kt_YY_sum = Kt_YY_sums.sum() + K_XY_sum = K_XY_sums_0.sum() + + if mmd_est == 'biased': + mmd2 = ((Kt_XX_sum + sum_diag_X) / (m * m) + (Kt_YY_sum + sum_diag_Y) / + (m * m) - 2 * K_XY_sum / (m * m)) + else: + assert mmd_est in {'unbiased', 'u-statistic'} + mmd2 = (Kt_XX_sum + Kt_YY_sum) / (m * (m - 1)) + if mmd_est == 'unbiased': + mmd2 -= 2 * K_XY_sum / (m * m) + else: + mmd2 -= 2 * (K_XY_sum - np.trace(K_XY)) / (m * (m - 1)) + + if not ret_var: + return mmd2 + + Kt_XX_2_sum = _sqn(K_XX) - sum_diag2_X + Kt_YY_2_sum = _sqn(K_YY) - sum_diag2_Y + K_XY_2_sum = _sqn(K_XY) + + dot_XX_XY = Kt_XX_sums.dot(K_XY_sums_1) + dot_YY_YX = Kt_YY_sums.dot(K_XY_sums_0) + + m1 = m - 1 + m2 = m - 2 + zeta1_est = ( + 1 / (m * m1 * m2) * + (_sqn(Kt_XX_sums) - Kt_XX_2_sum + _sqn(Kt_YY_sums) - Kt_YY_2_sum) - 1 / + (m * m1)**2 * (Kt_XX_sum**2 + Kt_YY_sum**2) + 1 / (m * m * m1) * + (_sqn(K_XY_sums_1) + _sqn(K_XY_sums_0) - 2 * K_XY_2_sum) - + 2 / m**4 * K_XY_sum**2 - 2 / (m * m * m1) * (dot_XX_XY + dot_YY_YX) + + 2 / (m**3 * m1) * (Kt_XX_sum + Kt_YY_sum) * K_XY_sum) + zeta2_est = (1 / (m * m1) * (Kt_XX_2_sum + Kt_YY_2_sum) - 1 / (m * m1)**2 * + (Kt_XX_sum**2 + Kt_YY_sum**2) + 2 / (m * m) * K_XY_2_sum - + 2 / m**4 * K_XY_sum**2 - 4 / (m * m * m1) * + (dot_XX_XY + dot_YY_YX) + 4 / (m**3 * m1) * + (Kt_XX_sum + Kt_YY_sum) * K_XY_sum) + var_est = (4 * (var_at_m - 2) / (var_at_m * (var_at_m - 1)) * zeta1_est + + 2 / (var_at_m * (var_at_m - 1)) * zeta2_est) + + return mmd2, var_est + + +def _sqn(arr): + flat = np.ravel(arr) + return flat.dot(flat) + + +def calculate_kid(real_activations, generated_activations): + kid_values = polynomial_mmd_averages(real_activations, + generated_activations, + n_subsets=100) + results = (kid_values[0].mean(), kid_values[0].std()) + return results diff --git a/mGPT/models/__init__.py b/mGPT/models/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mGPT/models/base.py b/mGPT/models/base.py new file mode 100644 index 0000000000000000000000000000000000000000..4631678b6ea879249fa4a63120656a90aa24a23a --- /dev/null +++ b/mGPT/models/base.py @@ -0,0 +1,204 @@ +import os +import numpy as np +import torch +import logging +from pathlib import Path +from pytorch_lightning import LightningModule +from os.path import join as pjoin +from collections import OrderedDict +from mGPT.metrics import BaseMetrics +from mGPT.config import get_obj_from_str + + +class BaseModel(LightningModule): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # self.configure_metrics() + + # Ablation + self.test_step_outputs = [] + self.times = [] + self.rep_i = 0 + + def training_step(self, batch, batch_idx): + return self.allsplit_step("train", batch, batch_idx) + + def validation_step(self, batch, batch_idx): + return self.allsplit_step("val", batch, batch_idx) + + def test_step(self, batch, batch_idx): + outputs = self.allsplit_step("test", batch, batch_idx) + self.test_step_outputs.append(outputs) + return outputs + + def predict_step(self, batch, batch_idx): + return self.forward(batch) + + def on_train_epoch_end(self): + # Log steps and losses + dico = self.step_log_dict() + # Log losses + dico.update(self.loss_log_dict('train')) + # Write to log only if not sanity check + if not self.trainer.sanity_checking: + self.log_dict(dico, sync_dist=True, rank_zero_only=True) + + def on_validation_epoch_end(self): + # Log steps and losses + dico = self.step_log_dict() + # Log losses + dico.update(self.loss_log_dict('train')) + dico.update(self.loss_log_dict('val')) + # Log metrics + dico.update(self.metrics_log_dict()) + # Write to log only if not sanity check + if not self.trainer.sanity_checking: + self.log_dict(dico, sync_dist=True, rank_zero_only=True) + + def on_test_epoch_end(self): + # Log metrics + dico = self.metrics_log_dict() + # Write to log only if not sanity check + if not self.trainer.sanity_checking: + self.log_dict(dico, sync_dist=True, rank_zero_only=True) + self.save_npy(self.test_step_outputs) + self.rep_i = self.rep_i + 1 + # Free up the memory + self.test_step_outputs.clear() + + def preprocess_state_dict(self, state_dict): + new_state_dict = OrderedDict() + + # metric_state_dict = self.metrics.state_dict() + loss_state_dict = self._losses.state_dict() + + # for k, v in metric_state_dict.items(): + # new_state_dict['metrics.' + k] = v + + for k, v in loss_state_dict.items(): + new_state_dict['_losses.' 
+ k] = v + + for k, v in state_dict.items(): + if '_losses' not in k and 'Metrics' not in k: + new_state_dict[k] = v + + return new_state_dict + + def load_state_dict(self, state_dict, strict=True): + new_state_dict = self.preprocess_state_dict(state_dict) + super().load_state_dict(new_state_dict, strict) + + def step_log_dict(self): + return { + "epoch": float(self.trainer.current_epoch), + "step": float(self.trainer.current_epoch) + } + + def loss_log_dict(self, split: str): + losses = self._losses['losses_' + split] + loss_dict = losses.compute(split) + return loss_dict + + def metrics_log_dict(self): + + # For TM2TMetrics MM + if self.trainer.datamodule.is_mm and "TM2TMetrics" in self.hparams.metrics_dict: + metrics_dicts = ['MMMetrics'] + else: + metrics_dicts = self.hparams.metrics_dict + + # Compute all metrics + metrics_log_dict = {} + for metric in metrics_dicts: + metrics_dict = getattr( + self.metrics, + metric).compute(sanity_flag=self.trainer.sanity_checking) + metrics_log_dict.update({ + f"Metrics/{metric}": value.item() + for metric, value in metrics_dict.items() + }) + + return metrics_log_dict + + def configure_optimizers(self): + # Optimizer + optim_target = self.hparams.cfg.TRAIN.OPTIM.target + if len(optim_target.split('.')) == 1: + optim_target = 'torch.optim.' + optim_target + optimizer = get_obj_from_str(optim_target)( + params=self.parameters(), **self.hparams.cfg.TRAIN.OPTIM.params) + + # Scheduler + scheduler_target = self.hparams.cfg.TRAIN.LR_SCHEDULER.target + if len(scheduler_target.split('.')) == 1: + scheduler_target = 'torch.optim.lr_scheduler.' + scheduler_target + lr_scheduler = get_obj_from_str(scheduler_target)( + optimizer=optimizer, **self.hparams.cfg.TRAIN.LR_SCHEDULER.params) + + return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler} + + def configure_metrics(self): + self.metrics = BaseMetrics(datamodule=self.datamodule, **self.hparams) + + def save_npy(self, outputs): + cfg = self.hparams.cfg + output_dir = Path( + os.path.join( + cfg.FOLDER, + str(cfg.model.target.split('.')[-2].lower()), + str(cfg.NAME), + "samples_" + cfg.TIME, + )) + if cfg.TEST.SAVE_PREDICTIONS: + lengths = [i[1] for i in outputs] + outputs = [i[0] for i in outputs] + + if cfg.TEST.DATASETS[0].lower() in ["humanml3d", "kit"]: + keyids = self.trainer.datamodule.test_dataset.name_list + for i in range(len(outputs)): + for bid in range( + min(cfg.TEST.BATCH_SIZE, outputs[i].shape[0])): + keyid = keyids[i * cfg.TEST.BATCH_SIZE + bid] + data = self.trainer.datamodule.test_dataset.data_dict[ + keyid] + + motion = torch.tensor(data['motion'], + device=outputs[i].device) + motion = self.datamodule.normalize(motion) + length = data['length'] + text_list = data['text'] + gen_joints = outputs[i][bid][:lengths[i][bid]].cpu( + ).numpy() + if cfg.TEST.REPLICATION_TIMES > 1: + name = f"{keyid}.npy" + else: + name = f"{keyid}.npy" + # save predictions results + npypath = output_dir / name + np.save(npypath, gen_joints) + npypath = output_dir / f"{keyid}_gt.npy" + joints = self.feats2joints(motion).cpu().numpy() + np.save(npypath, joints) + + with open(output_dir / f"{keyid}.txt", "a") as f: + for text in text_list: + f.write(f"{text['caption']}\n") + + elif cfg.TEST.DATASETS[0].lower() in ["humanact12", "uestc"]: + keyids = range(len(self.trainer.datamodule.test_dataset)) + for i in range(len(outputs)): + for bid in range( + min(cfg.TEST.BATCH_SIZE, outputs[i].shape[0])): + keyid = keyids[i * cfg.TEST.BATCH_SIZE + bid] + gen_joints = outputs[i][bid].cpu() + gen_joints = 
gen_joints.permute(2, 0, + 1)[:lengths[i][bid], + ...].numpy() + if cfg.TEST.REPLICATION_TIMES > 1: + name = f"{keyid}_{self.rep_i}" + else: + name = f"{keyid}.npy" + # save predictions results + npypath = output_dir / name + np.save(npypath, gen_joints) diff --git a/mGPT/models/build_model.py b/mGPT/models/build_model.py new file mode 100644 index 0000000000000000000000000000000000000000..53c9effa160be57ffe235180e1c5daa85825c170 --- /dev/null +++ b/mGPT/models/build_model.py @@ -0,0 +1,8 @@ +from omegaconf import OmegaConf +from mGPT.config import instantiate_from_config + +def build_model(cfg, datamodule): + model_config = OmegaConf.to_container(cfg.model, resolve=True) + model_config['params']['cfg'] = cfg + model_config['params']['datamodule'] = datamodule + return instantiate_from_config(model_config) diff --git a/mGPT/models/mgpt.py b/mGPT/models/mgpt.py new file mode 100644 index 0000000000000000000000000000000000000000..c8db4a45978020bf712bfae2757b20fc283b13de --- /dev/null +++ b/mGPT/models/mgpt.py @@ -0,0 +1,494 @@ +import numpy as np +import os +import random +import torch +import time +from mGPT.config import instantiate_from_config +from os.path import join as pjoin +from mGPT.losses.mgpt import GPTLosses +from mGPT.models.base import BaseModel +from .base import BaseModel +import json +import mGPT.render.matplot.plot_3d_global as plot_3d + + +class MotionGPT(BaseModel): + """ + Stage 1 Motion Tokenizer + Stage 2 Motion-language pretrian + Stage 3 Motion-language instruction tuning + """ + + def __init__(self, + cfg, + datamodule, + lm, + motion_vae, + codebook_size=512, + stage='vae', + debug=True, + condition='text', + task='t2m', + metrics_dict=['TM2TMetrics'], + **kwargs): + + self.save_hyperparameters(ignore='datamodule', logger=False) + self.datamodule = datamodule + super().__init__() + + # Instantiate motion tokenizer + if motion_vae != None: + self.vae = instantiate_from_config(motion_vae) + + # Instantiate motion-language model + self.lm = instantiate_from_config(lm) + + # Freeze the motion tokenizer for lm training + if 'lm' in self.hparams.stage: + self.vae.training = False + for p in self.vae.parameters(): + p.requires_grad = False + + # Instantiate the losses + self._losses = torch.nn.ModuleDict({ + split: GPTLosses(cfg, self.hparams.stage, self.datamodule.njoints) + for split in ["losses_train", "losses_test", "losses_val"] + }) + + # Data transform + self.feats2joints = datamodule.feats2joints + + # Count codebook frequency + self.codePred = [] + self.codeFrequency = torch.zeros((self.hparams.codebook_size, )) + + def forward(self, batch, task="t2m"): + texts = batch["text"] + lengths_ref = batch["length"] + + # Forward + # texts = ['Generate motion: ' + text for text in texts] + outputs, output_texts = self.lm.generate_direct(texts, do_sample=True) + + # Motion Decode + feats_rst_lst = [] + lengths = [] + max_len = 0 + + for i in range(len(texts)): + if task == "pred": + motion = self.vae.decode( + torch.cat((batch["motion"][i], outputs[i]))) + elif task in ["t2m", "m2t", "inbetween"]: + motion = self.vae.decode(outputs[i]) + # motion = self.datamodule.denormalize(motion) + lengths.append(motion.shape[1]) + else: + raise NotImplementedError + + if motion.shape[1] > max_len: + max_len = motion.shape[1] + + if task in ["t2m", "m2t", "pred"]: + feats_rst_lst.append(motion) + + elif task == "inbetween": + motion = torch.cat( + (batch["motion_heading"][i][None], + motion[:, lengths_ref[i] // 4:lengths_ref[i] // 4 * 3, + ...], batch["motion_tailing"][i][None]), + 
dim=1) + feats_rst_lst.append(motion) + + feats_rst = torch.zeros( + (len(feats_rst_lst), max_len, motion.shape[-1])).to(self.device) + + # padding and concat + for i in range(len(feats_rst_lst)): + feats_rst[i, :feats_rst_lst[i].shape[1], ...] = feats_rst_lst[i] + + # Recover joints for evaluation + joints_rst = self.feats2joints(feats_rst) + + # return set + outputs = { + "texts": output_texts, + "feats": feats_rst, + "joints": joints_rst, + "length": lengths + } + + return outputs + + def train_lm_forward(self, batch): + tokens_ref = batch["motion"] + texts = batch["text"] + lengths = batch["length"] + tasks = batch["tasks"] + all_captions = batch['all_captions'] + if self.hparams.condition == 'caption': + texts = [random.choice(all_captions[i]) for i in range(len(texts))] + + # LLM Forward + outputs = self.lm(texts, tokens_ref, lengths, tasks) + # outputs = self.t2m_gpt.generate(texts) + return {'outputs': outputs} + + @torch.no_grad() + def val_t2m_forward(self, batch): + feats_ref = batch["motion"] + texts = batch["text"] + lengths = batch["length"] + tasks = None + if self.trainer.datamodule.is_mm: + texts = texts * self.hparams.cfg.METRIC.MM_NUM_REPEATS + feats_ref = feats_ref.repeat_interleave( + self.hparams.cfg.METRIC.MM_NUM_REPEATS, dim=0) + lengths = lengths * self.hparams.cfg.METRIC.MM_NUM_REPEATS + instructions = pjoin(self.datamodule.hparams.data_root, + 'template_instructions.json') + instructions = json.load(open(instructions, 'r')) + tasks = [instructions["Text-to-Motion"]["caption"]] * len(texts) + + if self.hparams.condition == 'caption': + tasks = [{ + 'input': [''], + 'output': [''] + }] * len(texts) + + if self.hparams.cfg.DATASET.TASK_PATH: + instructions = pjoin(self.hparams.cfg.DATASET.TASK_PATH) + instructions = json.load(open(instructions, 'r')) + tasks = [instructions["Text-to-Motion"]["t2m"]] * len(texts) + + min_len = lengths.copy() + # Forward + outputs = self.lm.generate_conditional(texts, + lengths=lengths, + stage='test', + tasks=tasks) + + # Motion Decode + feats_rst = torch.zeros_like(feats_ref) + + for i in range(len(texts)): + outputs[i] = torch.clamp(outputs[i], + 0, + self.hparams.codebook_size - 1, + out=None) + + if len(outputs[i]) > 1: + motion = self.vae.decode(outputs[i]) + else: + motion = torch.zeros_like(feats_ref[i:i + 1, ...]) + + min_len[i] = min(motion.shape[1], lengths[i]) + + # Cut Motion + feats_rst[i:i + 1, :min_len[i], ...] 
= motion[:, :lengths[i]] + + # Recover joints for evaluation + joints_ref = self.feats2joints(feats_ref) + joints_rst = self.feats2joints(feats_rst) + + # Renorm for evaluation + feats_ref = self.datamodule.renorm4t2m(feats_ref) + feats_rst = self.datamodule.renorm4t2m(feats_rst) + + # return set + rs_set = { + "m_ref": feats_ref, + "m_rst": feats_rst, + "joints_ref": joints_ref, + "joints_rst": joints_rst, + "length": min_len + # "length": lengths + } + + return rs_set + + @torch.no_grad() + def val_m2t_forward(self, batch): + self.hparams.metrics_dict = [] + + feats_ref = batch["motion"] + texts = batch["text"] + lengths = batch["length"] + all_captions = batch['all_captions'] + + # Motion Encode + motion_tokens = [] + lengths_tokens = [] + for i in range(len(feats_ref)): + motion_token, _ = self.vae.encode(feats_ref[i:i + 1]) + motion_tokens.append(motion_token[0]) + lengths_tokens.append(motion_token.shape[1]) + + # Forward + outputs = self.lm.generate_conditional(motion_tokens=motion_tokens, + lengths=lengths_tokens, + task="m2t", + stage='test') + + # return set + rs_set = { + "m_ref": feats_ref, + "t_ref": all_captions, + # "t_ref": texts, + "t_pred": outputs, + "length": lengths + } + + return rs_set + + @torch.no_grad() + def val_m2m_forward(self, batch, task="pred"): + feats_ref = batch["motion"] + lengths = batch["length"] + + # Motion Encode + motion_tokens = [] + lengths_tokens = [] + for i in range(len(feats_ref)): + motion_token, _ = self.vae.encode(feats_ref[i:i + 1]) + motion_tokens.append(motion_token[0]) + + # Forward + outputs = self.lm.generate_conditional(motion_tokens=motion_tokens, + lengths=lengths, + task=task, + stage='test') + + # Motion Decode + feats_rst = torch.zeros_like(feats_ref) + min_len = lengths.copy() + + for i in range(len(lengths)): + outputs[i] = torch.clamp(outputs[i], + 0, + self.hparams.codebook_size - 1, + out=None) + + if len(outputs[i]) > 1: + motion = self.vae.decode(outputs[i]) + else: + motion = torch.zeros_like(feats_ref[i:i + 1, ...]) + + min_len[i] = min(motion.shape[1], lengths[i]) + + # Cut Motion + feats_rst[i:i + 1, :min_len[i], ...] 
= motion[:, :lengths[i]] + + # Recover joints for evaluation + joints_ref = self.feats2joints(feats_ref) + joints_rst = self.feats2joints(feats_rst) + + # Renorm for evaluation + feats_ref = self.datamodule.renorm4t2m(feats_ref) + feats_rst = self.datamodule.renorm4t2m(feats_rst) + + # return set + rs_set = { + "m_ref": feats_ref, + "m_rst": feats_rst, + "joints_ref": joints_ref, + "joints_rst": joints_rst, + "length": min_len + # "length": lengths + } + + return rs_set + + def train_vae_forward(self, batch): + # batch detach + feats_ref = batch["motion"] + joints_ref = self.feats2joints(feats_ref) + # motion encode & decode + feats_rst, loss_commit, perplexity = self.vae(feats_ref) + joints_rst = self.feats2joints(feats_rst) + # return set + rs_set = { + "m_ref": feats_ref, + "joints_ref": joints_ref, + "m_rst": feats_rst, + "joints_rst": joints_rst, + "loss_commit": loss_commit, + "perplexity": perplexity, + } + return rs_set + + @torch.no_grad() + def val_vae_forward(self, batch, split="train"): + # Detach batch + feats_ref = batch["motion"] + lengths = batch["length"] + + # Repeat for multimodal evaluation + if self.trainer.datamodule.is_mm: + feats_ref = feats_ref.repeat_interleave( + self.hparams.cfg.METRIC.MM_NUM_REPEATS, dim=0) + lengths = lengths * self.hparams.cfg.METRIC.MM_NUM_REPEATS + + # Motion encode & decode + feats_rst = torch.zeros_like(feats_ref) + + for i in range(len(feats_ref)): + if lengths[i] == 0: + continue + feats_pred, _, _ = self.vae(feats_ref[i:i + 1, :lengths[i]]) + feats_rst[i:i + 1, :feats_pred.shape[1], :] = feats_pred + + code_pred, _ = self.vae.encode(feats_ref[i:i + 1, :lengths[i]]) + + # codeFre_pred = torch.bincount(code_pred[0], + # minlength=self.hparams.codebook_size).to( + # self.codeFrequency.device) + # self.codePred.append(code_pred[0]) + # self.codeFrequency += codeFre_pred + + # np.save('../memData/results/codeFrequency.npy', + # self.codeFrequency.cpu().numpy()) + + # Recover joints for evaluation + joints_ref = self.feats2joints(feats_ref) + joints_rst = self.feats2joints(feats_rst) + + # Renorm for evaluation + feats_ref = self.datamodule.renorm4t2m(feats_ref) + feats_rst = self.datamodule.renorm4t2m(feats_rst) + + # Return set + rs_set = { + "m_ref": feats_ref, + "joints_ref": joints_ref, + "m_rst": feats_rst, + "joints_rst": joints_rst, + "length": lengths, + } + + return rs_set + + + def allsplit_step(self, split: str, batch, batch_idx): + # Compute the losses + loss = None + + if self.hparams.stage == "vae" and split in ["train", "val"]: + rs_set = self.train_vae_forward(batch) + loss = self._losses['losses_' + split].update(rs_set) + elif self.hparams.stage in ["lm_instruct", "lm_pretrain" + ] and split in ["train"]: + rs_set = self.train_lm_forward(batch) + loss = self._losses['losses_' + split].update(rs_set) + elif self.hparams.stage == 'lm_rl' and split in ['train']: + rs_set = self.train_rl_forward(batch) + loss = None + + # Compute the metrics + if split in ["val", "test"]: + if self.hparams.stage == "vae": + rs_set = self.val_vae_forward(batch, split) + elif self.hparams.stage in ["lm_instruct", "lm_pretrain", "lm_rl"]: + if self.hparams.task == "t2m": + rs_set = self.val_t2m_forward(batch) + elif self.hparams.task == "m2t": + rs_set = self.val_m2t_forward(batch) + elif self.hparams.task in ["m2m", "pred", "inbetween"]: + rs_set = self.val_m2m_forward(batch, self.hparams.task) + + if self.hparams.task not in ["m2t"]: + # MultiModality evaluation sperately + if self.trainer.datamodule.is_mm: + metrics_dicts = ['MMMetrics'] + 
else: + metrics_dicts = self.hparams.metrics_dict + + if self.hparams.task not in ['pred', 'inbetween']: + metrics_dicts.remove('PredMetrics') + + for metric in metrics_dicts: + lengths = batch['length'] + if metric == "TemosMetric": + getattr(self.metrics, + metric).update(rs_set["joints_rst"], + rs_set["joints_ref"], lengths) + elif metric == "TM2TMetrics": + if self.hparams.stage in [ + "lm_instruct", "lm_pretrain", "lm_rl" + ]: + word_embs = batch['word_embs'] + pos_ohot = batch['pos_ohot'] + text_lengths = batch['text_len'] + if self.trainer.datamodule.is_mm: + word_embs = word_embs.repeat_interleave( + self.hparams.cfg.METRIC.MM_NUM_REPEATS, + dim=0) + pos_ohot = pos_ohot.repeat_interleave( + self.hparams.cfg.METRIC.MM_NUM_REPEATS, + dim=0) + text_lengths = text_lengths.repeat_interleave( + self.hparams.cfg.METRIC.MM_NUM_REPEATS, + dim=0) + else: + word_embs = None + pos_ohot = None + text_lengths = None + + getattr(self.metrics, metric).update( + feats_ref=rs_set["m_ref"], + feats_rst=rs_set["m_rst"], + lengths_ref=lengths, + lengths_rst=rs_set['length'], + word_embs=word_embs, + pos_ohot=pos_ohot, + text_lengths=text_lengths, + ) + elif metric == "UncondMetrics": + getattr(self.metrics, metric).update( + recmotion_embeddings=rs_set["lat_rm"], + gtmotion_embeddings=rs_set["lat_m"], + lengths=lengths, + ) + elif metric == "MRMetrics": + getattr(self.metrics, + metric).update(rs_set["joints_rst"], + rs_set["joints_ref"], lengths) + elif metric == "PredMetrics": + getattr(self.metrics, + metric).update(rs_set["joints_rst"], + rs_set["joints_ref"], lengths) + elif metric == "MMMetrics": + # pass + getattr(self.metrics, + metric).update(rs_set["m_rst"], + rs_set['length']) + else: + raise TypeError(f"Not support this metric {metric}") + + elif self.hparams.task == "m2t" and self.hparams.stage in [ + "lm_instruct", "lm_pretrain", "lm_rl" + ]: + self.hparams.metrics_dict = metrics_dicts = ['M2TMetrics'] + for metric in metrics_dicts: + if metric == "M2TMetrics": + getattr(self.metrics, metric).update( + feats_ref=rs_set["m_ref"], + pred_texts=rs_set["t_pred"], + gt_texts=batch["all_captions"], + lengths=rs_set['length'], + word_embs=batch["word_embs"], + pos_ohot=batch["pos_ohot"], + text_lengths=batch["text_len"], + ) + + # return forward output rather than loss during test + if split in ["test"]: + if self.hparams.task == "t2m": + return rs_set["joints_rst"], rs_set["length"], rs_set[ + "joints_ref"] + # pass + elif self.hparams.task == "m2t": + return rs_set["t_pred"], batch["length"] + # return batch["length"] + + return loss diff --git a/mGPT/models/utils/__init__.py b/mGPT/models/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mGPT/models/utils/adain.py b/mGPT/models/utils/adain.py new file mode 100644 index 0000000000000000000000000000000000000000..3588f33e19fa3434ee2801f941c40566923abf41 --- /dev/null +++ b/mGPT/models/utils/adain.py @@ -0,0 +1,66 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class AdaptiveInstanceNorm1d(nn.Module): + def __init__(self, num_features, eps=1e-5, momentum=0.1): + super(AdaptiveInstanceNorm1d, self).__init__() + self.num_features = num_features + self.eps = eps + self.momentum = momentum + self.weight = None + self.bias = None + self.register_buffer('running_mean', torch.zeros(num_features)) + self.register_buffer('running_var', torch.ones(num_features)) + + def forward(self, x, direct_weighting=False, no_std=False): + assert 
self.weight is not None and \ + self.bias is not None, "Please assign AdaIN weight first" + # (bs, nfeats, nframe) <= (nframe, bs, nfeats) + x = x.permute(1,2,0) + + b, c = x.size(0), x.size(1) # batch size & channels + running_mean = self.running_mean.repeat(b) + running_var = self.running_var.repeat(b) + # self.weight = torch.ones_like(self.weight) + + if direct_weighting: + x_reshaped = x.contiguous().view(b * c) + if no_std: + out = x_reshaped + self.bias + else: + out = x_reshaped.mul(self.weight) + self.bias + out = out.view(b, c, *x.size()[2:]) + else: + x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:]) + out = F.batch_norm( + x_reshaped, running_mean, running_var, self.weight, self.bias, + True, self.momentum, self.eps) + out = out.view(b, c, *x.size()[2:]) + + # (nframe, bs, nfeats) <= (bs, nfeats, nframe) + out = out.permute(2,0,1) + return out + + def __repr__(self): + return self.__class__.__name__ + '(' + str(self.num_features) + ')' + +def assign_adain_params(adain_params, model): + # assign the adain_params to the AdaIN layers in model + for m in model.modules(): + if m.__class__.__name__ == "AdaptiveInstanceNorm1d": + mean = adain_params[: , : m.num_features] + std = adain_params[: , m.num_features: 2 * m.num_features] + m.bias = mean.contiguous().view(-1) + m.weight = std.contiguous().view(-1) + if adain_params.size(1) > 2 * m.num_features: + adain_params = adain_params[: , 2 * m.num_features:] + + +def get_num_adain_params(model): + # return the number of AdaIN parameters needed by the model + num_adain_params = 0 + for m in model.modules(): + if m.__class__.__name__ == "AdaptiveInstanceNorm1d": + num_adain_params += 2 * m.num_features + return num_adain_params diff --git a/mGPT/models/utils/blocks.py b/mGPT/models/utils/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..e657b3863f4b74530e7afbb66973ccf24c18ff50 --- /dev/null +++ b/mGPT/models/utils/blocks.py @@ -0,0 +1,146 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from mGPT.models.notused import AdaptiveInstanceNorm1d + + +class MLP(nn.Module): + + def __init__(self, cfg, out_dim, is_init): + super(MLP, self).__init__() + dims = cfg.MODEL.MOTION_DECODER.MLP_DIM + n_blk = len(dims) + norm = 'none' + acti = 'lrelu' + + layers = [] + for i in range(n_blk - 1): + layers += LinearBlock(dims[i], dims[i + 1], norm=norm, acti=acti) + layers += LinearBlock(dims[-1], out_dim, norm='none', acti='none') + self.model = nn.Sequential(*layers) + + if is_init: + for m in self.modules(): + if isinstance(m, nn.Linear): + #nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + nn.init.constant_(m.weight, 1) + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward(self, x): + return self.model(x.view(x.size(0), -1)) + + +def ZeroPad1d(sizes): + return nn.ConstantPad1d(sizes, 0) + + +def get_acti_layer(acti='relu', inplace=True): + + if acti == 'relu': + return [nn.ReLU(inplace=inplace)] + elif acti == 'lrelu': + return [nn.LeakyReLU(0.2, inplace=inplace)] + elif acti == 'tanh': + return [nn.Tanh()] + elif acti == 'none': + return [] + else: + assert 0, "Unsupported activation: {}".format(acti) + + +def get_norm_layer(norm='none', norm_dim=None): + + if norm == 'bn': + return [nn.BatchNorm1d(norm_dim)] + elif norm == 'in': + # return [nn.InstanceNorm1d(norm_dim, affine=False)] # for rt42! 
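+        # affine=True keeps a learnable per-channel scale and shift on top of
+        # the instance normalization (the commented-out line above drops them)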
+ return [nn.InstanceNorm1d(norm_dim, affine=True)] + elif norm == 'adain': + return [AdaptiveInstanceNorm1d(norm_dim)] + elif norm == 'none': + return [] + else: + assert 0, "Unsupported normalization: {}".format(norm) + + +def get_dropout_layer(dropout=None): + if dropout is not None: + return [nn.Dropout(p=dropout)] + else: + return [] + + +def ConvLayers(kernel_size, + in_channels, + out_channels, + stride=1, + pad_type='reflect', + use_bias=True): + """ + returns a list of [pad, conv] => should be += to some list, then apply sequential + """ + + if pad_type == 'reflect': + pad = nn.ReflectionPad1d + elif pad_type == 'replicate': + pad = nn.ReplicationPad1d + elif pad_type == 'zero': + pad = ZeroPad1d + else: + assert 0, "Unsupported padding type: {}".format(pad_type) + + pad_l = (kernel_size - 1) // 2 + pad_r = kernel_size - 1 - pad_l + return [ + pad((pad_l, pad_r)), + nn.Conv1d(in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + bias=use_bias) + ] + + +def ConvBlock(kernel_size, + in_channels, + out_channels, + stride=1, + pad_type='reflect', + dropout=None, + norm='none', + acti='lrelu', + acti_first=False, + use_bias=True, + inplace=True): + """ + returns a list of [pad, conv, norm, acti] or [acti, pad, conv, norm] + """ + + layers = ConvLayers(kernel_size, + in_channels, + out_channels, + stride=stride, + pad_type=pad_type, + use_bias=use_bias) + layers += get_dropout_layer(dropout) + layers += get_norm_layer(norm, norm_dim=out_channels) + acti_layers = get_acti_layer(acti, inplace=inplace) + + if acti_first: + return acti_layers + layers + else: + return layers + acti_layers + + +def LinearBlock(in_dim, out_dim, dropout=None, norm='none', acti='relu'): + + use_bias = True + layers = [] + layers.append(nn.Linear(in_dim, out_dim, bias=use_bias)) + layers += get_dropout_layer(dropout) + layers += get_norm_layer(norm, norm_dim=out_dim) + layers += get_acti_layer(acti) + + return layers diff --git a/mGPT/models/utils/cross_attention.py b/mGPT/models/utils/cross_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..deb1f053e575bd0940d12a9cc526a44f689f24c0 --- /dev/null +++ b/mGPT/models/utils/cross_attention.py @@ -0,0 +1,412 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +DETR Transformer class. 
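+It also defines SkipTransformerEncoder / SkipTransformerDecoder, U-Net style
+variants that pair symmetric layers with skip connections: the paired
+activations are concatenated, projected back to d_model by a linear layer,
+and then passed through the next layer.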
+Copy-paste from torch.nn.Transformer with modifications: + * positional encodings are passed in MHattention + * extra LN at the end of encoder is removed + * decoder returns a stack of activations from all decoding layers +""" +import copy +from typing import List, Optional +from numpy import block + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + + +class SkipTransformerEncoder(nn.Module): + def __init__(self, encoder_layer, num_layers, norm=None): + super().__init__() + self.d_model = encoder_layer.d_model + + self.num_layers = num_layers + self.norm = norm + + assert num_layers % 2 == 1 + + num_block = (num_layers-1)//2 + self.input_blocks = _get_clones(encoder_layer, num_block) + self.middle_block = _get_clone(encoder_layer) + self.output_blocks = _get_clones(encoder_layer, num_block) + self.linear_blocks = _get_clones(nn.Linear(2*self.d_model, self.d_model), num_block) + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, src, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + x = src + + xs = [] + for module in self.input_blocks: + x = module(x, src_mask=mask, + src_key_padding_mask=src_key_padding_mask, pos=pos) + xs.append(x) + + x = self.middle_block(x, src_mask=mask, + src_key_padding_mask=src_key_padding_mask, pos=pos) + + for (module, linear) in zip(self.output_blocks, self.linear_blocks): + x = torch.cat([x, xs.pop()], dim=-1) + x = linear(x) + x = module(x, src_mask=mask, + src_key_padding_mask=src_key_padding_mask, pos=pos) + + if self.norm is not None: + x = self.norm(x) + return x + +class SkipTransformerDecoder(nn.Module): + def __init__(self, decoder_layer, num_layers, norm=None): + super().__init__() + self.d_model = decoder_layer.d_model + + self.num_layers = num_layers + self.norm = norm + + assert num_layers % 2 == 1 + + num_block = (num_layers-1)//2 + self.input_blocks = _get_clones(decoder_layer, num_block) + self.middle_block = _get_clone(decoder_layer) + self.output_blocks = _get_clones(decoder_layer, num_block) + self.linear_blocks = _get_clones(nn.Linear(2*self.d_model, self.d_model), num_block) + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + x = tgt + + xs = [] + for module in self.input_blocks: + x = module(x, memory, tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos, query_pos=query_pos) + xs.append(x) + + x = self.middle_block(x, memory, tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos, query_pos=query_pos) + + for (module, linear) in zip(self.output_blocks, self.linear_blocks): + x = torch.cat([x, xs.pop()], dim=-1) + x = linear(x) + x = module(x, memory, tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos, query_pos=query_pos) + + if self.norm is not None: + x = self.norm(x) + + return 
x + +class Transformer(nn.Module): + + def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, + num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False, + return_intermediate_dec=False): + super().__init__() + + encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, + dropout, activation, normalize_before) + encoder_norm = nn.LayerNorm(d_model) if normalize_before else None + self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) + + decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, + dropout, activation, normalize_before) + decoder_norm = nn.LayerNorm(d_model) + self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm, + return_intermediate=return_intermediate_dec) + + self._reset_parameters() + + self.d_model = d_model + self.nhead = nhead + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, src, mask, query_embed, pos_embed): + # flatten NxCxHxW to HWxNxC + bs, c, h, w = src.shape + src = src.flatten(2).permute(2, 0, 1) + pos_embed = pos_embed.flatten(2).permute(2, 0, 1) + query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) + mask = mask.flatten(1) + + tgt = torch.zeros_like(query_embed) + memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) + hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, + pos=pos_embed, query_pos=query_embed) + return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w) + + +class TransformerEncoder(nn.Module): + + def __init__(self, encoder_layer, num_layers, norm=None): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward(self, src, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + output = src + + for layer in self.layers: + output = layer(output, src_mask=mask, + src_key_padding_mask=src_key_padding_mask, pos=pos) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerDecoder(nn.Module): + + def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + self.return_intermediate = return_intermediate + + def forward(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + output = tgt + + intermediate = [] + + for layer in self.layers: + output = layer(output, memory, tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos, query_pos=query_pos) + if self.return_intermediate: + intermediate.append(self.norm(output)) + + if self.norm is not None: + output = self.norm(output) + if self.return_intermediate: + intermediate.pop() + intermediate.append(output) + + if self.return_intermediate: + return torch.stack(intermediate) + + return output.unsqueeze(0) + + +class TransformerEncoderLayer(nn.Module): + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False): + super().__init__() + self.d_model = 
d_model + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + q = k = self.with_pos_embed(src, pos) + src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, + key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + return src + + def forward_pre(self, src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + src2 = self.norm1(src) + q = k = self.with_pos_embed(src2, pos) + src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, + key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src2 = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) + src = src + self.dropout2(src2) + return src + + def forward(self, src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(src, src_mask, src_key_padding_mask, pos) + return self.forward_post(src, src_mask, src_key_padding_mask, pos) + + +class TransformerDecoderLayer(nn.Module): + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.d_model = d_model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, 
attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt + + def forward_pre(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + tgt2 = self.norm1(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt2 = self.norm2(tgt) + tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout2(tgt2) + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + + def forward(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(tgt, memory, tgt_mask, memory_mask, + tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) + return self.forward_post(tgt, memory, tgt_mask, memory_mask, + tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) + + +def _get_clone(module): + return copy.deepcopy(module) + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def build_transformer(args): + return Transformer( + d_model=args.hidden_dim, + dropout=args.dropout, + nhead=args.nheads, + dim_feedforward=args.dim_feedforward, + num_encoder_layers=args.enc_layers, + num_decoder_layers=args.dec_layers, + normalize_before=args.pre_norm, + return_intermediate_dec=True, + ) + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(F"activation should be relu/gelu, not {activation}.") \ No newline at end of file diff --git a/mGPT/models/utils/position_encoding.py b/mGPT/models/utils/position_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..669c67707fdc007f4dfb956bc04ba23bc31aa157 --- /dev/null +++ b/mGPT/models/utils/position_encoding.py @@ -0,0 +1,192 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Various positional encodings for the transformer. 
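+Besides the DETR-style 2D sine / learned embeddings for image feature maps,
+this module provides 1D sine / learned variants for temporal (motion)
+sequences; build_position_encoding selects the variant from its arguments.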
+""" +import math +from typing import List, Optional + +import numpy as np +import torch +from torch import Tensor, nn + +# from util.misc import NestedTensor + + +class NestedTensor(object): + + def __init__(self, tensors, mask: Optional[Tensor]): + self.tensors = tensors + self.mask = mask + + def to(self, device): + # type: (Device) -> NestedTensor # noqa + cast_tensor = self.tensors.to(device) + mask = self.mask + if mask is not None: + assert mask is not None + cast_mask = mask.to(device) + else: + cast_mask = None + return NestedTensor(cast_tensor, cast_mask) + + def decompose(self): + return self.tensors, self.mask + + def __repr__(self): + return str(self.tensors) + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + """ + + def __init__(self, + num_pos_feats=64, + temperature=10000, + normalize=False, + scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, tensor_list: NestedTensor): + x = tensor_list.tensors + mask = tensor_list.mask + assert mask is not None + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, + dtype=torch.float32, + device=x.device) + dim_t = self.temperature**(2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), + dim=4).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), + dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +class PositionEmbeddingLearned(nn.Module): + """ + Absolute pos embedding, learned. 
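+    2D variant: separate learned row and column embedding tables (up to 50
+    positions each) whose outputs are concatenated at every spatial location.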
+ """ + + def __init__(self, num_pos_feats=256): + super().__init__() + self.row_embed = nn.Embedding(50, num_pos_feats) + self.col_embed = nn.Embedding(50, num_pos_feats) + self.reset_parameters() + + def reset_parameters(self): + nn.init.uniform_(self.row_embed.weight) + nn.init.uniform_(self.col_embed.weight) + + def forward(self, tensor_list: NestedTensor): + x = tensor_list.tensors + h, w = x.shape[-2:] + i = torch.arange(w, device=x.device) + j = torch.arange(h, device=x.device) + x_emb = self.col_embed(i) + y_emb = self.row_embed(j) + pos = torch.cat([ + x_emb.unsqueeze(0).repeat(h, 1, 1), + y_emb.unsqueeze(1).repeat(1, w, 1), + ], + dim=-1).permute(2, 0, 1).unsqueeze(0).repeat( + x.shape[0], 1, 1, 1) + return pos + + +class PositionEmbeddingSine1D(nn.Module): + + def __init__(self, d_model, max_len=500, batch_first=False): + super().__init__() + self.batch_first = batch_first + + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0).transpose(0, 1) + + self.register_buffer('pe', pe) + + def forward(self, x): + # not used in the final model + if self.batch_first: + pos = self.pe.permute(1, 0, 2)[:, :x.shape[1], :] + else: + pos = self.pe[:x.shape[0], :] + return pos + + +class PositionEmbeddingLearned1D(nn.Module): + + def __init__(self, d_model, max_len=500, batch_first=False): + super().__init__() + self.batch_first = batch_first + # self.dropout = nn.Dropout(p=dropout) + + self.pe = nn.Parameter(torch.zeros(max_len, 1, d_model)) + # self.pe = pe.unsqueeze(0).transpose(0, 1) + + self.reset_parameters() + + def reset_parameters(self): + nn.init.uniform_(self.pe) + + def forward(self, x): + # not used in the final model + if self.batch_first: + pos = self.pe.permute(1, 0, 2)[:, :x.shape[1], :] + else: + x = x + self.pe[:x.shape[0], :] + return x + # return self.dropout(x) + + +def build_position_encoding(N_steps, + position_embedding="sine", + embedding_dim="1D"): + # N_steps = hidden_dim // 2 + if embedding_dim == "1D": + if position_embedding in ('v2', 'sine'): + position_embedding = PositionEmbeddingSine1D(N_steps) + elif position_embedding in ('v3', 'learned'): + position_embedding = PositionEmbeddingLearned1D(N_steps) + else: + raise ValueError(f"not supported {position_embedding}") + elif embedding_dim == "2D": + if position_embedding in ('v2', 'sine'): + # TODO find a better way of exposing other arguments + position_embedding = PositionEmbeddingSine(N_steps, normalize=True) + elif position_embedding in ('v3', 'learned'): + position_embedding = PositionEmbeddingLearned(N_steps) + else: + raise ValueError(f"not supported {position_embedding}") + else: + raise ValueError(f"not supported {embedding_dim}") + + return position_embedding diff --git a/mGPT/models/utils/position_encoding_layer.py b/mGPT/models/utils/position_encoding_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..699c860bf5d28c384390196b086d93552b2cff64 --- /dev/null +++ b/mGPT/models/utils/position_encoding_layer.py @@ -0,0 +1,30 @@ +import numpy as np +import torch +from torch import nn + + +class PositionalEncoding(nn.Module): + + def __init__(self, d_model, dropout=0.1, max_len=5000, batch_first=False): + super().__init__() + self.batch_first = batch_first + + self.dropout = nn.Dropout(p=dropout) + + pe = torch.zeros(max_len, 
d_model) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange( + 0, d_model, 2).float() * (-np.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0).transpose(0, 1) + + self.register_buffer("pe", pe) + + def forward(self, x): + # not used in the final model + if self.batch_first: + x = x + self.pe.permute(1, 0, 2)[:, : x.shape[1], :] + else: + x = x + self.pe[: x.shape[0], :] + return self.dropout(x) diff --git a/mGPT/models/utils/tools.py b/mGPT/models/utils/tools.py new file mode 100644 index 0000000000000000000000000000000000000000..89ecab5616c1f0d46ed5bc9b348c5e6ad3ee603d --- /dev/null +++ b/mGPT/models/utils/tools.py @@ -0,0 +1,37 @@ +import torch.nn as nn + +def remove_padding(tensors, lengths): + return [tensor[:tensor_length] for tensor, tensor_length in zip(tensors, lengths)] + +class AutoParams(nn.Module): + def __init__(self, **kargs): + try: + for param in self.needed_params: + if param in kargs: + setattr(self, param, kargs[param]) + else: + raise ValueError(f"{param} is needed.") + except : + pass + + try: + for param, default in self.optional_params.items(): + if param in kargs and kargs[param] is not None: + setattr(self, param, kargs[param]) + else: + setattr(self, param, default) + except : + pass + super().__init__() + + +# taken from joeynmt repo +def freeze_params(module: nn.Module) -> None: + """ + Freeze the parameters of this module, + i.e. do not update them during training + + :param module: freeze parameters of this module + """ + for _, p in module.named_parameters(): + p.requires_grad = False diff --git a/mGPT/render/__init__.py b/mGPT/render/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mGPT/render/anim.py b/mGPT/render/anim.py new file mode 100644 index 0000000000000000000000000000000000000000..4a78fecf801c363092e2bc6e7fe02827ba5480ae --- /dev/null +++ b/mGPT/render/anim.py @@ -0,0 +1,155 @@ +# Inspired by +# - https://github.com/anindita127/Complextext2animation/blob/main/src/utils/visualization.py +# - https://github.com/facebookresearch/QuaterNet/blob/main/common/visualization.py + +from typing import List, Tuple +import numpy as np +from mGPT.utils.joints import mmm_kinematic_tree, mmm_to_smplh_scaling_factor + +mmm_colors = ['black', 'magenta', 'red', 'green', 'blue'] + + +def init_axis(fig, title, radius=1.5, dist=10): + ax = fig.add_subplot(1, 1, 1, projection='3d') + ax.view_init(elev=20., azim=-60) + + fact = 2 + ax.set_xlim3d([-radius / fact, radius / fact]) + ax.set_ylim3d([-radius / fact, radius / fact]) + ax.set_zlim3d([0, radius]) + + ax.set_aspect('auto') + ax.set_xticklabels([]) + ax.set_yticklabels([]) + ax.set_zticklabels([]) + + ax.set_axis_off() + + ax.dist = dist + ax.grid(b=False) + + ax.set_title(title, loc='center', wrap=True) + return ax + + +def plot_floor(ax, minx, maxx, miny, maxy, minz): + from mpl_toolkits.mplot3d.art3d import Poly3DCollection + # Plot a plane XZ + verts = [ + [minx, miny, minz], + [minx, maxy, minz], + [maxx, maxy, minz], + [maxx, miny, minz] + ] + xz_plane = Poly3DCollection([verts], zorder=1) + xz_plane.set_facecolor((0.5, 0.5, 0.5, 1)) + ax.add_collection3d(xz_plane) + + # Plot a bigger square plane XZ + radius = max((maxx - minx), (maxy - miny)) + + # center +- radius + minx_all = (maxx + minx) / 2 - radius + maxx_all = (maxx + minx) / 2 + radius + + miny_all = (maxy + miny) / 2 - 
radius + maxy_all = (maxy + miny) / 2 + radius + + verts = [ + [minx_all, miny_all, minz], + [minx_all, maxy_all, minz], + [maxx_all, maxy_all, minz], + [maxx_all, miny_all, minz] + ] + xz_plane = Poly3DCollection([verts], zorder=1) + xz_plane.set_facecolor((0.5, 0.5, 0.5, 0.5)) + ax.add_collection3d(xz_plane) + return ax + + +def update_camera(ax, root, radius=1.5): + fact = 2 + ax.set_xlim3d([-radius / fact + root[0], radius / fact + root[0]]) + ax.set_ylim3d([-radius / fact + root[1], radius / fact + root[1]]) + + +def render_animation(joints: np.ndarray, output: str = "notebook", title: str = "", + fps: float = 12.5, + kinematic_tree: List[List[int]] = mmm_kinematic_tree, + colors: List[str] = mmm_colors, + figsize: Tuple[int] = (4, 4), + fontsize: int = 15): + import matplotlib.pyplot as plt + from matplotlib.animation import FuncAnimation + import matplotlib.patheffects as pe + plt.rcParams.update({'font.size': fontsize}) + + # Z is gravity here + x, y, z = 0, 1, 2 + + # Convert mmm joints for visualization + # into smpl-h "scale" and axis + joints = joints.copy()[..., [2, 0, 1]] * mmm_to_smplh_scaling_factor + + # Create a figure and initialize 3d plot + fig = plt.figure(figsize=figsize) + ax = init_axis(fig, title) + + # Create spline line + trajectory = joints[:, 0, [x, y]] + avg_segment_length = np.mean(np.linalg.norm(np.diff(trajectory, axis=0), axis=1)) + 1e-3 + draw_offset = int(25 / avg_segment_length) + spline_line, = ax.plot(*trajectory.T, zorder=10, color="white") + + # Create a floor + minx, miny, _ = joints.min(axis=(0, 1)) + maxx, maxy, _ = joints.max(axis=(0, 1)) + plot_floor(ax, minx, maxx, miny, maxy, 0) + + # Put the character on the floor + height_offset = np.min(joints[:, :, z]) # Min height + joints = joints.copy() + joints[:, :, z] -= height_offset + + # Initialization for redrawing + lines = [] + initialized = False + + def update(frame): + nonlocal initialized + skeleton = joints[frame] + + root = skeleton[0] + update_camera(ax, root) + + for index, (chain, color) in enumerate(zip(reversed(kinematic_tree), reversed(colors))): + if not initialized: + lines.append(ax.plot(skeleton[chain, x], + skeleton[chain, y], + skeleton[chain, z], linewidth=8.0, color=color, zorder=20, + path_effects=[pe.SimpleLineShadow(), pe.Normal()])) + + else: + lines[index][0].set_xdata(skeleton[chain, x]) + lines[index][0].set_ydata(skeleton[chain, y]) + lines[index][0].set_3d_properties(skeleton[chain, z]) + + left = max(frame - draw_offset, 0) + right = min(frame + draw_offset, trajectory.shape[0]) + + spline_line.set_xdata(trajectory[left:right, 0]) + spline_line.set_ydata(trajectory[left:right, 1]) + spline_line.set_3d_properties(np.zeros_like(trajectory[left:right, 0])) + initialized = True + + fig.tight_layout() + frames = joints.shape[0] + anim = FuncAnimation(fig, update, frames=frames, interval=1000 / fps, repeat=False) + + if output == "notebook": + from IPython.display import HTML + HTML(anim.to_jshtml()) + else: + anim.save(output, writer='ffmpeg', fps=fps) + + plt.close() diff --git a/mGPT/render/blender/__init__.py b/mGPT/render/blender/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a82255db45b763479586f83f0b7c904387b814ba --- /dev/null +++ b/mGPT/render/blender/__init__.py @@ -0,0 +1 @@ +from .render import render diff --git a/mGPT/render/blender/camera.py b/mGPT/render/blender/camera.py new file mode 100644 index 0000000000000000000000000000000000000000..ee037c22ba76fb5cd501f7656ea55e4ed46d3edd --- /dev/null +++ 
b/mGPT/render/blender/camera.py @@ -0,0 +1,52 @@ +import bpy + + +class Camera: + def __init__(self, *, first_root, mode, is_mesh): + camera = bpy.data.objects['Camera'] + + ## initial position + camera.location.x = 7.36 + camera.location.y = -6.93 + if is_mesh: + # camera.location.z = 5.45 + camera.location.z = 5.6 + else: + camera.location.z = 5.2 + + # wider point of view + if mode == "sequence": + if is_mesh: + camera.data.lens = 65 + else: + camera.data.lens = 85 + elif mode == "frame": + if is_mesh: + camera.data.lens = 130 + else: + camera.data.lens = 85 + elif mode == "video": + if is_mesh: + camera.data.lens = 110 + else: + # avoid cutting person + camera.data.lens = 85 + # camera.data.lens = 140 + + # camera.location.x += 0.75 + + self.mode = mode + self.camera = camera + + self.camera.location.x += first_root[0] + self.camera.location.y += first_root[1] + + self._root = first_root + + def update(self, newroot): + delta_root = newroot - self._root + + self.camera.location.x += delta_root[0] + self.camera.location.y += delta_root[1] + + self._root = newroot diff --git a/mGPT/render/blender/data.py b/mGPT/render/blender/data.py new file mode 100644 index 0000000000000000000000000000000000000000..17c6a40dda1721a4f38f176251e913dd04095499 --- /dev/null +++ b/mGPT/render/blender/data.py @@ -0,0 +1,3 @@ +class Data: + def __len__(self): + return self.N diff --git a/mGPT/render/blender/floor.py b/mGPT/render/blender/floor.py new file mode 100644 index 0000000000000000000000000000000000000000..3be1e5926e07d8f591d58c4334f2f1785b4f1f16 --- /dev/null +++ b/mGPT/render/blender/floor.py @@ -0,0 +1,73 @@ +import bpy +from .materials import floor_mat + + +def get_trajectory(data, is_mesh): + if is_mesh: + # mean of the vertices + trajectory = data[:, :, [0, 1]].mean(1) + else: + # get the root joint + trajectory = data[:, 0, [0, 1]] + return trajectory + + +def plot_floor(data, big_plane=True): + # Create a floor + minx, miny, _ = data.min(axis=(0, 1)) + maxx, maxy, _ = data.max(axis=(0, 1)) + minz = 0 + + location = ((maxx + minx)/2, (maxy + miny)/2, 0) + # a little bit bigger + scale = (1.08*(maxx - minx)/2, 1.08*(maxy - miny)/2, 1) + + bpy.ops.mesh.primitive_plane_add(size=2, enter_editmode=False, align='WORLD', location=location, scale=(1, 1, 1)) + + bpy.ops.transform.resize(value=scale, orient_type='GLOBAL', orient_matrix=((1, 0, 0), (0, 1, 0), (0, 0, 1)), orient_matrix_type='GLOBAL', + constraint_axis=(False, True, False), mirror=True, use_proportional_edit=False, + proportional_edit_falloff='SMOOTH', proportional_size=1, use_proportional_connected=False, + use_proportional_projected=False, release_confirm=True) + obj = bpy.data.objects["Plane"] + obj.name = "SmallPlane" + obj.data.name = "SmallPlane" + + if not big_plane: + obj.active_material = floor_mat(color=(0.2, 0.2, 0.2, 1)) + else: + obj.active_material = floor_mat(color=(0.1, 0.1, 0.1, 1)) + + if big_plane: + location = ((maxx + minx)/2, (maxy + miny)/2, -0.01) + bpy.ops.mesh.primitive_plane_add(size=2, enter_editmode=False, align='WORLD', location=location, scale=(1, 1, 1)) + + bpy.ops.transform.resize(value=[2*x for x in scale], orient_type='GLOBAL', orient_matrix=((1, 0, 0), (0, 1, 0), (0, 0, 1)), orient_matrix_type='GLOBAL', + constraint_axis=(False, True, False), mirror=True, use_proportional_edit=False, + proportional_edit_falloff='SMOOTH', proportional_size=1, use_proportional_connected=False, + use_proportional_projected=False, release_confirm=True) + + obj = bpy.data.objects["Plane"] + obj.name = "BigPlane" + 
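+        # Rename the mesh datablock to match and give this larger plane a
+        # slightly lighter grey material; it is placed 0.01 units below the
+        # fitted plane, so the motion-sized patch reads as a darker inset on a
+        # wider ground.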
obj.data.name = "BigPlane" + obj.active_material = floor_mat(color=(0.2, 0.2, 0.2, 1)) + + +def show_traj(coords): + pass + # create the Curve Datablock + # curveData = bpy.data.curves.new('myCurve', type='CURVE') + # curveData.dimensions = '3D' + # curveData.resolution_u = 2 + + # # map coords to spline + # polyline = curveData.splines.new('POLY') + # polyline.points.add(len(coords)-1) + # for i, coord in enumerate(coords): + # x, y = coord + # polyline.points[i].co = (x, y, 0.001, 1) + + # # create Object + # curveOB = bpy.data.objects.new('myCurve', curveData) + # curveData.bevel_depth = 0.01 + + # bpy.context.collection.objects.link(curveOB) diff --git a/mGPT/render/blender/joints.py b/mGPT/render/blender/joints.py new file mode 100644 index 0000000000000000000000000000000000000000..9d836a67fc6e40c2aecf77465d20a8e691d32b1e --- /dev/null +++ b/mGPT/render/blender/joints.py @@ -0,0 +1,378 @@ +import math + +import bpy +import numpy as np + +from mGPT.utils.joints import (humanml3d_joints, humanml3d_kinematic_tree, + mmm_joints, mmm_kinematic_tree, + mmm_to_smplh_scaling_factor) + +# from .materials import colored_material_diffuse_BSDF as colored_material +from .materials import colored_material_relection_BSDF as colored_material + +sat_factor = 1.1 + +JOINTS_MATS = [ + # colored_material(0.2500, 0.0357, 0.0349, saturation_factor = sat_factor), + # # colored_material(0.4500, 0.0357, 0.0349), + # colored_material(0.6500, 0.175, 0.0043, saturation_factor = sat_factor), + # colored_material(0.0349, 0.3500, 0.0349, saturation_factor = sat_factor), + # colored_material(0.018, 0.059, 0.600, saturation_factor = sat_factor), + # colored_material(0.032, 0.325, 0.421, saturation_factor = sat_factor), + # colored_material(0.3, 0.3, 0.3, saturation_factor = sat_factor), + colored_material(0.3500, 0.0357, 0.0349, saturation_factor=sat_factor), + # colored_material(0.4500, 0.0357, 0.0349), + colored_material(0.6500, 0.175, 0.0043, saturation_factor=sat_factor), + colored_material(0.0349, 0.3500, 0.0349, saturation_factor=sat_factor), + colored_material(0.018, 0.059, 0.600, saturation_factor=sat_factor), + colored_material(0.032, 0.325, 0.421, saturation_factor=sat_factor), + colored_material(0.3, 0.3, 0.3, saturation_factor=sat_factor), +] + + +class Joints: + + def __init__(self, + data, + *, + mode, + canonicalize, + always_on_floor, + jointstype="mmm", + **kwargs): + data = prepare_joints( + data, + canonicalize=canonicalize, + always_on_floor=always_on_floor, + jointstype=jointstype, + ) + + self.data = data + self.mode = mode + + self.N = len(data) + + self.N = len(data) + self.trajectory = data[:, 0, [0, 1]] + + if jointstype == "mmm": + self.kinematic_tree = mmm_kinematic_tree + self.joints = mmm_joints + self.joinst.append("") + elif jointstype == "humanml3d": + self.kinematic_tree = humanml3d_kinematic_tree + self.joints = humanml3d_joints + + self.mat = JOINTS_MATS + + def get_sequence_mat(self, frac): + return self.mat + + def get_root(self, index): + return self.data[index][0] + + def get_mean_root(self): + return self.data[:, 0].mean(0) + + def load_in_blender(self, index, mats): + skeleton = self.data[index] + head_mat = mats[0] + body_mat = mats[-1] + for lst, mat in zip(self.kinematic_tree, mats): + for j1, j2 in zip(lst[:-1], lst[1:]): + # spine and head + if self.joints[j2] in [ + "BUN", + ]: + sphere_between(skeleton[j1], skeleton[j2], head_mat) + elif self.joints[j2] in [ + "LE", + "RE", + "LW", + "RW", + ]: + cylinder_sphere_between(skeleton[j1], skeleton[j2], 0.040, + mat) + elif 
self.joints[j2] in [ + "LMrot", + "RMrot", + "RK", + "LK", + ]: + cylinder_sphere_between(skeleton[j1], skeleton[j2], 0.040, + mat) + elif self.joints[j2] in [ + "LS", + "RS", + "LF", + "RF", + ]: + cylinder_between(skeleton[j1], skeleton[j2], 0.040, mat) + elif self.joints[j2] in ["RK", "LK"]: + print(self.joints[j1], self.joints[j2]) + # body + sphere(0.14, skeleton[self.joints.index("BLN")], body_mat) + sphere_between( + skeleton[self.joints.index("BLN")], + skeleton[self.joints.index("root")], + body_mat, + factor=0.28, + ) + sphere(0.11, skeleton[self.joints.index("root")], body_mat) + # sphere_between( + # skeleton[self.joints.index("BLN")], + # skeleton[self.joints.index("BT")], + # mats[0], + # ) + # hip + # sphere_between( + # skeleton[self.joints.index("LH")], + # skeleton[self.joints.index("RH")], + # mats[0], + # factor=0.6, + # ) + # + # sphere(skeleton[self.joints.index("BLN")], 0.05, mats[0]) + # sphere_between(skeleton[13], skeleton[14], mat) + # node + # print(self.joints.index("BUN")) + # print(len(lst)) + # sphere(lst[self.joints.index("BUN")], 0.2, mat) # head + + return ["Cylinder", "Sphere"] + + def __len__(self): + return self.N + + +def softmax(x, softness=1.0, dim=None): + maxi, mini = x.max(dim), x.min(dim) + return maxi + np.log(softness + np.exp(mini - maxi)) + + +def softmin(x, softness=1.0, dim=0): + return -softmax(-x, softness=softness, dim=dim) + + +def get_forward_direction(poses, jointstype="mmm"): + if jointstype == "mmm" or jointstype == "mmmns": + joints = mmm_joints + elif jointstype == "humanml3d": + joints = humanml3d_joints + else: + raise TypeError("Only supports mmm, mmmns and humanl3d jointstype") + # Shoulders + LS, RS = joints.index("LS"), joints.index("RS") + # Hips + LH, RH = mmm_joints.index("LH"), mmm_joints.index("RH") + + across = (poses[..., RH, :] - poses[..., LH, :] + poses[..., RS, :] - + poses[..., LS, :]) + forward = np.stack((-across[..., 2], across[..., 0]), axis=-1) + forward = forward / np.linalg.norm(forward, axis=-1) + return forward + + +def cylinder_between(t1, t2, r, mat): + x1, y1, z1 = t1 + x2, y2, z2 = t2 + + dx = x2 - x1 + dy = y2 - y1 + dz = z2 - z1 + dist = math.sqrt(dx**2 + dy**2 + dz**2) + + bpy.ops.mesh.primitive_cylinder_add(radius=r, + depth=dist, + location=(dx / 2 + x1, dy / 2 + y1, + dz / 2 + z1)) + + phi = math.atan2(dy, dx) + theta = math.acos(dz / dist) + bpy.context.object.rotation_euler[1] = theta + bpy.context.object.rotation_euler[2] = phi + # bpy.context.object.shade_smooth() + bpy.context.object.active_material = mat + + bpy.ops.mesh.primitive_uv_sphere_add(radius=r, location=(x1, y1, z1)) + bpy.context.object.active_material = mat + bpy.ops.mesh.primitive_uv_sphere_add(radius=r, location=(x2, y2, z2)) + bpy.context.object.active_material = mat + + +def cylinder_sphere_between(t1, t2, r, mat): + x1, y1, z1 = t1 + x2, y2, z2 = t2 + dx = x2 - x1 + dy = y2 - y1 + dz = z2 - z1 + dist = math.sqrt(dx**2 + dy**2 + dz**2) + phi = math.atan2(dy, dx) + theta = math.acos(dz / dist) + dist = dist - 0.2 * r + # sphere node + sphere(r * 0.9, t1, mat) + sphere(r * 0.9, t2, mat) + # leveled cylinder + bpy.ops.mesh.primitive_cylinder_add( + radius=r, + depth=dist, + location=(dx / 2 + x1, dy / 2 + y1, dz / 2 + z1), + enter_editmode=True, + ) + bpy.ops.mesh.select_mode(type="EDGE") + bpy.ops.mesh.select_all(action="DESELECT") + bpy.ops.mesh.select_face_by_sides(number=32, extend=False) + bpy.ops.mesh.bevel(offset=r, segments=8) + bpy.ops.object.editmode_toggle(False) + # bpy.ops.object.shade_smooth() + 
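+    # Orient the bevelled limb cylinder along the bone: theta (acos(dz / dist))
+    # tilts it away from the +Z axis and phi (atan2(dy, dx)) spins it around Z
+    # towards the bone direction, using the angles computed above.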
bpy.context.object.rotation_euler[1] = theta + bpy.context.object.rotation_euler[2] = phi + bpy.context.object.active_material = mat + + +def sphere(r, t, mat): + bpy.ops.mesh.primitive_uv_sphere_add(segments=50, + ring_count=50, + radius=r, + location=t) + # bpy.ops.mesh.primitive_uv_sphere_add(radius=r, location=t) + # bpy.context.object.shade_smooth() + bpy.context.object.active_material = mat + + +def sphere_between(t1, t2, mat, factor=1): + x1, y1, z1 = t1 + x2, y2, z2 = t2 + + dx = x2 - x1 + dy = y2 - y1 + dz = z2 - z1 + dist = math.sqrt(dx**2 + dy**2 + dz**2) * factor + + bpy.ops.mesh.primitive_uv_sphere_add( + segments=50, + ring_count=50, + # bpy.ops.mesh.primitive_uv_sphere_add( + radius=dist, + location=(dx / 2 + x1, dy / 2 + y1, dz / 2 + z1)) + + # bpy.context.object.shade_smooth() + bpy.context.object.active_material = mat + + +def matrix_of_angles(cos, sin, inv=False): + sin = -sin if inv else sin + return np.stack((np.stack( + (cos, -sin), axis=-1), np.stack((sin, cos), axis=-1)), + axis=-2) + + +def get_floor(poses, jointstype="mmm"): + if jointstype == "mmm" or jointstype == "mmmns": + joints = mmm_joints + elif jointstype == "humanml3d": + joints = humanml3d_joints + else: + raise TypeError("Only supports mmm, mmmns and humanl3d jointstype") + # Feet + LM, RM = joints.index("LMrot"), joints.index("RMrot") + LF, RF = joints.index("LF"), joints.index("RF") + ndim = len(poses.shape) + + foot_heights = poses[..., (LM, LF, RM, RF), 1].min(-1) + floor_height = softmin(foot_heights, softness=0.5, dim=-1) + return floor_height[tuple((ndim - 2) * [None])].T + + +def canonicalize_joints(joints, jointstype="mmm"): + poses = joints.copy() + + translation = joints[..., 0, :].copy() + + # Let the root have the Y translation + translation[..., 1] = 0 + # Trajectory => Translation without gravity axis (Y) + trajectory = translation[..., [0, 2]] + + # Remove the floor + poses[..., 1] -= get_floor(poses, jointstype) + + # Remove the trajectory of the joints + poses[..., [0, 2]] -= trajectory[..., None, :] + + # Let the first pose be in the center + trajectory = trajectory - trajectory[..., 0, :] + + # Compute the forward direction of the first frame + forward = get_forward_direction(poses[..., 0, :, :], jointstype) + + # Construct the inverse rotation matrix + sin, cos = forward[..., 0], forward[..., 1] + rotations_inv = matrix_of_angles(cos, sin, inv=True) + + # Rotate the trajectory + trajectory_rotated = np.einsum("...j,...jk->...k", trajectory, + rotations_inv) + + # Rotate the poses + poses_rotated = np.einsum("...lj,...jk->...lk", poses[..., [0, 2]], + rotations_inv) + poses_rotated = np.stack( + (poses_rotated[..., 0], poses[..., 1], poses_rotated[..., 1]), axis=-1) + + # Re-merge the pose and translation + poses_rotated[..., (0, 2)] += trajectory_rotated[..., None, :] + return poses_rotated + + +def prepare_joints(joints, + canonicalize=True, + always_on_floor=False, + jointstype="mmm"): + # All face the same direction for the first frame + if canonicalize: + data = canonicalize_joints(joints, jointstype) + else: + data = joints + + # Rescaling, shift axis and swap left/right + if jointstype == "humanml3d": + data = data * mmm_to_smplh_scaling_factor + data[..., 1] = - data[..., 1] + + # Swap axis (gravity=Z instead of Y) + data = data[..., [2, 0, 1]] + + if jointstype == "mmm": + # Make left/right correct + data[..., [1]] = -data[..., [1]] + + # Center the first root to the first frame + data -= data[[0], [0], :] + + # Remove the floor + data[..., 2] -= data[..., 2].min() + + # 
Put all the body on the floor + if always_on_floor: + data[..., 2] -= data[..., 2].min(1)[:, None] + + return data + + +def NormalInDirection(normal, direction, limit=0.5): + return direction.dot(normal) > limit + + +def GoingUp(normal, limit=0.5): + return NormalInDirection(normal, (0, 0, 1), limit) + + +def GoingDown(normal, limit=0.5): + return NormalInDirection(normal, (0, 0, -1), limit) + + +def GoingSide(normal, limit=0.5): + return GoingUp(normal, limit) == False and GoingDown(normal, + limit) == False diff --git a/mGPT/render/blender/materials.py b/mGPT/render/blender/materials.py new file mode 100644 index 0000000000000000000000000000000000000000..4f0bf1a1c28254a776469058ab6473c7ca9a451d --- /dev/null +++ b/mGPT/render/blender/materials.py @@ -0,0 +1,135 @@ +import bpy + + +def clear_material(material): + if material.node_tree: + material.node_tree.links.clear() + material.node_tree.nodes.clear() + + +def colored_material_diffuse_BSDF(r, g, b, a=1, roughness=0.127451): + materials = bpy.data.materials + material = materials.new(name="body") + material.use_nodes = True + clear_material(material) + nodes = material.node_tree.nodes + links = material.node_tree.links + output = nodes.new(type='ShaderNodeOutputMaterial') + diffuse = nodes.new(type='ShaderNodeBsdfDiffuse') + diffuse.inputs["Color"].default_value = (r, g, b, a) + diffuse.inputs["Roughness"].default_value = roughness + links.new(diffuse.outputs['BSDF'], output.inputs['Surface']) + return material + +def colored_material_relection_BSDF(r, g, b, a=1, roughness=0.127451, saturation_factor=1): + materials = bpy.data.materials + material = materials.new(name="body") + material.use_nodes = True + # clear_material(material) + nodes = material.node_tree.nodes + links = material.node_tree.links + output = nodes.new(type='ShaderNodeOutputMaterial') + # diffuse = nodes.new(type='ShaderNodeBsdfDiffuse') + diffuse = nodes["Principled BSDF"] + diffuse.inputs["Base Color"].default_value = (r*saturation_factor, g*saturation_factor, b*saturation_factor, a) + diffuse.inputs["Roughness"].default_value = roughness + links.new(diffuse.outputs['BSDF'], output.inputs['Surface']) + return material + +# keys: +# ['Base Color', 'Subsurface', 'Subsurface Radius', 'Subsurface Color', 'Metallic', 'Specular', 'Specular Tint', 'Roughness', 'Anisotropic', 'Anisotropic Rotation', 'Sheen', 1Sheen Tint', 'Clearcoat', 'Clearcoat Roughness', 'IOR', 'Transmission', 'Transmission Roughness', 'Emission', 'Emission Strength', 'Alpha', 'Normal', 'Clearcoat Normal', 'Tangent'] +DEFAULT_BSDF_SETTINGS = {"Subsurface": 0.15, + "Subsurface Radius": [1.1, 0.2, 0.1], + "Metallic": 0.3, + "Specular": 0.5, + "Specular Tint": 0.5, + "Roughness": 0.75, + "Anisotropic": 0.25, + "Anisotropic Rotation": 0.25, + "Sheen": 0.75, + "Sheen Tint": 0.5, + "Clearcoat": 0.5, + "Clearcoat Roughness": 0.5, + "IOR": 1.450, + "Transmission": 0.1, + "Transmission Roughness": 0.1, + "Emission": (0, 0, 0, 1), + "Emission Strength": 0.0, + "Alpha": 1.0} + +def body_material(r, g, b, a=1, name="body", oldrender=True): + if oldrender: + material = colored_material_diffuse_BSDF(r, g, b, a=a) + else: + materials = bpy.data.materials + material = materials.new(name=name) + material.use_nodes = True + nodes = material.node_tree.nodes + diffuse = nodes["Principled BSDF"] + inputs = diffuse.inputs + + settings = DEFAULT_BSDF_SETTINGS.copy() + settings["Base Color"] = (r, g, b, a) + settings["Subsurface Color"] = (r, g, b, a) + settings["Subsurface"] = 0.0 + + for setting, val in settings.items(): + 
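+            # Copy every configured value onto the matching input socket of the
+            # Principled BSDF node; the keys must match the socket names listed
+            # in DEFAULT_BSDF_SETTINGS above.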
inputs[setting].default_value = val + + return material + + +def colored_material_bsdf(name, **kwargs): + materials = bpy.data.materials + material = materials.new(name=name) + material.use_nodes = True + nodes = material.node_tree.nodes + diffuse = nodes["Principled BSDF"] + inputs = diffuse.inputs + + settings = DEFAULT_BSDF_SETTINGS.copy() + for key, val in kwargs.items(): + settings[key] = val + + for setting, val in settings.items(): + inputs[setting].default_value = val + + return material + + +def floor_mat(name="floor_mat", color=(0.1, 0.1, 0.1, 1), roughness=0.127451): + return colored_material_diffuse_BSDF(color[0], color[1], color[2], a=color[3], roughness=roughness) + + +def plane_mat(): + materials = bpy.data.materials + material = materials.new(name="plane") + material.use_nodes = True + clear_material(material) + nodes = material.node_tree.nodes + links = material.node_tree.links + output = nodes.new(type='ShaderNodeOutputMaterial') + diffuse = nodes.new(type='ShaderNodeBsdfDiffuse') + checker = nodes.new(type="ShaderNodeTexChecker") + checker.inputs["Scale"].default_value = 1024 + checker.inputs["Color1"].default_value = (0.8, 0.8, 0.8, 1) + checker.inputs["Color2"].default_value = (0.3, 0.3, 0.3, 1) + links.new(checker.outputs["Color"], diffuse.inputs['Color']) + links.new(diffuse.outputs['BSDF'], output.inputs['Surface']) + diffuse.inputs["Roughness"].default_value = 0.127451 + return material + + +def plane_mat_uni(): + materials = bpy.data.materials + material = materials.new(name="plane_uni") + material.use_nodes = True + clear_material(material) + nodes = material.node_tree.nodes + links = material.node_tree.links + output = nodes.new(type='ShaderNodeOutputMaterial') + diffuse = nodes.new(type='ShaderNodeBsdfDiffuse') + diffuse.inputs["Color"].default_value = (0.8, 0.8, 0.8, 1) + diffuse.inputs["Roughness"].default_value = 0.127451 + links.new(diffuse.outputs['BSDF'], output.inputs['Surface']) + return material diff --git a/mGPT/render/blender/meshes.py b/mGPT/render/blender/meshes.py new file mode 100644 index 0000000000000000000000000000000000000000..284de6c5bef4c17078b316fa2f4501b33dcb2444 --- /dev/null +++ b/mGPT/render/blender/meshes.py @@ -0,0 +1,93 @@ +import numpy as np + +from .materials import body_material + +# green +# GT_SMPL = body_material(0.009, 0.214, 0.029) +GT_SMPL = body_material(0.035, 0.415, 0.122) + +# blue +# GEN_SMPL = body_material(0.022, 0.129, 0.439) +# Blues => cmap(0.87) +# GEN_SMPL = body_material(0.035, 0.322, 0.615) +# Oranges => cmap(0.87) +GEN_SMPL = body_material(0.658, 0.214, 0.0114) + + +class Meshes: + def __init__(self, data, *, gt, mode, faces_path, canonicalize, always_on_floor, oldrender=True, is_smplx=False, **kwargs): + data = prepare_meshes(data, canonicalize=canonicalize, + always_on_floor=always_on_floor, + is_smplx=is_smplx) + + if isinstance(faces_path, str): + self.faces = np.load(faces_path) + else: + self.faces = faces_path + + self.data = data + self.mode = mode + self.oldrender = oldrender + + self.N = len(data) + self.trajectory = data[:, :, [0, 1]].mean(1) + + if gt: + self.mat = GT_SMPL + else: + self.mat = GEN_SMPL + + def get_sequence_mat(self, frac): + import matplotlib + # cmap = matplotlib.cm.get_cmap('Blues') + cmap = matplotlib.cm.get_cmap('Oranges') + # begin = 0.60 + # end = 0.90 + begin = 0.50 + end = 0.90 + rgbcolor = cmap(begin + (end-begin)*frac) + mat = body_material(*rgbcolor, oldrender=self.oldrender) + return mat + + def get_root(self, index): + return self.data[index].mean(0) + + def 
get_mean_root(self): + return self.data.mean((0, 1)) + + def load_in_blender(self, index, mat): + vertices = self.data[index] + faces = self.faces + name = f"{str(index).zfill(4)}" + + from .tools import load_numpy_vertices_into_blender + load_numpy_vertices_into_blender(vertices, faces, name, mat) + + return name + + def __len__(self): + return self.N + + +def prepare_meshes(data, canonicalize=True, always_on_floor=False, is_smplx=False): + if canonicalize: + print("No canonicalization for now") + + # fitted mesh do not need fixing axis + # fix axis + if is_smplx: + data[..., 1] = - data[..., 1] + # data[..., 0] = - data[..., 0] + + + # Swap axis (gravity=Z instead of Y) + data = data[..., [2, 0, 1]] + + # Remove the floor + data[..., 2] -= data[..., 2].min() + + # Put all the body on the floor + if always_on_floor: + data[..., 2] -= data[..., 2].min(1)[:, None] + + return data diff --git a/mGPT/render/blender/sampler.py b/mGPT/render/blender/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..7aa8d853f5867974aceae5b50ac9ae4b99f1e686 --- /dev/null +++ b/mGPT/render/blender/sampler.py @@ -0,0 +1,15 @@ +import numpy as np + +def get_frameidx(*, mode, nframes, exact_frame, frames_to_keep): + if mode == "sequence": + frameidx = np.linspace(0, nframes - 1, frames_to_keep) + frameidx = np.round(frameidx).astype(int) + frameidx = list(frameidx) + elif mode == "frame": + index_frame = int(exact_frame*nframes) + frameidx = [index_frame] + elif mode == "video": + frameidx = range(0, nframes) + else: + raise ValueError(f"Not support {mode} render mode") + return frameidx diff --git a/mGPT/render/blender/scene.py b/mGPT/render/blender/scene.py new file mode 100644 index 0000000000000000000000000000000000000000..5b35e6c64dc0e0cd7a0168286cbd868c5936573d --- /dev/null +++ b/mGPT/render/blender/scene.py @@ -0,0 +1,96 @@ +import bpy +from .materials import plane_mat # noqa + + +def setup_renderer(denoising=True, oldrender=True, accelerator="gpu", device=[0]): + bpy.context.scene.render.engine = "CYCLES" + bpy.data.scenes[0].render.engine = "CYCLES" + if accelerator.lower() == "gpu": + bpy.context.preferences.addons[ + "cycles" + ].preferences.compute_device_type = "CUDA" + bpy.context.scene.cycles.device = "GPU" + i = 0 + bpy.context.preferences.addons["cycles"].preferences.get_devices() + for d in bpy.context.preferences.addons["cycles"].preferences.devices: + if i in device: # gpu id + d["use"] = 1 + print(d["name"], "".join(str(i) for i in device)) + else: + d["use"] = 0 + i += 1 + + if denoising: + bpy.context.scene.cycles.use_denoising = True + + bpy.context.scene.render.tile_x = 256 + bpy.context.scene.render.tile_y = 256 + bpy.context.scene.cycles.samples = 64 + # bpy.context.scene.cycles.denoiser = 'OPTIX' + + if not oldrender: + bpy.context.scene.view_settings.view_transform = "Standard" + bpy.context.scene.render.film_transparent = True + bpy.context.scene.display_settings.display_device = "sRGB" + bpy.context.scene.view_settings.gamma = 1.2 + bpy.context.scene.view_settings.exposure = -0.75 + + +# Setup scene +def setup_scene( + res="high", denoising=True, oldrender=True, accelerator="gpu", device=[0] +): + scene = bpy.data.scenes["Scene"] + assert res in ["ultra", "high", "med", "low"] + if res == "high": + scene.render.resolution_x = 1280 + scene.render.resolution_y = 1024 + elif res == "med": + scene.render.resolution_x = 1280 // 2 + scene.render.resolution_y = 1024 // 2 + elif res == "low": + scene.render.resolution_x = 1280 // 4 + scene.render.resolution_y = 
1024 // 4 + elif res == "ultra": + scene.render.resolution_x = 1280 * 2 + scene.render.resolution_y = 1024 * 2 + + scene.render.film_transparent= True + world = bpy.data.worlds["World"] + world.use_nodes = True + bg = world.node_tree.nodes["Background"] + bg.inputs[0].default_value[:3] = (1.0, 1.0, 1.0) + bg.inputs[1].default_value = 1.0 + + # Remove default cube + if "Cube" in bpy.data.objects: + bpy.data.objects["Cube"].select_set(True) + bpy.ops.object.delete() + + bpy.ops.object.light_add( + type="SUN", align="WORLD", location=(0, 0, 0), scale=(1, 1, 1) + ) + bpy.data.objects["Sun"].data.energy = 1.5 + + # rotate camera + bpy.ops.object.empty_add( + type="PLAIN_AXES", align="WORLD", location=(0, 0, 0), scale=(1, 1, 1) + ) + bpy.ops.transform.resize( + value=(10, 10, 10), + orient_type="GLOBAL", + orient_matrix=((1, 0, 0), (0, 1, 0), (0, 0, 1)), + orient_matrix_type="GLOBAL", + mirror=True, + use_proportional_edit=False, + proportional_edit_falloff="SMOOTH", + proportional_size=1, + use_proportional_connected=False, + use_proportional_projected=False, + ) + bpy.ops.object.select_all(action="DESELECT") + + setup_renderer( + denoising=denoising, oldrender=oldrender, accelerator=accelerator, device=device + ) + return scene diff --git a/mGPT/render/blender/tools.py b/mGPT/render/blender/tools.py new file mode 100644 index 0000000000000000000000000000000000000000..3c64ea62f3934b4c356bd7f29cd9c949f58d1050 --- /dev/null +++ b/mGPT/render/blender/tools.py @@ -0,0 +1,56 @@ +import bpy +import numpy as np + + +def style_detect(data): + is_mesh = False + is_smplx = False + jointstyle = 'mmm' + # heuristic + if data.shape[1] > 1000: + is_mesh = True + if data.shape[1] == 10475: + is_smplx = True + if data.shape[1] == 22: + jointstyle = 'humanml3d' + + return is_mesh, is_smplx, jointstyle + + + +# see this for more explanation +# https://gist.github.com/iyadahmed/7c7c0fae03c40bd87e75dc7059e35377 +# This should be solved with new version of blender +class ndarray_pydata(np.ndarray): + def __bool__(self) -> bool: + return len(self) > 0 + + +def load_numpy_vertices_into_blender(vertices, faces, name, mat): + mesh = bpy.data.meshes.new(name) + mesh.from_pydata(vertices, [], faces.view(ndarray_pydata)) + mesh.validate() + + obj = bpy.data.objects.new(name, mesh) + bpy.context.scene.collection.objects.link(obj) + + bpy.ops.object.select_all(action='DESELECT') + obj.select_set(True) + obj.active_material = mat + bpy.context.view_layer.objects.active = obj + bpy.ops.object.shade_smooth() + bpy.ops.object.select_all(action='DESELECT') + return True + + +def delete_objs(names): + if not isinstance(names, list): + names = [names] + # bpy.ops.object.mode_set(mode='OBJECT') + bpy.ops.object.select_all(action='DESELECT') + for obj in bpy.context.scene.objects: + for name in names: + if obj.name.startswith(name) or obj.name.endswith(name): + obj.select_set(True) + bpy.ops.object.delete() + bpy.ops.object.select_all(action='DESELECT') diff --git a/mGPT/render/blender/vertices.py b/mGPT/render/blender/vertices.py new file mode 100644 index 0000000000000000000000000000000000000000..78be1b12a2fec4ca43ab9065e99a0a1ba368be5a --- /dev/null +++ b/mGPT/render/blender/vertices.py @@ -0,0 +1,17 @@ +import numpy as np + + +def prepare_vertices(vertices, canonicalize=True): + data = vertices + # Swap axis (gravity=Z instead of Y) + # data = data[..., [2, 0, 1]] + + # Make left/right correct + # data[..., [1]] = -data[..., [1]] + + # Center the first root to the first frame + data -= data[[0], [0], :] + + # Remove the floor + 
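+    # (i.e. shift the whole sequence down so its lowest vertex sits at z = 0)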
data[..., 2] -= np.min(data[..., 2]) + return data diff --git a/mGPT/render/matplot/plot_3d_global.py b/mGPT/render/matplot/plot_3d_global.py new file mode 100644 index 0000000000000000000000000000000000000000..c55d5667660a2aac1ec4d3393a016633b3fbc412 --- /dev/null +++ b/mGPT/render/matplot/plot_3d_global.py @@ -0,0 +1,151 @@ +import torch +import matplotlib.pyplot as plt +import numpy as np +import io +import matplotlib +from mpl_toolkits.mplot3d.art3d import Poly3DCollection +import mpl_toolkits.mplot3d.axes3d as p3 +from textwrap import wrap +import imageio + + +def plot_3d_motion(args, figsize=(10, 10), fps=120, radius=4): + matplotlib.use('Agg') + + joints, out_name, title = args + + title_sp = title.split(' ') + if len(title_sp) > 20: + title = '\n'.join([' '.join(title_sp[:10]), ' '.join(title_sp[10:20]), ' '.join(title_sp[20:])]) + elif len(title_sp) > 10: + title = '\n'.join([' '.join(title_sp[:10]), ' '.join(title_sp[10:])]) + + data = joints.copy().reshape(len(joints), -1, 3) + + nb_joints = joints.shape[1] + smpl_kinetic_chain = [ + [0, 11, 12, 13, 14, 15], [0, 16, 17, 18, 19, 20], [0, 1, 2, 3, 4], + [3, 5, 6, 7], [3, 8, 9, 10] + ] if nb_joints == 21 else [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], + [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21], + [9, 13, 16, 18, 20]] + limits = 1000 if nb_joints == 21 else 2 + + MINS = data.min(axis=0).min(axis=0) + MAXS = data.max(axis=0).max(axis=0) + + colors = [ + 'red', 'blue', 'black', 'red', 'blue', 'darkblue', 'darkblue', + 'darkblue', 'darkblue', 'darkblue', 'darkred', 'darkred', 'darkred', + 'darkred', 'darkred' + ] + frame_number = data.shape[0] + # print(data.shape) + + height_offset = MINS[1] + data[:, :, 1] -= height_offset + trajec = data[:, 0, [0, 2]] + + data[..., 0] -= data[:, 0:1, 0] + data[..., 2] -= data[:, 0:1, 2] + + def update(index): + def init(): + ax.set_xlim3d([-radius / 2, radius / 2]) + ax.set_ylim3d([0, radius]) + ax.set_zlim3d([0, radius]) + ax.grid(b=False) + + def plot_xzPlane(minx, maxx, miny, minz, maxz): + ## Plot a plane XZ + verts = [[minx, miny, minz], [minx, miny, maxz], + [maxx, miny, maxz], [maxx, miny, minz]] + xz_plane = Poly3DCollection([verts]) + xz_plane.set_facecolor((0.5, 0.5, 0.5, 0.5)) + ax.add_collection3d(xz_plane) + + fig = plt.figure(figsize=(480 / 96., 320 / 96.), + dpi=96) if nb_joints == 21 else plt.figure( + figsize=(10, 10), dpi=96) + # fig.tight_layout() + if title is not None: + wraped_title = '\n'.join(wrap(title, 40)) + fig.suptitle(wraped_title, fontsize=16) + ax = p3.Axes3D(fig, auto_add_to_figure=False) + fig.add_axes(ax) + + init() + + # ax.lines = [] + # ax.collections = [] + ax.view_init(elev=110, azim=-90) + ax.dist = 7.5 + # ax = + plot_xzPlane(MINS[0] - trajec[index, 0], MAXS[0] - trajec[index, 0], 0, + MINS[2] - trajec[index, 1], MAXS[2] - trajec[index, 1]) + # ax.scatter(data[index, :22, 0], data[index, :22, 1], data[index, :22, 2], color='black', s=3) + + if index > 1: + ax.plot3D(trajec[:index, 0] - trajec[index, 0], + np.zeros_like(trajec[:index, 0]), + trajec[:index, 1] - trajec[index, 1], + linewidth=1.0, + color='blue') + # ax = plot_xzPlane(ax, MINS[0], MAXS[0], 0, MINS[2], MAXS[2]) + + for i, (chain, color) in enumerate(zip(smpl_kinetic_chain, colors)): + # print(color) + if i < 5: + linewidth = 4.0 + else: + linewidth = 2.0 + ax.plot3D(data[index, chain, 0], + data[index, chain, 1], + data[index, chain, 2], + linewidth=linewidth, + color=color) + # print(trajec[:index, 0].shape) + + plt.axis('off') + + ax.set_xticklabels([]) + ax.set_yticklabels([]) + 
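+        # The z tick labels are cleared below as well; the finished frame is
+        # then either written to disk (when out_name is given) or rasterized
+        # through an in-memory buffer and returned as a numpy array.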
ax.set_zticklabels([]) + + if out_name is not None: + plt.savefig(out_name, dpi=96) + plt.close() + + else: + io_buf = io.BytesIO() + fig.savefig(io_buf, format='raw', dpi=96) + io_buf.seek(0) + # print(fig.bbox.bounds) + arr = np.reshape(np.frombuffer(io_buf.getvalue(), dtype=np.uint8), + newshape=(int(fig.bbox.bounds[3]), + int(fig.bbox.bounds[2]), -1)) + io_buf.close() + plt.close() + return arr + + out = [] + for i in range(frame_number): + out.append(update(i)) + out = np.stack(out, axis=0) + return torch.from_numpy(out) + + +def draw_to_batch(smpl_joints_batch, title_batch=None, outname=None): + + batch_size = len(smpl_joints_batch) + out = [] + for i in range(batch_size): + out.append( + plot_3d_motion([ + smpl_joints_batch[i], None, + title_batch[i] if title_batch is not None else None + ])) + if outname is not None: + imageio.mimsave(outname[i], np.array(out[-1]), duration=50) + out = torch.stack(out, axis=0) + return out diff --git a/mGPT/render/pyrender/hybrik_loc2rot.py b/mGPT/render/pyrender/hybrik_loc2rot.py new file mode 100644 index 0000000000000000000000000000000000000000..5739617e95a5200edef446cdfd15d89f4b36160e --- /dev/null +++ b/mGPT/render/pyrender/hybrik_loc2rot.py @@ -0,0 +1,140 @@ +import numpy as np + +SMPL_BODY_BONES = [-0.0018, -0.2233, 0.0282, 0.0695, -0.0914, -0.0068, -0.0677, -0.0905, -0.0043, + -0.0025, 0.1090, -0.0267, 0.0343, -0.3752, -0.0045, -0.0383, -0.3826, -0.0089, + 0.0055, 0.1352, 0.0011, -0.0136, -0.3980, -0.0437, 0.0158, -0.3984, -0.0423, + 0.0015, 0.0529, 0.0254, 0.0264, -0.0558, 0.1193, -0.0254, -0.0481, 0.1233, + -0.0028, 0.2139, -0.0429, 0.0788, 0.1217, -0.0341, -0.0818, 0.1188, -0.0386, + 0.0052, 0.0650, 0.0513, 0.0910, 0.0305, -0.0089, -0.0960, 0.0326, -0.0091, + 0.2596, -0.0128, -0.0275, -0.2537, -0.0133, -0.0214, 0.2492, 0.0090, -0.0012, + -0.2553, 0.0078, -0.0056, 0.0840, -0.0082, -0.0149, -0.0846, -0.0061, -0.0103] + + +class HybrIKJointsToRotmat: + def __init__(self): + self.naive_hybrik = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0] + self.num_nodes = 22 + self.parents = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 16, 17, 18, 19] + self.child = [-1, 4, 5, 6, 7, 8, 9, 10, 11, -1, -2, -2, 15, + 16, 17, -2, 18, 19, 20, 21, -2, -2] + self.bones = np.reshape(np.array(SMPL_BODY_BONES), [24, 3])[:self.num_nodes] + + def multi_child_rot(self, t, p, + pose_global_parent): + """ + t: B x 3 x child_num + p: B x 3 x child_num + pose_global_parent: B x 3 x 3 + """ + m = np.matmul(t, np.transpose(np.matmul(np.linalg.inv(pose_global_parent), p), [0, 2, 1])) + u, s, vt = np.linalg.svd(m) + r = np.matmul(np.transpose(vt, [0, 2, 1]), np.transpose(u, [0, 2, 1])) + err_det_mask = (np.linalg.det(r) < 0.0).reshape(-1, 1, 1) + id_fix = np.reshape(np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]]), + [1, 3, 3]) + r_fix = np.matmul(np.transpose(vt, [0, 2, 1]), + np.matmul(id_fix, + np.transpose(u, [0, 2, 1]))) + r = r * (1.0 - err_det_mask) + r_fix * err_det_mask + return r, np.matmul(pose_global_parent, r) + + def single_child_rot(self, t, p, pose_global_parent, twist=None): + """ + t: B x 3 x 1 + p: B x 3 x 1 + pose_global_parent: B x 3 x 3 + twist: B x 2 if given, default to None + """ + p_rot = np.matmul(np.linalg.inv(pose_global_parent), p) + cross = np.cross(t, p_rot, axisa=1, axisb=1, axisc=1) + sina = np.linalg.norm(cross, axis=1, keepdims=True) / (np.linalg.norm(t, axis=1, keepdims=True) * + np.linalg.norm(p_rot, axis=1, keepdims=True)) + cross = cross / np.linalg.norm(cross, axis=1, keepdims=True) + 
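+        # cross is now the unit rotation axis between the template bone t and
+        # the predicted bone p_rot (both expressed in the parent frame); sina
+        # and cosa are the sine and cosine of the swing angle, and the rotation
+        # below is assembled with Rodrigues' formula
+        #   R = I + sin(a) * K + (1 - cos(a)) * K @ K,
+        # where K is the skew-symmetric matrix of the axis.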
cosa = np.sum(t * p_rot, axis=1, keepdims=True) / (np.linalg.norm(t, axis=1, keepdims=True) * + np.linalg.norm(p_rot, axis=1, keepdims=True)) + sina = np.reshape(sina, [-1, 1, 1]) + cosa = np.reshape(cosa, [-1, 1, 1]) + skew_sym_t = np.stack([0.0 * cross[:, 0], -cross[:, 2], cross[:, 1], + cross[:, 2], 0.0 * cross[:, 0], -cross[:, 0], + -cross[:, 1], cross[:, 0], 0.0 * cross[:, 0]], 1) + skew_sym_t = np.reshape(skew_sym_t, [-1, 3, 3]) + dsw_rotmat = np.reshape(np.eye(3), [1, 3, 3] + ) + sina * skew_sym_t + (1.0 - cosa) * np.matmul(skew_sym_t, + skew_sym_t) + if twist is not None: + skew_sym_t = np.stack([0.0 * t[:, 0], -t[:, 2], t[:, 1], + t[:, 2], 0.0 * t[:, 0], -t[:, 0], + -t[:, 1], t[:, 0], 0.0 * t[:, 0]], 1) + skew_sym_t = np.reshape(skew_sym_t, [-1, 3, 3]) + sina = np.reshape(twist[:, 1], [-1, 1, 1]) + cosa = np.reshape(twist[:, 0], [-1, 1, 1]) + dtw_rotmat = np.reshape(np.eye(3), [1, 3, 3] + ) + sina * skew_sym_t + (1.0 - cosa) * np.matmul(skew_sym_t, + skew_sym_t) + dsw_rotmat = np.matmul(dsw_rotmat, dtw_rotmat) + return dsw_rotmat, np.matmul(pose_global_parent, dsw_rotmat) + + def __call__(self, joints, twist=None): + """ + joints: B x N x 3 + twist: B x N x 2 if given, default to None + """ + expand_dim = False + if len(joints.shape) == 2: + expand_dim = True + joints = np.expand_dims(joints, 0) + if twist is not None: + twist = np.expand_dims(twist, 0) + assert (len(joints.shape) == 3) + batch_size = np.shape(joints)[0] + joints_rel = joints - joints[:, self.parents] + joints_hybrik = 0.0 * joints_rel + pose_global = np.zeros([batch_size, self.num_nodes, 3, 3]) + pose = np.zeros([batch_size, self.num_nodes, 3, 3]) + for i in range(self.num_nodes): + if i == 0: + joints_hybrik[:, 0] = joints[:, 0] + else: + joints_hybrik[:, i] = np.matmul(pose_global[:, self.parents[i]], + np.reshape(self.bones[i], [1, 3, 1])).reshape(-1, 3) + \ + joints_hybrik[:, self.parents[i]] + if self.child[i] == -2: + pose[:, i] = pose[:, i] + np.eye(3).reshape(1, 3, 3) + pose_global[:, i] = pose_global[:, self.parents[i]] + continue + if i == 0: + r, rg = self.multi_child_rot(np.transpose(self.bones[[1, 2, 3]].reshape(1, 3, 3), [0, 2, 1]), + np.transpose(joints_rel[:, [1, 2, 3]], [0, 2, 1]), + np.eye(3).reshape(1, 3, 3)) + + elif i == 9: + r, rg = self.multi_child_rot(np.transpose(self.bones[[12, 13, 14]].reshape(1, 3, 3), [0, 2, 1]), + np.transpose(joints_rel[:, [12, 13, 14]], [0, 2, 1]), + pose_global[:, self.parents[9]]) + else: + p = joints_rel[:, self.child[i]] + if self.naive_hybrik[i] == 0: + p = joints[:, self.child[i]] - joints_hybrik[:, i] + twi = None + if twist is not None: + twi = twist[:, i] + r, rg = self.single_child_rot(self.bones[self.child[i]].reshape(1, 3, 1), + p.reshape(-1, 3, 1), + pose_global[:, self.parents[i]], + twi) + pose[:, i] = r + pose_global[:, i] = rg + if expand_dim: + pose = pose[0] + return pose + + +if __name__ == "__main__": + jts2rot_hybrik = HybrIKJointsToRotmat() + joints = np.array(SMPL_BODY_BONES).reshape(1, 24, 3)[:, :22] + parents = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 16, 17, 18, 19] + for i in range(1, 22): + joints[:, i] = joints[:, i] + joints[:, parents[i]] + pose = jts2rot_hybrik(joints) + print(pose) diff --git a/mGPT/render/pyrender/j3ds_render_smpl.py b/mGPT/render/pyrender/j3ds_render_smpl.py new file mode 100644 index 0000000000000000000000000000000000000000..3c8bd86528565c7ab1cded16ef61c996cc527650 --- /dev/null +++ b/mGPT/render/pyrender/j3ds_render_smpl.py @@ -0,0 +1,48 @@ +import os +import argparse +import numpy as np +from 
scripts.hybrik_loc2rot import HybrIKJointsToRotmat +from scripts.pyrender import SMPLRender +import cv2 +from scipy.spatial.transform import Rotation as RRR + +parser = argparse.ArgumentParser( + description='Render a SMPL video by a j3ds npy file.') +parser.add_argument('--input', type=str, default='', help='the npy file path') +parser.add_argument('--render', + type=int, + default=1, + help='render the video if 1') +args = parser.parse_args() + +input_path = args.input +output_npy_path = args.input.replace('.npy', '_pose.npy') +data = np.load(input_path) +data = data - data[0, 0] +pose_generator = HybrIKJointsToRotmat() +pose = pose_generator(data) +pose = np.concatenate( + [pose, np.stack([np.stack([np.eye(3)] * pose.shape[0], 0)] * 2, 1)], 1) +np.save(output_npy_path, pose) +shape = [768, 768] +if args.render: + render = SMPLRender() + output_mp4_path = args.input.replace('.npy', '_smpl.mp4') + os.environ['PYOPENGL_PLATFORM'] = 'egl' + size = (shape[1], shape[0]) + fps = 30.0 + fourcc = cv2.VideoWriter_fourcc('M', 'P', '4', 'V') + videoWriter = cv2.VideoWriter(output_mp4_path, fourcc, fps, size) + r = RRR.from_rotvec(np.array([np.pi, 0.0, 0.0])) + pose[:, 0] = np.matmul(r.as_matrix().reshape(1, 3, 3), pose[:, 0]) + for i in range(data.shape[0]): + img = np.zeros([shape[0], shape[1], 3]) + aroot = data[[i], 0] + np.array([[0.0, 0.0, 30.0]]) + aroot[:, 1] = -aroot[:, 1] + params = dict(pred_shape=np.zeros([1, 10]), + pred_root=aroot, + pred_pose=pose[[i]]) + renderImg = render.render(img.copy(), params) + renderImg = (renderImg * 255).astype(np.uint8) + videoWriter.write(renderImg) + videoWriter.release() diff --git a/mGPT/render/pyrender/smpl_render.py b/mGPT/render/pyrender/smpl_render.py new file mode 100644 index 0000000000000000000000000000000000000000..5c5591461516c3d7b571759f129ccc7f930a9fd3 --- /dev/null +++ b/mGPT/render/pyrender/smpl_render.py @@ -0,0 +1,135 @@ +import os + +os.environ['PYOPENGL_PLATFORM'] = 'egl' +import torch +import numpy as np +import cv2 + +import matplotlib.pyplot as plt +import glob +import pickle +import pyrender +import trimesh +from smplx import SMPL as _SMPL +from smplx.utils import SMPLOutput as ModelOutput +from scipy.spatial.transform.rotation import Rotation as RRR + +class SMPL(_SMPL): + """ Extension of the official SMPL implementation to support more joints """ + + def __init__(self, *args, **kwargs): + super(SMPL, self).__init__(*args, **kwargs) + # joints = [constants.JOINT_MAP[i] for i in constants.JOINT_NAMES] + # J_regressor_extra = np.load(config.JOINT_REGRESSOR_TRAIN_EXTRA) + # self.register_buffer('J_regressor_extra', torch.tensor(J_regressor_extra, dtype=torch.float32)) + # self.joint_map = torch.tensor(joints, dtype=torch.long) + + def forward(self, *args, **kwargs): + kwargs['get_skin'] = True + smpl_output = super(SMPL, self).forward(*args, **kwargs) + # extra_joints = vertices2joints(self.J_regressor_extra, smpl_output.vertices) #Additional 9 joints #Check doc/J_regressor_extra.png + # joints = torch.cat([smpl_output.joints, extra_joints], dim=1) #[N, 24 + 21, 3] + [N, 9, 3] + # joints = joints[:, self.joint_map, :] + joints = smpl_output.joints + output = ModelOutput(vertices=smpl_output.vertices, + global_orient=smpl_output.global_orient, + body_pose=smpl_output.body_pose, + joints=joints, + betas=smpl_output.betas, + full_pose=smpl_output.full_pose) + return output + +class Renderer: + """ + Renderer used for visualizing the SMPL model + Code adapted from https://github.com/vchoutas/smplify-x + """ + def __init__(self, 
focal_length=5000, img_res=(224,224), faces=None): + self.renderer = pyrender.OffscreenRenderer(viewport_width=img_res[0], + viewport_height=img_res[1], + point_size=1.0) + self.focal_length = focal_length + self.camera_center = [img_res[0] // 2, img_res[1] // 2] + self.faces = faces + def __call__(self, vertices, camera_translation, image): + material = pyrender.MetallicRoughnessMaterial( + metallicFactor=0.2, + alphaMode='OPAQUE', + baseColorFactor=(0.8, 0.3, 0.3, 1.0)) + + camera_translation[0] *= -1. + + mesh = trimesh.Trimesh(vertices, self.faces) + rot = trimesh.transformations.rotation_matrix( + np.radians(180), [1, 0, 0]) + mesh.apply_transform(rot) + mesh = pyrender.Mesh.from_trimesh(mesh, material=material) + + + scene = pyrender.Scene(bg_color = [1, 1, 1, 0.8], ambient_light=(0.4, 0.4, 0.4)) + scene.add(mesh, 'mesh') + + camera_pose = np.eye(4) + camera_pose[:3, 3] = camera_translation + camera = pyrender.IntrinsicsCamera(fx=self.focal_length, fy=self.focal_length, + cx=self.camera_center[0], cy=self.camera_center[1]) + scene.add(camera, pose=camera_pose) + + + light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=300) + light_pose = np.eye(4) + + light_pose[:3, 3] = np.array([0, -1, 1]) + scene.add(light, pose=light_pose) + + light_pose[:3, 3] = np.array([0, 1, 1]) + scene.add(light, pose=light_pose) + + light_pose[:3, 3] = np.array([1, 1, 2]) + scene.add(light, pose=light_pose) + + color, rend_depth = self.renderer.render(scene, flags=pyrender.RenderFlags.RGBA) + color = color.astype(np.float32) / 255.0 + valid_mask = (rend_depth > 0)[:,:,None] + output_img = (color[:, :, :3] * valid_mask + + (1 - valid_mask) * image) + return output_img + +class SMPLRender(): + def __init__(self, SMPL_MODEL_DIR): + if torch.cuda.is_available(): + self.device = torch.device("cuda") + else: + self.device = torch.device("cpu") + + self.smpl = SMPL(SMPL_MODEL_DIR, + batch_size=1, + create_transl=False).to(self.device) + + self.focal_length = 5000 + + def render(self, image, smpl_param, is_headroot=False): + pose = smpl_param['pred_pose'] + if pose.size==72: + pose = pose.reshape(-1,3) + pose = RRR.from_rotvec(pose).as_matrix() + pose = pose.reshape(1,24,3,3) + pred_betas = torch.from_numpy(smpl_param['pred_shape'].reshape(1, 10).astype(np.float32)).to(self.device) + pred_rotmat = torch.from_numpy(pose.astype(np.float32)).to(self.device) + pred_camera_t = smpl_param['pred_root'].reshape(1, 3).astype(np.float32) + smpl_output = self.smpl(betas=pred_betas, body_pose=pred_rotmat[:, 1:], + global_orient=pred_rotmat[:, 0].unsqueeze(1), pose2rot=False) + + + vertices = smpl_output.vertices[0].detach().cpu().numpy() + pred_camera_t = pred_camera_t[0] + + if is_headroot: + pred_camera_t = pred_camera_t - smpl_output.joints[0,12].detach().cpu().numpy() + + renderer = Renderer(focal_length=self.focal_length, + img_res=(image.shape[1], image.shape[0]), faces=self.smpl.faces) + + renderImg = renderer(vertices, pred_camera_t.copy(), image / 255.0) + renderer.renderer.delete() + return renderImg diff --git a/mGPT/render/renderer.py b/mGPT/render/renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..2dd66ab748f795bd3b1915b5b609dc5daace6ce4 --- /dev/null +++ b/mGPT/render/renderer.py @@ -0,0 +1,179 @@ +""" +This script is borrowed from https://github.com/mkocabas/VIBE + Adhere to their licence to use this script + It has been modified +""" + +import os +import math +import trimesh + +import pyrender +import numpy as np +from pyrender.constants import RenderFlags + + +# 
os.environ['DISPLAY'] = ':0.0' +# os.environ['PYOPENGL_PLATFORM'] = 'egl' +# os.environ['PYOPENGL_PLATFORM'] = 'osmesa' +SMPL_MODEL_DIR = "data/smpl_data/" + + +def get_smpl_faces(): + return np.load(os.path.join(SMPL_MODEL_DIR, "smplfaces.npy")) + + +class WeakPerspectiveCamera(pyrender.Camera): + def __init__(self, + scale, + translation, + znear=pyrender.camera.DEFAULT_Z_NEAR, + zfar=None, + name=None): + super(WeakPerspectiveCamera, self).__init__( + znear=znear, + zfar=zfar, + name=name, + ) + self.scale = scale + self.translation = translation + + def get_projection_matrix(self, width=None, height=None): + P = np.eye(4) + P[0, 0] = self.scale[0] + P[1, 1] = self.scale[1] + P[0, 3] = self.translation[0] * self.scale[0] + P[1, 3] = -self.translation[1] * self.scale[1] + P[2, 2] = -1 + return P + + +class Renderer: + def __init__(self, background=None, resolution=(224, 224), bg_color=[0, 0, 0, 0.5], orig_img=False, wireframe=False, cam_pose=np.eye(4)): + width, height = resolution + self.background = np.zeros((height, width, 3)) + self.resolution = resolution + + self.faces = get_smpl_faces() + self.orig_img = orig_img + self.wireframe = wireframe + self.renderer = pyrender.OffscreenRenderer( + viewport_width=self.resolution[0], + viewport_height=self.resolution[1], + point_size=0.5 + ) + + # set the scene + self.scene = pyrender.Scene(bg_color=bg_color, ambient_light=(0.4, 0.4, 0.4)) + + light = pyrender.PointLight(color=[1.0, 1.0, 1.0], intensity=4) + + + light_pose = np.eye(4) + light_pose[:3, 3] = [0, -1, 1] + self.scene.add(light, pose=np.dot(cam_pose,light_pose).copy()) + + light_pose[:3, 3] = [0, 1, 1] + self.scene.add(light, pose=np.dot(cam_pose,light_pose).copy()) + + light_pose[:3, 3] = [1, 1, 2] + self.scene.add(light, pose=np.dot(cam_pose,light_pose).copy()) + + """ok + light_pose = np.eye(4) + light_pose[:3, 3] = [0, -1, 1] + self.scene.add(light, pose=light_pose) + + light_pose[:3, 3] = [0, 1, 1] + self.scene.add(light, pose=light_pose) + + light_pose[:3, 3] = [1, 1, 2] + self.scene.add(light, pose=light_pose) + """ + + # light_pose[:3, 3] = [0, -2, 2] + # [droite, hauteur, profondeur camera] + """ + light_pose = np.eye(4) + light_pose[:3, 3] = [0, -1, 1] + self.scene.add(light, pose=light_pose) + + light_pose[:3, 3] = [0, 1, 1] + self.scene.add(light, pose=light_pose) + + light_pose[:3, 3] = [1, 1, 2] + self.scene.add(light, pose=light_pose) + """ + + def render(self, img, verts, cam, angle=None, axis=None, mesh_filename=None, color=[1.0, 1.0, 0.9], + cam_pose=np.eye(4)): + mesh = trimesh.Trimesh(vertices=verts, faces=self.faces, process=False) + Rx = trimesh.transformations.rotation_matrix(math.radians(180), [1, 0, 0]) + # Rx = trimesh.transformations.rotation_matrix(math.radians(-90), [1, 0, 0]) + mesh.apply_transform(Rx) + + if mesh_filename is not None: + mesh.export(mesh_filename) + + if angle and axis: + R = trimesh.transformations.rotation_matrix(math.radians(angle), axis) + mesh.apply_transform(R) + + sx, sy, tx, ty = cam + + camera = WeakPerspectiveCamera( + scale=[sx, sy], + translation=[tx, ty], + zfar=100000. 
+ ) + + material = pyrender.MetallicRoughnessMaterial( + metallicFactor=0.0, # 0.0 for no specular lighting + # metallicFactor=0.7, # 0.0 for no specular lighting + alphaMode='OPAQUE', + baseColorFactor=(color[0], color[1], color[2], 1.0) + ) + + mesh = pyrender.Mesh.from_trimesh(mesh, material=material) + + mesh_node = self.scene.add(mesh, 'mesh') + + cam_node = self.scene.add(camera, pose=cam_pose) + + if self.wireframe: + render_flags = RenderFlags.RGBA | RenderFlags.ALL_WIREFRAME + else: + render_flags = RenderFlags.RGBA + + rgb, _ = self.renderer.render(self.scene, flags=render_flags) + if rgb.shape[-1]==3: + # Debug + # 0 not distinguish alpha + valid_mask = (rgb[:, :, -1] > 0)[:, :, np.newaxis] + output_img = rgb * valid_mask + (1 - valid_mask) * img + elif rgb.shape[-1]==4: + # valid_mask = (rgb[:, :, -1] > 128)[:, :, np.newaxis] + # output_img = rgb[:, :, :-1] * valid_mask + (1 - valid_mask) * img + + # # output alpha + valid_mask = (rgb[:, :, -1] > 128)[:, :] + output_img = np.copy(rgb) + output_img[:, :, -1] *= valid_mask + # output_img = img + else: + raise ValueError(f"rgb shape {rgb.shape[-1]} is not correct!") + image = output_img.astype(np.uint8) + + self.scene.remove_node(mesh_node) + self.scene.remove_node(cam_node) + + return image + + +def get_renderer(width, height, cam_pose): + renderer = Renderer(resolution=(width, height), + bg_color=[1, 1, 1, 0.5], + orig_img=False, + wireframe=False, + cam_pose=cam_pose) + return renderer diff --git a/mGPT/render/rendermotion.py b/mGPT/render/rendermotion.py new file mode 100644 index 0000000000000000000000000000000000000000..f3d29d936f53374f6d08b4f6de5cc6723913e90a --- /dev/null +++ b/mGPT/render/rendermotion.py @@ -0,0 +1,134 @@ +import numpy as np +import imageio +import os +import argparse +from tqdm import tqdm +from .renderer import get_renderer + + +def get_rotation(theta=np.pi / 3): + import mGPT.utils.rotation_conversions as geometry + import torch + axis = torch.tensor([0, 1, 0], dtype=torch.float) + axisangle = theta * axis + matrix = geometry.axis_angle_to_matrix(axisangle) + return matrix.numpy() + + +def render_video(meshes, + key, + action, + renderer, + savepath, + backgrounds, + cam_pose, + cams=(0.75, 0.75, 0, 0.10), + color=[0.11, 0.53, 0.8]): + # cams=(0.75, 0.75, 0, 0.10), color=[165.0/255,112/255,140/255]): + # center the first frame + if key not in ["real", "ntf", "side"]: + w = int(key) / 6.0 + # purpole to green + # color = w*np.array([0.9,102/255,120/255]) + (1-w)*np.array([0.11, 0.9, 0.11]) + # color = (1-w)*np.array([165.0/255,112/255,140/255]) + w*np.array([0.11, 0.8, 0.11]) + color = (1 - w) * np.array([0.75, 0.13, 0.7]) + w * np.array( + [0.12, 0.7, 0.14]) + + meshes = meshes - meshes[0].mean(axis=0) + imgs = [] + idx = 0 + # for mesh in meshes: + for mesh in tqdm(meshes, desc=f"Visualize {key}, action {action}"): + # file_name = '3dpw_rot-90_glob_trimesh.ply' mesh_filename=file_name, + # prepare background + if len(backgrounds.shape) == 3: + background = backgrounds + cam = cams + elif len(backgrounds.shape) == 4: + background = backgrounds[idx] + cam = cams[idx] + idx += 1 + # prepare cams + img = renderer.render(background, + mesh, + cam, + color=color, + cam_pose=cam_pose) + imgs.append(img) + # show(img) + + imgs = np.array(imgs) + # masks = ~(imgs/255. 
> 0.96).all(-1) + # coords = np.argwhere(masks.sum(axis=0)) + # y1, x1 = coords.min(axis=0) + # y2, x2 = coords.max(axis=0) + # writer = imageio.get_writer(savepath, fps=30) + # for cimg in imgs[:, y1:y2, x1:x2]: + # writer.append_data(cimg) + # writer.close() + + # from mld.utils.uicap_utils import write_rgba_seqs + # write_rgba_seqs(imgs, savepath) + + writer = imageio.get_writer(savepath, fps=30) + for cimg in imgs: + writer.append_data(cimg) + writer.close() + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("filename") + opt = parser.parse_args() + filename = opt.filename + savefolder = os.path.splitext(filename)[0] + os.makedirs(savefolder, exist_ok=True) + + output = np.load(filename) + + if output.shape[0] == 3: + visualization, generation, reconstruction = output + output = { + "visualization": visualization, + "generation": generation, + "reconstruction": reconstruction + } + else: + # output = {f"generation_{key}": output[key] for key in range(2)} # len(output))} + # output = {f"generation_{key}": output[key] for key in range(len(output))} + output = { + f"generation_{key}": output[key] + for key in range(len(output)) + } + + width = 1024 + height = 1024 + + background = np.zeros((height, width, 3)) + renderer = get_renderer(width, height) + + # if duration mode, put back durations + if output["generation_3"].shape[-1] == 100: + output["generation_0"] = output["generation_0"][:, :, :, :40] + output["generation_1"] = output["generation_1"][:, :, :, :60] + output["generation_2"] = output["generation_2"][:, :, :, :80] + output["generation_3"] = output["generation_3"][:, :, :, :100] + elif output["generation_3"].shape[-1] == 160: + print("160 mode") + output["generation_0"] = output["generation_0"][:, :, :, :100] + output["generation_1"] = output["generation_1"][:, :, :, :120] + output["generation_2"] = output["generation_2"][:, :, :, :140] + output["generation_3"] = output["generation_3"][:, :, :, :160] + + # if str(action) == str(1) and str(key) == "generation_4": + for key in output: + vidmeshes = output[key] + for action in range(len(vidmeshes)): + meshes = vidmeshes[action].transpose(2, 0, 1) + path = os.path.join(savefolder, + "action{}_{}.mp4".format(action, key)) + render_video(meshes, key, action, renderer, path, background) + + +if __name__ == "__main__": + main() diff --git a/mGPT/render/video.py b/mGPT/render/video.py new file mode 100644 index 0000000000000000000000000000000000000000..d0d4eeb2072d5c23f5917efdb4aa8f21ed3a3bb5 --- /dev/null +++ b/mGPT/render/video.py @@ -0,0 +1,67 @@ +import moviepy.editor as mp +import moviepy.video.fx.all as vfx +import os +import imageio + + +def mask_png(frames): + for frame in frames: + im = imageio.imread(frame) + im[im[:, :, 3] < 1, :] = 255 + imageio.imwrite(frame, im[:, :, 0:3]) + return + + +class Video: + def __init__(self, frame_path: str, fps: float = 12.5, res="high"): + frame_path = str(frame_path) + self.fps = fps + + self._conf = {"codec": "libx264", + "fps": self.fps, + "audio_codec": "aac", + "temp_audiofile": "temp-audio.m4a", + "remove_temp": True} + + if res == "low": + bitrate = "500k" + else: + bitrate = "5000k" + + self._conf = {"bitrate": bitrate, + "fps": self.fps} + + # Load video + # video = mp.VideoFileClip(video1_path, audio=False) + # Load with frames + frames = [os.path.join(frame_path, x) + for x in sorted(os.listdir(frame_path))] + + # mask background white for videos + mask_png(frames) + + video = mp.ImageSequenceClip(frames, fps=fps) + self.video = video + self.duration = 
video.duration + + def add_text(self, text): + # needs ImageMagick + video_text = mp.TextClip(text, + font='Amiri', + color='white', + method='caption', + align="center", + size=(self.video.w, None), + fontsize=30) + video_text = video_text.on_color(size=(self.video.w, video_text.h + 5), + color=(0, 0, 0), + col_opacity=0.6) + # video_text = video_text.set_pos('bottom') + video_text = video_text.set_pos('top') + + self.video = mp.CompositeVideoClip([self.video, video_text]) + + def save(self, out_path): + out_path = str(out_path) + self.video.subclip(0, self.duration).write_videofile( + out_path, **self._conf) diff --git a/mGPT/render/visualize.py b/mGPT/render/visualize.py new file mode 100644 index 0000000000000000000000000000000000000000..7cc9c6cd9f77ef8f031aa4a9f2fe5926f6b84272 --- /dev/null +++ b/mGPT/render/visualize.py @@ -0,0 +1,747 @@ +from operator import mod +import os +# from cv2 import CAP_PROP_INTELPERC_DEPTH_LOW_CONFIDENCE_VALUE +import imageio +import shutil +import numpy as np +import torch +from tqdm import tqdm + +from scipy.spatial.transform import Rotation as R +from mGPT.render.renderer import get_renderer +from mGPT.render.rendermotion import render_video +# from mld.utils.img_utils import convert_img +# from mld.utils.uicap_utils import output_pkl + + +def parsename(path): + basebane = os.path.basename(path) + base = os.path.splitext(basebane)[0] + strs = base.split('_') + key = strs[-2] + action = strs[-1] + return key, action + + +def load_anim(path, timesize=None): + data = np.array(imageio.mimread(path, memtest=False)) #[..., :3] + if timesize is None: + return data + + # take the last frame and put shadow repeat the last frame but with a little shadow + # lastframe = add_shadow(data[-1]) + # alldata = np.tile(lastframe, (timesize, 1, 1, 1)) + alldata = data + + # debug fix mat dim + if len(data.shape) == 3 and len(alldata.shape) == 4: + data = data[:, None, :, :] + + # copy the first frames + lenanim = data.shape[0] + alldata[:lenanim] = data[:lenanim] + return alldata + + +def plot_3d_motion_dico(x): + motion, length, save_path, params, kargs = x + plot_3d_motion(motion, length, save_path, params, **kargs) + + +def plot_3d_motion(motion, + length, + save_path, + params, + title="", + interval=50, + pred_cam=None, + imgs=None, + bbox=None, + side=None): + # render smpl + # [nframes, nVs, 3] + if motion.shape[1] == 6890: + # width = 250 + # height = 250 + width = 600 + height = 600 + if pred_cam is None: + # cam=(0.75, 0.75, 0, 0.1) + cam = (0.8, 0.8, 0, 0.1) + # cam=(0.9, 0.9, 0, 0.1) + else: + assert bbox is not None + assert imgs is not None + + # Tmp visulize + # weak perspective camera parameters in cropped image space (s,tx,ty) + # to + # weak perspective camera parameters in original image space (sx,sy,tx,ty) + cam = np.concatenate( + (pred_cam[:, [0]], pred_cam[:, [0]], pred_cam[:, 1:3]), axis=1) + + # ToDo convert to original cam + # load original img? + # calculate cam after padding??? 
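+            # For now the weak-perspective parameters predicted in the cropped
+            # image space are reused directly as (s, s, tx, ty); mapping them
+            # back to the original image space (the commented-out
+            # convert_crop_cam_to_orig_img call below) is still a ToDo.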
+ # + # cam = convert_crop_cam_to_orig_img( + # cam=pred_cam, + # bbox=bbox, + # img_width=width, + # img_height=height + # ) + cam_pose = np.eye(4) + cam_pose[0:3, 0:3] = R.from_euler('x', -90, degrees=True).as_matrix() + cam_pose[0:3, 3] = [0, 0, 0] + if side: + rz = np.eye(4) + rz[0:3, 0:3] = R.from_euler('z', -90, degrees=True).as_matrix() + cam_pose = np.matmul(rz, cam_pose) + + # # reshape input imgs + # if imgs is not None: + # imgs = convert_img(imgs.unsqueeze(0), height)[:,0] + backgrounds = imgs if imgs is not None else np.ones( + (height, width, 3)) * 255 + renderer = get_renderer(width, height, cam_pose) + + # [nframes, nVs, 3] + meshes = motion + key, action = parsename(save_path) + render_video(meshes, + key, + action, + renderer, + save_path, + backgrounds, + cam_pose, + cams=cam) + return + + +def stack_images(real, real_gens, gen, real_imgs=None): + # change to 3 channel + # print(real.shape) + # print(real_gens.shape) + # print(real_gens.shape) + # real = real[:3] + # real_gens = real_gens[:3] + # gen = gen[:3] + + nleft_cols = len(real_gens) + 1 + print("Stacking frames..") + allframes = np.concatenate( + (real[:, None, ...], *[x[:, None, ...] for x in real_gens], gen), 1) + nframes, nspa, nats, h, w, pix = allframes.shape + + blackborder = np.zeros((w // 30, h * nats, pix), dtype=allframes.dtype) + # blackborder = np.ones((w//30, h*nats, pix), dtype=allframes.dtype)*255 + frames = [] + for frame_idx in tqdm(range(nframes)): + columns = np.vstack(allframes[frame_idx].transpose(1, 2, 3, 4, + 0)).transpose( + 3, 1, 0, 2) + frame = np.concatenate( + (*columns[0:nleft_cols], blackborder, *columns[nleft_cols:]), + 0).transpose(1, 0, 2) + + frames.append(frame) + + if real_imgs is not None: + resize_imgs = convert_img(real_imgs, h)[:nframes, ...] + + for i in range(len(frames)): + imgs = np.vstack(resize_imgs[i, ...]) + imgs4 = np.ones( + (imgs.shape[0], imgs.shape[1], 4), dtype=np.uint8) * 255 + imgs4[:, :, :3] = imgs + #imgs = torch2numpy(imgs) + frames[i] = np.concatenate((imgs4, frames[i]), 1) + return np.stack(frames) + + +def stack_images_gen(gen, real_imgs=None): + print("Stacking frames..") + allframes = gen + nframes, nspa, nats, h, w, pix = allframes.shape + blackborder = np.zeros((w * nspa, h // 30, pix), dtype=allframes.dtype) + blackborder = blackborder[None, ...].repeat(nats, + axis=0).transpose(0, 2, 1, 3) + + frames = [] + for frame_idx in tqdm(range(nframes)): + rows = np.vstack(allframes[frame_idx].transpose(0, 3, 2, 4, + 1)).transpose( + 3, 1, 0, 2) + rows = np.concatenate((rows, blackborder), 1) + frame = np.concatenate(rows, 0) + frames.append(frame) + + if real_imgs is not None: + # ToDo Add images + resize_imgs = convert_img(real_imgs, h)[:nframes, ...] 
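The stack_images / stack_images_gen helpers tile per-sample animations into one composite frame per time step through a chain of transpose and vstack calls. The standalone sketch below shows the basic tiling operation underneath with a single transpose and reshape; tile_grid and the demo shapes are purely illustrative and make no attempt to reproduce the black-border layout used above.

import numpy as np

def tile_grid(frames_grid):
    # frames_grid: (nframes, nrows, ncols, h, w, c) -> (nframes, nrows*h, ncols*w, c)
    nframes, nrows, ncols, h, w, c = frames_grid.shape
    tiled = frames_grid.transpose(0, 1, 3, 2, 4, 5)   # (T, nrows, h, ncols, w, c)
    return tiled.reshape(nframes, nrows * h, ncols * w, c)

# e.g. a 2 x 3 grid of 64 x 64 RGBA animations, 17 frames long
demo = np.zeros((17, 2, 3, 64, 64, 4), dtype=np.uint8)
assert tile_grid(demo).shape == (17, 128, 192, 4)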
+ for i in range(len(frames)): + imgs = np.vstack(resize_imgs[i, ...]) + #imgs = torch2numpy(imgs) + frames[i] = np.concatenate((imgs, frames[i]), 1) + return np.stack(frames) + + +def generate_by_video(visualization, reconstructions, generation, + label_to_action_name, params, nats, nspa, tmp_path): + # shape : (17, 3, 4, 480, 640, 3) + # (nframes, row, column, h, w, 3) + fps = params["fps"] + + params = params.copy() + + gen_only = False + if visualization is None: + gen_only = True + outputkey = "output_vertices" + params["pose_rep"] = "vertices" + elif "output_vertices" in visualization: + outputkey = "output_vertices" + params["pose_rep"] = "vertices" + elif "output_xyz" in visualization: + outputkey = "output_xyz" + params["pose_rep"] = "xyz" + else: + outputkey = "poses" + + keep = [outputkey, 'lengths', "y"] + gener = {key: generation[key].data.cpu().numpy() for key in keep} + if not gen_only: + visu = {key: visualization[key].data.cpu().numpy() for key in keep} + recons = {} + # visualize regressor results + if 'vertices_hat' in reconstructions['ntf']: + recons['regressor'] = { + 'output_vertices': + reconstructions['ntf']['vertices_hat'].data.cpu().numpy(), + 'lengths': + reconstructions['ntf']['lengths'].data.cpu().numpy(), + 'y': + reconstructions['ntf']['y'].data.cpu().numpy() + } + + recons['regressor_side'] = { + 'output_vertices': + reconstructions['ntf']['vertices_hat'].data.cpu().numpy(), + 'lengths': + reconstructions['ntf']['lengths'].data.cpu().numpy(), + 'y': + reconstructions['ntf']['y'].data.cpu().numpy(), + 'side': + True + } + # ToDo rendering overlap results + # recons['overlap'] = {'output_vertices':reconstructions['ntf']['vertices_hat'].data.cpu().numpy(), + # 'lengths':reconstructions['ntf']['lengths'].data.cpu().numpy(), + # 'y':reconstructions['ntf']['y'].data.cpu().numpy(), + # 'imgs':reconstructions['ntf']['imgs'], + # 'bbox':reconstructions['ntf']['bbox'].data.cpu().numpy(), + # 'cam':reconstructions['ntf']['preds'][0]['cam'].data.cpu().numpy()} + for mode, reconstruction in reconstructions.items(): + recons[mode] = { + key: reconstruction[key].data.cpu().numpy() + for key in keep + } + recons[mode + '_side'] = { + key: reconstruction[key].data.cpu().numpy() + for key in keep + } + recons[mode + '_side']['side'] = True + + # lenmax = max(gener['lengths'].max(), visu['lengths'].max()) + # timesize = lenmax + 5 longer visulization + lenmax = gener['lengths'].max() + timesize = lenmax + + import multiprocessing + + def pool_job_with_desc(pool, iterator, desc, max_, save_path_format, isij): + with tqdm(total=max_, desc=desc.format("Render")) as pbar: + for data in iterator: + plot_3d_motion_dico(data) + # for _ in pool.imap_unordered(plot_3d_motion_dico, iterator): + # pbar.update() + if isij: + array = np.stack([[ + load_anim(save_path_format.format(i, j), timesize) + for j in range(nats) + ] for i in tqdm(range(nspa), desc=desc.format("Load"))]) + return array.transpose(2, 0, 1, 3, 4, 5) + else: + array = np.stack([ + load_anim(save_path_format.format(i), timesize) + for i in tqdm(range(nats), desc=desc.format("Load")) + ]) + return array.transpose(1, 0, 2, 3, 4) + + pool = None + # if True: + with multiprocessing.Pool() as pool: + # Generated samples + save_path_format = os.path.join(tmp_path, "gen_{}_{}.gif") + iterator = ((gener[outputkey][i, j], gener['lengths'][i, j], + save_path_format.format(i, j), params, { + "title": + f"gen: {label_to_action_name(gener['y'][i, j])}", + "interval": 1000 / fps + }) for j in range(nats) for i in range(nspa)) + 
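pool_job_with_desc renders every job serially here and keeps the multiprocessing path only as a comment. The sketch below shows both options side by side; it assumes plot_3d_motion_dico from this module, and the pooled path should only be enabled when the renderer can be created inside each worker process (EGL/OpenGL contexts frequently cannot be shared across forks).

import multiprocessing
from tqdm import tqdm

def render_jobs(iterator, total, workers=1):
    # Each item is the tuple consumed by plot_3d_motion_dico:
    # (motion, length, save_path, params, extra_kwargs).
    if workers <= 1:
        for job in tqdm(iterator, total=total, desc="Render"):
            plot_3d_motion_dico(job)
    else:
        # Workers must be able to build their own renderer; do not share GL state.
        with multiprocessing.Pool(workers) as pool:
            for _ in tqdm(pool.imap_unordered(plot_3d_motion_dico, iterator),
                          total=total, desc="Render"):
                pass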
gener["frames"] = pool_job_with_desc(pool, iterator, + "{} the generated samples", + nats * nspa, save_path_format, + True) + if not gen_only: + # Real samples + save_path_format = os.path.join(tmp_path, "real_{}.gif") + iterator = ((visu[outputkey][i], visu['lengths'][i], + save_path_format.format(i), params, { + "title": + f"real: {label_to_action_name(visu['y'][i])}", + "interval": 1000 / fps + }) for i in range(nats)) + visu["frames"] = pool_job_with_desc(pool, iterator, + "{} the real samples", nats, + save_path_format, False) + for mode, recon in recons.items(): + # Reconstructed samples + save_path_format = os.path.join( + tmp_path, f"reconstructed_{mode}_" + "{}.gif") + if mode == 'overlap': + iterator = (( + recon[outputkey][i], recon['lengths'][i], + save_path_format.format(i), params, { + "title": + f"recons: {label_to_action_name(recon['y'][i])}", + "interval": 1000 / fps, + "pred_cam": recon['cam'][i], + "imgs": recon['imgs'][i], + "bbox": recon['bbox'][i] + }) for i in range(nats)) + else: + side = True if 'side' in recon.keys() else False + iterator = (( + recon[outputkey][i], recon['lengths'][i], + save_path_format.format(i), params, { + "title": + f"recons: {label_to_action_name(recon['y'][i])}", + "interval": 1000 / fps, + "side": side + }) for i in range(nats)) + recon["frames"] = pool_job_with_desc( + pool, iterator, "{} the reconstructed samples", nats, + save_path_format, False) + # vis img in visu + if not gen_only: + input_imgs = visualization["imgs"] if visualization[ + "imgs"] is not None else None + vis = visu["frames"] if not gen_only else None + rec = [recon["frames"] + for recon in recons.values()] if not gen_only else None + gen = gener["frames"] + frames = stack_images(vis, rec, gen, input_imgs) + else: + gen = gener["frames"] + frames = stack_images_gen(gen) + return frames + + +def viz_epoch(model, + dataset, + epoch, + params, + folder, + module=None, + writer=None, + exps=''): + """ Generate & viz samples """ + module = model if module is None else module + + # visualize with joints3D + model.outputxyz = True + + print(f"Visualization of the epoch {epoch}") + + noise_same_action = params["noise_same_action"] + noise_diff_action = params["noise_diff_action"] + duration_mode = params["duration_mode"] + reconstruction_mode = params["reconstruction_mode"] + decoder_test = params["decoder_test"] + + fact = params["fact_latent"] + figname = params["figname"].format(epoch) + + nspa = params["num_samples_per_action"] + nats = params["num_actions_to_sample"] + + num_classes = params["num_classes"] + # nats = min(num_classes, nats) + + # define some classes + classes = torch.randperm(num_classes)[:nats] + # duplicate same classes when sampling too much + if nats > num_classes: + classes = classes.expand(nats) + + meandurations = torch.from_numpy( + np.array([ + round(dataset.get_mean_length_label(cl.item())) for cl in classes + ])) + + if duration_mode == "interpolate" or decoder_test == "diffduration": + points, step = np.linspace(-nspa, nspa, nspa, retstep=True) + # points = np.round(10*points/step).astype(int) + points = np.array([5, 10, 16, 30, 60, 80]).astype(int) + # gendurations = meandurations.repeat((nspa, 1)) + points[:, None] + gendurations = torch.from_numpy(points[:, None]).expand( + (nspa, 1)).repeat((1, nats)) + else: + gendurations = meandurations.repeat((nspa, 1)) + print("Duration time: ") + print(gendurations[:, 0]) + + # extract the real samples + # real_samples, real_theta, mask_real, real_lengths, imgs, paths + batch = 
dataset.get_label_sample_batch(classes.numpy()) + + # ToDo + # clean these data + # Visualizaion of real samples + visualization = { + "x": batch['x'].to(model.device), + "y": classes.to(model.device), + "mask": batch['mask'].to(model.device), + 'lengths': batch['lengths'].to(model.device), + "output": batch['x'].to(model.device), + "theta": + batch['theta'].to(model.device) if 'theta' in batch.keys() else None, + "imgs": + batch['imgs'].to(model.device) if 'imgs' in batch.keys() else None, + "paths": batch['paths'] if 'paths' in batch.keys() else None, + } + + # Visualizaion of real samples + if reconstruction_mode == "both": + reconstructions = { + "tf": { + "x": + batch['x'].to(model.device), + "y": + classes.to(model.device), + 'lengths': + batch['lengths'].to(model.device), + "mask": + batch['mask'].to(model.device), + "teacher_force": + True, + "theta": + batch['theta'].to(model.device) + if 'theta' in batch.keys() else None + }, + "ntf": { + "x": + batch['x'].to(model.device), + "y": + classes.to(model.device), + 'lengths': + batch['lengths'].to(model.device), + "mask": + batch['mask'].to(model.device), + "theta": + batch['theta'].to(model.device) + if 'theta' in batch.keys() else None + } + } + else: + reconstructions = { + reconstruction_mode: { + "x": + batch['x'].to(model.device), + "y": + classes.to(model.device), + 'lengths': + batch['lengths'].to(model.device), + "mask": + batch['mask'].to(model.device), + "teacher_force": + reconstruction_mode == "tf", + "imgs": + batch['imgs'].to(model.device) + if 'imgs' in batch.keys() else None, + "theta": + batch['theta'].to(model.device) + if 'theta' in batch.keys() else None, + "bbox": + batch['bbox'] if 'bbox' in batch.keys() else None + } + } + print("Computing the samples poses..") + + # generate the repr (joints3D/pose etc) + model.eval() + with torch.no_grad(): + # Reconstruction of the real data + for mode in reconstructions: + # update reconstruction dicts + reconstructions[mode] = model(reconstructions[mode]) + reconstruction = reconstructions[list(reconstructions.keys())[0]] + + if decoder_test == "gt": + # Generate the new data + gt_input = { + "x": batch['x'].repeat(nspa, 1, 1, 1).to(model.device), + "y": classes.repeat(nspa).to(model.device), + "mask": batch['mask'].repeat(nspa, 1).to(model.device), + 'lengths': batch['lengths'].repeat(nspa).to(model.device) + } + generation = model(gt_input) + if decoder_test == "new": + # Generate the new data + generation = module.generate(gendurations, + classes=classes, + nspa=nspa, + noise_same_action=noise_same_action, + noise_diff_action=noise_diff_action, + fact=fact) + elif decoder_test == "diffaction": + assert nats == nspa + # keep the same noise for each "sample" + z = reconstruction["z"].repeat((nspa, 1)) + mask = reconstruction["mask"].repeat((nspa, 1)) + lengths = reconstruction['lengths'].repeat(nspa) + # but use other labels + y = classes.repeat_interleave(nspa).to(model.device) + generation = {"z": z, "y": y, "mask": mask, 'lengths': lengths} + model.decoder(generation) + + elif decoder_test == "diffduration": + z = reconstruction["z"].repeat((nspa, 1)) + lengths = gendurations.reshape(-1).to(model.device) + mask = model.lengths_to_mask(lengths) + y = classes.repeat(nspa).to(model.device) + generation = {"z": z, "y": y, "mask": mask, 'lengths': lengths} + model.decoder(generation) + + elif decoder_test == "interpolate_action": + assert nats == nspa + # same noise for each sample + z_diff_action = torch.randn(1, + model.latent_dim, + device=model.device).repeat(nats, 1) + 
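For decoder_test == "interpolate_action", the branch below blends one-hot action labels (and durations) with convex weights through einsum. The standalone sketch that follows expresses the same label mixing with plain broadcasting; mix_labels is an illustrative name, and torch.nn.functional is imported explicitly here.

import torch
import torch.nn.functional as F

def mix_labels(classes, num_classes, nspa):
    # classes: (nats,) integer action ids
    # returns: (nspa * nats, num_classes) soft labels interpolating each
    #          action towards the next one in the batch
    y = F.one_hot(classes, num_classes).float()       # (nats, K)
    y_next = torch.roll(y, shifts=-1, dims=0)         # the "label below" each action
    t = torch.linspace(0, 1, nspa).view(nspa, 1, 1)   # convex factors in [0, 1]
    y_mixed = (1 - t) * y.unsqueeze(0) + t * y_next.unsqueeze(0)   # (nspa, nats, K)
    return y_mixed.reshape(-1, num_classes)

# e.g. 4 sampled actions, 6 interpolation steps -> 24 soft labels
soft = mix_labels(torch.tensor([0, 2, 1, 3]), num_classes=5, nspa=6)
assert soft.shape == (24, 5)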
z = z_diff_action.repeat((nspa, 1)) + + # but use combination of labels and labels below + y = F.one_hot(classes.to(model.device), + model.num_classes).to(model.device) + y_below = F.one_hot(torch.cat((classes[1:], classes[0:1])), + model.num_classes).to(model.device) + convex_factors = torch.linspace(0, 1, nspa, device=model.device) + y_mixed = torch.einsum("nk,m->mnk", y, 1-convex_factors) + \ + torch.einsum("nk,m->mnk", y_below, convex_factors) + y_mixed = y_mixed.reshape(nspa * nats, y_mixed.shape[-1]) + + durations = gendurations[0].to(model.device) + durations_below = torch.cat((durations[1:], durations[0:1])) + + gendurations = torch.einsum("l,k->kl", durations, 1-convex_factors) + \ + torch.einsum("l,k->kl", durations_below, convex_factors) + gendurations = gendurations.to(dtype=durations.dtype) + + lengths = gendurations.to(model.device).reshape(z.shape[0]) + mask = model.lengths_to_mask(lengths) + + generation = { + "z": z, + "y": y_mixed, + "mask": mask, + 'lengths': lengths + } + generation = model.decoder(generation) + + visualization = module.prepare(visualization) + visualization["output_xyz"] = visualization["x_xyz"] + visualization["output_vertices"] = visualization["x_vertices"] + # Get xyz for the real ones + # visualization["output_xyz"] = module.rot2xyz(visualization["output"], visualization["mask"], jointstype="smpl") + # # Get smpl vertices for the real ones + # if module.cvae.pose_rep != "xyz": + # visualization["output_vertices"] = module.rot2xyz(visualization["output"], visualization["mask"], jointstype="vertices") + + for key, val in generation.items(): + if len(generation[key].shape) == 1: + generation[key] = val.reshape(nspa, nats) + else: + generation[key] = val.reshape(nspa, nats, *val.shape[1:]) + + finalpath = os.path.join(folder, figname + exps + ".gif") + tmp_path = os.path.join(folder, f"subfigures_{figname}") + os.makedirs(tmp_path, exist_ok=True) + + print("Generate the videos..") + frames = generate_by_video(visualization, reconstructions, generation, + dataset.label_to_action_name, params, nats, + nspa, tmp_path) + + print(f"Writing video {finalpath}") + imageio.mimsave(finalpath.replace('gif', 'mp4'), frames, fps=params["fps"]) + shutil.rmtree(tmp_path) + + # output npy + output = { + "data_id": batch['id'], + "paths": batch['paths'], + "x": batch['x'].cpu().numpy(), + "x_vertices": visualization["x_vertices"].cpu().numpy(), + "output_vertices": + reconstructions['ntf']["output_vertices"].cpu().numpy(), + "gen_vertices": generation["output_vertices"].cpu().numpy() + } + + outputpath = finalpath.replace('gif', 'npy') + np.save(outputpath, output) + + # output pkl + batch_recon = reconstructions["ntf"] + outputpath = finalpath.replace('gif', 'pkl') + # output_pkl([batch_recon], outputpath) + + if writer is not None: + writer.add_video(f"Video/Epoch {epoch}", + frames.transpose(0, 3, 1, 2)[None], + epoch, + fps=params["fps"]) + return finalpath + + +def viz_dataset(dataset, params, folder): + """ Generate & viz samples """ + print("Visualization of the dataset") + + nspa = params["num_samples_per_action"] + nats = params["num_actions_to_sample"] + + num_classes = params["num_classes"] + + figname = "{}_{}_numframes_{}_sampling_{}_step_{}".format( + params["dataset"], params["pose_rep"], params["num_frames"], + params["sampling"], params["sampling_step"]) + + # define some classes + classes = torch.randperm(num_classes)[:nats] + + allclasses = classes.repeat(nspa, 1).reshape(nspa * nats) + # extract the real samples + real_samples, mask_real, 
real_lengths = dataset.get_label_sample_batch( + allclasses.numpy()) + # to visualize directly + + # Visualizaion of real samples + visualization = { + "x": real_samples, + "y": allclasses, + "mask": mask_real, + 'lengths': real_lengths, + "output": real_samples + } + + from mGPT.models.rotation2xyz import Rotation2xyz + + device = params["device"] + rot2xyz = Rotation2xyz(device=device) + + rot2xyz_params = { + "pose_rep": params["pose_rep"], + "glob_rot": params["glob_rot"], + "glob": params["glob"], + "jointstype": params["jointstype"], + "translation": params["translation"] + } + + output = visualization["output"] + visualization["output_xyz"] = rot2xyz(output.to(device), + visualization["mask"].to(device), + **rot2xyz_params) + + for key, val in visualization.items(): + if len(visualization[key].shape) == 1: + visualization[key] = val.reshape(nspa, nats) + else: + visualization[key] = val.reshape(nspa, nats, *val.shape[1:]) + + finalpath = os.path.join(folder, figname + ".gif") + tmp_path = os.path.join(folder, f"subfigures_{figname}") + os.makedirs(tmp_path, exist_ok=True) + + print("Generate the videos..") + frames = generate_by_video_sequences(visualization, + dataset.label_to_action_name, params, + nats, nspa, tmp_path) + + print(f"Writing video {finalpath}..") + imageio.mimsave(finalpath, frames, fps=params["fps"]) + + +def generate_by_video_sequences(visualization, label_to_action_name, params, + nats, nspa, tmp_path): + # shape : (17, 3, 4, 480, 640, 3) + # (nframes, row, column, h, w, 3) + fps = params["fps"] + if "output_vetices" in visualization: + outputkey = "output_vetices" + params["pose_rep"] = "vertices" + elif "output_xyz" in visualization: + outputkey = "output_xyz" + params["pose_rep"] = "xyz" + else: + outputkey = "poses" + + keep = [outputkey, 'lengths', "y"] + visu = {key: visualization[key].data.cpu().numpy() for key in keep} + lenmax = visu['lengths'].max() + + timesize = lenmax + 5 + + # import multiprocessing + + def pool_job_with_desc(pool, iterator, desc, max_, save_path_format): + for data in iterator: + plot_3d_motion_dico(data) + # with tqdm(total=max_, desc=desc.format("Render")) as pbar: + # for _ in pool.imap_unordered(plot_3d_motion_dico, iterator): + # pbar.update() + array = np.stack([[ + load_anim(save_path_format.format(i, j), timesize) + for j in range(nats) + ] for i in tqdm(range(nspa), desc=desc.format("Load"))]) + return array.transpose(2, 0, 1, 3, 4, 5) + + pool = None + # with multiprocessing.Pool() as pool: + # Real samples + save_path_format = os.path.join(tmp_path, "real_{}_{}.gif") + iterator = ((visu[outputkey][i, j], visu['lengths'][i, j], + save_path_format.format(i, j), params, { + "title": f"real: {label_to_action_name(visu['y'][i, j])}", + "interval": 1000 / fps + }) for j in range(nats) for i in range(nspa)) + visu["frames"] = pool_job_with_desc(pool, iterator, "{} the real samples", + nats, save_path_format) + frames = stack_images_sequence(visu["frames"]) + return frames + + +def stack_images_sequence(visu): + print("Stacking frames..") + allframes = visu + nframes, nspa, nats, h, w, pix = allframes.shape + frames = [] + for frame_idx in tqdm(range(nframes)): + columns = np.vstack(allframes[frame_idx].transpose(1, 2, 3, 4, + 0)).transpose( + 3, 1, 0, 2) + frame = np.concatenate(columns).transpose(1, 0, 2) + frames.append(frame) + return np.stack(frames) diff --git a/mGPT/utils/__init__.py b/mGPT/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff 
--git a/mGPT/utils/demo_utils.py b/mGPT/utils/demo_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1a6299332523f4ed07c72740b4eb433a6550022c --- /dev/null +++ b/mGPT/utils/demo_utils.py @@ -0,0 +1,79 @@ +import os +from pathlib import Path + + +# load example data +def load_example_input(txt_path): + file = open(txt_path, "r") + Lines = file.readlines() + count = 0 + texts, lens = [], [] + # Strips the newline character + for line in Lines: + count += 1 + s = line.strip() + s_l = s.split(" ")[0] + s_t = s[(len(s_l) + 1):] + lens.append(int(s_l)) + texts.append(s_t) + print("Length-{}: {}".format(s_l, s_t)) + return texts, lens + + +# render batch +def render_batch(npy_dir, execute_python="./scripts/visualize_motion.sh", mode="sequence"): + os.system(f"{execute_python} {npy_dir} {mode}") + + +# render +def render(execute_python, npy_path, jointtype, cfg_path): + # execute_python = "/apdcephfs/share_1227775/shingxchen/libs/blender_bpy/blender-2.93.2-linux-x64/blender" + # execute_python = "/apdcephfs/share_1227775/mingzhenzhu/jiangbiao/libs/blender-2.93.2-linux-x64/blender" + export_scripts = "render.py" + + os.system( + f"{execute_python} --background --python {export_scripts} -- --cfg={cfg_path} --npy={npy_path} --joint_type={jointtype}" + ) + + fig_path = Path(str(npy_path).replace(".npy", ".png")) + return fig_path + + +# origin render +# def render(npy_path, jointtype): +# execute_python = '/apdcephfs/share_1227775/shingxchen/libs/blender_bpy/blender-2.93.2-linux-x64/blender' +# export_scripts = 'render.py' + +# os.system(f"{execute_python} --background --python {export_scripts} -- npy={npy_path} jointstype={jointtype}") + +# fig_path = Path(str(npy_path).replace(".npy",".png")) +# return fig_path + +# export fbx with hand params from pkl files +# refer to /apdcephfs/share_1227775/shingxchen/AIMotion/TMOST/scripts/fbx_output_smplx.py +def export_fbx_hand(pkl_path): + input = pkl_path + output = pkl_path.replace(".pkl", ".fbx") + + execute_python = "/apdcephfs/share_1227775/shingxchen/libs/blender_bpy/blender-2.93.2-linux-x64/blender" + export_scripts = "./scripts/fbx_output_smplx.py" + os.system( + f"{execute_python} -noaudio --background --python {export_scripts}\ + --input {input} \ + --output {output}" + ) + + +# export fbx without hand params from pkl files +# refer to /apdcephfs/share_1227775/shingxchen/AIMotion/TMOST/scripts/fbx_output.py +def export_fbx(pkl_path): + input = pkl_path + output = pkl_path.replace(".pkl", ".fbx") + + execute_python = "/apdcephfs/share_1227775/shingxchen/libs/blender_bpy/blender-2.93.2-linux-x64/blender" + export_scripts = "./scripts/fbx_output.py" + os.system( + f"{execute_python} -noaudio --background --python {export_scripts}\ + --input {input} \ + --output {output}" + ) diff --git a/mGPT/utils/easyconvert.py b/mGPT/utils/easyconvert.py new file mode 100644 index 0000000000000000000000000000000000000000..248073783ec412b9089b195d61df94e6754a6bfe --- /dev/null +++ b/mGPT/utils/easyconvert.py @@ -0,0 +1,84 @@ +from .geometry_tools import * + + +def rep_to_rep(oldtype, newtype, rotations): + if newtype in ["matrix"]: + return to_matrix(oldtype, rotations) + + if oldtype in ["rotvec", "axisangle"]: + return axis_angle_to(newtype, rotations) + elif oldtype in ["matrix"]: + return matrix_to(newtype, rotations) + else: + raise NotImplementedError("Only rotvec and matrix are supported.") + +def nfeats_of(rottype): + if rottype in ["rotvec", "axisangle"]: + return 3 + elif rottype in ["rotquat", "quaternion"]: + return 4 + elif 
rottype in ["rot6d", "6drot", "rotation6d"]: + return 6 + elif rottype in ["rotmat"]: + return 9 + else: + return TypeError("This rotation type doesn't have features.") + + +def axis_angle_to(newtype, rotations): + if newtype in ["matrix"]: + rotations = axis_angle_to_matrix(rotations) + return rotations + elif newtype in ["rotmat"]: + rotations = axis_angle_to_matrix(rotations) + rotations = matrix_to("rotmat", rotations) + return rotations + elif newtype in ["rot6d", "6drot", "rotation6d"]: + rotations = axis_angle_to_matrix(rotations) + rotations = matrix_to("rot6d", rotations) + return rotations + elif newtype in ["rotquat", "quaternion"]: + rotations = axis_angle_to_quaternion(rotations) + return rotations + elif newtype in ["rotvec", "axisangle"]: + return rotations + else: + raise NotImplementedError + + +def matrix_to(newtype, rotations): + if newtype in ["matrix"]: + return rotations + if newtype in ["rotmat"]: + rotations = rotations.reshape((*rotations.shape[:-2], 9)) + return rotations + elif newtype in ["rot6d", "6drot", "rotation6d"]: + rotations = matrix_to_rotation_6d(rotations) + return rotations + elif newtype in ["rotquat", "quaternion"]: + rotations = matrix_to_quaternion(rotations) + return rotations + elif newtype in ["rotvec", "axisangle"]: + rotations = matrix_to_axis_angle(rotations) + return rotations + else: + raise NotImplementedError + + +def to_matrix(oldtype, rotations): + if oldtype in ["matrix"]: + return rotations + if oldtype in ["rotmat"]: + rotations = rotations.reshape((*rotations.shape[:-2], 3, 3)) + return rotations + elif oldtype in ["rot6d", "6drot", "rotation6d"]: + rotations = rotation_6d_to_matrix(rotations) + return rotations + elif oldtype in ["rotquat", "quaternion"]: + rotations = quaternion_to_matrix(rotations) + return rotations + elif oldtype in ["rotvec", "axisangle"]: + rotations = axis_angle_to_matrix(rotations) + return rotations + else: + raise NotImplementedError diff --git a/mGPT/utils/fixseed.py b/mGPT/utils/fixseed.py new file mode 100644 index 0000000000000000000000000000000000000000..a43a273b138c45dccafef4da3628dd4c2a3f84a4 --- /dev/null +++ b/mGPT/utils/fixseed.py @@ -0,0 +1,18 @@ +import numpy as np +import torch +import random + + +def fixseed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + +SEED = 10 +EVALSEED = 0 +# Provoc warning: not fully functionnal yet +# torch.set_deterministic(True) +torch.backends.cudnn.benchmark = False + +fixseed(SEED) diff --git a/mGPT/utils/geometry_conver.py b/mGPT/utils/geometry_conver.py new file mode 100644 index 0000000000000000000000000000000000000000..5c5c8443c47f685dec0350b533f165c413feac69 --- /dev/null +++ b/mGPT/utils/geometry_conver.py @@ -0,0 +1,550 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. 
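The easyconvert module added above is a thin dispatcher over the geometry helpers: rep_to_rep, axis_angle_to, matrix_to and to_matrix route a tensor between rotation representations selected by name. A small usage sketch, assuming the package layout introduced in this diff:

import torch
from mGPT.utils import easyconvert

# 24 axis-angle joint rotations (e.g. one SMPL pose), shape (..., 3)
rotvec = torch.randn(24, 3)

rot6d = easyconvert.axis_angle_to("rot6d", rotvec)   # (24, 6) continuous 6D representation
matrix = easyconvert.to_matrix("rot6d", rot6d)       # (24, 3, 3) rotation matrices
back = easyconvert.matrix_to("rotvec", matrix)       # (24, 3) equivalent axis-angle vectors

assert rot6d.shape[-1] == easyconvert.nfeats_of("rot6d")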
+# +# Contact: ps-license@tuebingen.mpg.de + +import torch +import numpy as np +from torch.nn import functional as F + + +def axis_angle_to_quaternion(axis_angle): + """ + Convert rotations given as axis/angle to quaternions. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + quaternions with real part first, as tensor of shape (..., 4). + """ + angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) + half_angles = 0.5 * angles + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles]) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48) + quaternions = torch.cat( + [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], + dim=-1) + return quaternions + + +def quaternion_to_matrix(quaternions): + """ + Convert rotations given as quaternions to rotation matrices. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + r, i, j, k = torch.unbind(quaternions, -1) + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def axis_angle_to_matrix(axis_angle): + """ + Convert rotations given as axis/angle to rotation matrices. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). 
+ """ + return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle)) + + +def matrix_of_angles(cos, sin, inv=False, dim=2): + assert dim in [2, 3] + sin = -sin if inv else sin + if dim == 2: + row1 = torch.stack((cos, -sin), axis=-1) + row2 = torch.stack((sin, cos), axis=-1) + return torch.stack((row1, row2), axis=-2) + elif dim == 3: + row1 = torch.stack((cos, -sin, 0 * cos), axis=-1) + row2 = torch.stack((sin, cos, 0 * cos), axis=-1) + row3 = torch.stack((0 * sin, 0 * cos, 1 + 0 * cos), axis=-1) + return torch.stack((row1, row2, row3), axis=-2) + + +def matrot2axisangle(matrots): + # This function is borrowed from https://github.com/davrempe/humor/utils/transforms.py + # axisang N x 3 + ''' + :param matrots: N*num_joints*9 + :return: N*num_joints*3 + ''' + import cv2 + batch_size = matrots.shape[0] + matrots = matrots.reshape([batch_size, -1, 9]) + out_axisangle = [] + for mIdx in range(matrots.shape[0]): + cur_axisangle = [] + for jIdx in range(matrots.shape[1]): + a = cv2.Rodrigues(matrots[mIdx, + jIdx:jIdx + 1, :].reshape(3, + 3))[0].reshape( + (1, 3)) + cur_axisangle.append(a) + + out_axisangle.append(np.array(cur_axisangle).reshape([1, -1, 3])) + return np.vstack(out_axisangle) + + +def axisangle2matrots(axisangle): + # This function is borrowed from https://github.com/davrempe/humor/utils/transforms.py + # axisang N x 3 + ''' + :param axisangle: N*num_joints*3 + :return: N*num_joints*9 + ''' + import cv2 + batch_size = axisangle.shape[0] + axisangle = axisangle.reshape([batch_size, -1, 3]) + out_matrot = [] + for mIdx in range(axisangle.shape[0]): + cur_axisangle = [] + for jIdx in range(axisangle.shape[1]): + a = cv2.Rodrigues(axisangle[mIdx, jIdx:jIdx + 1, :].reshape(1, + 3))[0] + cur_axisangle.append(a) + + out_matrot.append(np.array(cur_axisangle).reshape([1, -1, 9])) + return np.vstack(out_matrot) + + +def batch_rodrigues(axisang): + # This function is borrowed from https://github.com/MandyMo/pytorch_HMR/blob/master/src/util.py#L37 + # axisang N x 3 + axisang_norm = torch.norm(axisang + 1e-8, p=2, dim=1) + angle = torch.unsqueeze(axisang_norm, -1) + axisang_normalized = torch.div(axisang, angle) + angle = angle * 0.5 + v_cos = torch.cos(angle) + v_sin = torch.sin(angle) + + quat = torch.cat([v_cos, v_sin * axisang_normalized], dim=1) + rot_mat = quat2mat(quat) + rot_mat = rot_mat.view(rot_mat.shape[0], 9) + return rot_mat + + +def quat2mat(quat): + """ + This function is borrowed from https://github.com/MandyMo/pytorch_HMR/blob/master/src/util.py#L50 + + Convert quaternion coefficients to rotation matrix. + Args: + quat: size = [batch_size, 4] 4 <===>(w, x, y, z) + Returns: + Rotation matrix corresponding to the quaternion -- size = [batch_size, 3, 3] + """ + norm_quat = quat + norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True) + w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, + 2], norm_quat[:, + 3] + + batch_size = quat.size(0) + + w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2) + wx, wy, wz = w * x, w * y, w * z + xy, xz, yz = x * y, x * z, y * z + + rotMat = torch.stack([ + w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy, + w2 - x2 + y2 - z2, 2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz, + w2 - x2 - y2 + z2 + ], + dim=1).view(batch_size, 3, 3) + return rotMat + + +def rotation_matrix_to_angle_axis(rotation_matrix): + """ + This function is borrowed from https://github.com/kornia/kornia + + Convert 3x4 rotation matrix to Rodrigues vector + + Args: + rotation_matrix (Tensor): rotation matrix. 
+ + Returns: + Tensor: Rodrigues vector transformation. + + Shape: + - Input: :math:`(N, 3, 4)` + - Output: :math:`(N, 3)` + + Example: + >>> input = torch.rand(2, 3, 4) # Nx4x4 + >>> output = tgm.rotation_matrix_to_angle_axis(input) # Nx3 + """ + if rotation_matrix.shape[1:] == (3, 3): + rot_mat = rotation_matrix.reshape(-1, 3, 3) + hom = torch.tensor([0, 0, 1], + dtype=torch.float32, + device=rotation_matrix.device).reshape( + 1, 3, 1).expand(rot_mat.shape[0], -1, -1) + rotation_matrix = torch.cat([rot_mat, hom], dim=-1) + + quaternion = rotation_matrix_to_quaternion(rotation_matrix) + aa = quaternion_to_angle_axis(quaternion) + aa[torch.isnan(aa)] = 0.0 + return aa + + +def quaternion_to_angle_axis(quaternion: torch.Tensor) -> torch.Tensor: + """ + This function is borrowed from https://github.com/kornia/kornia + + Convert quaternion vector to angle axis of rotation. + + Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h + + Args: + quaternion (torch.Tensor): tensor with quaternions. + + Return: + torch.Tensor: tensor with angle axis of rotation. + + Shape: + - Input: :math:`(*, 4)` where `*` means, any number of dimensions + - Output: :math:`(*, 3)` + + Example: + >>> quaternion = torch.rand(2, 4) # Nx4 + >>> angle_axis = tgm.quaternion_to_angle_axis(quaternion) # Nx3 + """ + if not torch.is_tensor(quaternion): + raise TypeError("Input type is not a torch.Tensor. Got {}".format( + type(quaternion))) + + if not quaternion.shape[-1] == 4: + raise ValueError( + "Input must be a tensor of shape Nx4 or 4. Got {}".format( + quaternion.shape)) + # unpack input and compute conversion + q1: torch.Tensor = quaternion[..., 1] + q2: torch.Tensor = quaternion[..., 2] + q3: torch.Tensor = quaternion[..., 3] + sin_squared_theta: torch.Tensor = q1 * q1 + q2 * q2 + q3 * q3 + + sin_theta: torch.Tensor = torch.sqrt(sin_squared_theta) + cos_theta: torch.Tensor = quaternion[..., 0] + two_theta: torch.Tensor = 2.0 * torch.where( + cos_theta < 0.0, torch.atan2(-sin_theta, -cos_theta), + torch.atan2(sin_theta, cos_theta)) + + k_pos: torch.Tensor = two_theta / sin_theta + k_neg: torch.Tensor = 2.0 * torch.ones_like(sin_theta) + k: torch.Tensor = torch.where(sin_squared_theta > 0.0, k_pos, k_neg) + + angle_axis: torch.Tensor = torch.zeros_like(quaternion)[..., :3] + angle_axis[..., 0] += q1 * k + angle_axis[..., 1] += q2 * k + angle_axis[..., 2] += q3 * k + return angle_axis + + +def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6): + """ + This function is borrowed from https://github.com/kornia/kornia + + Convert 3x4 rotation matrix to 4d quaternion vector + + This algorithm is based on algorithm described in + https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201 + + Args: + rotation_matrix (Tensor): the rotation matrix to convert. + + Return: + Tensor: the rotation in quaternion + + Shape: + - Input: :math:`(N, 3, 4)` + - Output: :math:`(N, 4)` + + Example: + >>> input = torch.rand(4, 3, 4) # Nx3x4 + >>> output = tgm.rotation_matrix_to_quaternion(input) # Nx4 + """ + if not torch.is_tensor(rotation_matrix): + raise TypeError("Input type is not a torch.Tensor. Got {}".format( + type(rotation_matrix))) + + if len(rotation_matrix.shape) > 3: + raise ValueError( + "Input size must be a three dimensional tensor. Got {}".format( + rotation_matrix.shape)) + if not rotation_matrix.shape[-2:] == (3, 4): + raise ValueError( + "Input size must be a N x 3 x 4 tensor. 
Got {}".format( + rotation_matrix.shape)) + + rmat_t = torch.transpose(rotation_matrix, 1, 2) + + mask_d2 = rmat_t[:, 2, 2] < eps + + mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1] + mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1] + + t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2] + q0 = torch.stack([ + rmat_t[:, 1, 2] - rmat_t[:, 2, 1], t0, + rmat_t[:, 0, 1] + rmat_t[:, 1, 0], rmat_t[:, 2, 0] + rmat_t[:, 0, 2] + ], -1) + t0_rep = t0.repeat(4, 1).t() + + t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2] + q1 = torch.stack([ + rmat_t[:, 2, 0] - rmat_t[:, 0, 2], rmat_t[:, 0, 1] + rmat_t[:, 1, 0], + t1, rmat_t[:, 1, 2] + rmat_t[:, 2, 1] + ], -1) + t1_rep = t1.repeat(4, 1).t() + + t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2] + q2 = torch.stack([ + rmat_t[:, 0, 1] - rmat_t[:, 1, 0], rmat_t[:, 2, 0] + rmat_t[:, 0, 2], + rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2 + ], -1) + t2_rep = t2.repeat(4, 1).t() + + t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2] + q3 = torch.stack([ + t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1], + rmat_t[:, 2, 0] - rmat_t[:, 0, 2], rmat_t[:, 0, 1] - rmat_t[:, 1, 0] + ], -1) + t3_rep = t3.repeat(4, 1).t() + + mask_c0 = mask_d2 * mask_d0_d1 + mask_c1 = mask_d2 * ~mask_d0_d1 + mask_c2 = ~mask_d2 * mask_d0_nd1 + mask_c3 = ~mask_d2 * ~mask_d0_nd1 + mask_c0 = mask_c0.view(-1, 1).type_as(q0) + mask_c1 = mask_c1.view(-1, 1).type_as(q1) + mask_c2 = mask_c2.view(-1, 1).type_as(q2) + mask_c3 = mask_c3.view(-1, 1).type_as(q3) + + q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3 + q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 + # noqa + t2_rep * mask_c2 + t3_rep * mask_c3) # noqa + q *= 0.5 + return q + + +def estimate_translation_np(S, + joints_2d, + joints_conf, + focal_length=5000., + img_size=224.): + """ + This function is borrowed from https://github.com/nkolot/SPIN/utils/geometry.py + + Find camera translation that brings 3D joints S closest to 2D the corresponding joints_2d. + Input: + S: (25, 3) 3D joint locations + joints: (25, 3) 2D joint locations and confidence + Returns: + (3,) camera translation vector + """ + + num_joints = S.shape[0] + # focal length + f = np.array([focal_length, focal_length]) + # optical center + center = np.array([img_size / 2., img_size / 2.]) + + # transformations + Z = np.reshape(np.tile(S[:, 2], (2, 1)).T, -1) + XY = np.reshape(S[:, 0:2], -1) + O = np.tile(center, num_joints) + F = np.tile(f, num_joints) + weight2 = np.reshape(np.tile(np.sqrt(joints_conf), (2, 1)).T, -1) + + # least squares + Q = np.array([ + F * np.tile(np.array([1, 0]), num_joints), + F * np.tile(np.array([0, 1]), num_joints), + O - np.reshape(joints_2d, -1) + ]).T + c = (np.reshape(joints_2d, -1) - O) * Z - F * XY + + # weighted least squares + W = np.diagflat(weight2) + Q = np.dot(W, Q) + c = np.dot(W, c) + + # square matrix + A = np.dot(Q.T, Q) + b = np.dot(Q.T, c) + + # solution + trans = np.linalg.solve(A, b) + + return trans + + +def estimate_translation(S, joints_2d, focal_length=5000., img_size=224.): + """ + This function is borrowed from https://github.com/nkolot/SPIN/utils/geometry.py + + Find camera translation that brings 3D joints S closest to 2D the corresponding joints_2d. 
+ Input: + S: (B, 49, 3) 3D joint locations + joints: (B, 49, 3) 2D joint locations and confidence + Returns: + (B, 3) camera translation vectors + """ + + device = S.device + # Use only joints 25:49 (GT joints) + S = S[:, 25:, :].cpu().numpy() + joints_2d = joints_2d[:, 25:, :].cpu().numpy() + joints_conf = joints_2d[:, :, -1] + joints_2d = joints_2d[:, :, :-1] + trans = np.zeros((S.shape[0], 3), dtype=np.float6432) + # Find the translation for each example in the batch + for i in range(S.shape[0]): + S_i = S[i] + joints_i = joints_2d[i] + conf_i = joints_conf[i] + trans[i] = estimate_translation_np(S_i, + joints_i, + conf_i, + focal_length=focal_length, + img_size=img_size) + return torch.from_numpy(trans).to(device) + + +def rot6d_to_rotmat_spin(x): + """Convert 6D rotation representation to 3x3 rotation matrix. + Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019 + Input: + (B,6) Batch of 6-D rotation representations + Output: + (B,3,3) Batch of corresponding rotation matrices + """ + x = x.view(-1, 3, 2) + a1 = x[:, :, 0] + a2 = x[:, :, 1] + b1 = F.normalize(a1) + b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1) + + # inp = a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1 + # denom = inp.pow(2).sum(dim=1).sqrt().unsqueeze(-1) + 1e-8 + # b2 = inp / denom + + b3 = torch.cross(b1, b2) + return torch.stack((b1, b2, b3), dim=-1) + + +def rot6d_to_rotmat(x): + x = x.view(-1, 3, 2) + + # Normalize the first vector + b1 = F.normalize(x[:, :, 0], dim=1, eps=1e-6) + + dot_prod = torch.sum(b1 * x[:, :, 1], dim=1, keepdim=True) + # Compute the second vector by finding the orthogonal complement to it + b2 = F.normalize(x[:, :, 1] - dot_prod * b1, dim=-1, eps=1e-6) + + # Finish building the basis by taking the cross product + b3 = torch.cross(b1, b2, dim=1) + rot_mats = torch.stack([b1, b2, b3], dim=-1) + + return rot_mats + + +import mGPT.utils.rotation_conversions as rotation_conversions + + +def rot6d(x_rotations, pose_rep): + time, njoints, feats = x_rotations.shape + + # Compute rotations (convert only masked sequences output) + if pose_rep == "rotvec": + rotations = rotation_conversions.axis_angle_to_matrix(x_rotations) + elif pose_rep == "rotmat": + rotations = x_rotations.view(njoints, 3, 3) + elif pose_rep == "rotquat": + rotations = rotation_conversions.quaternion_to_matrix(x_rotations) + elif pose_rep == "rot6d": + rotations = rotation_conversions.rotation_6d_to_matrix(x_rotations) + else: + raise NotImplementedError("No geometry for this one.") + + rotations_6d = rotation_conversions.matrix_to_rotation_6d(rotations) + return rotations_6d + + +def rot6d_batch(x_rotations, pose_rep): + nsamples, time, njoints, feats = x_rotations.shape + + # Compute rotations (convert only masked sequences output) + if pose_rep == "rotvec": + rotations = rotation_conversions.axis_angle_to_matrix(x_rotations) + elif pose_rep == "rotmat": + rotations = x_rotations.view(-1, njoints, 3, 3) + elif pose_rep == "rotquat": + rotations = rotation_conversions.quaternion_to_matrix(x_rotations) + elif pose_rep == "rot6d": + rotations = rotation_conversions.rotation_6d_to_matrix(x_rotations) + else: + raise NotImplementedError("No geometry for this one.") + + rotations_6d = rotation_conversions.matrix_to_rotation_6d(rotations) + return rotations_6d + + +def rot6d_to_rotvec_batch(pose): + # nsamples, time, njoints, feats = rot6d.shape + bs, nfeats = pose.shape + rot6d = pose.reshape(bs, 24, 6) + rotations = 
rotation_conversions.rotation_6d_to_matrix(rot6d) + rotvec = rotation_conversions.matrix_to_axis_angle(rotations) + return rotvec.reshape(bs, 24 * 3) diff --git a/mGPT/utils/geometry_tools.py b/mGPT/utils/geometry_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..e6eafa2e1f2459a0f6f5ad1280c71e6a9625549e --- /dev/null +++ b/mGPT/utils/geometry_tools.py @@ -0,0 +1,566 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# Check PYTORCH3D_LICENCE before use + +import functools +from typing import Optional + +import torch +import torch.nn.functional as F + + +""" +The transformation matrices returned from the functions in this file assume +the points on which the transformation will be applied are column vectors. +i.e. the R matrix is structured as + + R = [ + [Rxx, Rxy, Rxz], + [Ryx, Ryy, Ryz], + [Rzx, Rzy, Rzz], + ] # (3, 3) + +This matrix can be applied to column vectors by post multiplication +by the points e.g. + + points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point + transformed_points = R * points + +To apply the same matrix to points which are row vectors, the R matrix +can be transposed and pre multiplied by the points: + +e.g. + points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point + transformed_points = points * R.transpose(1, 0) +""" + + +# Added +def matrix_of_angles(cos, sin, inv=False, dim=2): + assert dim in [2, 3] + sin = -sin if inv else sin + if dim == 2: + row1 = torch.stack((cos, -sin), axis=-1) + row2 = torch.stack((sin, cos), axis=-1) + return torch.stack((row1, row2), axis=-2) + elif dim == 3: + row1 = torch.stack((cos, -sin, 0*cos), axis=-1) + row2 = torch.stack((sin, cos, 0*cos), axis=-1) + row3 = torch.stack((0*sin, 0*cos, 1+0*cos), axis=-1) + return torch.stack((row1, row2, row3),axis=-2) + + +def quaternion_to_matrix(quaternions): + """ + Convert rotations given as quaternions to rotation matrices. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + r, i, j, k = torch.unbind(quaternions, -1) + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def _copysign(a, b): + """ + Return a tensor where each element has the absolute value taken from the, + corresponding element of a, with sign taken from the corresponding + element of b. This is like the standard copysign floating-point operation, + but is not careful about negative 0 and NaN. + + Args: + a: source tensor. + b: tensor whose signs will be used, of the same shape as a. + + Returns: + Tensor of the same shape as a with the signs of b. + """ + signs_differ = (a < 0) != (b < 0) + return torch.where(signs_differ, -a, a) + + +def _sqrt_positive_part(x): + """ + Returns torch.sqrt(torch.max(0, x)) + but with a zero subgradient where x is 0. + """ + ret = torch.zeros_like(x) + positive_mask = x > 0 + ret[positive_mask] = torch.sqrt(x[positive_mask]) + return ret + + +def matrix_to_quaternion(matrix): + """ + Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). 
+ + Returns: + quaternions with real part first, as tensor of shape (..., 4). + """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.") + m00 = matrix[..., 0, 0] + m11 = matrix[..., 1, 1] + m22 = matrix[..., 2, 2] + o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22) + x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22) + y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22) + z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22) + o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2]) + o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0]) + o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1]) + return torch.stack((o0, o1, o2, o3), -1) + + +def _axis_angle_rotation(axis: str, angle): + """ + Return the rotation matrices for one of the rotations about an axis + of which Euler angles describe, for each value of the angle given. + + Args: + axis: Axis label "X" or "Y or "Z". + angle: any shape tensor of Euler angles in radians + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + + cos = torch.cos(angle) + sin = torch.sin(angle) + one = torch.ones_like(angle) + zero = torch.zeros_like(angle) + + if axis == "X": + R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) + if axis == "Y": + R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) + if axis == "Z": + R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) + + return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) + + +def euler_angles_to_matrix(euler_angles, convention: str): + """ + Convert rotations given as Euler angles in radians to rotation matrices. + + Args: + euler_angles: Euler angles in radians as tensor of shape (..., 3). + convention: Convention string of three uppercase letters from + {"X", "Y", and "Z"}. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3: + raise ValueError("Invalid input euler angles.") + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1)) + return functools.reduce(torch.matmul, matrices) + + +def _angle_from_tan( + axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool +): + """ + Extract the first or third Euler angle from the two members of + the matrix which are positive constant times its sine and cosine. + + Args: + axis: Axis label "X" or "Y or "Z" for the angle we are finding. + other_axis: Axis label "X" or "Y or "Z" for the middle axis in the + convention. + data: Rotation matrices as tensor of shape (..., 3, 3). + horizontal: Whether we are looking for the angle for the third axis, + which means the relevant entries are in the same row of the + rotation matrix. If not, they are in the same column. + tait_bryan: Whether the first and third axes in the convention differ. + + Returns: + Euler Angles in radians for each matrix in data as a tensor + of shape (...). 
+ """ + + i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis] + if horizontal: + i2, i1 = i1, i2 + even = (axis + other_axis) in ["XY", "YZ", "ZX"] + if horizontal == even: + return torch.atan2(data[..., i1], data[..., i2]) + if tait_bryan: + return torch.atan2(-data[..., i2], data[..., i1]) + return torch.atan2(data[..., i2], -data[..., i1]) + + +def _index_from_letter(letter: str): + if letter == "X": + return 0 + if letter == "Y": + return 1 + if letter == "Z": + return 2 + + +def matrix_to_euler_angles(matrix, convention: str): + """ + Convert rotations given as rotation matrices to Euler angles in radians. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + convention: Convention string of three uppercase letters. + + Returns: + Euler angles in radians as tensor of shape (..., 3). + """ + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.") + i0 = _index_from_letter(convention[0]) + i2 = _index_from_letter(convention[2]) + tait_bryan = i0 != i2 + if tait_bryan: + central_angle = torch.asin( + matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0) + ) + else: + central_angle = torch.acos(matrix[..., i0, i0]) + + o = ( + _angle_from_tan( + convention[0], convention[1], matrix[..., i2], False, tait_bryan + ), + central_angle, + _angle_from_tan( + convention[2], convention[1], matrix[..., i0, :], True, tait_bryan + ), + ) + return torch.stack(o, -1) + + +def random_quaternions( + n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate random quaternions representing rotations, + i.e. versors with nonnegative real part. + + Args: + n: Number of quaternions in a batch to return. + dtype: Type to return. + device: Desired device of returned tensor. Default: + uses the current device for the default tensor type. + requires_grad: Whether the resulting tensor should have the gradient + flag set. + + Returns: + Quaternions as tensor of shape (N, 4). + """ + o = torch.randn((n, 4), dtype=dtype, device=device, requires_grad=requires_grad) + s = (o * o).sum(1) + o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None] + return o + + +def random_rotations( + n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate random rotations as 3x3 rotation matrices. + + Args: + n: Number of rotation matrices in a batch to return. + dtype: Type to return. + device: Device of returned tensor. Default: if None, + uses the current device for the default tensor type. + requires_grad: Whether the resulting tensor should have the gradient + flag set. + + Returns: + Rotation matrices as tensor of shape (n, 3, 3). + """ + quaternions = random_quaternions( + n, dtype=dtype, device=device, requires_grad=requires_grad + ) + return quaternion_to_matrix(quaternions) + + +def random_rotation( + dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate a single random 3x3 rotation matrix. + + Args: + dtype: Type to return + device: Device of returned tensor. 
Default: if None, + uses the current device for the default tensor type + requires_grad: Whether the resulting tensor should have the gradient + flag set + + Returns: + Rotation matrix as tensor of shape (3, 3). + """ + return random_rotations(1, dtype, device, requires_grad)[0] + + +def standardize_quaternion(quaternions): + """ + Convert a unit quaternion to a standard form: one in which the real + part is non negative. + + Args: + quaternions: Quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Standardized quaternions as tensor of shape (..., 4). + """ + return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions) + + +def quaternion_raw_multiply(a, b): + """ + Multiply two quaternions. + Usual torch rules for broadcasting apply. + + Args: + a: Quaternions as tensor of shape (..., 4), real part first. + b: Quaternions as tensor of shape (..., 4), real part first. + + Returns: + The product of a and b, a tensor of quaternions shape (..., 4). + """ + aw, ax, ay, az = torch.unbind(a, -1) + bw, bx, by, bz = torch.unbind(b, -1) + ow = aw * bw - ax * bx - ay * by - az * bz + ox = aw * bx + ax * bw + ay * bz - az * by + oy = aw * by - ax * bz + ay * bw + az * bx + oz = aw * bz + ax * by - ay * bx + az * bw + return torch.stack((ow, ox, oy, oz), -1) + + +def quaternion_multiply(a, b): + """ + Multiply two quaternions representing rotations, returning the quaternion + representing their composition, i.e. the versor with nonnegative real part. + Usual torch rules for broadcasting apply. + + Args: + a: Quaternions as tensor of shape (..., 4), real part first. + b: Quaternions as tensor of shape (..., 4), real part first. + + Returns: + The product of a and b, a tensor of quaternions of shape (..., 4). + """ + ab = quaternion_raw_multiply(a, b) + return standardize_quaternion(ab) + + +def quaternion_invert(quaternion): + """ + Given a quaternion representing rotation, get the quaternion representing + its inverse. + + Args: + quaternion: Quaternions as tensor of shape (..., 4), with real part + first, which must be versors (unit quaternions). + + Returns: + The inverse, a tensor of quaternions of shape (..., 4). + """ + + return quaternion * quaternion.new_tensor([1, -1, -1, -1]) + + +def quaternion_apply(quaternion, point): + """ + Apply the rotation given by a quaternion to a 3D point. + Usual torch rules for broadcasting apply. + + Args: + quaternion: Tensor of quaternions, real part first, of shape (..., 4). + point: Tensor of 3D points of shape (..., 3). + + Returns: + Tensor of rotated points of shape (..., 3). + """ + if point.size(-1) != 3: + raise ValueError(f"Points are not in 3D, f{point.shape}.") + real_parts = point.new_zeros(point.shape[:-1] + (1,)) + point_as_quaternion = torch.cat((real_parts, point), -1) + out = quaternion_raw_multiply( + quaternion_raw_multiply(quaternion, point_as_quaternion), + quaternion_invert(quaternion), + ) + return out[..., 1:] + + +def axis_angle_to_matrix(axis_angle): + """ + Convert rotations given as axis/angle to rotation matrices. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle)) + + +def matrix_to_axis_angle(matrix): + """ + Convert rotations given as rotation matrices to axis/angle. 
+ + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. + """ + return quaternion_to_axis_angle(matrix_to_quaternion(matrix)) + + +def axis_angle_to_quaternion(axis_angle): + """ + Convert rotations given as axis/angle to quaternions. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + quaternions with real part first, as tensor of shape (..., 4). + """ + angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) + half_angles = 0.5 * angles + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + ) + quaternions = torch.cat( + [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1 + ) + return quaternions + + +def quaternion_to_axis_angle(quaternions): + """ + Convert rotations given as quaternions to axis/angle. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. + """ + norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True) + half_angles = torch.atan2(norms, quaternions[..., :1]) + angles = 2 * half_angles + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + ) + return quaternions[..., 1:] / sin_half_angles_over_angles + + +def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor: + """ + Converts 6D rotation representation by Zhou et al. [1] to rotation matrix + using Gram--Schmidt orthogonalisation per Section B of [1]. + Args: + d6: 6D rotation representation, of size (*, 6) + + Returns: + batch of rotation matrices of size (*, 3, 3) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. + Retrieved from http://arxiv.org/abs/1812.07035 + """ + + a1, a2 = d6[..., :3], d6[..., 3:] + b1 = F.normalize(a1, dim=-1) + b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 + b2 = F.normalize(b2, dim=-1) + b3 = torch.cross(b1, b2, dim=-1) + return torch.stack((b1, b2, b3), dim=-2) + + +def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor: + """ + Converts rotation matrices to 6D rotation representation by Zhou et al. [1] + by dropping the last row. Note that 6D representation is not unique. 
+ Args: + matrix: batch of rotation matrices of size (*, 3, 3) + + Returns: + 6D rotation representation, of size (*, 6) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. + Retrieved from http://arxiv.org/abs/1812.07035 + """ + return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6) diff --git a/mGPT/utils/joints.py b/mGPT/utils/joints.py new file mode 100644 index 0000000000000000000000000000000000000000..98199c6831c3416a6be3170b21ae3be119ef8981 --- /dev/null +++ b/mGPT/utils/joints.py @@ -0,0 +1,444 @@ +mmm_joints = [ + "root", + "BP", + "BT", + "BLN", + "BUN", + "LS", + "LE", + "LW", + "RS", + "RE", + "RW", + "LH", + "LK", + "LA", + "LMrot", + "LF", + "RH", + "RK", + "RA", + "RMrot", + "RF", +] + +humanml3d_joints = [ + "root", + "RH", + "LH", + "BP", + "RK", + "LK", + "BT", + "RMrot", + "LMrot", + "BLN", + "RF", + "LF", + "BMN", + "RSI", + "LSI", + "BUN", + "RS", + "LS", + "RE", + "LE", + "RW", + "LW", +] + +smplx_joints = [ + "pelvis", + "left_hip", + "right_hip", + "spine1", + "left_knee", + "right_knee", + "spine2", + "left_ankle", + "right_ankle", + "spine3", + "left_foot", + "right_foot", + "neck", + "left_collar", + "right_collar", + "head", + "left_shoulder", + "right_shoulder", + "left_elbow", + "right_elbow", + "left_wrist", + "right_wrist", + "jaw", + "left_eye_smplhf", + "right_eye_smplhf", + "left_index1", + "left_index2", + "left_index3", + "left_middle1", + "left_middle2", + "left_middle3", + "left_pinky1", + "left_pinky2", + "left_pinky3", + "left_ring1", + "left_ring2", + "left_ring3", + "left_thumb1", + "left_thumb2", + "left_thumb3", + "right_index1", + "right_index2", + "right_index3", + "right_middle1", + "right_middle2", + "right_middle3", + "right_pinky1", + "right_pinky2", + "right_pinky3", + "right_ring1", + "right_ring2", + "right_ring3", + "right_thumb1", + "right_thumb2", + "right_thumb3", + "nose", + "right_eye", + "left_eye", + "right_ear", + "left_ear", + "left_big_toe", + "left_small_toe", + "left_heel", + "right_big_toe", + "right_small_toe", + "right_heel", + "left_thumb", + "left_index", + "left_middle", + "left_ring", + "left_pinky", + "right_thumb", + "right_index", + "right_middle", + "right_ring", + "right_pinky", + "right_eye_brow1", + "right_eye_brow2", + "right_eye_brow3", + "right_eye_brow4", + "right_eye_brow5", + "left_eye_brow5", + "left_eye_brow4", + "left_eye_brow3", + "left_eye_brow2", + "left_eye_brow1", + "nose1", + "nose2", + "nose3", + "nose4", + "right_nose_2", + "right_nose_1", + "nose_middle", + "left_nose_1", + "left_nose_2", + "right_eye1", + "right_eye2", + "right_eye3", + "right_eye4", + "right_eye5", + "right_eye6", + "left_eye4", + "left_eye3", + "left_eye2", + "left_eye1", + "left_eye6", + "left_eye5", + "right_mouth_1", + "right_mouth_2", + "right_mouth_3", + "mouth_top", + "left_mouth_3", + "left_mouth_2", + "left_mouth_1", + "left_mouth_5", # 59 in OpenPose output + "left_mouth_4", # 58 in OpenPose output + "mouth_bottom", + "right_mouth_4", + "right_mouth_5", + "right_lip_1", + "right_lip_2", + "lip_top", + "left_lip_2", + "left_lip_1", + "left_lip_3", + "lip_bottom", + "right_lip_3", + # Face contour + "right_contour_1", + "right_contour_2", + "right_contour_3", + "right_contour_4", + "right_contour_5", + "right_contour_6", + "right_contour_7", + "right_contour_8", + "contour_middle", + "left_contour_8", + "left_contour_7", + "left_contour_6", + "left_contour_5", + 
"left_contour_4", + "left_contour_3", + "left_contour_2", + "left_contour_1", +] + +smplxnh_joints = [ + "pelvis", + "left_hip", + "right_hip", + "spine1", + "left_knee", + "right_knee", + "spine2", + "left_ankle", + "right_ankle", + "spine3", + "left_foot", + "right_foot", + "neck", + "left_collar", + "right_collar", + "head", + "left_shoulder", + "right_shoulder", + "left_elbow", + "right_elbow", + "left_wrist", + "right_wrist", +] + +smplh_joints = [ + "pelvis", + "left_hip", + "right_hip", + "spine1", + "left_knee", + "right_knee", + "spine2", + "left_ankle", + "right_ankle", + "spine3", + "left_foot", + "right_foot", + "neck", + "left_collar", + "right_collar", + "head", + "left_shoulder", + "right_shoulder", + "left_elbow", + "right_elbow", + "left_wrist", + "right_wrist", + "left_index1", + "left_index2", + "left_index3", + "left_middle1", + "left_middle2", + "left_middle3", + "left_pinky1", + "left_pinky2", + "left_pinky3", + "left_ring1", + "left_ring2", + "left_ring3", + "left_thumb1", + "left_thumb2", + "left_thumb3", + "right_index1", + "right_index2", + "right_index3", + "right_middle1", + "right_middle2", + "right_middle3", + "right_pinky1", + "right_pinky2", + "right_pinky3", + "right_ring1", + "right_ring2", + "right_ring3", + "right_thumb1", + "right_thumb2", + "right_thumb3", + "nose", + "right_eye", + "left_eye", + "right_ear", + "left_ear", + "left_big_toe", + "left_small_toe", + "left_heel", + "right_big_toe", + "right_small_toe", + "right_heel", + "left_thumb", + "left_index", + "left_middle", + "left_ring", + "left_pinky", + "right_thumb", + "right_index", + "right_middle", + "right_ring", + "right_pinky", +] + +smplnh_joints = [ + "pelvis", + "left_hip", + "right_hip", + "spine1", + "left_knee", + "right_knee", + "spine2", + "left_ankle", + "right_ankle", + "spine3", + "left_foot", + "right_foot", + "neck", + "left_collar", + "right_collar", + "head", + "left_shoulder", + "right_shoulder", + "left_elbow", + "right_elbow", + "left_wrist", + "right_wrist", +] + +mmm2smplh_correspondence = { + "root": "pelvis", + "BP": "spine1", + "BT": "spine3", + "BLN": "neck", + "BUN": "head", + "LS": "left_shoulder", + "LE": "left_elbow", + "LW": "left_wrist", + "RS": "right_shoulder", + "RE": "right_elbow", + "RW": "right_wrist", + "LH": "left_hip", + "LK": "left_knee", + "LA": "left_ankle", + "LMrot": "left_heel", + "LF": "left_foot", + "RH": "right_hip", + "RK": "right_knee", + "RA": "right_ankle", + "RMrot": "right_heel", + "RF": "right_foot", +} + +smplh2mmm_correspondence = { + val: key + for key, val in mmm2smplh_correspondence.items() +} +smplh2mmm_indexes = [ + smplh_joints.index(mmm2smplh_correspondence[x]) for x in mmm_joints +] + +smplnh2smplh_correspondence = {key: key for key in smplnh_joints} +smplh2smplnh_correspondence = { + val: key + for key, val in smplnh2smplh_correspondence.items() +} + +smplh2smplnh_indexes = [ + smplh_joints.index(smplnh2smplh_correspondence[x]) for x in smplnh_joints +] + +mmm_kinematic_tree = [ + [0, 1, 2, 3, 4], # body + [3, 5, 6, 7], # right arm + [3, 8, 9, 10], # left arm + [0, 11, 12, 13, 14, 15], # right leg + [0, 16, 17, 18, 19, 20], +] # left leg + +humanml3d_kinematic_tree = [ + [0, 3, 6, 9, 12, 15], # body + [9, 14, 17, 19, 21], # right arm + [9, 13, 16, 18, 20], # left arm + [0, 2, 5, 8, 11], # right leg + [0, 1, 4, 7, 10], +] # left leg + +smplh_to_mmm_scaling_factor = 480 / 0.75 +mmm_to_smplh_scaling_factor = 0.75 / 480 + +mmm_joints_info = { + "root": + mmm_joints.index("root"), + "feet": [ + mmm_joints.index("LMrot"), + 
mmm_joints.index("RMrot"), + mmm_joints.index("LF"), + mmm_joints.index("RF"), + ], + "shoulders": [mmm_joints.index("LS"), + mmm_joints.index("RS")], + "hips": [mmm_joints.index("LH"), + mmm_joints.index("RH")], +} + +smplnh_joints_info = { + "root": + smplnh_joints.index("pelvis"), + "feet": [ + smplnh_joints.index("left_ankle"), + smplnh_joints.index("right_ankle"), + smplnh_joints.index("left_foot"), + smplnh_joints.index("right_foot"), + ], + "shoulders": [ + smplnh_joints.index("left_shoulder"), + smplnh_joints.index("right_shoulder"), + ], + "hips": + [smplnh_joints.index("left_hip"), + smplnh_joints.index("right_hip")], +} + +infos = {"mmm": mmm_joints_info, "smplnh": smplnh_joints_info} + +smplh_indexes = {"mmm": smplh2mmm_indexes, "smplnh": smplh2smplnh_indexes} + +root_joints = { + "mmm": mmm_joints_info["root"], + "mmmns": mmm_joints_info["root"], + "smplmmm": mmm_joints_info["root"], + "smplnh": smplnh_joints_info["root"], + "smplh": smplh_joints.index("pelvis"), +} + + +def get_root_idx(joinstype): + return root_joints[joinstype] + + +# def mmm2smpl(joints_mmm): +# mmm2smplnh_indexes = [] +# for x in smplnh_joints: +# if x in smplh2mmm_correspondence: +# mmm2smplnh_indexes.append(mmm_joints.index(smplh2mmm_correspondence[x])) + +# spine2 = 0.5*(joints[mmm_joints.index("spine1")] + joints[mmm_joints.index("spine3")]) + +# joints = joints_mmm[indexes] +# return joints diff --git a/mGPT/utils/load_checkpoint.py b/mGPT/utils/load_checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..121495808add1285ba0a4db3a1686d552547898c --- /dev/null +++ b/mGPT/utils/load_checkpoint.py @@ -0,0 +1,34 @@ +import torch + +def load_pretrained(cfg, model, logger, phase="train"): + logger.info(f"Loading pretrain model from {cfg.TRAIN.PRETRAINED}") + if phase == "train": + ckpt_path = cfg.TRAIN.PRETRAINED + elif phase == "test": + ckpt_path = cfg.TEST.CHECKPOINTS + + state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"] + model.load_state_dict(state_dict, strict=True) + return model + + +def load_pretrained_vae(cfg, model, logger): + state_dict = torch.load(cfg.TRAIN.PRETRAINED_VAE, + map_location="cpu")['state_dict'] + logger.info(f"Loading pretrain vae from {cfg.TRAIN.PRETRAINED_VAE}") + # Extract encoder/decoder + from collections import OrderedDict + vae_dict = OrderedDict() + for k, v in state_dict.items(): + if "motion_vae" in k: + name = k.replace("motion_vae.", "") + vae_dict[name] = v + elif "vae" in k: + name = k.replace("vae.", "") + vae_dict[name] = v + if hasattr(model, 'vae'): + model.vae.load_state_dict(vae_dict, strict=True) + else: + model.motion_vae.load_state_dict(vae_dict, strict=True) + + return model diff --git a/mGPT/utils/logger.py b/mGPT/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..6e1cafed01595c0afb6b3f8c9674e688342302a8 --- /dev/null +++ b/mGPT/utils/logger.py @@ -0,0 +1,68 @@ +from pathlib import Path +import os +import time +import logging +from omegaconf import OmegaConf +from pytorch_lightning.utilities.rank_zero import rank_zero_only + +def create_logger(cfg, phase='train'): + # root dir set by cfg + root_output_dir = Path(cfg.FOLDER) + # set up logger + if not root_output_dir.exists(): + print('=> creating {}'.format(root_output_dir)) + root_output_dir.mkdir() + + cfg_name = cfg.NAME + model = cfg.model.target.split('.')[-2] + cfg_name = os.path.basename(cfg_name).split('.')[0] + + final_output_dir = root_output_dir / model / cfg_name + cfg.FOLDER_EXP = str(final_output_dir) + + 
time_str = time.strftime('%Y-%m-%d-%H-%M-%S') + + new_dir(cfg, phase, time_str, final_output_dir) + + head = '%(asctime)-15s %(message)s' + logger = config_logger(final_output_dir, time_str, phase, head) + if logger is None: + logger = logging.getLogger() + logger.setLevel(logging.CRITICAL) + logging.basicConfig(format=head) + return logger + + +@rank_zero_only +def config_logger(final_output_dir, time_str, phase, head): + log_file = '{}_{}_{}.log'.format('log', time_str, phase) + final_log_file = final_output_dir / log_file + logging.basicConfig(filename=str(final_log_file)) + logger = logging.getLogger() + logger.setLevel(logging.INFO) + console = logging.StreamHandler() + formatter = logging.Formatter(head) + console.setFormatter(formatter) + logging.getLogger('').addHandler(console) + file_handler = logging.FileHandler(final_log_file, 'w') + file_handler.setFormatter(logging.Formatter(head)) + file_handler.setLevel(logging.INFO) + logging.getLogger('').addHandler(file_handler) + return logger + + +@rank_zero_only +def new_dir(cfg, phase, time_str, final_output_dir): + # new experiment folder + cfg.TIME = str(time_str) + if os.path.exists(final_output_dir) and not os.path.exists(cfg.TRAIN.RESUME) and not cfg.DEBUG and phase not in ['test', 'demo']: + file_list = sorted(os.listdir(final_output_dir), reverse=True) + for item in file_list: + if item.endswith('.log'): + os.rename(str(final_output_dir), str(final_output_dir) + '_' + cfg.TIME) + break + final_output_dir.mkdir(parents=True, exist_ok=True) + # write config yaml + config_file = '{}_{}_{}.yaml'.format('config', time_str, phase) + final_config_file = final_output_dir / config_file + OmegaConf.save(config=cfg, f=final_config_file) diff --git a/mGPT/utils/misc.py b/mGPT/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..4f2a68d68019098e66905e0e21cb96678031bab0 --- /dev/null +++ b/mGPT/utils/misc.py @@ -0,0 +1,29 @@ +import torch + + +def to_numpy(tensor): + if torch.is_tensor(tensor): + return tensor.cpu().numpy() + elif type(tensor).__module__ != 'numpy': + raise ValueError("Cannot convert {} to numpy array".format( + type(tensor))) + return tensor + + +def to_torch(ndarray): + if type(ndarray).__module__ == 'numpy': + return torch.from_numpy(ndarray) + elif not torch.is_tensor(ndarray): + raise ValueError("Cannot convert {} to torch tensor".format( + type(ndarray))) + return ndarray + + +def cleanexit(): + import sys + import os + try: + sys.exit(0) + except SystemExit: + os._exit(0) + diff --git a/mGPT/utils/rotation_conversions.py b/mGPT/utils/rotation_conversions.py new file mode 100644 index 0000000000000000000000000000000000000000..770c3bf36f05fcaf89cbb03e17035357f3c0a4df --- /dev/null +++ b/mGPT/utils/rotation_conversions.py @@ -0,0 +1,551 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# Check PYTORCH3D_LICENCE before use + +import functools +from typing import Optional + +import torch +import torch.nn.functional as F + + +""" +The transformation matrices returned from the functions in this file assume +the points on which the transformation will be applied are column vectors. +i.e. the R matrix is structured as + + R = [ + [Rxx, Rxy, Rxz], + [Ryx, Ryy, Ryz], + [Rzx, Rzy, Rzz], + ] # (3, 3) + +This matrix can be applied to column vectors by post multiplication +by the points e.g. 
+ + points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point + transformed_points = R * points + +To apply the same matrix to points which are row vectors, the R matrix +can be transposed and pre multiplied by the points: + +e.g. + points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point + transformed_points = points * R.transpose(1, 0) +""" + + +def quaternion_to_matrix(quaternions): + """ + Convert rotations given as quaternions to rotation matrices. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + r, i, j, k = torch.unbind(quaternions, -1) + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def _copysign(a, b): + """ + Return a tensor where each element has the absolute value taken from the, + corresponding element of a, with sign taken from the corresponding + element of b. This is like the standard copysign floating-point operation, + but is not careful about negative 0 and NaN. + + Args: + a: source tensor. + b: tensor whose signs will be used, of the same shape as a. + + Returns: + Tensor of the same shape as a with the signs of b. + """ + signs_differ = (a < 0) != (b < 0) + return torch.where(signs_differ, -a, a) + + +def _sqrt_positive_part(x): + """ + Returns torch.sqrt(torch.max(0, x)) + but with a zero subgradient where x is 0. + """ + ret = torch.zeros_like(x) + positive_mask = x > 0 + ret[positive_mask] = torch.sqrt(x[positive_mask]) + return ret + + +def matrix_to_quaternion(matrix): + """ + Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + quaternions with real part first, as tensor of shape (..., 4). + """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.") + m00 = matrix[..., 0, 0] + m11 = matrix[..., 1, 1] + m22 = matrix[..., 2, 2] + o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22) + x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22) + y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22) + z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22) + o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2]) + o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0]) + o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1]) + return torch.stack((o0, o1, o2, o3), -1) + + +def _axis_angle_rotation(axis: str, angle): + """ + Return the rotation matrices for one of the rotations about an axis + of which Euler angles describe, for each value of the angle given. + + Args: + axis: Axis label "X" or "Y or "Z". + angle: any shape tensor of Euler angles in radians + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). 
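+
+    Example (illustrative):
+        _axis_angle_rotation("Z", torch.tensor(0.0))  # 3x3 identity matrix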
+ """ + + cos = torch.cos(angle) + sin = torch.sin(angle) + one = torch.ones_like(angle) + zero = torch.zeros_like(angle) + + if axis == "X": + R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) + if axis == "Y": + R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) + if axis == "Z": + R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) + + return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) + + +def euler_angles_to_matrix(euler_angles, convention: str): + """ + Convert rotations given as Euler angles in radians to rotation matrices. + + Args: + euler_angles: Euler angles in radians as tensor of shape (..., 3). + convention: Convention string of three uppercase letters from + {"X", "Y", and "Z"}. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3: + raise ValueError("Invalid input euler angles.") + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1)) + return functools.reduce(torch.matmul, matrices) + + +def _angle_from_tan( + axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool +): + """ + Extract the first or third Euler angle from the two members of + the matrix which are positive constant times its sine and cosine. + + Args: + axis: Axis label "X" or "Y or "Z" for the angle we are finding. + other_axis: Axis label "X" or "Y or "Z" for the middle axis in the + convention. + data: Rotation matrices as tensor of shape (..., 3, 3). + horizontal: Whether we are looking for the angle for the third axis, + which means the relevant entries are in the same row of the + rotation matrix. If not, they are in the same column. + tait_bryan: Whether the first and third axes in the convention differ. + + Returns: + Euler Angles in radians for each matrix in data as a tensor + of shape (...). + """ + + i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis] + if horizontal: + i2, i1 = i1, i2 + even = (axis + other_axis) in ["XY", "YZ", "ZX"] + if horizontal == even: + return torch.atan2(data[..., i1], data[..., i2]) + if tait_bryan: + return torch.atan2(-data[..., i2], data[..., i1]) + return torch.atan2(data[..., i2], -data[..., i1]) + + +def _index_from_letter(letter: str): + if letter == "X": + return 0 + if letter == "Y": + return 1 + if letter == "Z": + return 2 + + +def matrix_to_euler_angles(matrix, convention: str): + """ + Convert rotations given as rotation matrices to Euler angles in radians. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + convention: Convention string of three uppercase letters. + + Returns: + Euler angles in radians as tensor of shape (..., 3). 
+ """ + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.") + i0 = _index_from_letter(convention[0]) + i2 = _index_from_letter(convention[2]) + tait_bryan = i0 != i2 + if tait_bryan: + central_angle = torch.asin( + matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0) + ) + else: + central_angle = torch.acos(matrix[..., i0, i0]) + + o = ( + _angle_from_tan( + convention[0], convention[1], matrix[..., i2], False, tait_bryan + ), + central_angle, + _angle_from_tan( + convention[2], convention[1], matrix[..., i0, :], True, tait_bryan + ), + ) + return torch.stack(o, -1) + + +def random_quaternions( + n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate random quaternions representing rotations, + i.e. versors with nonnegative real part. + + Args: + n: Number of quaternions in a batch to return. + dtype: Type to return. + device: Desired device of returned tensor. Default: + uses the current device for the default tensor type. + requires_grad: Whether the resulting tensor should have the gradient + flag set. + + Returns: + Quaternions as tensor of shape (N, 4). + """ + o = torch.randn((n, 4), dtype=dtype, device=device, requires_grad=requires_grad) + s = (o * o).sum(1) + o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None] + return o + + +def random_rotations( + n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate random rotations as 3x3 rotation matrices. + + Args: + n: Number of rotation matrices in a batch to return. + dtype: Type to return. + device: Device of returned tensor. Default: if None, + uses the current device for the default tensor type. + requires_grad: Whether the resulting tensor should have the gradient + flag set. + + Returns: + Rotation matrices as tensor of shape (n, 3, 3). + """ + quaternions = random_quaternions( + n, dtype=dtype, device=device, requires_grad=requires_grad + ) + return quaternion_to_matrix(quaternions) + + +def random_rotation( + dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate a single random 3x3 rotation matrix. + + Args: + dtype: Type to return + device: Device of returned tensor. Default: if None, + uses the current device for the default tensor type + requires_grad: Whether the resulting tensor should have the gradient + flag set + + Returns: + Rotation matrix as tensor of shape (3, 3). + """ + return random_rotations(1, dtype, device, requires_grad)[0] + + +def standardize_quaternion(quaternions): + """ + Convert a unit quaternion to a standard form: one in which the real + part is non negative. + + Args: + quaternions: Quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Standardized quaternions as tensor of shape (..., 4). + """ + return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions) + + +def quaternion_raw_multiply(a, b): + """ + Multiply two quaternions. + Usual torch rules for broadcasting apply. + + Args: + a: Quaternions as tensor of shape (..., 4), real part first. + b: Quaternions as tensor of shape (..., 4), real part first. 
+ + Returns: + The product of a and b, a tensor of quaternions shape (..., 4). + """ + aw, ax, ay, az = torch.unbind(a, -1) + bw, bx, by, bz = torch.unbind(b, -1) + ow = aw * bw - ax * bx - ay * by - az * bz + ox = aw * bx + ax * bw + ay * bz - az * by + oy = aw * by - ax * bz + ay * bw + az * bx + oz = aw * bz + ax * by - ay * bx + az * bw + return torch.stack((ow, ox, oy, oz), -1) + + +def quaternion_multiply(a, b): + """ + Multiply two quaternions representing rotations, returning the quaternion + representing their composition, i.e. the versor with nonnegative real part. + Usual torch rules for broadcasting apply. + + Args: + a: Quaternions as tensor of shape (..., 4), real part first. + b: Quaternions as tensor of shape (..., 4), real part first. + + Returns: + The product of a and b, a tensor of quaternions of shape (..., 4). + """ + ab = quaternion_raw_multiply(a, b) + return standardize_quaternion(ab) + + +def quaternion_invert(quaternion): + """ + Given a quaternion representing rotation, get the quaternion representing + its inverse. + + Args: + quaternion: Quaternions as tensor of shape (..., 4), with real part + first, which must be versors (unit quaternions). + + Returns: + The inverse, a tensor of quaternions of shape (..., 4). + """ + + return quaternion * quaternion.new_tensor([1, -1, -1, -1]) + + +def quaternion_apply(quaternion, point): + """ + Apply the rotation given by a quaternion to a 3D point. + Usual torch rules for broadcasting apply. + + Args: + quaternion: Tensor of quaternions, real part first, of shape (..., 4). + point: Tensor of 3D points of shape (..., 3). + + Returns: + Tensor of rotated points of shape (..., 3). + """ + if point.size(-1) != 3: + raise ValueError(f"Points are not in 3D, f{point.shape}.") + real_parts = point.new_zeros(point.shape[:-1] + (1,)) + point_as_quaternion = torch.cat((real_parts, point), -1) + out = quaternion_raw_multiply( + quaternion_raw_multiply(quaternion, point_as_quaternion), + quaternion_invert(quaternion), + ) + return out[..., 1:] + + +def axis_angle_to_matrix(axis_angle): + """ + Convert rotations given as axis/angle to rotation matrices. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle)) + + +def matrix_to_axis_angle(matrix): + """ + Convert rotations given as rotation matrices to axis/angle. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. + """ + return quaternion_to_axis_angle(matrix_to_quaternion(matrix)) + + +def axis_angle_to_quaternion(axis_angle): + """ + Convert rotations given as axis/angle to quaternions. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + quaternions with real part first, as tensor of shape (..., 4). 
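+
+    Example (illustrative, a half-turn about the Z axis):
+        axis_angle_to_quaternion(torch.tensor([0.0, 0.0, 3.1415927]))
+        # ~= tensor([0., 0., 0., 1.])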
+ """ + angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) + half_angles = 0.5 * angles + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + ) + quaternions = torch.cat( + [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1 + ) + return quaternions + + +def quaternion_to_axis_angle(quaternions): + """ + Convert rotations given as quaternions to axis/angle. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. + """ + norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True) + half_angles = torch.atan2(norms, quaternions[..., :1]) + angles = 2 * half_angles + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + ) + return quaternions[..., 1:] / sin_half_angles_over_angles + + +def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor: + """ + Converts 6D rotation representation by Zhou et al. [1] to rotation matrix + using Gram--Schmidt orthogonalisation per Section B of [1]. + Args: + d6: 6D rotation representation, of size (*, 6) + + Returns: + batch of rotation matrices of size (*, 3, 3) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. + Retrieved from http://arxiv.org/abs/1812.07035 + """ + + a1, a2 = d6[..., :3], d6[..., 3:] + b1 = F.normalize(a1, dim=-1) + b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 + b2 = F.normalize(b2, dim=-1) + b3 = torch.cross(b1, b2, dim=-1) + return torch.stack((b1, b2, b3), dim=-2) + + +def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor: + """ + Converts rotation matrices to 6D rotation representation by Zhou et al. [1] + by dropping the last row. Note that 6D representation is not unique. + Args: + matrix: batch of rotation matrices of size (*, 3, 3) + + Returns: + 6D rotation representation, of size (*, 6) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. 
+ Retrieved from http://arxiv.org/abs/1812.07035 + """ + return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6) diff --git a/mGPT/utils/sample_utils.py b/mGPT/utils/sample_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..724b109af5ecd08e20cf0af681df317b390d4d44 --- /dev/null +++ b/mGPT/utils/sample_utils.py @@ -0,0 +1,18 @@ +import logging +from pathlib import Path +logger = logging.getLogger(__name__) + +def cfg_mean_nsamples_resolution(cfg): + if cfg.mean and cfg.number_of_samples > 1: + logger.error("All the samples will be the mean.. cfg.number_of_samples=1 will be forced.") + cfg.number_of_samples = 1 + + return cfg.number_of_samples == 1 + + +def get_path(sample_path: Path, is_amass: bool, gender: str, split: str, onesample: bool, mean: bool, fact: float): + extra_str = ("_mean" if mean else "") if onesample else "_multi" + fact_str = "" if fact == 1 else f"{fact}_" + gender_str = gender + "_" if is_amass else "" + path = sample_path / f"{fact_str}{gender_str}{split}{extra_str}" + return path diff --git a/mGPT/utils/temos_utils.py b/mGPT/utils/temos_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..888fd69dbece5727f3955548af4faf6bb9d430bd --- /dev/null +++ b/mGPT/utils/temos_utils.py @@ -0,0 +1,133 @@ +from typing import Dict, List + +import numpy as np +import torch +from torch import Tensor + +import mGPT.utils.geometry_conver as geometry_conver + + +def lengths_to_mask(lengths: List[int], + device: torch.device, + max_len: int = None) -> Tensor: + lengths = torch.tensor(lengths, device=device) + max_len = max_len if max_len else max(lengths) + mask = torch.arange(max_len, device=device).expand( + len(lengths), max_len) < lengths.unsqueeze(1) + return mask + + +def detach_to_numpy(tensor): + return tensor.detach().cpu().numpy() + + +def remove_padding(tensors, lengths): + return [ + tensor[:tensor_length] + for tensor, tensor_length in zip(tensors, lengths) + ] + + +def nfeats_of(rottype): + if rottype in ["rotvec", "axisangle"]: + return 3 + elif rottype in ["rotquat", "quaternion"]: + return 4 + elif rottype in ["rot6d", "6drot", "rotation6d"]: + return 6 + elif rottype in ["rotmat"]: + return 9 + else: + return TypeError("This rotation type doesn't have features.") + + +def axis_angle_to(newtype, rotations): + if newtype in ["matrix"]: + rotations = geometry_conver.axis_angle_to_matrix(rotations) + return rotations + elif newtype in ["rotmat"]: + rotations = geometry_conver.axis_angle_to_matrix(rotations) + rotations = matrix_to("rotmat", rotations) + return rotations + elif newtype in ["rot6d", "6drot", "rotation6d"]: + rotations = geometry_conver.axis_angle_to_matrix(rotations) + rotations = matrix_to("rot6d", rotations) + return rotations + elif newtype in ["rotquat", "quaternion"]: + rotations = geometry_conver.axis_angle_to_quaternion(rotations) + return rotations + elif newtype in ["rotvec", "axisangle"]: + return rotations + else: + raise NotImplementedError + + +def matrix_to(newtype, rotations): + if newtype in ["matrix"]: + return rotations + if newtype in ["rotmat"]: + rotations = rotations.reshape((*rotations.shape[:-2], 9)) + return rotations + elif newtype in ["rot6d", "6drot", "rotation6d"]: + rotations = geometry_conver.matrix_to_rotation_6d(rotations) + return rotations + elif newtype in ["rotquat", "quaternion"]: + rotations = geometry_conver.matrix_to_quaternion(rotations) + return rotations + elif newtype in ["rotvec", "axisangle"]: + rotations = 
geometry_conver.matrix_to_axis_angle(rotations) + return rotations + else: + raise NotImplementedError + + +def to_matrix(oldtype, rotations): + if oldtype in ["matrix"]: + return rotations + if oldtype in ["rotmat"]: + rotations = rotations.reshape((*rotations.shape[:-2], 3, 3)) + return rotations + elif oldtype in ["rot6d", "6drot", "rotation6d"]: + rotations = geometry_conver.rotation_6d_to_matrix(rotations) + return rotations + elif oldtype in ["rotquat", "quaternion"]: + rotations = geometry_conver.quaternion_to_matrix(rotations) + return rotations + elif oldtype in ["rotvec", "axisangle"]: + rotations = geometry_conver.axis_angle_to_matrix(rotations) + return rotations + else: + raise NotImplementedError + + +# TODO: use a real subsampler.. +def subsample(num_frames, last_framerate, new_framerate): + step = int(last_framerate / new_framerate) + assert step >= 1 + frames = np.arange(0, num_frames, step) + return frames + + +# TODO: use a real upsampler.. +def upsample(motion, last_framerate, new_framerate): + step = int(new_framerate / last_framerate) + assert step >= 1 + + # Alpha blending => interpolation + alpha = np.linspace(0, 1, step + 1) + last = np.einsum("l,...->l...", 1 - alpha, motion[:-1]) + new = np.einsum("l,...->l...", alpha, motion[1:]) + + chuncks = (last + new)[:-1] + output = np.concatenate(chuncks.swapaxes(1, 0)) + # Don't forget the last one + output = np.concatenate((output, motion[[-1]])) + return output + + +if __name__ == "__main__": + motion = np.arange(105) + submotion = motion[subsample(len(motion), 100.0, 12.5)] + newmotion = upsample(submotion, 12.5, 100) + + print(newmotion) diff --git a/mGPT/utils/tensors.py b/mGPT/utils/tensors.py new file mode 100644 index 0000000000000000000000000000000000000000..166143893e5ad1494e3bdf8a9a12261f61e77335 --- /dev/null +++ b/mGPT/utils/tensors.py @@ -0,0 +1,74 @@ +import torch + + +def lengths_to_mask(lengths): + max_len = max(lengths) + mask = torch.arange(max_len, device=lengths.device).expand( + len(lengths), max_len) < lengths.unsqueeze(1) + return mask + + +def collate_tensors(batch): + dims = batch[0].dim() + max_size = [max([b.size(i) for b in batch]) for i in range(dims)] + size = (len(batch),) + tuple(max_size) + canvas = batch[0].new_zeros(size=size) + for i, b in enumerate(batch): + sub_tensor = canvas[i] + for d in range(dims): + sub_tensor = sub_tensor.narrow(d, 0, b.size(d)) + sub_tensor.add_(b) + return canvas + + +def collate(batch): + databatch = [b[0] for b in batch] + labelbatch = [b[1] for b in batch] + lenbatch = [len(b[0][0][0]) for b in batch] + + databatchTensor = collate_tensors(databatch) + labelbatchTensor = torch.as_tensor(labelbatch) + lenbatchTensor = torch.as_tensor(lenbatch) + + maskbatchTensor = lengths_to_mask(lenbatchTensor) + # x - [bs, njoints, nfeats, lengths] + # - nfeats, the representation of a joint + # y - [bs] + # mask - [bs, lengths] + # lengths - [bs] + batch = {"x": databatchTensor, "y": labelbatchTensor, + "mask": maskbatchTensor, 'lengths': lenbatchTensor} + return batch + + +# slow version with padding +def collate_data3d_slow(batch): + batchTensor = {} + for key in batch[0].keys(): + databatch = [b[key] for b in batch] + batchTensor[key] = collate_tensors(databatch) + batch = batchTensor + # theta - [bs, lengths, 85], theta shape (85,) + # - (np.array([1., 0., 0.]), pose(72), shape(10)), axis=0) + # kp_2d - [bs, lengths, njoints, nfeats], nfeats (x,y,weight) + # kp_3d - [bs, lengths, njoints, nfeats], nfeats (x,y,z) + # w_smpl - [bs, lengths] zeros + # w_3d - [bs, 
lengths] zeros + return batch + +def collate_data3d(batch): + batchTensor = {} + for key in batch[0].keys(): + databatch = [b[key] for b in batch] + if key == "paths": + batchTensor[key] = databatch + else: + batchTensor[key] = torch.stack(databatch,axis=0) + batch = batchTensor + # theta - [bs, lengths, 85], theta shape (85,) + # - (np.array([1., 0., 0.]), pose(72), shape(10)), axis=0) + # kp_2d - [bs, lengths, njoints, nfeats], nfeats (x,y,weight) + # kp_3d - [bs, lengths, njoints, nfeats], nfeats (x,y,z) + # w_smpl - [bs, lengths] zeros + # w_3d - [bs, lengths] zeros + return batch diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..bd9891223a368ff7f620ba294de9858b04b720b1 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,37 @@ +opencv-python +tensorboard +pytorch_lightning +torchmetrics +omegaconf +shortuuid +chumpy +transformers +diffusers +einops +wandb +rich +matplotlib +bert-score +nlg-metricverse +gdown +numpy==1.23.1 + +# for visualization +gradio==3.43.2 +pyglet==1.4.0a1 +pyrender +PyOpenGL==3.1.4 +PyOpenGL_accelerate +smplx==0.1.28 +trimesh==3.9.24 +joblib==1.2.0 +shapely +triangle +h5py +scikit-image +spacy +ftfy +more-itertools +natsort +moviepy +librosa
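For reference, a minimal usage sketch of the rotation utilities in mGPT/utils/rotation_conversions.py, assuming the repository root is on PYTHONPATH and the pinned requirements above are installed:

    import torch

    from mGPT.utils.rotation_conversions import (
        matrix_to_rotation_6d,
        random_rotations,
        rotation_6d_to_matrix,
    )

    # Round-trip a small batch of random rotations through the 6D representation.
    rotations = random_rotations(4)        # (4, 3, 3) rotation matrices
    d6 = matrix_to_rotation_6d(rotations)  # (4, 6) continuous 6D representation
    recovered = rotation_6d_to_matrix(d6)  # (4, 3, 3) rebuilt via Gram-Schmidt
    assert torch.allclose(rotations, recovered, atol=1e-5)

The same round-trip pattern applies to the quaternion and axis-angle converters, which is how matrix_to and to_matrix in mGPT/utils/temos_utils.py dispatch between rotation formats.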