Spaces:
Runtime error
Runtime error
bill-jiang
commited on
Commit
•
4409449
1
Parent(s):
a0563b6
Init
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +165 -0
- README.md +4 -4
- app.py +511 -0
- assets/css/custom.css +359 -0
- assets/images/avatar_bot.jpg +0 -0
- assets/meta/mean.npy +3 -0
- assets/meta/mean_eval.npy +3 -0
- assets/meta/std.npy +3 -0
- assets/meta/std_eval.npy +3 -0
- assets/videos/m2t_0.mp4 +0 -0
- assets/videos/t2m_0.mp4 +0 -0
- configs/assets.yaml +32 -0
- configs/default.yaml +141 -0
- configs/evaluator/tm2t.yaml +19 -0
- configs/lm/default.yaml +7 -0
- configs/render.yaml +23 -0
- configs/vq/default.yaml +15 -0
- configs/webui.yaml +74 -0
- deps/smpl/smpl_models/SMPL_downsample_index.pkl +3 -0
- deps/smpl/smpl_models/gmm_08.pkl +3 -0
- deps/smpl/smpl_models/neutral_smpl_mean_params.h5 +3 -0
- deps/smpl/smpl_models/smpl.faces +0 -0
- deps/smpl/smpl_models/smpl.tar.gz +3 -0
- deps/smpl/smpl_models/smpl/SMPL_FEMALE.pkl +3 -0
- deps/smpl/smpl_models/smpl/SMPL_MALE.pkl +3 -0
- deps/smpl/smpl_models/smpl/SMPL_NEUTRAL.pkl +3 -0
- deps/smpl/smpl_models/smpl/readme.txt +1 -0
- deps/smpl/smpl_models/smplh/SMPLH_FEMALE.npz +3 -0
- deps/smpl/smpl_models/smplh/SMPLH_MALE.npz +3 -0
- deps/smpl/smpl_models/smplh/SMPLH_NEUTRAL.npz +3 -0
- deps/smpl/smpl_models/smplh/mano_v1_2.zip +3 -0
- deps/smpl/smpl_models/smplh/smplh.faces +0 -0
- deps/smpl/smpl_models/smplh/smplh.tar.xz +3 -0
- deps/smpl/smpl_models/smplx_parts_segm.pkl +3 -0
- mGPT/__init__.py +0 -0
- mGPT/archs/__init__.py +0 -0
- mGPT/archs/mgpt_lm.py +592 -0
- mGPT/archs/mgpt_vq.py +190 -0
- mGPT/archs/tm2t_evaluator.py +111 -0
- mGPT/archs/tools/embeddings.py +322 -0
- mGPT/archs/tools/quantize_cnn.py +414 -0
- mGPT/archs/tools/resnet.py +82 -0
- mGPT/archs/tools/token_emb.py +73 -0
- mGPT/archs/tools/transformer_layers.py +285 -0
- mGPT/callback.py +200 -0
- mGPT/config.py +217 -0
- mGPT/data/HumanML3D.py +117 -0
- mGPT/data/Kit.py +88 -0
- mGPT/data/__init__.py +103 -0
- mGPT/data/build_data.py +15 -0
.gitignore
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
.DS_Store
|
9 |
+
pyglet
|
10 |
+
app2.py
|
11 |
+
render.py
|
12 |
+
cache
|
13 |
+
|
14 |
+
# Distribution / packaging
|
15 |
+
.Python
|
16 |
+
build/
|
17 |
+
develop-eggs/
|
18 |
+
dist/
|
19 |
+
downloads/
|
20 |
+
eggs/
|
21 |
+
.eggs/
|
22 |
+
lib/
|
23 |
+
lib64/
|
24 |
+
parts/
|
25 |
+
sdist/
|
26 |
+
var/
|
27 |
+
wheels/
|
28 |
+
share/python-wheels/
|
29 |
+
*.egg-info/
|
30 |
+
.installed.cfg
|
31 |
+
*.egg
|
32 |
+
MANIFEST
|
33 |
+
|
34 |
+
# PyInstaller
|
35 |
+
# Usually these files are written by a python script from a template
|
36 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
37 |
+
*.manifest
|
38 |
+
*.spec
|
39 |
+
|
40 |
+
# Installer logs
|
41 |
+
pip-log.txt
|
42 |
+
pip-delete-this-directory.txt
|
43 |
+
|
44 |
+
# Unit test / coverage reports
|
45 |
+
htmlcov/
|
46 |
+
.tox/
|
47 |
+
.nox/
|
48 |
+
.coverage
|
49 |
+
.coverage.*
|
50 |
+
.cache
|
51 |
+
nosetests.xml
|
52 |
+
coverage.xml
|
53 |
+
*.cover
|
54 |
+
*.py,cover
|
55 |
+
.hypothesis/
|
56 |
+
.pytest_cache/
|
57 |
+
cover/
|
58 |
+
|
59 |
+
# Translations
|
60 |
+
*.mo
|
61 |
+
*.pot
|
62 |
+
|
63 |
+
# Django stuff:
|
64 |
+
*.log
|
65 |
+
local_settings.py
|
66 |
+
db.sqlite3
|
67 |
+
db.sqlite3-journal
|
68 |
+
|
69 |
+
# Flask stuff:
|
70 |
+
instance/
|
71 |
+
.webassets-cache
|
72 |
+
|
73 |
+
# Scrapy stuff:
|
74 |
+
.scrapy
|
75 |
+
|
76 |
+
# Sphinx documentation
|
77 |
+
docs/_build/
|
78 |
+
|
79 |
+
# PyBuilder
|
80 |
+
.pybuilder/
|
81 |
+
target/
|
82 |
+
|
83 |
+
# Jupyter Notebook
|
84 |
+
.ipynb_checkpoints
|
85 |
+
|
86 |
+
# IPython
|
87 |
+
profile_default/
|
88 |
+
ipython_config.py
|
89 |
+
|
90 |
+
# pyenv
|
91 |
+
# For a library or package, you might want to ignore these files since the code is
|
92 |
+
# intended to run in multiple environments; otherwise, check them in:
|
93 |
+
# .python-version
|
94 |
+
|
95 |
+
# pipenv
|
96 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
97 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
98 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
99 |
+
# install all needed dependencies.
|
100 |
+
#Pipfile.lock
|
101 |
+
|
102 |
+
# poetry
|
103 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
104 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
105 |
+
# commonly ignored for libraries.
|
106 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
107 |
+
#poetry.lock
|
108 |
+
|
109 |
+
# pdm
|
110 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
111 |
+
#pdm.lock
|
112 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
113 |
+
# in version control.
|
114 |
+
# https://pdm.fming.dev/#use-with-ide
|
115 |
+
.pdm.toml
|
116 |
+
|
117 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
118 |
+
__pypackages__/
|
119 |
+
|
120 |
+
# Celery stuff
|
121 |
+
celerybeat-schedule
|
122 |
+
celerybeat.pid
|
123 |
+
|
124 |
+
# SageMath parsed files
|
125 |
+
*.sage.py
|
126 |
+
|
127 |
+
# Environments
|
128 |
+
.env
|
129 |
+
.venv
|
130 |
+
env/
|
131 |
+
venv/
|
132 |
+
ENV/
|
133 |
+
env.bak/
|
134 |
+
venv.bak/
|
135 |
+
|
136 |
+
# Spyder project settings
|
137 |
+
.spyderproject
|
138 |
+
.spyproject
|
139 |
+
|
140 |
+
# Rope project settings
|
141 |
+
.ropeproject
|
142 |
+
|
143 |
+
# mkdocs documentation
|
144 |
+
/site
|
145 |
+
|
146 |
+
# mypy
|
147 |
+
.mypy_cache/
|
148 |
+
.dmypy.json
|
149 |
+
dmypy.json
|
150 |
+
|
151 |
+
# Pyre type checker
|
152 |
+
.pyre/
|
153 |
+
|
154 |
+
# pytype static type analyzer
|
155 |
+
.pytype/
|
156 |
+
|
157 |
+
# Cython debug symbols
|
158 |
+
cython_debug/
|
159 |
+
|
160 |
+
# PyCharm
|
161 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
162 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
163 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
164 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
165 |
+
#.idea/
|
README.md
CHANGED
@@ -1,13 +1,13 @@
|
|
1 |
---
|
2 |
title: MotionGPT
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
sdk_version: 3.43.2
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
-
license:
|
11 |
---
|
12 |
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
title: MotionGPT
|
3 |
+
emoji: 🏃
|
4 |
+
colorFrom: yellow
|
5 |
+
colorTo: blue
|
6 |
sdk: gradio
|
7 |
sdk_version: 3.43.2
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
+
license: mit
|
11 |
---
|
12 |
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,511 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import random
|
3 |
+
import torch
|
4 |
+
import time
|
5 |
+
import cv2
|
6 |
+
import os
|
7 |
+
import numpy as np
|
8 |
+
import OpenGL.GL as gl
|
9 |
+
import pytorch_lightning as pl
|
10 |
+
import moviepy.editor as mp
|
11 |
+
from pathlib import Path
|
12 |
+
from mGPT.data.build_data import build_data
|
13 |
+
from mGPT.models.build_model import build_model
|
14 |
+
from mGPT.config import parse_args
|
15 |
+
from scipy.spatial.transform import Rotation as RRR
|
16 |
+
import mGPT.render.matplot.plot_3d_global as plot_3d
|
17 |
+
from mGPT.render.pyrender.hybrik_loc2rot import HybrIKJointsToRotmat
|
18 |
+
from mGPT.render.pyrender.smpl_render import SMPLRender
|
19 |
+
from transformers import WhisperProcessor, WhisperForConditionalGeneration
|
20 |
+
import librosa
|
21 |
+
from huggingface_hub import snapshot_download
|
22 |
+
|
23 |
+
os.environ["PYOPENGL_PLATFORM"] = "egl"
|
24 |
+
os.environ["MESA_GL_VERSION_OVERRIDE"] = "4.1"
|
25 |
+
os.system('pip install /home/user/app/pyrender')
|
26 |
+
|
27 |
+
# Load model
|
28 |
+
cfg = parse_args(phase="webui") # parse config file
|
29 |
+
cfg.FOLDER = 'cache'
|
30 |
+
output_dir = Path(cfg.FOLDER)
|
31 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
32 |
+
pl.seed_everything(cfg.SEED_VALUE)
|
33 |
+
if torch.cuda.is_available():
|
34 |
+
device = torch.device("cuda")
|
35 |
+
else:
|
36 |
+
device = torch.device("cpu")
|
37 |
+
|
38 |
+
model_path = snapshot_download(repo_id="bill-jiang/MotionGPT-base")
|
39 |
+
|
40 |
+
datamodule = build_data(cfg, phase="test")
|
41 |
+
model = build_model(cfg, datamodule)
|
42 |
+
state_dict = torch.load(f'{model_path}/motiongpt_s3_h3d.tar',
|
43 |
+
map_location="cpu")["state_dict"]
|
44 |
+
model.load_state_dict(state_dict)
|
45 |
+
model.to(device)
|
46 |
+
|
47 |
+
audio_processor = WhisperProcessor.from_pretrained(cfg.model.whisper_path)
|
48 |
+
audio_model = WhisperForConditionalGeneration.from_pretrained(
|
49 |
+
cfg.model.whisper_path).to(device)
|
50 |
+
forced_decoder_ids_zh = audio_processor.get_decoder_prompt_ids(
|
51 |
+
language="zh", task="translate")
|
52 |
+
forced_decoder_ids_en = audio_processor.get_decoder_prompt_ids(
|
53 |
+
language="en", task="translate")
|
54 |
+
|
55 |
+
# HTML Style
|
56 |
+
|
57 |
+
Video_Components = """
|
58 |
+
<div class="side-video" style="position: relative;">
|
59 |
+
<video width="340" autoplay loop>
|
60 |
+
<source src="file/{video_path}" type="video/mp4">
|
61 |
+
</video>
|
62 |
+
<a class="videodl-button" href="file/{video_path}" download="{video_fname}" title="Download Video">
|
63 |
+
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="#000000" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-video"><path d="m22 8-6 4 6 4V8Z"/><rect width="14" height="12" x="2" y="6" rx="2" ry="2"/></svg>
|
64 |
+
</a>
|
65 |
+
<a class="npydl-button" href="file/{motion_path}" download="{motion_fname}" title="Download Motion">
|
66 |
+
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="#000000" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-file-box"><path d="M14.5 22H18a2 2 0 0 0 2-2V7.5L14.5 2H6a2 2 0 0 0-2 2v4"/><polyline points="14 2 14 8 20 8"/><path d="M2.97 13.12c-.6.36-.97 1.02-.97 1.74v3.28c0 .72.37 1.38.97 1.74l3 1.83c.63.39 1.43.39 2.06 0l3-1.83c.6-.36.97-1.02.97-1.74v-3.28c0-.72-.37-1.38-.97-1.74l-3-1.83a1.97 1.97 0 0 0-2.06 0l-3 1.83Z"/><path d="m7 17-4.74-2.85"/><path d="m7 17 4.74-2.85"/><path d="M7 17v5"/></svg>
|
67 |
+
</a>
|
68 |
+
</div>
|
69 |
+
"""
|
70 |
+
|
71 |
+
Video_Components_example = """
|
72 |
+
<div class="side-video" style="position: relative;">
|
73 |
+
<video width="340" autoplay loop controls>
|
74 |
+
<source src="file/{video_path}" type="video/mp4">
|
75 |
+
</video>
|
76 |
+
<a class="npydl-button" href="file/{video_path}" download="{video_fname}" title="Download Video">
|
77 |
+
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-video"><path d="m22 8-6 4 6 4V8Z"/><rect width="14" height="12" x="2" y="6" rx="2" ry="2"/></svg>
|
78 |
+
</a>
|
79 |
+
</div>
|
80 |
+
"""
|
81 |
+
|
82 |
+
Text_Components = """
|
83 |
+
<h3 class="side-content" >{msg}</h3>
|
84 |
+
"""
|
85 |
+
|
86 |
+
|
87 |
+
def motion_token_to_string(motion_token, lengths, codebook_size=512):
|
88 |
+
motion_string = []
|
89 |
+
for i in range(motion_token.shape[0]):
|
90 |
+
motion_i = motion_token[i].cpu(
|
91 |
+
) if motion_token.device.type == 'cuda' else motion_token[i]
|
92 |
+
motion_list = motion_i.tolist()[:lengths[i]]
|
93 |
+
motion_string.append(
|
94 |
+
(f'<motion_id_{codebook_size}>' +
|
95 |
+
''.join([f'<motion_id_{int(i)}>' for i in motion_list]) +
|
96 |
+
f'<motion_id_{codebook_size + 1}>'))
|
97 |
+
return motion_string
|
98 |
+
|
99 |
+
|
100 |
+
def render_motion(data, feats, method='fast'):
|
101 |
+
fname = time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime(
|
102 |
+
time.time())) + str(np.random.randint(10000, 99999))
|
103 |
+
video_fname = fname + '.mp4'
|
104 |
+
feats_fname = fname + '.npy'
|
105 |
+
output_npy_path = os.path.join(output_dir, feats_fname)
|
106 |
+
output_mp4_path = os.path.join(output_dir, video_fname)
|
107 |
+
np.save(output_npy_path, feats)
|
108 |
+
|
109 |
+
if method == 'slow':
|
110 |
+
if len(data.shape) == 4:
|
111 |
+
data = data[0]
|
112 |
+
data = data - data[0, 0]
|
113 |
+
pose_generator = HybrIKJointsToRotmat()
|
114 |
+
pose = pose_generator(data)
|
115 |
+
pose = np.concatenate([
|
116 |
+
pose,
|
117 |
+
np.stack([np.stack([np.eye(3)] * pose.shape[0], 0)] * 2, 1)
|
118 |
+
], 1)
|
119 |
+
shape = [768, 768]
|
120 |
+
render = SMPLRender(cfg.RENDER.SMPL_MODEL_PATH)
|
121 |
+
|
122 |
+
if not os.environ.get("PYOPENGL_PLATFORM"):
|
123 |
+
os.environ["DISPLAY"] = ":0.0"
|
124 |
+
os.environ["PYOPENGL_PLATFORM"] = "egl"
|
125 |
+
|
126 |
+
size = (shape[1], shape[0])
|
127 |
+
fps = 20.0
|
128 |
+
fourcc = cv2.VideoWriter_fourcc('M', 'P', '4', 'V')
|
129 |
+
videoWriter = cv2.VideoWriter(output_mp4_path, fourcc, fps, size)
|
130 |
+
r = RRR.from_rotvec(np.array([np.pi, 0.0, 0.0]))
|
131 |
+
pose[:, 0] = np.matmul(r.as_matrix().reshape(1, 3, 3), pose[:, 0])
|
132 |
+
for i in range(data.shape[0]):
|
133 |
+
img = np.zeros([shape[0], shape[1], 3])
|
134 |
+
aroot = data[[i], 0] + np.array([[0.0, 0.0, 30.0]])
|
135 |
+
aroot[:, 1] = -aroot[:, 1]
|
136 |
+
params = dict(pred_shape=np.zeros([1, 10]),
|
137 |
+
pred_root=aroot,
|
138 |
+
pred_pose=pose[[i]])
|
139 |
+
renderImg = render.render(img.copy(), params)
|
140 |
+
renderImg = (renderImg * 255).astype(np.uint8)
|
141 |
+
videoWriter.write(renderImg)
|
142 |
+
videoWriter.release()
|
143 |
+
output_video_h264_name = output_mp4_path[:-4] + '_h264.mp4'
|
144 |
+
command = 'ffmpeg -y -i {} -vcodec h264 {}'.format(
|
145 |
+
output_mp4_path, output_video_h264_name)
|
146 |
+
os.system(command)
|
147 |
+
output_mp4_path = output_video_h264_name
|
148 |
+
video_fname = video_fname[:-4] + '_h264.mp4'
|
149 |
+
elif method == 'fast':
|
150 |
+
output_gif_path = output_mp4_path[:-4] + '.gif'
|
151 |
+
if len(data.shape) == 3:
|
152 |
+
data = data[None]
|
153 |
+
if isinstance(data, torch.Tensor):
|
154 |
+
data = data.cpu().numpy()
|
155 |
+
pose_vis = plot_3d.draw_to_batch(data, [''], [output_gif_path])
|
156 |
+
out_video = mp.VideoFileClip(output_gif_path)
|
157 |
+
out_video.write_videofile(output_mp4_path)
|
158 |
+
|
159 |
+
return output_mp4_path, video_fname, output_npy_path, feats_fname
|
160 |
+
|
161 |
+
|
162 |
+
def load_motion(motion_uploaded, method):
|
163 |
+
file = motion_uploaded['file']
|
164 |
+
|
165 |
+
feats = torch.tensor(np.load(file), device=model.device)
|
166 |
+
if len(feats.shape) == 2:
|
167 |
+
feats = feats[None]
|
168 |
+
# feats = model.datamodule.normalize(feats)
|
169 |
+
|
170 |
+
# Motion tokens
|
171 |
+
motion_lengths = feats.shape[0]
|
172 |
+
motion_token, _ = model.vae.encode(feats)
|
173 |
+
|
174 |
+
motion_token_string = model.lm.motion_token_to_string(
|
175 |
+
motion_token, [motion_token.shape[1]])[0]
|
176 |
+
motion_token_length = motion_token.shape[1]
|
177 |
+
|
178 |
+
# Motion rendered
|
179 |
+
joints = model.datamodule.feats2joints(feats.cpu()).cpu().numpy()
|
180 |
+
output_mp4_path, video_fname, output_npy_path, joints_fname = render_motion(
|
181 |
+
joints,
|
182 |
+
feats.to('cpu').numpy(), method)
|
183 |
+
|
184 |
+
motion_uploaded.update({
|
185 |
+
"feats": feats,
|
186 |
+
"joints": joints,
|
187 |
+
"motion_video": output_mp4_path,
|
188 |
+
"motion_video_fname": video_fname,
|
189 |
+
"motion_joints": output_npy_path,
|
190 |
+
"motion_joints_fname": joints_fname,
|
191 |
+
"motion_lengths": motion_lengths,
|
192 |
+
"motion_token": motion_token,
|
193 |
+
"motion_token_string": motion_token_string,
|
194 |
+
"motion_token_length": motion_token_length,
|
195 |
+
})
|
196 |
+
|
197 |
+
return motion_uploaded
|
198 |
+
|
199 |
+
|
200 |
+
def add_text(history, text, motion_uploaded, data_stored, method):
|
201 |
+
data_stored = data_stored + [{'user_input': text}]
|
202 |
+
|
203 |
+
text = f"""<h3>{text}</h3>"""
|
204 |
+
history = history + [(text, None)]
|
205 |
+
if 'file' in motion_uploaded.keys():
|
206 |
+
motion_uploaded = load_motion(motion_uploaded, method)
|
207 |
+
output_mp4_path = motion_uploaded['motion_video']
|
208 |
+
video_fname = motion_uploaded['motion_video_fname']
|
209 |
+
output_npy_path = motion_uploaded['motion_joints']
|
210 |
+
joints_fname = motion_uploaded['motion_joints_fname']
|
211 |
+
history = history + [(Video_Components.format(
|
212 |
+
video_path=output_mp4_path,
|
213 |
+
video_fname=video_fname,
|
214 |
+
motion_path=output_npy_path,
|
215 |
+
motion_fname=joints_fname), None)]
|
216 |
+
|
217 |
+
return history, gr.update(value="",
|
218 |
+
interactive=False), motion_uploaded, data_stored
|
219 |
+
|
220 |
+
|
221 |
+
def add_audio(history, audio_path, data_stored, language='en'):
|
222 |
+
audio, sampling_rate = librosa.load(audio_path, sr=16000)
|
223 |
+
input_features = audio_processor(
|
224 |
+
audio, sampling_rate, return_tensors="pt"
|
225 |
+
).input_features # whisper training sampling rate, do not modify
|
226 |
+
input_features = torch.Tensor(input_features).to(device)
|
227 |
+
|
228 |
+
if language == 'English':
|
229 |
+
forced_decoder_ids = forced_decoder_ids_en
|
230 |
+
else:
|
231 |
+
forced_decoder_ids = forced_decoder_ids_zh
|
232 |
+
predicted_ids = audio_model.generate(input_features,
|
233 |
+
forced_decoder_ids=forced_decoder_ids)
|
234 |
+
text_input = audio_processor.batch_decode(predicted_ids,
|
235 |
+
skip_special_tokens=True)
|
236 |
+
text_input = str(text_input).strip('[]"')
|
237 |
+
data_stored = data_stored + [{'user_input': text_input}]
|
238 |
+
gr.update(value=data_stored, interactive=False)
|
239 |
+
history = history + [(text_input, None)]
|
240 |
+
|
241 |
+
return history, data_stored
|
242 |
+
|
243 |
+
|
244 |
+
def add_file(history, file, txt, motion_uploaded):
|
245 |
+
motion_uploaded['file'] = file.name
|
246 |
+
txt = txt.replace(" <Motion_Placeholder>", "") + " <Motion_Placeholder>"
|
247 |
+
return history, gr.update(value=txt, interactive=True), motion_uploaded
|
248 |
+
|
249 |
+
|
250 |
+
def bot(history, motion_uploaded, data_stored, method):
|
251 |
+
|
252 |
+
motion_length, motion_token_string = motion_uploaded[
|
253 |
+
"motion_lengths"], motion_uploaded["motion_token_string"]
|
254 |
+
|
255 |
+
input = data_stored[-1]['user_input']
|
256 |
+
prompt = model.lm.placeholder_fulfill(input, motion_length,
|
257 |
+
motion_token_string, "")
|
258 |
+
data_stored[-1]['model_input'] = prompt
|
259 |
+
batch = {
|
260 |
+
"length": [motion_length],
|
261 |
+
"text": [prompt],
|
262 |
+
}
|
263 |
+
|
264 |
+
outputs = model(batch, task="t2m")
|
265 |
+
out_feats = outputs["feats"][0]
|
266 |
+
out_lengths = outputs["length"][0]
|
267 |
+
out_joints = outputs["joints"][:out_lengths].detach().cpu().numpy()
|
268 |
+
out_texts = outputs["texts"][0]
|
269 |
+
output_mp4_path, video_fname, output_npy_path, joints_fname = render_motion(
|
270 |
+
out_joints,
|
271 |
+
out_feats.to('cpu').numpy(), method)
|
272 |
+
|
273 |
+
motion_uploaded = {
|
274 |
+
"feats": None,
|
275 |
+
"joints": None,
|
276 |
+
"motion_video": None,
|
277 |
+
"motion_lengths": 0,
|
278 |
+
"motion_token": None,
|
279 |
+
"motion_token_string": '',
|
280 |
+
"motion_token_length": 0,
|
281 |
+
}
|
282 |
+
|
283 |
+
data_stored[-1]['model_output'] = {
|
284 |
+
"feats": out_feats,
|
285 |
+
"joints": out_joints,
|
286 |
+
"length": out_lengths,
|
287 |
+
"texts": out_texts,
|
288 |
+
"motion_video": output_mp4_path,
|
289 |
+
"motion_video_fname": video_fname,
|
290 |
+
"motion_joints": output_npy_path,
|
291 |
+
"motion_joints_fname": joints_fname,
|
292 |
+
}
|
293 |
+
|
294 |
+
if '<Motion_Placeholder>' == out_texts:
|
295 |
+
response = [
|
296 |
+
Video_Components.format(video_path=output_mp4_path,
|
297 |
+
video_fname=video_fname,
|
298 |
+
motion_path=output_npy_path,
|
299 |
+
motion_fname=joints_fname)
|
300 |
+
]
|
301 |
+
elif '<Motion_Placeholder>' in out_texts:
|
302 |
+
response = [
|
303 |
+
Text_Components.format(
|
304 |
+
msg=out_texts.split("<Motion_Placeholder>")[0]),
|
305 |
+
Video_Components.format(video_path=output_mp4_path,
|
306 |
+
video_fname=video_fname,
|
307 |
+
motion_path=output_npy_path,
|
308 |
+
motion_fname=joints_fname),
|
309 |
+
Text_Components.format(
|
310 |
+
msg=out_texts.split("<Motion_Placeholder>")[1]),
|
311 |
+
]
|
312 |
+
else:
|
313 |
+
response = f"""<h3>{out_texts}</h3>"""
|
314 |
+
|
315 |
+
history[-1][1] = ""
|
316 |
+
for character in response:
|
317 |
+
history[-1][1] += character
|
318 |
+
time.sleep(0.02)
|
319 |
+
yield history, motion_uploaded, data_stored
|
320 |
+
|
321 |
+
|
322 |
+
def bot_example(history, responses):
|
323 |
+
for response in responses:
|
324 |
+
history[-1][1] = ""
|
325 |
+
for character in response:
|
326 |
+
history[-1][1] += character
|
327 |
+
time.sleep(0.02)
|
328 |
+
yield history, motion_uploaded, data_stored
|
329 |
+
|
330 |
+
|
331 |
+
# Examples
|
332 |
+
chat_instruct = [
|
333 |
+
(None,
|
334 |
+
"**👋 Hi, I'm MotionGPT! I can generate realistic human motion from text, or generate text from motion.**"
|
335 |
+
),
|
336 |
+
(None,
|
337 |
+
"You can chat with me in pure text like generating human motion following your descriptions."
|
338 |
+
),
|
339 |
+
(None,
|
340 |
+
"After generation, you can click the button in the top right of generation human motion result to download the human motion video or feature stored in .npy format."
|
341 |
+
),
|
342 |
+
(None,
|
343 |
+
"With the human motion feature file downloaded or got from dataset, you are able to ask me to translate it!"
|
344 |
+
),
|
345 |
+
(None,
|
346 |
+
"Of courser, you can also purely chat with me and let me give you human motion in text, here are some examples!"
|
347 |
+
),
|
348 |
+
(None,
|
349 |
+
"We provide two motion visulization methods. The default fast method is skeleton line ploting which is like the examples below:"
|
350 |
+
),
|
351 |
+
(None,
|
352 |
+
Video_Components_example.format(video_path="assets/videos/t2m_0.mp4",
|
353 |
+
video_fname="example1.mp4")),
|
354 |
+
(None,
|
355 |
+
"And the slow method is SMPL model rendering which is more realistic but slower."
|
356 |
+
),
|
357 |
+
(None,
|
358 |
+
Video_Components_example.format(video_path="assets/videos/t2m_0.mp4",
|
359 |
+
video_fname="example1.mp4")),
|
360 |
+
(None, "👉 Follow the examples and try yourself!"),
|
361 |
+
]
|
362 |
+
|
363 |
+
t2m_examples = [
|
364 |
+
(None,
|
365 |
+
"You can chat with me in pure text, following are some examples of text-to-motion generation!"
|
366 |
+
),
|
367 |
+
("Generate a person is walking forwards, but stumbles and steps back, then carries on forward.",
|
368 |
+
Video_Components_example.format(video_path="assets/videos/t2m_0.mp4",
|
369 |
+
video_fname="example1.mp4")),
|
370 |
+
("Generate a person is walking forwards, but stumbles and steps back, then carries on forward.",
|
371 |
+
Video_Components_example.format(video_path="assets/videos/t2m_0.mp4",
|
372 |
+
video_fname="example1.mp4")),
|
373 |
+
("Generate a person is walking forwards, but stumbles and steps back, then carries on forward.",
|
374 |
+
Video_Components_example.format(video_path="assets/videos/t2m_0.mp4",
|
375 |
+
video_fname="example1.mp4")),
|
376 |
+
]
|
377 |
+
|
378 |
+
m2t_examples = [
|
379 |
+
(None,
|
380 |
+
"With the human motion feature file downloaded or got from dataset, you are able to ask me to translate it, here are some examples!"
|
381 |
+
),
|
382 |
+
("Please explain the movement shown in [Motion_tokens] using natural language.",
|
383 |
+
None),
|
384 |
+
(Video_Components_example.format(video_path="assets/videos/m2t_0.mp4",
|
385 |
+
video_fname="example2.mp4"),
|
386 |
+
"a person walks forward then does a backwards z-shape movement to its left side. then back to the right."
|
387 |
+
),
|
388 |
+
("Please explain the movement shown in [Motion_tokens] using natural language.",
|
389 |
+
None),
|
390 |
+
(Video_Components_example.format(video_path="assets/videos/m2t_0.mp4",
|
391 |
+
video_fname="example2.mp4"),
|
392 |
+
"a person walks forward then does a backwards z-shape movement to its left side. then back to the right."
|
393 |
+
),
|
394 |
+
]
|
395 |
+
|
396 |
+
t2t_examples = [
|
397 |
+
(None,
|
398 |
+
"Of courser, you can also purely chat with me and let me give you human motion in text, here are some examples!"
|
399 |
+
),
|
400 |
+
('Depict a motion as like you have seen it.',
|
401 |
+
"The person walks while swaying their hips along a curved path to the left slowly then stops to look down at the edge of the grey platform at something."
|
402 |
+
),
|
403 |
+
('Depict a motion as like you have seen it.',
|
404 |
+
"The person walks while swaying their hips along a curved path to the left slowly then stops to look down at the edge of the grey platform at something."
|
405 |
+
),
|
406 |
+
]
|
407 |
+
|
408 |
+
Init_chatbot = [
|
409 |
+
(None,
|
410 |
+
"**👋 Hi, I'm MotionGPT! I can generate realistic human motion from text, or generate text from motion.**"
|
411 |
+
)
|
412 |
+
] + t2m_examples[:3] + m2t_examples[:2] + t2t_examples[:2] + chat_instruct[-4:]
|
413 |
+
|
414 |
+
with open("assets/css/custom.css", "r", encoding="utf-8") as f:
|
415 |
+
customCSS = f.read()
|
416 |
+
|
417 |
+
with gr.Blocks(css=customCSS) as demo:
|
418 |
+
|
419 |
+
# Variables
|
420 |
+
motion_uploaded = gr.State({
|
421 |
+
"feats": None,
|
422 |
+
"joints": None,
|
423 |
+
"motion_video": None,
|
424 |
+
"motion_lengths": 0,
|
425 |
+
"motion_token": None,
|
426 |
+
"motion_token_string": '',
|
427 |
+
"motion_token_length": 0,
|
428 |
+
})
|
429 |
+
data_stored = gr.State([])
|
430 |
+
|
431 |
+
gr.Markdown("# MotionGPT")
|
432 |
+
|
433 |
+
chatbot = gr.Chatbot(Init_chatbot,
|
434 |
+
elem_id="mGPT",
|
435 |
+
height=600,
|
436 |
+
label="MotionGPT",
|
437 |
+
avatar_images=(None,
|
438 |
+
("assets/images/avatar_bot.jpg")),
|
439 |
+
bubble_full_width=False)
|
440 |
+
|
441 |
+
with gr.Row():
|
442 |
+
with gr.Column(scale=0.85):
|
443 |
+
with gr.Row():
|
444 |
+
txt = gr.Textbox(
|
445 |
+
label="Text",
|
446 |
+
show_label=False,
|
447 |
+
placeholder=
|
448 |
+
"Enter text and press ENTER or speak to input. You can also upload motion.",
|
449 |
+
container=False)
|
450 |
+
|
451 |
+
with gr.Row():
|
452 |
+
aud = gr.Audio(source="microphone",
|
453 |
+
label="Speak input",
|
454 |
+
type='filepath')
|
455 |
+
btn = gr.UploadButton("📁 Upload motion",
|
456 |
+
elem_id="upload",
|
457 |
+
file_types=["file"],
|
458 |
+
variant='primary')
|
459 |
+
regen = gr.Button("🔄 Regenerate", elem_id="regen")
|
460 |
+
clear = gr.ClearButton([txt, chatbot, aud], value='🗑️ Clear')
|
461 |
+
|
462 |
+
with gr.Row():
|
463 |
+
gr.Markdown('''
|
464 |
+
### You can get more examples (pre-generated for faster response) by clicking the buttons below:
|
465 |
+
''')
|
466 |
+
|
467 |
+
with gr.Row():
|
468 |
+
instruct = gr.Button("Instructions", elem_id="instruction")
|
469 |
+
t2m_eg = gr.Button("Text-to-Motion", elem_id="t2m")
|
470 |
+
m2t_eg = gr.Button("Motion-to-Text", elem_id="m2t")
|
471 |
+
t2t_eg = gr.Button("Random description", elem_id="t2t")
|
472 |
+
|
473 |
+
with gr.Column(scale=0.15, min_width=150):
|
474 |
+
method = gr.Dropdown(["slow", "fast"],
|
475 |
+
label="Visulization method",
|
476 |
+
interactive=True,
|
477 |
+
elem_id="method",
|
478 |
+
value="fast")
|
479 |
+
|
480 |
+
language = gr.Dropdown(["English", "中文"],
|
481 |
+
label="Speech language",
|
482 |
+
interactive=True,
|
483 |
+
elem_id="language",
|
484 |
+
value="English")
|
485 |
+
|
486 |
+
txt_msg = txt.submit(
|
487 |
+
add_text, [chatbot, txt, motion_uploaded, data_stored, method],
|
488 |
+
[chatbot, txt, motion_uploaded, data_stored],
|
489 |
+
queue=False).then(bot, [chatbot, motion_uploaded, data_stored, method],
|
490 |
+
[chatbot, motion_uploaded, data_stored])
|
491 |
+
|
492 |
+
txt_msg.then(lambda: gr.update(interactive=True), None, [txt], queue=False)
|
493 |
+
|
494 |
+
file_msg = btn.upload(add_file, [chatbot, btn, txt, motion_uploaded],
|
495 |
+
[chatbot, txt, motion_uploaded],
|
496 |
+
queue=False)
|
497 |
+
aud_msg = aud.stop_recording(
|
498 |
+
add_audio, [chatbot, aud, data_stored, language],
|
499 |
+
[chatbot, data_stored],
|
500 |
+
queue=False).then(bot, [chatbot, motion_uploaded, data_stored, method],
|
501 |
+
[chatbot, motion_uploaded, data_stored])
|
502 |
+
regen_msg = regen.click(bot,
|
503 |
+
[chatbot, motion_uploaded, data_stored, method],
|
504 |
+
[chatbot, motion_uploaded, data_stored],
|
505 |
+
queue=False)
|
506 |
+
chatbot.change(scroll_to_output=True)
|
507 |
+
|
508 |
+
demo.queue()
|
509 |
+
|
510 |
+
if __name__ == "__main__":
|
511 |
+
demo.launch(debug=True)
|
assets/css/custom.css
ADDED
@@ -0,0 +1,359 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* Borrowed from https://huggingface.co/spaces/project-baize/chat-with-baize */
|
2 |
+
|
3 |
+
:root {
|
4 |
+
--chatbot-color-light: #f6f6f6;
|
5 |
+
--chatbot-color-dark: #121111;
|
6 |
+
}
|
7 |
+
|
8 |
+
/* Light mode (default) */
|
9 |
+
#mGPT {
|
10 |
+
background-color: var(--chatbot-color-light) !important;
|
11 |
+
color: #000000 !important;
|
12 |
+
}
|
13 |
+
[data-testid='bot'] {
|
14 |
+
background-color: #ffffff !important;
|
15 |
+
}
|
16 |
+
[data-testid='user'] {
|
17 |
+
background-color: #95ec69 !important;
|
18 |
+
}
|
19 |
+
|
20 |
+
/* Dark mode */
|
21 |
+
.dark #mGPT {
|
22 |
+
background-color: var(--chatbot-color-dark) !important;
|
23 |
+
color: #ffffff !important;
|
24 |
+
}
|
25 |
+
.dark [data-testid='bot'] {
|
26 |
+
background-color: #2c2c2c !important;
|
27 |
+
}
|
28 |
+
|
29 |
+
.dark [data-testid='user'] {
|
30 |
+
background-color: #26b561 !important;
|
31 |
+
}
|
32 |
+
|
33 |
+
#mGPT {
|
34 |
+
height: 100%;
|
35 |
+
min-height: 500px;
|
36 |
+
}
|
37 |
+
|
38 |
+
[class*='message-buttons'] {
|
39 |
+
visibility: hidden;
|
40 |
+
}
|
41 |
+
|
42 |
+
[class*='message'] {
|
43 |
+
border: none;
|
44 |
+
font-size: var(--text-xl) !important;
|
45 |
+
line-height: var(--line-xl) !important;
|
46 |
+
}
|
47 |
+
/* [data-testid='bot'] {
|
48 |
+
max-width: 85%;
|
49 |
+
width: auto !important;
|
50 |
+
border-bottom-left-radius: 0 !important;
|
51 |
+
}
|
52 |
+
[data-testid='user'] {
|
53 |
+
max-width: 85%;
|
54 |
+
width: auto !important;
|
55 |
+
border-bottom-right-radius: 0 !important;
|
56 |
+
} */
|
57 |
+
|
58 |
+
/* Text & Video */
|
59 |
+
#method {
|
60 |
+
line-height: 1.95 !important;
|
61 |
+
}
|
62 |
+
|
63 |
+
.side-content {
|
64 |
+
max-width: 340px;
|
65 |
+
}
|
66 |
+
|
67 |
+
/* @media only screen and (min-width: 768px) {
|
68 |
+
.side-content {
|
69 |
+
float: left;
|
70 |
+
overflow-wrap: break-word;
|
71 |
+
padding-right: 2rem;
|
72 |
+
}
|
73 |
+
|
74 |
+
.side-video {
|
75 |
+
float: right;
|
76 |
+
}
|
77 |
+
} */
|
78 |
+
|
79 |
+
/* Buttom */
|
80 |
+
#upload {
|
81 |
+
color: #000000;
|
82 |
+
}
|
83 |
+
|
84 |
+
.videodl-button {
|
85 |
+
position: absolute;
|
86 |
+
left: 80%;
|
87 |
+
top: 5px;
|
88 |
+
width: 24px;
|
89 |
+
height: 24px;
|
90 |
+
}
|
91 |
+
|
92 |
+
.videodl-button svg {
|
93 |
+
width: 24px;
|
94 |
+
height: 24px;
|
95 |
+
}
|
96 |
+
|
97 |
+
.npydl-button {
|
98 |
+
position: absolute;
|
99 |
+
left: 90%;
|
100 |
+
top: 5px;
|
101 |
+
width: 24px;
|
102 |
+
height: 24px;
|
103 |
+
}
|
104 |
+
|
105 |
+
.npydl-button svg {
|
106 |
+
width: 24px;
|
107 |
+
height: 24px;
|
108 |
+
}
|
109 |
+
|
110 |
+
/* Table */
|
111 |
+
table {
|
112 |
+
margin: 1em 0;
|
113 |
+
border-collapse: collapse;
|
114 |
+
empty-cells: show;
|
115 |
+
}
|
116 |
+
td,
|
117 |
+
th {
|
118 |
+
border: 1.2px solid var(--border-color-primary) !important;
|
119 |
+
padding: 0.2em;
|
120 |
+
}
|
121 |
+
thead {
|
122 |
+
background-color: rgba(175, 184, 193, 0.2);
|
123 |
+
}
|
124 |
+
thead th {
|
125 |
+
padding: 0.5em 0.2em;
|
126 |
+
}
|
127 |
+
/* Inline code */
|
128 |
+
#mGPT code {
|
129 |
+
display: inline;
|
130 |
+
white-space: break-spaces;
|
131 |
+
border-radius: 6px;
|
132 |
+
margin: 0 2px 0 2px;
|
133 |
+
padding: 0.2em 0.4em 0.1em 0.4em;
|
134 |
+
background-color: rgba(175, 184, 193, 0.2);
|
135 |
+
}
|
136 |
+
/* Code block */
|
137 |
+
#mGPT pre code {
|
138 |
+
display: block;
|
139 |
+
overflow: auto;
|
140 |
+
white-space: pre;
|
141 |
+
background-color: hsla(0, 0%, 0%, 80%) !important;
|
142 |
+
border-radius: 10px;
|
143 |
+
padding: 1.4em 1.2em 0em 1.4em;
|
144 |
+
margin: 1.2em 2em 1.2em 0.5em;
|
145 |
+
color: #fff;
|
146 |
+
box-shadow: 6px 6px 16px hsla(0, 0%, 0%, 0.2);
|
147 |
+
}
|
148 |
+
/* Hightlight */
|
149 |
+
#mGPT .highlight {
|
150 |
+
background-color: transparent;
|
151 |
+
}
|
152 |
+
#mGPT .highlight .hll {
|
153 |
+
background-color: #49483e;
|
154 |
+
}
|
155 |
+
#mGPT .highlight .c {
|
156 |
+
color: #75715e;
|
157 |
+
} /* Comment */
|
158 |
+
#mGPT .highlight .err {
|
159 |
+
color: #960050;
|
160 |
+
background-color: #1e0010;
|
161 |
+
} /* Error */
|
162 |
+
#mGPT .highlight .k {
|
163 |
+
color: #66d9ef;
|
164 |
+
} /* Keyword */
|
165 |
+
#mGPT .highlight .l {
|
166 |
+
color: #ae81ff;
|
167 |
+
} /* Literal */
|
168 |
+
#mGPT .highlight .n {
|
169 |
+
color: #f8f8f2;
|
170 |
+
} /* Name */
|
171 |
+
#mGPT .highlight .o {
|
172 |
+
color: #f92672;
|
173 |
+
} /* Operator */
|
174 |
+
#mGPT .highlight .p {
|
175 |
+
color: #f8f8f2;
|
176 |
+
} /* Punctuation */
|
177 |
+
#mGPT .highlight .ch {
|
178 |
+
color: #75715e;
|
179 |
+
} /* Comment.Hashbang */
|
180 |
+
#mGPT .highlight .cm {
|
181 |
+
color: #75715e;
|
182 |
+
} /* Comment.Multiline */
|
183 |
+
#mGPT .highlight .cp {
|
184 |
+
color: #75715e;
|
185 |
+
} /* Comment.Preproc */
|
186 |
+
#mGPT .highlight .cpf {
|
187 |
+
color: #75715e;
|
188 |
+
} /* Comment.PreprocFile */
|
189 |
+
#mGPT .highlight .c1 {
|
190 |
+
color: #75715e;
|
191 |
+
} /* Comment.Single */
|
192 |
+
#mGPT .highlight .cs {
|
193 |
+
color: #75715e;
|
194 |
+
} /* Comment.Special */
|
195 |
+
#mGPT .highlight .gd {
|
196 |
+
color: #f92672;
|
197 |
+
} /* Generic.Deleted */
|
198 |
+
#mGPT .highlight .ge {
|
199 |
+
font-style: italic;
|
200 |
+
} /* Generic.Emph */
|
201 |
+
#mGPT .highlight .gi {
|
202 |
+
color: #a6e22e;
|
203 |
+
} /* Generic.Inserted */
|
204 |
+
#mGPT .highlight .gs {
|
205 |
+
font-weight: bold;
|
206 |
+
} /* Generic.Strong */
|
207 |
+
#mGPT .highlight .gu {
|
208 |
+
color: #75715e;
|
209 |
+
} /* Generic.Subheading */
|
210 |
+
#mGPT .highlight .kc {
|
211 |
+
color: #66d9ef;
|
212 |
+
} /* Keyword.Constant */
|
213 |
+
#mGPT .highlight .kd {
|
214 |
+
color: #66d9ef;
|
215 |
+
} /* Keyword.Declaration */
|
216 |
+
#mGPT .highlight .kn {
|
217 |
+
color: #f92672;
|
218 |
+
} /* Keyword.Namespace */
|
219 |
+
#mGPT .highlight .kp {
|
220 |
+
color: #66d9ef;
|
221 |
+
} /* Keyword.Pseudo */
|
222 |
+
#mGPT .highlight .kr {
|
223 |
+
color: #66d9ef;
|
224 |
+
} /* Keyword.Reserved */
|
225 |
+
#mGPT .highlight .kt {
|
226 |
+
color: #66d9ef;
|
227 |
+
} /* Keyword.Type */
|
228 |
+
#mGPT .highlight .ld {
|
229 |
+
color: #e6db74;
|
230 |
+
} /* Literal.Date */
|
231 |
+
#mGPT .highlight .m {
|
232 |
+
color: #ae81ff;
|
233 |
+
} /* Literal.Number */
|
234 |
+
#mGPT .highlight .s {
|
235 |
+
color: #e6db74;
|
236 |
+
} /* Literal.String */
|
237 |
+
#mGPT .highlight .na {
|
238 |
+
color: #a6e22e;
|
239 |
+
} /* Name.Attribute */
|
240 |
+
#mGPT .highlight .nb {
|
241 |
+
color: #f8f8f2;
|
242 |
+
} /* Name.Builtin */
|
243 |
+
#mGPT .highlight .nc {
|
244 |
+
color: #a6e22e;
|
245 |
+
} /* Name.Class */
|
246 |
+
#mGPT .highlight .no {
|
247 |
+
color: #66d9ef;
|
248 |
+
} /* Name.Constant */
|
249 |
+
#mGPT .highlight .nd {
|
250 |
+
color: #a6e22e;
|
251 |
+
} /* Name.Decorator */
|
252 |
+
#mGPT .highlight .ni {
|
253 |
+
color: #f8f8f2;
|
254 |
+
} /* Name.Entity */
|
255 |
+
#mGPT .highlight .ne {
|
256 |
+
color: #a6e22e;
|
257 |
+
} /* Name.Exception */
|
258 |
+
#mGPT .highlight .nf {
|
259 |
+
color: #a6e22e;
|
260 |
+
} /* Name.Function */
|
261 |
+
#mGPT .highlight .nl {
|
262 |
+
color: #f8f8f2;
|
263 |
+
} /* Name.Label */
|
264 |
+
#mGPT .highlight .nn {
|
265 |
+
color: #f8f8f2;
|
266 |
+
} /* Name.Namespace */
|
267 |
+
#mGPT .highlight .nx {
|
268 |
+
color: #a6e22e;
|
269 |
+
} /* Name.Other */
|
270 |
+
#mGPT .highlight .py {
|
271 |
+
color: #f8f8f2;
|
272 |
+
} /* Name.Property */
|
273 |
+
#mGPT .highlight .nt {
|
274 |
+
color: #f92672;
|
275 |
+
} /* Name.Tag */
|
276 |
+
#mGPT .highlight .nv {
|
277 |
+
color: #f8f8f2;
|
278 |
+
} /* Name.Variable */
|
279 |
+
#mGPT .highlight .ow {
|
280 |
+
color: #f92672;
|
281 |
+
} /* Operator.Word */
|
282 |
+
#mGPT .highlight .w {
|
283 |
+
color: #f8f8f2;
|
284 |
+
} /* Text.Whitespace */
|
285 |
+
#mGPT .highlight .mb {
|
286 |
+
color: #ae81ff;
|
287 |
+
} /* Literal.Number.Bin */
|
288 |
+
#mGPT .highlight .mf {
|
289 |
+
color: #ae81ff;
|
290 |
+
} /* Literal.Number.Float */
|
291 |
+
#mGPT .highlight .mh {
|
292 |
+
color: #ae81ff;
|
293 |
+
} /* Literal.Number.Hex */
|
294 |
+
#mGPT .highlight .mi {
|
295 |
+
color: #ae81ff;
|
296 |
+
} /* Literal.Number.Integer */
|
297 |
+
#mGPT .highlight .mo {
|
298 |
+
color: #ae81ff;
|
299 |
+
} /* Literal.Number.Oct */
|
300 |
+
#mGPT .highlight .sa {
|
301 |
+
color: #e6db74;
|
302 |
+
} /* Literal.String.Affix */
|
303 |
+
#mGPT .highlight .sb {
|
304 |
+
color: #e6db74;
|
305 |
+
} /* Literal.String.Backtick */
|
306 |
+
#mGPT .highlight .sc {
|
307 |
+
color: #e6db74;
|
308 |
+
} /* Literal.String.Char */
|
309 |
+
#mGPT .highlight .dl {
|
310 |
+
color: #e6db74;
|
311 |
+
} /* Literal.String.Delimiter */
|
312 |
+
#mGPT .highlight .sd {
|
313 |
+
color: #e6db74;
|
314 |
+
} /* Literal.String.Doc */
|
315 |
+
#mGPT .highlight .s2 {
|
316 |
+
color: #e6db74;
|
317 |
+
} /* Literal.String.Double */
|
318 |
+
#mGPT .highlight .se {
|
319 |
+
color: #ae81ff;
|
320 |
+
} /* Literal.String.Escape */
|
321 |
+
#mGPT .highlight .sh {
|
322 |
+
color: #e6db74;
|
323 |
+
} /* Literal.String.Heredoc */
|
324 |
+
#mGPT .highlight .si {
|
325 |
+
color: #e6db74;
|
326 |
+
} /* Literal.String.Interpol */
|
327 |
+
#mGPT .highlight .sx {
|
328 |
+
color: #e6db74;
|
329 |
+
} /* Literal.String.Other */
|
330 |
+
#mGPT .highlight .sr {
|
331 |
+
color: #e6db74;
|
332 |
+
} /* Literal.String.Regex */
|
333 |
+
#mGPT .highlight .s1 {
|
334 |
+
color: #e6db74;
|
335 |
+
} /* Literal.String.Single */
|
336 |
+
#mGPT .highlight .ss {
|
337 |
+
color: #e6db74;
|
338 |
+
} /* Literal.String.Symbol */
|
339 |
+
#mGPT .highlight .bp {
|
340 |
+
color: #f8f8f2;
|
341 |
+
} /* Name.Builtin.Pseudo */
|
342 |
+
#mGPT .highlight .fm {
|
343 |
+
color: #a6e22e;
|
344 |
+
} /* Name.Function.Magic */
|
345 |
+
#mGPT .highlight .vc {
|
346 |
+
color: #f8f8f2;
|
347 |
+
} /* Name.Variable.Class */
|
348 |
+
#mGPT .highlight .vg {
|
349 |
+
color: #f8f8f2;
|
350 |
+
} /* Name.Variable.Global */
|
351 |
+
#mGPT .highlight .vi {
|
352 |
+
color: #f8f8f2;
|
353 |
+
} /* Name.Variable.Instance */
|
354 |
+
#mGPT .highlight .vm {
|
355 |
+
color: #f8f8f2;
|
356 |
+
} /* Name.Variable.Magic */
|
357 |
+
#mGPT .highlight .il {
|
358 |
+
color: #ae81ff;
|
359 |
+
} /* Literal.Number.Integer.Long */
|
assets/images/avatar_bot.jpg
ADDED
assets/meta/mean.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0bdb5ba69a3a9e34d71990db15bc535ebc024c8d95ddb5574196f96058faa7d3
|
3 |
+
size 2232
|
assets/meta/mean_eval.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0bdb5ba69a3a9e34d71990db15bc535ebc024c8d95ddb5574196f96058faa7d3
|
3 |
+
size 2232
|
assets/meta/std.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6a5f7d60301c9465972fc225f8ad0ee8f957e7720431189123eb6d15873a9557
|
3 |
+
size 2232
|
assets/meta/std_eval.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6a5f7d60301c9465972fc225f8ad0ee8f957e7720431189123eb6d15873a9557
|
3 |
+
size 2232
|
assets/videos/m2t_0.mp4
ADDED
Binary file (500 kB). View file
|
|
assets/videos/t2m_0.mp4
ADDED
Binary file (811 kB). View file
|
|
configs/assets.yaml
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
CONFIG_FOLDER: configs # Config files path
|
2 |
+
FOLDER: experiments # Experiment files saving path
|
3 |
+
|
4 |
+
TEST:
|
5 |
+
FOLDER: results # Testing files saving path
|
6 |
+
|
7 |
+
DATASET:
|
8 |
+
TASK_ROOT: deps/mGPT_instructions
|
9 |
+
SMPL_PATH: deps/smpl
|
10 |
+
TRANSFORM_PATH: deps/transforms/
|
11 |
+
WORD_VERTILIZER_PATH: deps/glove/
|
12 |
+
KIT:
|
13 |
+
ROOT: datasets/kit-ml # KIT directory
|
14 |
+
SPLIT_ROOT: datasets/kit-ml # KIT splits directory
|
15 |
+
MEAN_STD_PATH: deps/t2m/
|
16 |
+
HUMANML3D:
|
17 |
+
ROOT: datasets/humanml3d # HumanML3D directory
|
18 |
+
SPLIT_ROOT: datasets/humanml3d # HumanML3D splits directory
|
19 |
+
MEAN_STD_PATH: deps/t2m/
|
20 |
+
|
21 |
+
METRIC:
|
22 |
+
TM2T:
|
23 |
+
t2m_path: deps/t2m/ # path for tm2t evaluator
|
24 |
+
|
25 |
+
model:
|
26 |
+
whisper_path: openai/whisper-large-v2 # path for whisper model, webui only
|
27 |
+
|
28 |
+
RENDER:
|
29 |
+
BLENDER_PATH: libs/blender-2.93.2-linux-x64/blender
|
30 |
+
SMPL_MODEL_PATH: deps/smpl/smpl_models/smpl
|
31 |
+
MODEL_PATH: deps/smpl/smpl_models/
|
32 |
+
FACES_PATH: deps/smplh/smplh.faces
|
configs/default.yaml
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
SEED_VALUE: 1234 # Seed value
|
2 |
+
DEBUG: True # Debug mode
|
3 |
+
FULL_CONFIG: false
|
4 |
+
|
5 |
+
TRAIN:
|
6 |
+
SPLIT: 'train' # Training split name
|
7 |
+
NUM_WORKERS: 8 # Number of workers
|
8 |
+
BATCH_SIZE: 8 # Size of batches
|
9 |
+
END_EPOCH: 2000 # End epoch
|
10 |
+
|
11 |
+
RESUME: '' # Experiment path to be resumed training
|
12 |
+
PRETRAINED_VAE: '' # Pretrained vae/vqvae model path
|
13 |
+
PRETRAINED: '' # Pretrained model path
|
14 |
+
|
15 |
+
OPTIM:
|
16 |
+
target: AdamW
|
17 |
+
params:
|
18 |
+
lr: 2e-4
|
19 |
+
betas: [0.9, 0.99]
|
20 |
+
weight_decay: 0.0
|
21 |
+
|
22 |
+
LR_SCHEDULER:
|
23 |
+
target: CosineAnnealingLR
|
24 |
+
params:
|
25 |
+
T_max: ${eval:${LOGGER.VAL_EVERY_STEPS} * 100}
|
26 |
+
eta_min: 1e-6
|
27 |
+
|
28 |
+
EVAL:
|
29 |
+
SPLIT: 'val' # Validation split name
|
30 |
+
BATCH_SIZE: 16 # Validation Batch size
|
31 |
+
NUM_WORKERS: 8 # Validation Batch size
|
32 |
+
|
33 |
+
TEST:
|
34 |
+
CHECKPOINTS: '' # Pretrained model path
|
35 |
+
SPLIT: 'test' # Testing split name
|
36 |
+
BATCH_SIZE: 16 # Testing Batch size
|
37 |
+
NUM_WORKERS: 8 # Testing Batch size
|
38 |
+
|
39 |
+
SAVE_PREDICTIONS: False # Weather to save predictions
|
40 |
+
COUNT_TIME: False # Weather to count time during test
|
41 |
+
REPLICATION_TIMES: 20 # Number of times to replicate the test
|
42 |
+
REP_I: 0 # For counting replication times
|
43 |
+
|
44 |
+
model:
|
45 |
+
target: mGPT.models.mgpt.MotionGPT
|
46 |
+
params:
|
47 |
+
condition: 'text'
|
48 |
+
task: 't2m'
|
49 |
+
lm: ${lm.default}
|
50 |
+
motion_vae: ${vq.default}
|
51 |
+
|
52 |
+
# Related parameters
|
53 |
+
stage: ${TRAIN.STAGE}
|
54 |
+
debug: ${DEBUG}
|
55 |
+
codebook_size: ${model.params.motion_vae.params.code_num}
|
56 |
+
metrics_dict: ${METRIC.TYPE}
|
57 |
+
|
58 |
+
LOSS:
|
59 |
+
LAMBDA_REC: 1.0 # Lambda for reconstruction losses
|
60 |
+
LAMBDA_JOINT: 1.0 # Lambda for joint losses
|
61 |
+
|
62 |
+
LAMBDA_LATENT: 1e-5 # Lambda for latent losses
|
63 |
+
LAMBDA_KL: 1e-5 # Lambda for kl losses
|
64 |
+
LAMBDA_GEN: 1.0 # Lambda for text-motion generation losses
|
65 |
+
LAMBDA_CROSS: 1.0 # Lambda for cross-reconstruction losses
|
66 |
+
LAMBDA_CYCLE: 1.0 # Lambda for cycle losses
|
67 |
+
LAMBDA_PRIOR: 0.0 # Lambda for diffusion prior losses
|
68 |
+
|
69 |
+
LAMBDA_VELOCITY: 0.5 # Lambda for velocity losses
|
70 |
+
LAMBDA_COMMIT: 0.02 # Lambda for commitment losses
|
71 |
+
|
72 |
+
ABLATION:
|
73 |
+
RECONS_LOSS: 'l1_smooth'
|
74 |
+
|
75 |
+
METRIC:
|
76 |
+
TASK: 't2m'
|
77 |
+
FORCE_IN_METER: True
|
78 |
+
DIST_SYNC_ON_STEP: True
|
79 |
+
MM_NUM_SAMPLES: 100 # Number of samples for multimodal test
|
80 |
+
MM_NUM_REPEATS: 30 # Number of repeats for multimodal test
|
81 |
+
MM_NUM_TIMES: 10 # Number of times to repeat the multimodal test
|
82 |
+
DIVERSITY_TIMES: 300 # Number of times to repeat the diversity test
|
83 |
+
TM2T: ${evaluator.tm2t}
|
84 |
+
|
85 |
+
DATASET:
|
86 |
+
target: mGPT.data.HumanML3D.HumanML3DDataModule
|
87 |
+
CODE_PATH: 'VQVAE'
|
88 |
+
TASK_ROOT: ''
|
89 |
+
TASK_PATH: ''
|
90 |
+
NFEATS: 263
|
91 |
+
KIT:
|
92 |
+
MAX_MOTION_LEN: 196
|
93 |
+
MIN_MOTION_LEN: 24
|
94 |
+
MAX_TEXT_LEN: 20
|
95 |
+
PICK_ONE_TEXT: true
|
96 |
+
FRAME_RATE: 12.5
|
97 |
+
UNIT_LEN: 4
|
98 |
+
HUMANML3D:
|
99 |
+
MAX_MOTION_LEN: 196
|
100 |
+
MIN_MOTION_LEN: 40
|
101 |
+
MAX_TEXT_LEN: 20
|
102 |
+
PICK_ONE_TEXT: true
|
103 |
+
FRAME_RATE: 20.0
|
104 |
+
UNIT_LEN: 4
|
105 |
+
STD_TEXT: False
|
106 |
+
|
107 |
+
ABLATION:
|
108 |
+
# For MotionGPT
|
109 |
+
use_length: False
|
110 |
+
predict_ratio: 0.2
|
111 |
+
inbetween_ratio: 0.25
|
112 |
+
image_size: 256
|
113 |
+
|
114 |
+
# For Motion-latent-diffusion
|
115 |
+
VAE_TYPE: 'actor' # vae ablation: actor or mcross
|
116 |
+
VAE_ARCH: 'encoder_decoder' # mdiffusion vae architecture
|
117 |
+
PE_TYPE: 'actor' # mdiffusion mld or actor
|
118 |
+
DIFF_PE_TYPE: 'actor' # mdiffusion mld or actor
|
119 |
+
SKIP_CONNECT: False # skip connection for denoiser va
|
120 |
+
MLP_DIST: False # use linear to expand mean and std rather expand token nums
|
121 |
+
IS_DIST: False # Mcross distribution kl
|
122 |
+
PREDICT_EPSILON: True # noise or motion
|
123 |
+
|
124 |
+
LOGGER:
|
125 |
+
VAL_EVERY_STEPS: 10
|
126 |
+
LOGGERS: ['tensorboard', 'wandb']
|
127 |
+
TENSORBOARD:
|
128 |
+
target: pytorch_lightning.loggers.TensorBoardLogger
|
129 |
+
params:
|
130 |
+
save_dir: ${FOLDER_EXP}
|
131 |
+
name: 'tensorboard'
|
132 |
+
version: ''
|
133 |
+
WANDB:
|
134 |
+
target: pytorch_lightning.loggers.WandbLogger
|
135 |
+
params:
|
136 |
+
project: null
|
137 |
+
offline: False
|
138 |
+
id: null
|
139 |
+
version: ''
|
140 |
+
name: ${NAME}
|
141 |
+
save_dir: ${FOLDER_EXP}
|
configs/evaluator/tm2t.yaml
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
t2m_textencoder:
|
2 |
+
target: mGPT.archs.tm2t_evaluator.TextEncoderBiGRUCo
|
3 |
+
params:
|
4 |
+
word_size: 300
|
5 |
+
pos_size: 15
|
6 |
+
hidden_size: 512
|
7 |
+
output_size: 512
|
8 |
+
t2m_moveencoder:
|
9 |
+
target: mGPT.archs.tm2t_evaluator.MovementConvEncoder
|
10 |
+
params:
|
11 |
+
input_size: ${eval:${DATASET.NFEATS} - 4}
|
12 |
+
hidden_size: 512
|
13 |
+
output_size: 512
|
14 |
+
t2m_motionencoder:
|
15 |
+
target: mGPT.archs.tm2t_evaluator.MotionEncoderBiGRUCo
|
16 |
+
params:
|
17 |
+
input_size: ${evaluator.tm2t.t2m_moveencoder.params.output_size}
|
18 |
+
hidden_size: 1024
|
19 |
+
output_size: 512
|
configs/lm/default.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
target: mGPT.archs.mgpt_lm.MLM
|
2 |
+
params:
|
3 |
+
model_type: t5
|
4 |
+
model_path: google/flan-t5-base
|
5 |
+
stage: ${TRAIN.STAGE}
|
6 |
+
motion_codebook_size: ${model.params.codebook_size}
|
7 |
+
ablation: ${ABLATION}
|
configs/render.yaml
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
NAME: '___render_do_not_need_name__' # Experiment name
|
2 |
+
ACCELERATOR: 'gpu' # Devices optioncal: “cpu”, “gpu”, “tpu”, “ipu”, “hpu”, “mps, “auto”
|
3 |
+
DEVICE: [0] # Index of gpus eg. [0] or [0,1,2,3]
|
4 |
+
|
5 |
+
RENDER:
|
6 |
+
FOLDER: '___no_need__'
|
7 |
+
INPUT_MODE: 'npy'
|
8 |
+
DIR: ''
|
9 |
+
NPY: '___no_need__'
|
10 |
+
DENOISING: True
|
11 |
+
OLDRENDER: True
|
12 |
+
# ["ultra", "high", "med", "low"]
|
13 |
+
# RES: 'high'
|
14 |
+
RES: 'med'
|
15 |
+
DOWNSAMPLE: False
|
16 |
+
FPS: 20.0
|
17 |
+
CANONICALIZE: True
|
18 |
+
EXACT_FRAME: 0.5
|
19 |
+
NUM: 8
|
20 |
+
MODE: '___no_need__' #sequence frame video
|
21 |
+
VID_EXT: mp4
|
22 |
+
ALWAYS_ON_FLOOR: false
|
23 |
+
GT: false
|
configs/vq/default.yaml
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
target: mGPT.archs.mgpt_vq.VQVae
|
2 |
+
params:
|
3 |
+
quantizer: 'ema_reset'
|
4 |
+
code_num: 512
|
5 |
+
code_dim: 512
|
6 |
+
output_emb_width: 512
|
7 |
+
down_t: 2
|
8 |
+
stride_t: 2
|
9 |
+
width: 512
|
10 |
+
depth: 3
|
11 |
+
dilation_growth_rate: 3
|
12 |
+
norm: None
|
13 |
+
activation: 'relu'
|
14 |
+
nfeats: ${DATASET.NFEATS}
|
15 |
+
ablation: ${ABLATION}
|
configs/webui.yaml
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
NAME: Webui # Experiment name
|
2 |
+
DEBUG: False # Debug mode
|
3 |
+
ACCELERATOR: 'cpu' # Devices optioncal: “cpu”, “gpu”, “tpu”, “ipu”, “hpu”, “mps, “auto”
|
4 |
+
DEVICE: [0] # Index of gpus eg. [0] or [0,1,2,3]
|
5 |
+
|
6 |
+
# Training configuration
|
7 |
+
TRAIN:
|
8 |
+
#---------------------------------
|
9 |
+
STAGE: lm_instruct
|
10 |
+
DATASETS: ['humanml3d'] # Training datasets
|
11 |
+
NUM_WORKERS: 32 # Number of workers
|
12 |
+
BATCH_SIZE: 16 # Size of batches
|
13 |
+
START_EPOCH: 0 # Start epochMMOTIONENCODER
|
14 |
+
END_EPOCH: 99999 # End epoch
|
15 |
+
ABLATION:
|
16 |
+
pkeep: 0.5
|
17 |
+
OPTIM:
|
18 |
+
TYPE: AdamW # Optimizer type
|
19 |
+
LR: 2e-4 # Learning rate
|
20 |
+
WEIGHT_DECAY: 0.0
|
21 |
+
LR_SCHEDULER: [100, 200, 300, 400]
|
22 |
+
GAMMA: 0.8
|
23 |
+
|
24 |
+
# Evaluating Configuration
|
25 |
+
EVAL:
|
26 |
+
DATASETS: ['humanml3d'] # Evaluating datasets
|
27 |
+
BATCH_SIZE: 32 # Evaluating Batch size
|
28 |
+
SPLIT: test
|
29 |
+
|
30 |
+
# Test Configuration
|
31 |
+
TEST:
|
32 |
+
CHECKPOINTS: checkpoints/MotionGPT-base/motiongpt_s3_h3d.ckpt
|
33 |
+
DATASETS: ['humanml3d'] # training datasets
|
34 |
+
SPLIT: test
|
35 |
+
BATCH_SIZE: 32 # training Batch size
|
36 |
+
MEAN: False
|
37 |
+
NUM_SAMPLES: 1
|
38 |
+
FACT: 1
|
39 |
+
|
40 |
+
# Datasets Configuration
|
41 |
+
DATASET:
|
42 |
+
JOINT_TYPE: 'humanml3d' # join type
|
43 |
+
CODE_PATH: 'VQBEST'
|
44 |
+
METRIC:
|
45 |
+
TYPE: ['TM2TMetrics']
|
46 |
+
# Losses Configuration
|
47 |
+
LOSS:
|
48 |
+
TYPE: t2mgpt # Losses type
|
49 |
+
LAMBDA_FEATURE: 1.0
|
50 |
+
LAMBDA_VELOCITY: 0.5
|
51 |
+
LAMBDA_COMMIT: 0.02
|
52 |
+
LAMBDA_CLS: 1.0
|
53 |
+
LAMBDA_M2T2M: 1.0
|
54 |
+
LAMBDA_T2M2T: 10.0
|
55 |
+
ABLATION:
|
56 |
+
RECONS_LOSS: 'l1_smooth'
|
57 |
+
|
58 |
+
# Model Configuration
|
59 |
+
model:
|
60 |
+
target: mGPT.models.mgpt.MotionGPT
|
61 |
+
params:
|
62 |
+
condition: 'text'
|
63 |
+
task: 't2m'
|
64 |
+
lm: ${lm.default}
|
65 |
+
motion_vae: ${vq.default}
|
66 |
+
|
67 |
+
# Logger configuration
|
68 |
+
LOGGER:
|
69 |
+
LOG_EVERY_STEPS: 5
|
70 |
+
VAL_EVERY_STEPS: 10
|
71 |
+
TENSORBOARD: True
|
72 |
+
wandb:
|
73 |
+
params:
|
74 |
+
project: null
|
deps/smpl/smpl_models/SMPL_downsample_index.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e5b783c1677079397ee4bc26df5c72d73b8bb393bea41fa295b951187443daec
|
3 |
+
size 3556
|
deps/smpl/smpl_models/gmm_08.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e1374908aae055a2afa01a2cd9a169bc6cfec1ceb7aa590e201a47b383060491
|
3 |
+
size 839127
|
deps/smpl/smpl_models/neutral_smpl_mean_params.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ac9b474c74daec0253ed084720f662059336e976850f08a4a9a3f76d06613776
|
3 |
+
size 4848
|
deps/smpl/smpl_models/smpl.faces
ADDED
Binary file (331 kB). View file
|
|
deps/smpl/smpl_models/smpl.tar.gz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:bf4793af6b29677b0841c58db392642cb70b477890dc91de01128c7f34738d8d
|
3 |
+
size 45
|
deps/smpl/smpl_models/smpl/SMPL_FEMALE.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a583c1b98e4afc19042641f1bae5cd8a1f712a6724886291a7627ec07acd408d
|
3 |
+
size 39056454
|
deps/smpl/smpl_models/smpl/SMPL_MALE.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0e8c0bbbbc635dcb166ed29c303fb4bef16ea5f623e5a89263495a9e403575bd
|
3 |
+
size 39056404
|
deps/smpl/smpl_models/smpl/SMPL_NEUTRAL.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:98e65c74ad9b998783132f00880d1025a8d64b158e040e6ef13a557e5098bc42
|
3 |
+
size 39001280
|
deps/smpl/smpl_models/smpl/readme.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
This directory leaves for SMPL models
|
deps/smpl/smpl_models/smplh/SMPLH_FEMALE.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f0fba73ef2494b26de243c1d88a1dbe1047e5566128cf7222c942089543f4560
|
3 |
+
size 39708434
|
deps/smpl/smpl_models/smplh/SMPLH_MALE.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:10b617fdd329557937d6fe38e8a542afab236a8887522d9da0bd42e7f2b76eaa
|
3 |
+
size 39686902
|
deps/smpl/smpl_models/smplh/SMPLH_NEUTRAL.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:42969b34d8cd383e172515a7bca6ff3b2c37aa2c5c78088c69d20e517fa96026
|
3 |
+
size 39708959
|
deps/smpl/smpl_models/smplh/mano_v1_2.zip
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:50976831790ea9657d8110e0c94e50e90eaf35cd76169f0b27e5d32f3fcd951f
|
3 |
+
size 175200815
|
deps/smpl/smpl_models/smplh/smplh.faces
ADDED
Binary file (165 kB). View file
|
|
deps/smpl/smpl_models/smplh/smplh.tar.xz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:46d5b8687be48c91181fa88271feff3a5e83aa62a481fad8a0bcb9254b2a74f1
|
3 |
+
size 113231292
|
deps/smpl/smpl_models/smplx_parts_segm.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:bb69c10801205c9cfb5353fdeb1b9cc5ade53d14c265c3339421cdde8b9c91e7
|
3 |
+
size 1323168
|
mGPT/__init__.py
ADDED
File without changes
|
mGPT/archs/__init__.py
ADDED
File without changes
|
mGPT/archs/mgpt_lm.py
ADDED
@@ -0,0 +1,592 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import List, Union
|
3 |
+
import numpy as np
|
4 |
+
import math
|
5 |
+
import time
|
6 |
+
import heapq
|
7 |
+
import torch
|
8 |
+
from torch import Tensor, nn
|
9 |
+
from torch.distributions.distribution import Distribution
|
10 |
+
from transformers import AutoModelForSeq2SeqLM, T5ForConditionalGeneration, T5Tokenizer, AutoTokenizer, GPT2LMHeadModel, GPT2Tokenizer
|
11 |
+
import random
|
12 |
+
from typing import Optional
|
13 |
+
from .tools.token_emb import NewTokenEmb
|
14 |
+
|
15 |
+
|
16 |
+
class MLM(nn.Module):
|
17 |
+
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
model_path: str,
|
21 |
+
model_type: str = "t5",
|
22 |
+
stage: str = "lm_pretrain",
|
23 |
+
new_token_type: str = "insert",
|
24 |
+
motion_codebook_size: int = 512,
|
25 |
+
framerate: float = 20.0,
|
26 |
+
down_t: int = 4,
|
27 |
+
predict_ratio: float = 0.2,
|
28 |
+
inbetween_ratio: float = 0.25,
|
29 |
+
max_length: int = 256,
|
30 |
+
lora: bool = False,
|
31 |
+
quota_ratio: float = 0.5,
|
32 |
+
noise_density: float = 0.15,
|
33 |
+
mean_noise_span_length: int = 3,
|
34 |
+
**kwargs,
|
35 |
+
) -> None:
|
36 |
+
|
37 |
+
super().__init__()
|
38 |
+
|
39 |
+
# Parameters
|
40 |
+
self.m_codebook_size = motion_codebook_size
|
41 |
+
self.max_length = max_length
|
42 |
+
self.framerate = framerate
|
43 |
+
self.down_t = down_t
|
44 |
+
self.predict_ratio = predict_ratio
|
45 |
+
self.inbetween_ratio = inbetween_ratio
|
46 |
+
self.noise_density = noise_density
|
47 |
+
self.mean_noise_span_length = mean_noise_span_length
|
48 |
+
self.quota_ratio = quota_ratio
|
49 |
+
self.stage = stage
|
50 |
+
|
51 |
+
# Instantiate language model
|
52 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_path, legacy=True)
|
53 |
+
if model_type == "t5":
|
54 |
+
self.language_model = T5ForConditionalGeneration.from_pretrained(
|
55 |
+
model_path)
|
56 |
+
self.lm_type = 'encdec'
|
57 |
+
elif model_type == "gpt2":
|
58 |
+
self.language_model = GPT2LMHeadModel.from_pretrained(model_path)
|
59 |
+
self.lm_type = 'dec'
|
60 |
+
else:
|
61 |
+
raise ValueError("type must be either seq2seq or conditional")
|
62 |
+
|
63 |
+
if self.lm_type == 'dec':
|
64 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
65 |
+
|
66 |
+
# Add motion tokens
|
67 |
+
self.tokenizer.add_tokens(
|
68 |
+
[f'<motion_id_{i}>' for i in range(self.m_codebook_size + 3)])
|
69 |
+
|
70 |
+
if new_token_type == "insert":
|
71 |
+
self.language_model.resize_token_embeddings(len(self.tokenizer))
|
72 |
+
elif new_token_type == "mlp":
|
73 |
+
shared = NewTokenEmb(self.language_model.shared,
|
74 |
+
self.m_codebook_size + 3)
|
75 |
+
# lm_head = NewTokenEmb(self.language_model.lm_head,
|
76 |
+
# self.m_codebook_size + 3)
|
77 |
+
self.language_model.resize_token_embeddings(len(self.tokenizer))
|
78 |
+
self.language_model.shared = shared
|
79 |
+
# self.language_model.lm_head = lm_head
|
80 |
+
|
81 |
+
# Lora
|
82 |
+
if lora:
|
83 |
+
from peft import LoraConfig, TaskType, get_peft_model, get_peft_model_state_dict
|
84 |
+
from peft.utils.other import fsdp_auto_wrap_policy
|
85 |
+
peft_config = LoraConfig(
|
86 |
+
bias="none",
|
87 |
+
task_type="CAUSAL_LM",
|
88 |
+
# inference_mode=False,
|
89 |
+
r=8,
|
90 |
+
lora_alpha=16,
|
91 |
+
lora_dropout=0.05)
|
92 |
+
self.language_model = get_peft_model(self.language_model,
|
93 |
+
peft_config)
|
94 |
+
|
95 |
+
def forward(self, texts: List[str], motion_tokens: Tensor,
|
96 |
+
lengths: List[int], tasks: dict):
|
97 |
+
if self.lm_type == 'encdec':
|
98 |
+
return self.forward_encdec(texts, motion_tokens, lengths, tasks)
|
99 |
+
elif self.lm_type == 'dec':
|
100 |
+
return self.forward_dec(texts, motion_tokens, lengths, tasks)
|
101 |
+
else:
|
102 |
+
raise NotImplementedError("Only conditional_multitask supported")
|
103 |
+
|
104 |
+
def forward_encdec(
|
105 |
+
self,
|
106 |
+
texts: List[str],
|
107 |
+
motion_tokens: Tensor,
|
108 |
+
lengths: List[int],
|
109 |
+
tasks: dict,
|
110 |
+
):
|
111 |
+
|
112 |
+
# Tensor to string
|
113 |
+
motion_strings = self.motion_token_to_string(motion_tokens, lengths)
|
114 |
+
|
115 |
+
# Supervised or unsupervised
|
116 |
+
# condition = random.choice(
|
117 |
+
# ['text', 'motion', 'supervised', 'supervised', 'supervised'])
|
118 |
+
condition = random.choice(['supervised', 'supervised', 'supervised'])
|
119 |
+
|
120 |
+
if condition == 'text':
|
121 |
+
inputs = texts
|
122 |
+
outputs = texts
|
123 |
+
elif condition == 'motion':
|
124 |
+
inputs = motion_strings
|
125 |
+
outputs = motion_strings
|
126 |
+
else:
|
127 |
+
inputs, outputs = self.template_fulfill(tasks, lengths,
|
128 |
+
motion_strings, texts)
|
129 |
+
|
130 |
+
# Tokenize
|
131 |
+
source_encoding = self.tokenizer(inputs,
|
132 |
+
padding='max_length',
|
133 |
+
max_length=self.max_length,
|
134 |
+
truncation=True,
|
135 |
+
return_attention_mask=True,
|
136 |
+
add_special_tokens=True,
|
137 |
+
return_tensors="pt")
|
138 |
+
|
139 |
+
source_attention_mask = source_encoding.attention_mask.to(
|
140 |
+
motion_tokens.device)
|
141 |
+
source_input_ids = source_encoding.input_ids.to(motion_tokens.device)
|
142 |
+
|
143 |
+
if condition in ['text', 'motion']:
|
144 |
+
batch_size, expandend_input_length = source_input_ids.shape
|
145 |
+
mask_indices = np.asarray([
|
146 |
+
self.random_spans_noise_mask(expandend_input_length)
|
147 |
+
for i in range(batch_size)
|
148 |
+
])
|
149 |
+
target_mask = ~mask_indices
|
150 |
+
input_ids_sentinel = self.create_sentinel_ids(
|
151 |
+
mask_indices.astype(np.int8))
|
152 |
+
target_sentinel = self.create_sentinel_ids(
|
153 |
+
target_mask.astype(np.int8))
|
154 |
+
|
155 |
+
labels_input_ids = self.filter_input_ids(source_input_ids,
|
156 |
+
target_sentinel)
|
157 |
+
source_input_ids = self.filter_input_ids(source_input_ids,
|
158 |
+
input_ids_sentinel)
|
159 |
+
|
160 |
+
else:
|
161 |
+
target_inputs = self.tokenizer(outputs,
|
162 |
+
padding='max_length',
|
163 |
+
max_length=self.max_length,
|
164 |
+
truncation=True,
|
165 |
+
return_attention_mask=True,
|
166 |
+
add_special_tokens=True,
|
167 |
+
return_tensors="pt")
|
168 |
+
|
169 |
+
labels_input_ids = target_inputs.input_ids.to(motion_tokens.device)
|
170 |
+
lables_attention_mask = target_inputs.attention_mask.to(
|
171 |
+
motion_tokens.device)
|
172 |
+
|
173 |
+
labels_input_ids[labels_input_ids == 0] = -100
|
174 |
+
outputs = self.language_model(
|
175 |
+
input_ids=source_input_ids,
|
176 |
+
attention_mask=source_attention_mask
|
177 |
+
if condition == 'supervised' else None,
|
178 |
+
labels=labels_input_ids,
|
179 |
+
decoder_attention_mask=lables_attention_mask
|
180 |
+
if condition == 'supervised' else None,
|
181 |
+
)
|
182 |
+
|
183 |
+
return outputs
|
184 |
+
|
185 |
+
def forward_dec(
|
186 |
+
self,
|
187 |
+
texts: List[str],
|
188 |
+
motion_tokens: Tensor,
|
189 |
+
lengths: List[int],
|
190 |
+
tasks: dict,
|
191 |
+
):
|
192 |
+
self.tokenizer.padding_side = "right"
|
193 |
+
|
194 |
+
# Tensor to string
|
195 |
+
motion_strings = self.motion_token_to_string(motion_tokens, lengths)
|
196 |
+
|
197 |
+
# Supervised or unsupervised
|
198 |
+
condition = random.choice(
|
199 |
+
['text', 'motion', 'supervised', 'supervised', 'supervised'])
|
200 |
+
|
201 |
+
if condition == 'text':
|
202 |
+
labels = texts
|
203 |
+
elif condition == 'motion':
|
204 |
+
labels = motion_strings
|
205 |
+
else:
|
206 |
+
inputs, outputs = self.template_fulfill(tasks, lengths,
|
207 |
+
motion_strings, texts)
|
208 |
+
labels = []
|
209 |
+
for i in range(len(inputs)):
|
210 |
+
labels.append(inputs[i] + ' \n ' + outputs[i] +
|
211 |
+
self.tokenizer.eos_token)
|
212 |
+
|
213 |
+
# Tokenize
|
214 |
+
inputs = self.tokenizer(labels,
|
215 |
+
padding='max_length',
|
216 |
+
max_length=self.max_length,
|
217 |
+
truncation=True,
|
218 |
+
return_attention_mask=True,
|
219 |
+
return_tensors="pt")
|
220 |
+
|
221 |
+
labels_input_ids = inputs.input_ids.to(motion_tokens.device)
|
222 |
+
lables_attention_mask = inputs.attention_mask.to(motion_tokens.device)
|
223 |
+
|
224 |
+
# print(labels_input_ids[0:5])
|
225 |
+
|
226 |
+
outputs = self.language_model(input_ids=labels_input_ids,
|
227 |
+
attention_mask=lables_attention_mask,
|
228 |
+
labels=inputs["input_ids"])
|
229 |
+
|
230 |
+
return outputs
|
231 |
+
|
232 |
+
def generate_direct(self,
|
233 |
+
texts: List[str],
|
234 |
+
max_length: int = 256,
|
235 |
+
num_beams: int = 1,
|
236 |
+
do_sample: bool = True,
|
237 |
+
bad_words_ids: List[int] = None):
|
238 |
+
|
239 |
+
# Device
|
240 |
+
self.device = self.language_model.device
|
241 |
+
|
242 |
+
# Tokenize
|
243 |
+
if self.lm_type == 'dec':
|
244 |
+
texts = [text + " \n " for text in texts]
|
245 |
+
|
246 |
+
source_encoding = self.tokenizer(texts,
|
247 |
+
padding='max_length',
|
248 |
+
max_length=self.max_length,
|
249 |
+
truncation=True,
|
250 |
+
return_attention_mask=True,
|
251 |
+
add_special_tokens=True,
|
252 |
+
return_tensors="pt")
|
253 |
+
|
254 |
+
source_input_ids = source_encoding.input_ids.to(self.device)
|
255 |
+
source_attention_mask = source_encoding.attention_mask.to(self.device)
|
256 |
+
|
257 |
+
if self.lm_type == 'encdec':
|
258 |
+
outputs = self.language_model.generate(
|
259 |
+
source_input_ids,
|
260 |
+
max_length=max_length,
|
261 |
+
num_beams=num_beams,
|
262 |
+
do_sample=do_sample,
|
263 |
+
bad_words_ids=bad_words_ids,
|
264 |
+
)
|
265 |
+
elif self.lm_type == 'dec':
|
266 |
+
outputs = self.language_model.generate(
|
267 |
+
input_ids=source_input_ids,
|
268 |
+
attention_mask=source_attention_mask,
|
269 |
+
pad_token_id=self.tokenizer.pad_token_id,
|
270 |
+
do_sample=do_sample,
|
271 |
+
max_new_tokens=max_length)
|
272 |
+
self.tokenizer.padding_side = 'left'
|
273 |
+
|
274 |
+
outputs_string = self.tokenizer.batch_decode(outputs,
|
275 |
+
skip_special_tokens=True)
|
276 |
+
|
277 |
+
print(texts[:2])
|
278 |
+
print(outputs_string[:2])
|
279 |
+
|
280 |
+
outputs_tokens, cleaned_text = self.motion_string_to_token(
|
281 |
+
outputs_string)
|
282 |
+
|
283 |
+
return outputs_tokens, cleaned_text
|
284 |
+
|
285 |
+
def generate_conditional(self,
|
286 |
+
texts: Optional[List[str]] = None,
|
287 |
+
motion_tokens: Optional[Tensor] = None,
|
288 |
+
lengths: Optional[List[int]] = None,
|
289 |
+
task: str = "t2m",
|
290 |
+
with_len: bool = False,
|
291 |
+
stage: str = 'train',
|
292 |
+
tasks: dict = None):
|
293 |
+
|
294 |
+
self.device = self.language_model.device
|
295 |
+
|
296 |
+
if task in ["t2m", "m2m", "pred", "inbetween"]:
|
297 |
+
|
298 |
+
if task == "t2m":
|
299 |
+
assert texts is not None
|
300 |
+
motion_strings = [''] * len(texts)
|
301 |
+
if not with_len:
|
302 |
+
if tasks is None:
|
303 |
+
tasks = [{
|
304 |
+
'input':
|
305 |
+
['Generate motion: <Caption_Placeholder>'],
|
306 |
+
'output': ['']
|
307 |
+
}] * len(texts)
|
308 |
+
|
309 |
+
lengths = [0] * len(texts)
|
310 |
+
else:
|
311 |
+
tasks = [{
|
312 |
+
'input': [
|
313 |
+
'Generate motion with <Frame_Placeholder> frames: <Caption_Placeholder>'
|
314 |
+
],
|
315 |
+
'output': ['']
|
316 |
+
}] * len(texts)
|
317 |
+
|
318 |
+
elif task == "pred":
|
319 |
+
assert motion_tokens is not None and lengths is not None
|
320 |
+
texts = [''] * len(lengths)
|
321 |
+
tasks = [{
|
322 |
+
'input': ['Predict motion: <Motion_Placeholder_s1>'],
|
323 |
+
'output': ['']
|
324 |
+
}] * len(lengths)
|
325 |
+
|
326 |
+
motion_strings_old = self.motion_token_to_string(
|
327 |
+
motion_tokens, lengths)
|
328 |
+
motion_strings = []
|
329 |
+
for i, length in enumerate(lengths):
|
330 |
+
split = length // 5
|
331 |
+
motion_strings.append(
|
332 |
+
'>'.join(motion_strings_old[i].split('>')[:split]) +
|
333 |
+
'>')
|
334 |
+
|
335 |
+
elif task == "inbetween":
|
336 |
+
assert motion_tokens is not None and lengths is not None
|
337 |
+
texts = [''] * len(lengths)
|
338 |
+
tasks = [{
|
339 |
+
'input': [
|
340 |
+
"Complete the masked motion: <Motion_Placeholder_Masked>"
|
341 |
+
],
|
342 |
+
'output': ['']
|
343 |
+
}] * len(lengths)
|
344 |
+
motion_strings = self.motion_token_to_string(
|
345 |
+
motion_tokens, lengths)
|
346 |
+
|
347 |
+
inputs, outputs = self.template_fulfill(tasks, lengths,
|
348 |
+
motion_strings, texts,
|
349 |
+
stage)
|
350 |
+
|
351 |
+
outputs_tokens, cleaned_text = self.generate_direct(inputs,
|
352 |
+
max_length=128,
|
353 |
+
num_beams=1,
|
354 |
+
do_sample=True)
|
355 |
+
|
356 |
+
return outputs_tokens
|
357 |
+
|
358 |
+
elif task == "m2t":
|
359 |
+
assert motion_tokens is not None and lengths is not None
|
360 |
+
|
361 |
+
motion_strings = self.motion_token_to_string(
|
362 |
+
motion_tokens, lengths)
|
363 |
+
|
364 |
+
if not with_len:
|
365 |
+
tasks = [{
|
366 |
+
'input': ['Generate text: <Motion_Placeholder>'],
|
367 |
+
'output': ['']
|
368 |
+
}] * len(lengths)
|
369 |
+
else:
|
370 |
+
tasks = [{
|
371 |
+
'input': [
|
372 |
+
'Generate text with <Frame_Placeholder> frames: <Motion_Placeholder>'
|
373 |
+
],
|
374 |
+
'output': ['']
|
375 |
+
}] * len(lengths)
|
376 |
+
|
377 |
+
texts = [''] * len(lengths)
|
378 |
+
|
379 |
+
inputs, outputs = self.template_fulfill(tasks, lengths,
|
380 |
+
motion_strings, texts)
|
381 |
+
outputs_tokens, cleaned_text = self.generate_direct(
|
382 |
+
inputs,
|
383 |
+
max_length=40,
|
384 |
+
num_beams=1,
|
385 |
+
do_sample=False,
|
386 |
+
# bad_words_ids=self.bad_words_ids
|
387 |
+
)
|
388 |
+
return cleaned_text
|
389 |
+
|
390 |
+
def motion_token_to_string(self, motion_token: Tensor, lengths: List[int]):
|
391 |
+
motion_string = []
|
392 |
+
for i in range(len(motion_token)):
|
393 |
+
motion_i = motion_token[i].cpu(
|
394 |
+
) if motion_token[i].device.type == 'cuda' else motion_token[i]
|
395 |
+
motion_list = motion_i.tolist()[:lengths[i]]
|
396 |
+
motion_string.append(
|
397 |
+
(f'<motion_id_{self.m_codebook_size}>' +
|
398 |
+
''.join([f'<motion_id_{int(i)}>' for i in motion_list]) +
|
399 |
+
f'<motion_id_{self.m_codebook_size + 1}>'))
|
400 |
+
return motion_string
|
401 |
+
|
402 |
+
def motion_token_list_to_string(self, motion_token: Tensor):
|
403 |
+
motion_string = []
|
404 |
+
for i in range(len(motion_token)):
|
405 |
+
motion_i = motion_token[i].cpu(
|
406 |
+
) if motion_token[i].device.type == 'cuda' else motion_token[i]
|
407 |
+
motion_list = motion_i.tolist()
|
408 |
+
motion_string.append(
|
409 |
+
(f'<motion_id_{self.m_codebook_size}>' +
|
410 |
+
''.join([f'<motion_id_{int(i)}>' for i in motion_list]) +
|
411 |
+
f'<motion_id_{self.m_codebook_size + 1}>'))
|
412 |
+
return motion_string
|
413 |
+
|
414 |
+
def motion_string_to_token(self, motion_string: List[str]):
|
415 |
+
motion_tokens = []
|
416 |
+
output_string = []
|
417 |
+
for i in range(len(motion_string)):
|
418 |
+
string = self.get_middle_str(
|
419 |
+
motion_string[i], f'<motion_id_{self.m_codebook_size}>',
|
420 |
+
f'<motion_id_{self.m_codebook_size + 1}>')
|
421 |
+
string_list = string.split('><')
|
422 |
+
token_list = [
|
423 |
+
int(i.split('_')[-1].replace('>', ''))
|
424 |
+
for i in string_list[1:-1]
|
425 |
+
]
|
426 |
+
if len(token_list) == 0:
|
427 |
+
token_list = [0]
|
428 |
+
token_list_padded = torch.tensor(token_list,
|
429 |
+
dtype=int).to(self.device)
|
430 |
+
motion_tokens.append(token_list_padded)
|
431 |
+
output_string.append(motion_string[i].replace(
|
432 |
+
string, '<Motion_Placeholder>'))
|
433 |
+
|
434 |
+
return motion_tokens, output_string
|
435 |
+
|
436 |
+
def placeholder_fulfill(self, prompt: str, length: int, motion_string: str,
|
437 |
+
text: str):
|
438 |
+
|
439 |
+
seconds = math.floor(length / self.framerate)
|
440 |
+
motion_splited = motion_string.split('>')
|
441 |
+
token_length = length / self.down_t
|
442 |
+
predict_head = int(token_length * self.predict_ratio + 1)
|
443 |
+
masked_head = int(token_length * self.inbetween_ratio + 1)
|
444 |
+
masked_tail = int(token_length * (1 - self.inbetween_ratio) + 1)
|
445 |
+
|
446 |
+
motion_predict_head = '>'.join(
|
447 |
+
motion_splited[:predict_head]
|
448 |
+
) + f'><motion_id_{self.m_codebook_size+1}>'
|
449 |
+
motion_predict_last = f'<motion_id_{self.m_codebook_size}>' + '>'.join(
|
450 |
+
motion_splited[predict_head:])
|
451 |
+
|
452 |
+
motion_masked = '>'.join(
|
453 |
+
motion_splited[:masked_head]
|
454 |
+
) + '>' + f'<motion_id_{self.m_codebook_size+2}>' * (
|
455 |
+
masked_tail - masked_head) + '>'.join(motion_splited[masked_tail:])
|
456 |
+
|
457 |
+
if random.random() < self.quota_ratio:
|
458 |
+
text = f'\"{text}\"'
|
459 |
+
|
460 |
+
prompt = prompt.replace('<Caption_Placeholder>', text).replace(
|
461 |
+
'<Motion_Placeholder>',
|
462 |
+
motion_string).replace('<Frame_Placeholder>', f'{length}').replace(
|
463 |
+
'<Second_Placeholder>', '%.1f' % seconds).replace(
|
464 |
+
'<Motion_Placeholder_s1>', motion_predict_head).replace(
|
465 |
+
'<Motion_Placeholder_s2>',
|
466 |
+
motion_predict_last).replace(
|
467 |
+
'<Motion_Placeholder_Masked>', motion_masked)
|
468 |
+
|
469 |
+
return prompt
|
470 |
+
|
471 |
+
def template_fulfill(self,
|
472 |
+
tasks,
|
473 |
+
lengths,
|
474 |
+
motion_strings,
|
475 |
+
texts,
|
476 |
+
stage='test'):
|
477 |
+
inputs = []
|
478 |
+
outputs = []
|
479 |
+
for i in range(len(lengths)):
|
480 |
+
input_template = random.choice(tasks[i]['input'])
|
481 |
+
output_template = random.choice(tasks[i]['output'])
|
482 |
+
length = lengths[i]
|
483 |
+
inputs.append(
|
484 |
+
self.placeholder_fulfill(input_template, length,
|
485 |
+
motion_strings[i], texts[i]))
|
486 |
+
outputs.append(
|
487 |
+
self.placeholder_fulfill(output_template, length,
|
488 |
+
motion_strings[i], texts[i]))
|
489 |
+
|
490 |
+
return inputs, outputs
|
491 |
+
|
492 |
+
def get_middle_str(self, content, startStr, endStr):
|
493 |
+
try:
|
494 |
+
startIndex = content.index(startStr)
|
495 |
+
if startIndex >= 0:
|
496 |
+
startIndex += len(startStr)
|
497 |
+
endIndex = content.index(endStr)
|
498 |
+
except:
|
499 |
+
return f'<motion_id_{self.m_codebook_size}><motion_id_0><motion_id_{self.m_codebook_size+1}>'
|
500 |
+
|
501 |
+
return f'<motion_id_{self.m_codebook_size}>' + content[
|
502 |
+
startIndex:endIndex] + f'<motion_id_{self.m_codebook_size+1}>'
|
503 |
+
|
504 |
+
def random_spans_noise_mask(self, length):
|
505 |
+
# From https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py
|
506 |
+
|
507 |
+
orig_length = length
|
508 |
+
|
509 |
+
num_noise_tokens = int(np.round(length * self.noise_density))
|
510 |
+
# avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens.
|
511 |
+
num_noise_tokens = min(max(num_noise_tokens, 1), length - 1)
|
512 |
+
num_noise_spans = int(
|
513 |
+
np.round(num_noise_tokens / self.mean_noise_span_length))
|
514 |
+
|
515 |
+
# avoid degeneracy by ensuring positive number of noise spans
|
516 |
+
num_noise_spans = max(num_noise_spans, 1)
|
517 |
+
num_nonnoise_tokens = length - num_noise_tokens
|
518 |
+
|
519 |
+
# pick the lengths of the noise spans and the non-noise spans
|
520 |
+
def _random_segmentation(num_items, num_segments):
|
521 |
+
"""Partition a sequence of items randomly into non-empty segments.
|
522 |
+
Args:
|
523 |
+
num_items: an integer scalar > 0
|
524 |
+
num_segments: an integer scalar in [1, num_items]
|
525 |
+
Returns:
|
526 |
+
a Tensor with shape [num_segments] containing positive integers that add
|
527 |
+
up to num_items
|
528 |
+
"""
|
529 |
+
mask_indices = np.arange(num_items - 1) < (num_segments - 1)
|
530 |
+
np.random.shuffle(mask_indices)
|
531 |
+
first_in_segment = np.pad(mask_indices, [[1, 0]])
|
532 |
+
segment_id = np.cumsum(first_in_segment)
|
533 |
+
# count length of sub segments assuming that list is sorted
|
534 |
+
_, segment_length = np.unique(segment_id, return_counts=True)
|
535 |
+
return segment_length
|
536 |
+
|
537 |
+
noise_span_lengths = _random_segmentation(num_noise_tokens,
|
538 |
+
num_noise_spans)
|
539 |
+
nonnoise_span_lengths = _random_segmentation(num_nonnoise_tokens,
|
540 |
+
num_noise_spans)
|
541 |
+
|
542 |
+
interleaved_span_lengths = np.reshape(
|
543 |
+
np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1),
|
544 |
+
[num_noise_spans * 2],
|
545 |
+
)
|
546 |
+
span_starts = np.cumsum(interleaved_span_lengths)[:-1]
|
547 |
+
span_start_indicator = np.zeros((length, ), dtype=np.int8)
|
548 |
+
span_start_indicator[span_starts] = True
|
549 |
+
span_num = np.cumsum(span_start_indicator)
|
550 |
+
is_noise = np.equal(span_num % 2, 1)
|
551 |
+
|
552 |
+
return is_noise[:orig_length]
|
553 |
+
|
554 |
+
def create_sentinel_ids(self, mask_indices):
|
555 |
+
# From https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_t5_mlm_flax.py
|
556 |
+
start_indices = mask_indices - np.roll(mask_indices, 1,
|
557 |
+
axis=-1) * mask_indices
|
558 |
+
start_indices[:, 0] = mask_indices[:, 0]
|
559 |
+
|
560 |
+
sentinel_ids = np.where(start_indices != 0,
|
561 |
+
np.cumsum(start_indices, axis=-1),
|
562 |
+
start_indices)
|
563 |
+
sentinel_ids = np.where(sentinel_ids != 0,
|
564 |
+
(len(self.tokenizer) - sentinel_ids), 0)
|
565 |
+
sentinel_ids -= mask_indices - start_indices
|
566 |
+
|
567 |
+
return sentinel_ids
|
568 |
+
|
569 |
+
def filter_input_ids(self, input_ids, sentinel_ids):
|
570 |
+
# From https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_t5_mlm_flax.py
|
571 |
+
batch_size = input_ids.shape[0]
|
572 |
+
|
573 |
+
input_ids_full = np.where(sentinel_ids != 0, sentinel_ids,
|
574 |
+
input_ids.to('cpu'))
|
575 |
+
|
576 |
+
# input_ids tokens and sentinel tokens are >= 0, tokens < 0 are
|
577 |
+
# masked tokens coming after sentinel tokens and should be removed
|
578 |
+
input_ids = input_ids_full[input_ids_full >= 0].reshape(
|
579 |
+
(batch_size, -1))
|
580 |
+
input_ids = np.concatenate(
|
581 |
+
[
|
582 |
+
input_ids,
|
583 |
+
np.full((batch_size, 1),
|
584 |
+
self.tokenizer.eos_token_id,
|
585 |
+
dtype=np.int32),
|
586 |
+
],
|
587 |
+
axis=-1,
|
588 |
+
)
|
589 |
+
|
590 |
+
input_ids = torch.tensor(input_ids, device=self.device)
|
591 |
+
|
592 |
+
return input_ids
|
mGPT/archs/mgpt_vq.py
ADDED
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Partially from https://github.com/Mael-zys/T2M-GPT
|
2 |
+
|
3 |
+
from typing import List, Optional, Union
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from torch import Tensor, nn
|
7 |
+
from torch.distributions.distribution import Distribution
|
8 |
+
from .tools.resnet import Resnet1D
|
9 |
+
from .tools.quantize_cnn import QuantizeEMAReset, Quantizer, QuantizeEMA, QuantizeReset
|
10 |
+
from collections import OrderedDict
|
11 |
+
|
12 |
+
|
13 |
+
class VQVae(nn.Module):
|
14 |
+
|
15 |
+
def __init__(self,
|
16 |
+
nfeats: int,
|
17 |
+
quantizer: str = "ema_reset",
|
18 |
+
code_num=512,
|
19 |
+
code_dim=512,
|
20 |
+
output_emb_width=512,
|
21 |
+
down_t=3,
|
22 |
+
stride_t=2,
|
23 |
+
width=512,
|
24 |
+
depth=3,
|
25 |
+
dilation_growth_rate=3,
|
26 |
+
norm=None,
|
27 |
+
activation: str = "relu",
|
28 |
+
**kwargs) -> None:
|
29 |
+
|
30 |
+
super().__init__()
|
31 |
+
|
32 |
+
self.code_dim = code_dim
|
33 |
+
|
34 |
+
self.encoder = Encoder(nfeats,
|
35 |
+
output_emb_width,
|
36 |
+
down_t,
|
37 |
+
stride_t,
|
38 |
+
width,
|
39 |
+
depth,
|
40 |
+
dilation_growth_rate,
|
41 |
+
activation=activation,
|
42 |
+
norm=norm)
|
43 |
+
|
44 |
+
self.decoder = Decoder(nfeats,
|
45 |
+
output_emb_width,
|
46 |
+
down_t,
|
47 |
+
stride_t,
|
48 |
+
width,
|
49 |
+
depth,
|
50 |
+
dilation_growth_rate,
|
51 |
+
activation=activation,
|
52 |
+
norm=norm)
|
53 |
+
|
54 |
+
if quantizer == "ema_reset":
|
55 |
+
self.quantizer = QuantizeEMAReset(code_num, code_dim, mu=0.99)
|
56 |
+
elif quantizer == "orig":
|
57 |
+
self.quantizer = Quantizer(code_num, code_dim, beta=1.0)
|
58 |
+
elif quantizer == "ema":
|
59 |
+
self.quantizer = QuantizeEMA(code_num, code_dim, mu=0.99)
|
60 |
+
elif quantizer == "reset":
|
61 |
+
self.quantizer = QuantizeReset(code_num, code_dim)
|
62 |
+
|
63 |
+
def preprocess(self, x):
|
64 |
+
# (bs, T, Jx3) -> (bs, Jx3, T)
|
65 |
+
x = x.permute(0, 2, 1)
|
66 |
+
return x
|
67 |
+
|
68 |
+
def postprocess(self, x):
|
69 |
+
# (bs, Jx3, T) -> (bs, T, Jx3)
|
70 |
+
x = x.permute(0, 2, 1)
|
71 |
+
return x
|
72 |
+
|
73 |
+
def forward(self, features: Tensor):
|
74 |
+
# Preprocess
|
75 |
+
x_in = self.preprocess(features)
|
76 |
+
|
77 |
+
# Encode
|
78 |
+
x_encoder = self.encoder(x_in)
|
79 |
+
|
80 |
+
# quantization
|
81 |
+
x_quantized, loss, perplexity = self.quantizer(x_encoder)
|
82 |
+
|
83 |
+
# decoder
|
84 |
+
x_decoder = self.decoder(x_quantized)
|
85 |
+
x_out = self.postprocess(x_decoder)
|
86 |
+
|
87 |
+
return x_out, loss, perplexity
|
88 |
+
|
89 |
+
def encode(
|
90 |
+
self,
|
91 |
+
features: Tensor,
|
92 |
+
) -> Union[Tensor, Distribution]:
|
93 |
+
|
94 |
+
N, T, _ = features.shape
|
95 |
+
x_in = self.preprocess(features)
|
96 |
+
x_encoder = self.encoder(x_in)
|
97 |
+
x_encoder = self.postprocess(x_encoder)
|
98 |
+
x_encoder = x_encoder.contiguous().view(-1,
|
99 |
+
x_encoder.shape[-1]) # (NT, C)
|
100 |
+
code_idx = self.quantizer.quantize(x_encoder)
|
101 |
+
code_idx = code_idx.view(N, -1)
|
102 |
+
|
103 |
+
# latent, dist
|
104 |
+
return code_idx, None
|
105 |
+
|
106 |
+
def decode(self, z: Tensor):
|
107 |
+
|
108 |
+
x_d = self.quantizer.dequantize(z)
|
109 |
+
x_d = x_d.view(1, -1, self.code_dim).permute(0, 2, 1).contiguous()
|
110 |
+
|
111 |
+
# decoder
|
112 |
+
x_decoder = self.decoder(x_d)
|
113 |
+
x_out = self.postprocess(x_decoder)
|
114 |
+
return x_out
|
115 |
+
|
116 |
+
|
117 |
+
class Encoder(nn.Module):
|
118 |
+
|
119 |
+
def __init__(self,
|
120 |
+
input_emb_width=3,
|
121 |
+
output_emb_width=512,
|
122 |
+
down_t=3,
|
123 |
+
stride_t=2,
|
124 |
+
width=512,
|
125 |
+
depth=3,
|
126 |
+
dilation_growth_rate=3,
|
127 |
+
activation='relu',
|
128 |
+
norm=None):
|
129 |
+
super().__init__()
|
130 |
+
|
131 |
+
blocks = []
|
132 |
+
filter_t, pad_t = stride_t * 2, stride_t // 2
|
133 |
+
blocks.append(nn.Conv1d(input_emb_width, width, 3, 1, 1))
|
134 |
+
blocks.append(nn.ReLU())
|
135 |
+
|
136 |
+
for i in range(down_t):
|
137 |
+
input_dim = width
|
138 |
+
block = nn.Sequential(
|
139 |
+
nn.Conv1d(input_dim, width, filter_t, stride_t, pad_t),
|
140 |
+
Resnet1D(width,
|
141 |
+
depth,
|
142 |
+
dilation_growth_rate,
|
143 |
+
activation=activation,
|
144 |
+
norm=norm),
|
145 |
+
)
|
146 |
+
blocks.append(block)
|
147 |
+
blocks.append(nn.Conv1d(width, output_emb_width, 3, 1, 1))
|
148 |
+
self.model = nn.Sequential(*blocks)
|
149 |
+
|
150 |
+
def forward(self, x):
|
151 |
+
return self.model(x)
|
152 |
+
|
153 |
+
|
154 |
+
class Decoder(nn.Module):
|
155 |
+
|
156 |
+
def __init__(self,
|
157 |
+
input_emb_width=3,
|
158 |
+
output_emb_width=512,
|
159 |
+
down_t=3,
|
160 |
+
stride_t=2,
|
161 |
+
width=512,
|
162 |
+
depth=3,
|
163 |
+
dilation_growth_rate=3,
|
164 |
+
activation='relu',
|
165 |
+
norm=None):
|
166 |
+
super().__init__()
|
167 |
+
blocks = []
|
168 |
+
|
169 |
+
filter_t, pad_t = stride_t * 2, stride_t // 2
|
170 |
+
blocks.append(nn.Conv1d(output_emb_width, width, 3, 1, 1))
|
171 |
+
blocks.append(nn.ReLU())
|
172 |
+
for i in range(down_t):
|
173 |
+
out_dim = width
|
174 |
+
block = nn.Sequential(
|
175 |
+
Resnet1D(width,
|
176 |
+
depth,
|
177 |
+
dilation_growth_rate,
|
178 |
+
reverse_dilation=True,
|
179 |
+
activation=activation,
|
180 |
+
norm=norm), nn.Upsample(scale_factor=2,
|
181 |
+
mode='nearest'),
|
182 |
+
nn.Conv1d(width, out_dim, 3, 1, 1))
|
183 |
+
blocks.append(block)
|
184 |
+
blocks.append(nn.Conv1d(width, width, 3, 1, 1))
|
185 |
+
blocks.append(nn.ReLU())
|
186 |
+
blocks.append(nn.Conv1d(width, input_emb_width, 3, 1, 1))
|
187 |
+
self.model = nn.Sequential(*blocks)
|
188 |
+
|
189 |
+
def forward(self, x):
|
190 |
+
return self.model(x)
|
mGPT/archs/tm2t_evaluator.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from torch.nn.utils.rnn import pack_padded_sequence
|
4 |
+
|
5 |
+
|
6 |
+
class MovementConvEncoder(nn.Module):
|
7 |
+
def __init__(self, input_size, hidden_size, output_size):
|
8 |
+
super(MovementConvEncoder, self).__init__()
|
9 |
+
self.main = nn.Sequential(
|
10 |
+
nn.Conv1d(input_size, hidden_size, 4, 2, 1),
|
11 |
+
nn.Dropout(0.2, inplace=True),
|
12 |
+
nn.LeakyReLU(0.2, inplace=True),
|
13 |
+
nn.Conv1d(hidden_size, output_size, 4, 2, 1),
|
14 |
+
nn.Dropout(0.2, inplace=True),
|
15 |
+
nn.LeakyReLU(0.2, inplace=True),
|
16 |
+
)
|
17 |
+
self.out_net = nn.Linear(output_size, output_size)
|
18 |
+
# self.main.apply(init_weight)
|
19 |
+
# self.out_net.apply(init_weight)
|
20 |
+
|
21 |
+
def forward(self, inputs):
|
22 |
+
inputs = inputs.permute(0, 2, 1)
|
23 |
+
outputs = self.main(inputs).permute(0, 2, 1)
|
24 |
+
# print(outputs.shape)
|
25 |
+
return self.out_net(outputs)
|
26 |
+
|
27 |
+
|
28 |
+
class MotionEncoderBiGRUCo(nn.Module):
|
29 |
+
def __init__(self, input_size, hidden_size, output_size):
|
30 |
+
super(MotionEncoderBiGRUCo, self).__init__()
|
31 |
+
|
32 |
+
self.input_emb = nn.Linear(input_size, hidden_size)
|
33 |
+
self.gru = nn.GRU(
|
34 |
+
hidden_size, hidden_size, batch_first=True, bidirectional=True
|
35 |
+
)
|
36 |
+
self.output_net = nn.Sequential(
|
37 |
+
nn.Linear(hidden_size * 2, hidden_size),
|
38 |
+
nn.LayerNorm(hidden_size),
|
39 |
+
nn.LeakyReLU(0.2, inplace=True),
|
40 |
+
nn.Linear(hidden_size, output_size),
|
41 |
+
)
|
42 |
+
|
43 |
+
# self.input_emb.apply(init_weight)
|
44 |
+
# self.output_net.apply(init_weight)
|
45 |
+
self.hidden_size = hidden_size
|
46 |
+
self.hidden = nn.Parameter(
|
47 |
+
torch.randn((2, 1, self.hidden_size), requires_grad=True)
|
48 |
+
)
|
49 |
+
|
50 |
+
# input(batch_size, seq_len, dim)
|
51 |
+
def forward(self, inputs, m_lens):
|
52 |
+
num_samples = inputs.shape[0]
|
53 |
+
|
54 |
+
input_embs = self.input_emb(inputs)
|
55 |
+
hidden = self.hidden.repeat(1, num_samples, 1)
|
56 |
+
|
57 |
+
cap_lens = m_lens.data.tolist()
|
58 |
+
|
59 |
+
# emb = pack_padded_sequence(input=input_embs, lengths=cap_lens, batch_first=True)
|
60 |
+
emb = input_embs
|
61 |
+
|
62 |
+
gru_seq, gru_last = self.gru(emb, hidden)
|
63 |
+
|
64 |
+
gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1)
|
65 |
+
|
66 |
+
return self.output_net(gru_last)
|
67 |
+
|
68 |
+
|
69 |
+
class TextEncoderBiGRUCo(nn.Module):
|
70 |
+
def __init__(self, word_size, pos_size, hidden_size, output_size):
|
71 |
+
super(TextEncoderBiGRUCo, self).__init__()
|
72 |
+
|
73 |
+
self.pos_emb = nn.Linear(pos_size, word_size)
|
74 |
+
self.input_emb = nn.Linear(word_size, hidden_size)
|
75 |
+
self.gru = nn.GRU(
|
76 |
+
hidden_size, hidden_size, batch_first=True, bidirectional=True
|
77 |
+
)
|
78 |
+
self.output_net = nn.Sequential(
|
79 |
+
nn.Linear(hidden_size * 2, hidden_size),
|
80 |
+
nn.LayerNorm(hidden_size),
|
81 |
+
nn.LeakyReLU(0.2, inplace=True),
|
82 |
+
nn.Linear(hidden_size, output_size),
|
83 |
+
)
|
84 |
+
|
85 |
+
# self.input_emb.apply(init_weight)
|
86 |
+
# self.pos_emb.apply(init_weight)
|
87 |
+
# self.output_net.apply(init_weight)
|
88 |
+
# self.linear2.apply(init_weight)
|
89 |
+
# self.batch_size = batch_size
|
90 |
+
self.hidden_size = hidden_size
|
91 |
+
self.hidden = nn.Parameter(
|
92 |
+
torch.randn((2, 1, self.hidden_size), requires_grad=True)
|
93 |
+
)
|
94 |
+
|
95 |
+
# input(batch_size, seq_len, dim)
|
96 |
+
def forward(self, word_embs, pos_onehot, cap_lens):
|
97 |
+
num_samples = word_embs.shape[0]
|
98 |
+
|
99 |
+
pos_embs = self.pos_emb(pos_onehot)
|
100 |
+
inputs = word_embs + pos_embs
|
101 |
+
input_embs = self.input_emb(inputs)
|
102 |
+
hidden = self.hidden.repeat(1, num_samples, 1)
|
103 |
+
|
104 |
+
cap_lens = cap_lens.data.tolist()
|
105 |
+
emb = pack_padded_sequence(input=input_embs, lengths=cap_lens, batch_first=True)
|
106 |
+
|
107 |
+
gru_seq, gru_last = self.gru(emb, hidden)
|
108 |
+
|
109 |
+
gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1)
|
110 |
+
|
111 |
+
return self.output_net(gru_last)
|
mGPT/archs/tools/embeddings.py
ADDED
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This file is taken from signjoey repository
|
2 |
+
import math
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import Tensor, nn
|
6 |
+
|
7 |
+
|
8 |
+
def get_activation(activation_type):
|
9 |
+
if activation_type == "relu":
|
10 |
+
return nn.ReLU()
|
11 |
+
elif activation_type == "relu6":
|
12 |
+
return nn.ReLU6()
|
13 |
+
elif activation_type == "prelu":
|
14 |
+
return nn.PReLU()
|
15 |
+
elif activation_type == "selu":
|
16 |
+
return nn.SELU()
|
17 |
+
elif activation_type == "celu":
|
18 |
+
return nn.CELU()
|
19 |
+
elif activation_type == "gelu":
|
20 |
+
return nn.GELU()
|
21 |
+
elif activation_type == "sigmoid":
|
22 |
+
return nn.Sigmoid()
|
23 |
+
elif activation_type == "softplus":
|
24 |
+
return nn.Softplus()
|
25 |
+
elif activation_type == "softshrink":
|
26 |
+
return nn.Softshrink()
|
27 |
+
elif activation_type == "softsign":
|
28 |
+
return nn.Softsign()
|
29 |
+
elif activation_type == "tanh":
|
30 |
+
return nn.Tanh()
|
31 |
+
elif activation_type == "tanhshrink":
|
32 |
+
return nn.Tanhshrink()
|
33 |
+
else:
|
34 |
+
raise ValueError("Unknown activation type {}".format(activation_type))
|
35 |
+
|
36 |
+
|
37 |
+
class MaskedNorm(nn.Module):
|
38 |
+
"""
|
39 |
+
Original Code from:
|
40 |
+
https://discuss.pytorch.org/t/batchnorm-for-different-sized-samples-in-batch/44251/8
|
41 |
+
"""
|
42 |
+
|
43 |
+
def __init__(self, norm_type, num_groups, num_features):
|
44 |
+
super().__init__()
|
45 |
+
self.norm_type = norm_type
|
46 |
+
if self.norm_type == "batch":
|
47 |
+
self.norm = nn.BatchNorm1d(num_features=num_features)
|
48 |
+
elif self.norm_type == "group":
|
49 |
+
self.norm = nn.GroupNorm(num_groups=num_groups, num_channels=num_features)
|
50 |
+
elif self.norm_type == "layer":
|
51 |
+
self.norm = nn.LayerNorm(normalized_shape=num_features)
|
52 |
+
else:
|
53 |
+
raise ValueError("Unsupported Normalization Layer")
|
54 |
+
|
55 |
+
self.num_features = num_features
|
56 |
+
|
57 |
+
def forward(self, x: Tensor, mask: Tensor):
|
58 |
+
if self.training:
|
59 |
+
reshaped = x.reshape([-1, self.num_features])
|
60 |
+
reshaped_mask = mask.reshape([-1, 1]) > 0
|
61 |
+
selected = torch.masked_select(reshaped, reshaped_mask).reshape(
|
62 |
+
[-1, self.num_features]
|
63 |
+
)
|
64 |
+
batch_normed = self.norm(selected)
|
65 |
+
scattered = reshaped.masked_scatter(reshaped_mask, batch_normed)
|
66 |
+
return scattered.reshape([x.shape[0], -1, self.num_features])
|
67 |
+
else:
|
68 |
+
reshaped = x.reshape([-1, self.num_features])
|
69 |
+
batched_normed = self.norm(reshaped)
|
70 |
+
return batched_normed.reshape([x.shape[0], -1, self.num_features])
|
71 |
+
|
72 |
+
|
73 |
+
# TODO (Cihan): Spatial and Word Embeddings are pretty much the same
|
74 |
+
# We might as well convert them into a single module class.
|
75 |
+
# Only difference is the lut vs linear layers.
|
76 |
+
class Embeddings(nn.Module):
|
77 |
+
|
78 |
+
"""
|
79 |
+
Simple embeddings class
|
80 |
+
"""
|
81 |
+
|
82 |
+
# pylint: disable=unused-argument
|
83 |
+
def __init__(
|
84 |
+
self,
|
85 |
+
embedding_dim: int = 64,
|
86 |
+
num_heads: int = 8,
|
87 |
+
scale: bool = False,
|
88 |
+
scale_factor: float = None,
|
89 |
+
norm_type: str = None,
|
90 |
+
activation_type: str = None,
|
91 |
+
vocab_size: int = 0,
|
92 |
+
padding_idx: int = 1,
|
93 |
+
freeze: bool = False,
|
94 |
+
**kwargs
|
95 |
+
):
|
96 |
+
"""
|
97 |
+
Create new embeddings for the vocabulary.
|
98 |
+
Use scaling for the Transformer.
|
99 |
+
|
100 |
+
:param embedding_dim:
|
101 |
+
:param scale:
|
102 |
+
:param vocab_size:
|
103 |
+
:param padding_idx:
|
104 |
+
:param freeze: freeze the embeddings during training
|
105 |
+
"""
|
106 |
+
super().__init__()
|
107 |
+
|
108 |
+
self.embedding_dim = embedding_dim
|
109 |
+
self.vocab_size = vocab_size
|
110 |
+
self.lut = nn.Embedding(vocab_size, self.embedding_dim, padding_idx=padding_idx)
|
111 |
+
|
112 |
+
self.norm_type = norm_type
|
113 |
+
if self.norm_type:
|
114 |
+
self.norm = MaskedNorm(
|
115 |
+
norm_type=norm_type, num_groups=num_heads, num_features=embedding_dim
|
116 |
+
)
|
117 |
+
|
118 |
+
self.activation_type = activation_type
|
119 |
+
if self.activation_type:
|
120 |
+
self.activation = get_activation(activation_type)
|
121 |
+
|
122 |
+
self.scale = scale
|
123 |
+
if self.scale:
|
124 |
+
if scale_factor:
|
125 |
+
self.scale_factor = scale_factor
|
126 |
+
else:
|
127 |
+
self.scale_factor = math.sqrt(self.embedding_dim)
|
128 |
+
|
129 |
+
if freeze:
|
130 |
+
freeze_params(self)
|
131 |
+
|
132 |
+
# pylint: disable=arguments-differ
|
133 |
+
def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
|
134 |
+
"""
|
135 |
+
Perform lookup for input `x` in the embedding table.
|
136 |
+
|
137 |
+
:param mask: token masks
|
138 |
+
:param x: index in the vocabulary
|
139 |
+
:return: embedded representation for `x`
|
140 |
+
"""
|
141 |
+
|
142 |
+
x = self.lut(x)
|
143 |
+
|
144 |
+
if self.norm_type:
|
145 |
+
x = self.norm(x, mask)
|
146 |
+
|
147 |
+
if self.activation_type:
|
148 |
+
x = self.activation(x)
|
149 |
+
|
150 |
+
if self.scale:
|
151 |
+
return x * self.scale_factor
|
152 |
+
else:
|
153 |
+
return x
|
154 |
+
|
155 |
+
def __repr__(self):
|
156 |
+
return "%s(embedding_dim=%d, vocab_size=%d)" % (
|
157 |
+
self.__class__.__name__,
|
158 |
+
self.embedding_dim,
|
159 |
+
self.vocab_size,
|
160 |
+
)
|
161 |
+
|
162 |
+
|
163 |
+
class SpatialEmbeddings(nn.Module):
|
164 |
+
|
165 |
+
"""
|
166 |
+
Simple Linear Projection Layer
|
167 |
+
(For encoder outputs to predict glosses)
|
168 |
+
"""
|
169 |
+
|
170 |
+
# pylint: disable=unused-argument
|
171 |
+
def __init__(
|
172 |
+
self,
|
173 |
+
embedding_dim: int,
|
174 |
+
input_size: int,
|
175 |
+
num_heads: int,
|
176 |
+
freeze: bool = False,
|
177 |
+
norm_type: str = "batch",
|
178 |
+
activation_type: str = "softsign",
|
179 |
+
scale: bool = False,
|
180 |
+
scale_factor: float = None,
|
181 |
+
**kwargs
|
182 |
+
):
|
183 |
+
"""
|
184 |
+
Create new embeddings for the vocabulary.
|
185 |
+
Use scaling for the Transformer.
|
186 |
+
|
187 |
+
:param embedding_dim:
|
188 |
+
:param input_size:
|
189 |
+
:param freeze: freeze the embeddings during training
|
190 |
+
"""
|
191 |
+
super().__init__()
|
192 |
+
|
193 |
+
self.embedding_dim = embedding_dim
|
194 |
+
self.input_size = input_size
|
195 |
+
self.ln = nn.Linear(self.input_size, self.embedding_dim)
|
196 |
+
|
197 |
+
self.norm_type = norm_type
|
198 |
+
if self.norm_type:
|
199 |
+
self.norm = MaskedNorm(
|
200 |
+
norm_type=norm_type, num_groups=num_heads, num_features=embedding_dim
|
201 |
+
)
|
202 |
+
|
203 |
+
self.activation_type = activation_type
|
204 |
+
if self.activation_type:
|
205 |
+
self.activation = get_activation(activation_type)
|
206 |
+
|
207 |
+
self.scale = scale
|
208 |
+
if self.scale:
|
209 |
+
if scale_factor:
|
210 |
+
self.scale_factor = scale_factor
|
211 |
+
else:
|
212 |
+
self.scale_factor = math.sqrt(self.embedding_dim)
|
213 |
+
|
214 |
+
if freeze:
|
215 |
+
freeze_params(self)
|
216 |
+
|
217 |
+
# pylint: disable=arguments-differ
|
218 |
+
def forward(self, x: Tensor, mask: Tensor) -> Tensor:
|
219 |
+
"""
|
220 |
+
:param mask: frame masks
|
221 |
+
:param x: input frame features
|
222 |
+
:return: embedded representation for `x`
|
223 |
+
"""
|
224 |
+
|
225 |
+
x = self.ln(x)
|
226 |
+
|
227 |
+
if self.norm_type:
|
228 |
+
x = self.norm(x, mask)
|
229 |
+
|
230 |
+
if self.activation_type:
|
231 |
+
x = self.activation(x)
|
232 |
+
|
233 |
+
if self.scale:
|
234 |
+
return x * self.scale_factor
|
235 |
+
else:
|
236 |
+
return x
|
237 |
+
|
238 |
+
def __repr__(self):
|
239 |
+
return "%s(embedding_dim=%d, input_size=%d)" % (
|
240 |
+
self.__class__.__name__,
|
241 |
+
self.embedding_dim,
|
242 |
+
self.input_size,
|
243 |
+
)
|
244 |
+
|
245 |
+
def get_timestep_embedding(
|
246 |
+
timesteps: torch.Tensor,
|
247 |
+
embedding_dim: int,
|
248 |
+
flip_sin_to_cos: bool = False,
|
249 |
+
downscale_freq_shift: float = 1,
|
250 |
+
scale: float = 1,
|
251 |
+
max_period: int = 10000,
|
252 |
+
):
|
253 |
+
"""
|
254 |
+
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
255 |
+
|
256 |
+
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
257 |
+
These may be fractional.
|
258 |
+
:param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
|
259 |
+
embeddings. :return: an [N x dim] Tensor of positional embeddings.
|
260 |
+
"""
|
261 |
+
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
262 |
+
|
263 |
+
half_dim = embedding_dim // 2
|
264 |
+
exponent = -math.log(max_period) * torch.arange(
|
265 |
+
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
|
266 |
+
)
|
267 |
+
exponent = exponent / (half_dim - downscale_freq_shift)
|
268 |
+
|
269 |
+
emb = torch.exp(exponent)
|
270 |
+
emb = timesteps[:, None].float() * emb[None, :]
|
271 |
+
|
272 |
+
# scale embeddings
|
273 |
+
emb = scale * emb
|
274 |
+
|
275 |
+
# concat sine and cosine embeddings
|
276 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
|
277 |
+
|
278 |
+
# flip sine and cosine embeddings
|
279 |
+
if flip_sin_to_cos:
|
280 |
+
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
|
281 |
+
|
282 |
+
# zero pad
|
283 |
+
if embedding_dim % 2 == 1:
|
284 |
+
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
285 |
+
return emb
|
286 |
+
|
287 |
+
|
288 |
+
class TimestepEmbedding(nn.Module):
|
289 |
+
def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "silu"):
|
290 |
+
super().__init__()
|
291 |
+
|
292 |
+
self.linear_1 = nn.Linear(channel, time_embed_dim)
|
293 |
+
self.act = None
|
294 |
+
if act_fn == "silu":
|
295 |
+
self.act = nn.SiLU()
|
296 |
+
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
|
297 |
+
|
298 |
+
def forward(self, sample):
|
299 |
+
sample = self.linear_1(sample)
|
300 |
+
|
301 |
+
if self.act is not None:
|
302 |
+
sample = self.act(sample)
|
303 |
+
|
304 |
+
sample = self.linear_2(sample)
|
305 |
+
return sample
|
306 |
+
|
307 |
+
|
308 |
+
class Timesteps(nn.Module):
|
309 |
+
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
|
310 |
+
super().__init__()
|
311 |
+
self.num_channels = num_channels
|
312 |
+
self.flip_sin_to_cos = flip_sin_to_cos
|
313 |
+
self.downscale_freq_shift = downscale_freq_shift
|
314 |
+
|
315 |
+
def forward(self, timesteps):
|
316 |
+
t_emb = get_timestep_embedding(
|
317 |
+
timesteps,
|
318 |
+
self.num_channels,
|
319 |
+
flip_sin_to_cos=self.flip_sin_to_cos,
|
320 |
+
downscale_freq_shift=self.downscale_freq_shift,
|
321 |
+
)
|
322 |
+
return t_emb
|
mGPT/archs/tools/quantize_cnn.py
ADDED
@@ -0,0 +1,414 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
class QuantizeEMAReset(nn.Module):
|
7 |
+
def __init__(self, nb_code, code_dim, mu):
|
8 |
+
super().__init__()
|
9 |
+
self.nb_code = nb_code
|
10 |
+
self.code_dim = code_dim
|
11 |
+
self.mu = mu
|
12 |
+
self.reset_codebook()
|
13 |
+
|
14 |
+
def reset_codebook(self):
|
15 |
+
self.init = False
|
16 |
+
self.code_sum = None
|
17 |
+
self.code_count = None
|
18 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
19 |
+
self.register_buffer('codebook', torch.zeros(self.nb_code, self.code_dim).to(device))
|
20 |
+
|
21 |
+
def _tile(self, x):
|
22 |
+
nb_code_x, code_dim = x.shape
|
23 |
+
if nb_code_x < self.nb_code:
|
24 |
+
n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x
|
25 |
+
std = 0.01 / np.sqrt(code_dim)
|
26 |
+
out = x.repeat(n_repeats, 1)
|
27 |
+
out = out + torch.randn_like(out) * std
|
28 |
+
else :
|
29 |
+
out = x
|
30 |
+
return out
|
31 |
+
|
32 |
+
def init_codebook(self, x):
|
33 |
+
out = self._tile(x)
|
34 |
+
self.codebook = out[:self.nb_code]
|
35 |
+
self.code_sum = self.codebook.clone()
|
36 |
+
self.code_count = torch.ones(self.nb_code, device=self.codebook.device)
|
37 |
+
self.init = True
|
38 |
+
|
39 |
+
@torch.no_grad()
|
40 |
+
def compute_perplexity(self, code_idx) :
|
41 |
+
# Calculate new centres
|
42 |
+
code_onehot = torch.zeros(self.nb_code, code_idx.shape[0], device=code_idx.device) # nb_code, N * L
|
43 |
+
code_onehot.scatter_(0, code_idx.view(1, code_idx.shape[0]), 1)
|
44 |
+
|
45 |
+
code_count = code_onehot.sum(dim=-1) # nb_code
|
46 |
+
prob = code_count / torch.sum(code_count)
|
47 |
+
perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
|
48 |
+
return perplexity
|
49 |
+
|
50 |
+
@torch.no_grad()
|
51 |
+
def update_codebook(self, x, code_idx):
|
52 |
+
|
53 |
+
code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device) # nb_code, N * L
|
54 |
+
code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1)
|
55 |
+
|
56 |
+
code_sum = torch.matmul(code_onehot, x) # nb_code, w
|
57 |
+
code_count = code_onehot.sum(dim=-1) # nb_code
|
58 |
+
|
59 |
+
out = self._tile(x)
|
60 |
+
code_rand = out[:self.nb_code]
|
61 |
+
|
62 |
+
# Update centres
|
63 |
+
self.code_sum = self.mu * self.code_sum + (1. - self.mu) * code_sum # w, nb_code
|
64 |
+
self.code_count = self.mu * self.code_count + (1. - self.mu) * code_count # nb_code
|
65 |
+
|
66 |
+
usage = (self.code_count.view(self.nb_code, 1) >= 1.0).float()
|
67 |
+
code_update = self.code_sum.view(self.nb_code, self.code_dim) / self.code_count.view(self.nb_code, 1)
|
68 |
+
|
69 |
+
self.codebook = usage * code_update + (1 - usage) * code_rand
|
70 |
+
prob = code_count / torch.sum(code_count)
|
71 |
+
perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
|
72 |
+
|
73 |
+
|
74 |
+
return perplexity
|
75 |
+
|
76 |
+
def preprocess(self, x):
|
77 |
+
# NCT -> NTC -> [NT, C]
|
78 |
+
x = x.permute(0, 2, 1).contiguous()
|
79 |
+
x = x.view(-1, x.shape[-1])
|
80 |
+
return x
|
81 |
+
|
82 |
+
def quantize(self, x):
|
83 |
+
# Calculate latent code x_l
|
84 |
+
k_w = self.codebook.t()
|
85 |
+
distance = torch.sum(x ** 2, dim=-1, keepdim=True) - 2 * torch.matmul(x, k_w) + torch.sum(k_w ** 2, dim=0,
|
86 |
+
keepdim=True) # (N * L, b)
|
87 |
+
_, code_idx = torch.min(distance, dim=-1)
|
88 |
+
return code_idx
|
89 |
+
|
90 |
+
def dequantize(self, code_idx):
|
91 |
+
x = F.embedding(code_idx, self.codebook)
|
92 |
+
return x
|
93 |
+
|
94 |
+
|
95 |
+
def forward(self, x):
|
96 |
+
N, width, T = x.shape
|
97 |
+
|
98 |
+
# Preprocess
|
99 |
+
x = self.preprocess(x)
|
100 |
+
|
101 |
+
# Init codebook if not inited
|
102 |
+
if self.training and not self.init:
|
103 |
+
self.init_codebook(x)
|
104 |
+
|
105 |
+
# quantize and dequantize through bottleneck
|
106 |
+
code_idx = self.quantize(x)
|
107 |
+
x_d = self.dequantize(code_idx)
|
108 |
+
|
109 |
+
# Update embeddings
|
110 |
+
if self.training:
|
111 |
+
perplexity = self.update_codebook(x, code_idx)
|
112 |
+
else :
|
113 |
+
perplexity = self.compute_perplexity(code_idx)
|
114 |
+
|
115 |
+
# Loss
|
116 |
+
commit_loss = F.mse_loss(x, x_d.detach())
|
117 |
+
|
118 |
+
# Passthrough
|
119 |
+
x_d = x + (x_d - x).detach()
|
120 |
+
|
121 |
+
# Postprocess
|
122 |
+
x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous() #(N, DIM, T)
|
123 |
+
|
124 |
+
return x_d, commit_loss, perplexity
|
125 |
+
|
126 |
+
|
127 |
+
|
128 |
+
class Quantizer(nn.Module):
|
129 |
+
def __init__(self, n_e, e_dim, beta):
|
130 |
+
super(Quantizer, self).__init__()
|
131 |
+
|
132 |
+
self.e_dim = e_dim
|
133 |
+
self.n_e = n_e
|
134 |
+
self.beta = beta
|
135 |
+
|
136 |
+
self.embedding = nn.Embedding(self.n_e, self.e_dim)
|
137 |
+
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
|
138 |
+
|
139 |
+
def forward(self, z):
|
140 |
+
|
141 |
+
N, width, T = z.shape
|
142 |
+
z = self.preprocess(z)
|
143 |
+
assert z.shape[-1] == self.e_dim
|
144 |
+
z_flattened = z.contiguous().view(-1, self.e_dim)
|
145 |
+
|
146 |
+
# B x V
|
147 |
+
d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
|
148 |
+
torch.sum(self.embedding.weight**2, dim=1) - 2 * \
|
149 |
+
torch.matmul(z_flattened, self.embedding.weight.t())
|
150 |
+
# B x 1
|
151 |
+
min_encoding_indices = torch.argmin(d, dim=1)
|
152 |
+
z_q = self.embedding(min_encoding_indices).view(z.shape)
|
153 |
+
|
154 |
+
# compute loss for embedding
|
155 |
+
loss = torch.mean((z_q - z.detach())**2) + self.beta * \
|
156 |
+
torch.mean((z_q.detach() - z)**2)
|
157 |
+
|
158 |
+
# preserve gradients
|
159 |
+
z_q = z + (z_q - z).detach()
|
160 |
+
z_q = z_q.view(N, T, -1).permute(0, 2, 1).contiguous() #(N, DIM, T)
|
161 |
+
|
162 |
+
min_encodings = F.one_hot(min_encoding_indices, self.n_e).type(z.dtype)
|
163 |
+
e_mean = torch.mean(min_encodings, dim=0)
|
164 |
+
perplexity = torch.exp(-torch.sum(e_mean*torch.log(e_mean + 1e-10)))
|
165 |
+
return z_q, loss, perplexity
|
166 |
+
|
167 |
+
def quantize(self, z):
|
168 |
+
|
169 |
+
assert z.shape[-1] == self.e_dim
|
170 |
+
|
171 |
+
# B x V
|
172 |
+
d = torch.sum(z ** 2, dim=1, keepdim=True) + \
|
173 |
+
torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \
|
174 |
+
torch.matmul(z, self.embedding.weight.t())
|
175 |
+
# B x 1
|
176 |
+
min_encoding_indices = torch.argmin(d, dim=1)
|
177 |
+
return min_encoding_indices
|
178 |
+
|
179 |
+
def dequantize(self, indices):
|
180 |
+
|
181 |
+
index_flattened = indices.view(-1)
|
182 |
+
z_q = self.embedding(index_flattened)
|
183 |
+
z_q = z_q.view(indices.shape + (self.e_dim, )).contiguous()
|
184 |
+
return z_q
|
185 |
+
|
186 |
+
def preprocess(self, x):
|
187 |
+
# NCT -> NTC -> [NT, C]
|
188 |
+
x = x.permute(0, 2, 1).contiguous()
|
189 |
+
x = x.view(-1, x.shape[-1])
|
190 |
+
return x
|
191 |
+
|
192 |
+
|
193 |
+
|
194 |
+
class QuantizeReset(nn.Module):
|
195 |
+
def __init__(self, nb_code, code_dim):
|
196 |
+
super().__init__()
|
197 |
+
self.nb_code = nb_code
|
198 |
+
self.code_dim = code_dim
|
199 |
+
self.reset_codebook()
|
200 |
+
self.codebook = nn.Parameter(torch.randn(nb_code, code_dim))
|
201 |
+
|
202 |
+
def reset_codebook(self):
|
203 |
+
self.init = False
|
204 |
+
self.code_count = None
|
205 |
+
|
206 |
+
def _tile(self, x):
|
207 |
+
nb_code_x, code_dim = x.shape
|
208 |
+
if nb_code_x < self.nb_code:
|
209 |
+
n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x
|
210 |
+
std = 0.01 / np.sqrt(code_dim)
|
211 |
+
out = x.repeat(n_repeats, 1)
|
212 |
+
out = out + torch.randn_like(out) * std
|
213 |
+
else :
|
214 |
+
out = x
|
215 |
+
return out
|
216 |
+
|
217 |
+
def init_codebook(self, x):
|
218 |
+
out = self._tile(x)
|
219 |
+
self.codebook = nn.Parameter(out[:self.nb_code])
|
220 |
+
self.code_count = torch.ones(self.nb_code, device=self.codebook.device)
|
221 |
+
self.init = True
|
222 |
+
|
223 |
+
@torch.no_grad()
|
224 |
+
def compute_perplexity(self, code_idx) :
|
225 |
+
# Calculate new centres
|
226 |
+
code_onehot = torch.zeros(self.nb_code, code_idx.shape[0], device=code_idx.device) # nb_code, N * L
|
227 |
+
code_onehot.scatter_(0, code_idx.view(1, code_idx.shape[0]), 1)
|
228 |
+
|
229 |
+
code_count = code_onehot.sum(dim=-1) # nb_code
|
230 |
+
prob = code_count / torch.sum(code_count)
|
231 |
+
perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
|
232 |
+
return perplexity
|
233 |
+
|
234 |
+
def update_codebook(self, x, code_idx):
|
235 |
+
|
236 |
+
code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device) # nb_code, N * L
|
237 |
+
code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1)
|
238 |
+
|
239 |
+
code_count = code_onehot.sum(dim=-1) # nb_code
|
240 |
+
|
241 |
+
out = self._tile(x)
|
242 |
+
code_rand = out[:self.nb_code]
|
243 |
+
|
244 |
+
# Update centres
|
245 |
+
self.code_count = code_count # nb_code
|
246 |
+
usage = (self.code_count.view(self.nb_code, 1) >= 1.0).float()
|
247 |
+
|
248 |
+
self.codebook.data = usage * self.codebook.data + (1 - usage) * code_rand
|
249 |
+
prob = code_count / torch.sum(code_count)
|
250 |
+
perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
|
251 |
+
|
252 |
+
|
253 |
+
return perplexity
|
254 |
+
|
255 |
+
def preprocess(self, x):
|
256 |
+
# NCT -> NTC -> [NT, C]
|
257 |
+
x = x.permute(0, 2, 1).contiguous()
|
258 |
+
x = x.view(-1, x.shape[-1])
|
259 |
+
return x
|
260 |
+
|
261 |
+
def quantize(self, x):
|
262 |
+
# Calculate latent code x_l
|
263 |
+
k_w = self.codebook.t()
|
264 |
+
distance = torch.sum(x ** 2, dim=-1, keepdim=True) - 2 * torch.matmul(x, k_w) + torch.sum(k_w ** 2, dim=0,
|
265 |
+
keepdim=True) # (N * L, b)
|
266 |
+
_, code_idx = torch.min(distance, dim=-1)
|
267 |
+
return code_idx
|
268 |
+
|
269 |
+
def dequantize(self, code_idx):
|
270 |
+
x = F.embedding(code_idx, self.codebook)
|
271 |
+
return x
|
272 |
+
|
273 |
+
|
274 |
+
def forward(self, x):
|
275 |
+
N, width, T = x.shape
|
276 |
+
# Preprocess
|
277 |
+
x = self.preprocess(x)
|
278 |
+
# Init codebook if not inited
|
279 |
+
if self.training and not self.init:
|
280 |
+
self.init_codebook(x)
|
281 |
+
# quantize and dequantize through bottleneck
|
282 |
+
code_idx = self.quantize(x)
|
283 |
+
x_d = self.dequantize(code_idx)
|
284 |
+
# Update embeddings
|
285 |
+
if self.training:
|
286 |
+
perplexity = self.update_codebook(x, code_idx)
|
287 |
+
else :
|
288 |
+
perplexity = self.compute_perplexity(code_idx)
|
289 |
+
|
290 |
+
# Loss
|
291 |
+
commit_loss = F.mse_loss(x, x_d.detach())
|
292 |
+
|
293 |
+
# Passthrough
|
294 |
+
x_d = x + (x_d - x).detach()
|
295 |
+
|
296 |
+
# Postprocess
|
297 |
+
x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous() #(N, DIM, T)
|
298 |
+
|
299 |
+
return x_d, commit_loss, perplexity
|
300 |
+
|
301 |
+
|
302 |
+
class QuantizeEMA(nn.Module):
|
303 |
+
def __init__(self, nb_code, code_dim, mu):
|
304 |
+
super().__init__()
|
305 |
+
self.nb_code = nb_code
|
306 |
+
self.code_dim = code_dim
|
307 |
+
self.mu = mu
|
308 |
+
self.reset_codebook()
|
309 |
+
|
310 |
+
def reset_codebook(self):
|
311 |
+
self.init = False
|
312 |
+
self.code_sum = None
|
313 |
+
self.code_count = None
|
314 |
+
self.register_buffer('codebook', torch.zeros(self.nb_code, self.code_dim).cuda())
|
315 |
+
|
316 |
+
def _tile(self, x):
|
317 |
+
nb_code_x, code_dim = x.shape
|
318 |
+
if nb_code_x < self.nb_code:
|
319 |
+
n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x
|
320 |
+
std = 0.01 / np.sqrt(code_dim)
|
321 |
+
out = x.repeat(n_repeats, 1)
|
322 |
+
out = out + torch.randn_like(out) * std
|
323 |
+
else :
|
324 |
+
out = x
|
325 |
+
return out
|
326 |
+
|
327 |
+
def init_codebook(self, x):
|
328 |
+
out = self._tile(x)
|
329 |
+
self.codebook = out[:self.nb_code]
|
330 |
+
self.code_sum = self.codebook.clone()
|
331 |
+
self.code_count = torch.ones(self.nb_code, device=self.codebook.device)
|
332 |
+
self.init = True
|
333 |
+
|
334 |
+
@torch.no_grad()
|
335 |
+
def compute_perplexity(self, code_idx) :
|
336 |
+
# Calculate new centres
|
337 |
+
code_onehot = torch.zeros(self.nb_code, code_idx.shape[0], device=code_idx.device) # nb_code, N * L
|
338 |
+
code_onehot.scatter_(0, code_idx.view(1, code_idx.shape[0]), 1)
|
339 |
+
|
340 |
+
code_count = code_onehot.sum(dim=-1) # nb_code
|
341 |
+
prob = code_count / torch.sum(code_count)
|
342 |
+
perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
|
343 |
+
return perplexity
|
344 |
+
|
345 |
+
@torch.no_grad()
|
346 |
+
def update_codebook(self, x, code_idx):
|
347 |
+
|
348 |
+
code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device) # nb_code, N * L
|
349 |
+
code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1)
|
350 |
+
|
351 |
+
code_sum = torch.matmul(code_onehot, x) # nb_code, w
|
352 |
+
code_count = code_onehot.sum(dim=-1) # nb_code
|
353 |
+
|
354 |
+
# Update centres
|
355 |
+
self.code_sum = self.mu * self.code_sum + (1. - self.mu) * code_sum # w, nb_code
|
356 |
+
self.code_count = self.mu * self.code_count + (1. - self.mu) * code_count # nb_code
|
357 |
+
|
358 |
+
code_update = self.code_sum.view(self.nb_code, self.code_dim) / self.code_count.view(self.nb_code, 1)
|
359 |
+
|
360 |
+
self.codebook = code_update
|
361 |
+
prob = code_count / torch.sum(code_count)
|
362 |
+
perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
|
363 |
+
|
364 |
+
return perplexity
|
365 |
+
|
366 |
+
def preprocess(self, x):
|
367 |
+
# NCT -> NTC -> [NT, C]
|
368 |
+
x = x.permute(0, 2, 1).contiguous()
|
369 |
+
x = x.view(-1, x.shape[-1])
|
370 |
+
return x
|
371 |
+
|
372 |
+
def quantize(self, x):
|
373 |
+
# Calculate latent code x_l
|
374 |
+
k_w = self.codebook.t()
|
375 |
+
distance = torch.sum(x ** 2, dim=-1, keepdim=True) - 2 * torch.matmul(x, k_w) + torch.sum(k_w ** 2, dim=0,
|
376 |
+
keepdim=True) # (N * L, b)
|
377 |
+
_, code_idx = torch.min(distance, dim=-1)
|
378 |
+
return code_idx
|
379 |
+
|
380 |
+
def dequantize(self, code_idx):
|
381 |
+
x = F.embedding(code_idx, self.codebook)
|
382 |
+
return x
|
383 |
+
|
384 |
+
|
385 |
+
def forward(self, x):
|
386 |
+
N, width, T = x.shape
|
387 |
+
|
388 |
+
# Preprocess
|
389 |
+
x = self.preprocess(x)
|
390 |
+
|
391 |
+
# Init codebook if not inited
|
392 |
+
if self.training and not self.init:
|
393 |
+
self.init_codebook(x)
|
394 |
+
|
395 |
+
# quantize and dequantize through bottleneck
|
396 |
+
code_idx = self.quantize(x)
|
397 |
+
x_d = self.dequantize(code_idx)
|
398 |
+
|
399 |
+
# Update embeddings
|
400 |
+
if self.training:
|
401 |
+
perplexity = self.update_codebook(x, code_idx)
|
402 |
+
else :
|
403 |
+
perplexity = self.compute_perplexity(code_idx)
|
404 |
+
|
405 |
+
# Loss
|
406 |
+
commit_loss = F.mse_loss(x, x_d.detach())
|
407 |
+
|
408 |
+
# Passthrough
|
409 |
+
x_d = x + (x_d - x).detach()
|
410 |
+
|
411 |
+
# Postprocess
|
412 |
+
x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous() #(N, DIM, T)
|
413 |
+
|
414 |
+
return x_d, commit_loss, perplexity
|
mGPT/archs/tools/resnet.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import torch
|
3 |
+
|
4 |
+
class nonlinearity(nn.Module):
|
5 |
+
def __init__(self):
|
6 |
+
super().__init__()
|
7 |
+
|
8 |
+
def forward(self, x):
|
9 |
+
# swish
|
10 |
+
return x * torch.sigmoid(x)
|
11 |
+
|
12 |
+
class ResConv1DBlock(nn.Module):
|
13 |
+
def __init__(self, n_in, n_state, dilation=1, activation='silu', norm=None, dropout=None):
|
14 |
+
super().__init__()
|
15 |
+
padding = dilation
|
16 |
+
self.norm = norm
|
17 |
+
if norm == "LN":
|
18 |
+
self.norm1 = nn.LayerNorm(n_in)
|
19 |
+
self.norm2 = nn.LayerNorm(n_in)
|
20 |
+
elif norm == "GN":
|
21 |
+
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=n_in, eps=1e-6, affine=True)
|
22 |
+
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=n_in, eps=1e-6, affine=True)
|
23 |
+
elif norm == "BN":
|
24 |
+
self.norm1 = nn.BatchNorm1d(num_features=n_in, eps=1e-6, affine=True)
|
25 |
+
self.norm2 = nn.BatchNorm1d(num_features=n_in, eps=1e-6, affine=True)
|
26 |
+
|
27 |
+
else:
|
28 |
+
self.norm1 = nn.Identity()
|
29 |
+
self.norm2 = nn.Identity()
|
30 |
+
|
31 |
+
if activation == "relu":
|
32 |
+
self.activation1 = nn.ReLU()
|
33 |
+
self.activation2 = nn.ReLU()
|
34 |
+
|
35 |
+
elif activation == "silu":
|
36 |
+
self.activation1 = nonlinearity()
|
37 |
+
self.activation2 = nonlinearity()
|
38 |
+
|
39 |
+
elif activation == "gelu":
|
40 |
+
self.activation1 = nn.GELU()
|
41 |
+
self.activation2 = nn.GELU()
|
42 |
+
|
43 |
+
|
44 |
+
|
45 |
+
self.conv1 = nn.Conv1d(n_in, n_state, 3, 1, padding, dilation)
|
46 |
+
self.conv2 = nn.Conv1d(n_state, n_in, 1, 1, 0,)
|
47 |
+
|
48 |
+
|
49 |
+
def forward(self, x):
|
50 |
+
x_orig = x
|
51 |
+
if self.norm == "LN":
|
52 |
+
x = self.norm1(x.transpose(-2, -1))
|
53 |
+
x = self.activation1(x.transpose(-2, -1))
|
54 |
+
else:
|
55 |
+
x = self.norm1(x)
|
56 |
+
x = self.activation1(x)
|
57 |
+
|
58 |
+
x = self.conv1(x)
|
59 |
+
|
60 |
+
if self.norm == "LN":
|
61 |
+
x = self.norm2(x.transpose(-2, -1))
|
62 |
+
x = self.activation2(x.transpose(-2, -1))
|
63 |
+
else:
|
64 |
+
x = self.norm2(x)
|
65 |
+
x = self.activation2(x)
|
66 |
+
|
67 |
+
x = self.conv2(x)
|
68 |
+
x = x + x_orig
|
69 |
+
return x
|
70 |
+
|
71 |
+
class Resnet1D(nn.Module):
|
72 |
+
def __init__(self, n_in, n_depth, dilation_growth_rate=1, reverse_dilation=True, activation='relu', norm=None):
|
73 |
+
super().__init__()
|
74 |
+
|
75 |
+
blocks = [ResConv1DBlock(n_in, n_in, dilation=dilation_growth_rate ** depth, activation=activation, norm=norm) for depth in range(n_depth)]
|
76 |
+
if reverse_dilation:
|
77 |
+
blocks = blocks[::-1]
|
78 |
+
|
79 |
+
self.model = nn.Sequential(*blocks)
|
80 |
+
|
81 |
+
def forward(self, x):
|
82 |
+
return self.model(x)
|
mGPT/archs/tools/token_emb.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from torch import Tensor, nn
|
3 |
+
|
4 |
+
class NewTokenEmb(nn.Module):
|
5 |
+
"""
|
6 |
+
For adding new tokens to a pretrained model
|
7 |
+
"""
|
8 |
+
|
9 |
+
def __init__(self,
|
10 |
+
old_embeddings: nn.Embedding,
|
11 |
+
new_num_tokens: int = None) -> None:
|
12 |
+
|
13 |
+
super().__init__()
|
14 |
+
|
15 |
+
self.num_tokens = old_embeddings.num_embeddings + new_num_tokens
|
16 |
+
self.old_num_tokens = old_embeddings.num_embeddings
|
17 |
+
self.new_num_tokens = new_num_tokens
|
18 |
+
self.embedding_dim = old_embeddings.embedding_dim
|
19 |
+
|
20 |
+
# For text embeddings
|
21 |
+
self.text_embeddings = nn.Embedding(
|
22 |
+
self.num_tokens,
|
23 |
+
self.embedding_dim,
|
24 |
+
device=old_embeddings.weight.device,
|
25 |
+
dtype=old_embeddings.weight.dtype)
|
26 |
+
with torch.no_grad():
|
27 |
+
self.text_embeddings.weight.data[:old_embeddings.
|
28 |
+
num_embeddings] = old_embeddings.weight.data
|
29 |
+
self.text_embeddings.weight.data[
|
30 |
+
self.old_num_tokens:] = torch.zeros(
|
31 |
+
self.new_num_tokens,
|
32 |
+
self.embedding_dim,
|
33 |
+
dtype=old_embeddings.weight.dtype,
|
34 |
+
device=old_embeddings.weight.device)
|
35 |
+
self.text_embeddings.weight.requires_grad_(False)
|
36 |
+
|
37 |
+
# For motion embeddings
|
38 |
+
self.motion_embeddings = nn.Embedding(
|
39 |
+
new_num_tokens,
|
40 |
+
self.embedding_dim,
|
41 |
+
device=old_embeddings.weight.device,
|
42 |
+
dtype=old_embeddings.weight.dtype)
|
43 |
+
with torch.no_grad():
|
44 |
+
self.motion_embeddings.weight.data[:self.
|
45 |
+
old_num_tokens] = torch.zeros(
|
46 |
+
new_num_tokens,
|
47 |
+
self.embedding_dim,
|
48 |
+
dtype=old_embeddings.weight.
|
49 |
+
dtype,
|
50 |
+
device=old_embeddings.
|
51 |
+
weight.device)
|
52 |
+
self.word2motionProj = nn.Linear(self.old_num_tokens, new_num_tokens)
|
53 |
+
|
54 |
+
def forward(self, input: Tensor) -> Tensor:
|
55 |
+
|
56 |
+
with torch.no_grad():
|
57 |
+
self.motion_embeddings.weight.data[:self.
|
58 |
+
old_num_tokens] = torch.zeros(
|
59 |
+
self.new_num_tokens,
|
60 |
+
self.embedding_dim,
|
61 |
+
dtype=self.motion_embeddings
|
62 |
+
.weight.dtype,
|
63 |
+
device=self.
|
64 |
+
motion_embeddings.weight.
|
65 |
+
device)
|
66 |
+
|
67 |
+
self.motion_embeddings.weight.data[
|
68 |
+
self.old_num_tokens:] = self.word2motionProj(
|
69 |
+
self.text_embeddings.weight.data[:self.old_num_tokens].permute(
|
70 |
+
1, 0)).permute(1, 0)
|
71 |
+
|
72 |
+
return self.text_embeddings(input) + self.motion_embeddings(input)
|
73 |
+
|
mGPT/archs/tools/transformer_layers.py
ADDED
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
import math
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from torch import Tensor
|
6 |
+
|
7 |
+
# Took from https://github.com/joeynmt/joeynmt/blob/fb66afcbe1beef9acd59283bcc084c4d4c1e6343/joeynmt/transformer_layers.py
|
8 |
+
|
9 |
+
|
10 |
+
# pylint: disable=arguments-differ
|
11 |
+
class MultiHeadedAttention(nn.Module):
|
12 |
+
"""
|
13 |
+
Multi-Head Attention module from "Attention is All You Need"
|
14 |
+
|
15 |
+
Implementation modified from OpenNMT-py.
|
16 |
+
https://github.com/OpenNMT/OpenNMT-py
|
17 |
+
"""
|
18 |
+
|
19 |
+
def __init__(self, num_heads: int, size: int, dropout: float = 0.1):
|
20 |
+
"""
|
21 |
+
Create a multi-headed attention layer.
|
22 |
+
:param num_heads: the number of heads
|
23 |
+
:param size: model size (must be divisible by num_heads)
|
24 |
+
:param dropout: probability of dropping a unit
|
25 |
+
"""
|
26 |
+
super().__init__()
|
27 |
+
|
28 |
+
assert size % num_heads == 0
|
29 |
+
|
30 |
+
self.head_size = head_size = size // num_heads
|
31 |
+
self.model_size = size
|
32 |
+
self.num_heads = num_heads
|
33 |
+
|
34 |
+
self.k_layer = nn.Linear(size, num_heads * head_size)
|
35 |
+
self.v_layer = nn.Linear(size, num_heads * head_size)
|
36 |
+
self.q_layer = nn.Linear(size, num_heads * head_size)
|
37 |
+
|
38 |
+
self.output_layer = nn.Linear(size, size)
|
39 |
+
self.softmax = nn.Softmax(dim=-1)
|
40 |
+
self.dropout = nn.Dropout(dropout)
|
41 |
+
|
42 |
+
def forward(self, k: Tensor, v: Tensor, q: Tensor, mask: Tensor = None):
|
43 |
+
"""
|
44 |
+
Computes multi-headed attention.
|
45 |
+
|
46 |
+
:param k: keys [B, M, D] with M being the sentence length.
|
47 |
+
:param v: values [B, M, D]
|
48 |
+
:param q: query [B, M, D]
|
49 |
+
:param mask: optional mask [B, 1, M] or [B, M, M]
|
50 |
+
:return:
|
51 |
+
"""
|
52 |
+
batch_size = k.size(0)
|
53 |
+
num_heads = self.num_heads
|
54 |
+
|
55 |
+
# project the queries (q), keys (k), and values (v)
|
56 |
+
k = self.k_layer(k)
|
57 |
+
v = self.v_layer(v)
|
58 |
+
q = self.q_layer(q)
|
59 |
+
|
60 |
+
# reshape q, k, v for our computation to [batch_size, num_heads, ..]
|
61 |
+
k = k.view(batch_size, -1, num_heads, self.head_size).transpose(1, 2)
|
62 |
+
v = v.view(batch_size, -1, num_heads, self.head_size).transpose(1, 2)
|
63 |
+
q = q.view(batch_size, -1, num_heads, self.head_size).transpose(1, 2)
|
64 |
+
|
65 |
+
# compute scores
|
66 |
+
q = q / math.sqrt(self.head_size)
|
67 |
+
|
68 |
+
# batch x num_heads x query_len x key_len
|
69 |
+
scores = torch.matmul(q, k.transpose(2, 3))
|
70 |
+
# torch.Size([48, 8, 183, 183])
|
71 |
+
|
72 |
+
# apply the mask (if we have one)
|
73 |
+
# we add a dimension for the heads to it below: [B, 1, 1, M]
|
74 |
+
if mask is not None:
|
75 |
+
scores = scores.masked_fill(~mask.unsqueeze(1), float('-inf'))
|
76 |
+
|
77 |
+
# apply attention dropout and compute context vectors.
|
78 |
+
attention = self.softmax(scores)
|
79 |
+
attention = self.dropout(attention)
|
80 |
+
# torch.Size([48, 8, 183, 183]) [bs, nheads, time, time] (for decoding)
|
81 |
+
|
82 |
+
# v: torch.Size([48, 8, 183, 32]) (32 is 256/8)
|
83 |
+
# get context vector (select values with attention) and reshape
|
84 |
+
# back to [B, M, D]
|
85 |
+
context = torch.matmul(attention, v) # torch.Size([48, 8, 183, 32])
|
86 |
+
context = context.transpose(1, 2).contiguous().view(
|
87 |
+
batch_size, -1, num_heads * self.head_size)
|
88 |
+
# torch.Size([48, 183, 256]) put back to 256 (combine the heads)
|
89 |
+
|
90 |
+
output = self.output_layer(context)
|
91 |
+
# torch.Size([48, 183, 256]): 1 output per time step
|
92 |
+
|
93 |
+
return output
|
94 |
+
|
95 |
+
|
96 |
+
# pylint: disable=arguments-differ
|
97 |
+
class PositionwiseFeedForward(nn.Module):
|
98 |
+
"""
|
99 |
+
Position-wise Feed-forward layer
|
100 |
+
Projects to ff_size and then back down to input_size.
|
101 |
+
"""
|
102 |
+
|
103 |
+
def __init__(self, input_size, ff_size, dropout=0.1):
|
104 |
+
"""
|
105 |
+
Initializes position-wise feed-forward layer.
|
106 |
+
:param input_size: dimensionality of the input.
|
107 |
+
:param ff_size: dimensionality of intermediate representation
|
108 |
+
:param dropout:
|
109 |
+
"""
|
110 |
+
super().__init__()
|
111 |
+
self.layer_norm = nn.LayerNorm(input_size, eps=1e-6)
|
112 |
+
self.pwff_layer = nn.Sequential(
|
113 |
+
nn.Linear(input_size, ff_size),
|
114 |
+
nn.ReLU(),
|
115 |
+
nn.Dropout(dropout),
|
116 |
+
nn.Linear(ff_size, input_size),
|
117 |
+
nn.Dropout(dropout),
|
118 |
+
)
|
119 |
+
|
120 |
+
def forward(self, x):
|
121 |
+
x_norm = self.layer_norm(x)
|
122 |
+
return self.pwff_layer(x_norm) + x
|
123 |
+
|
124 |
+
|
125 |
+
# pylint: disable=arguments-differ
|
126 |
+
class PositionalEncoding(nn.Module):
|
127 |
+
"""
|
128 |
+
Pre-compute position encodings (PE).
|
129 |
+
In forward pass, this adds the position-encodings to the
|
130 |
+
input for as many time steps as necessary.
|
131 |
+
|
132 |
+
Implementation based on OpenNMT-py.
|
133 |
+
https://github.com/OpenNMT/OpenNMT-py
|
134 |
+
"""
|
135 |
+
|
136 |
+
def __init__(self, size: int = 0, max_len: int = 5000):
|
137 |
+
"""
|
138 |
+
Positional Encoding with maximum length max_len
|
139 |
+
:param size:
|
140 |
+
:param max_len:
|
141 |
+
:param dropout:
|
142 |
+
"""
|
143 |
+
if size % 2 != 0:
|
144 |
+
raise ValueError("Cannot use sin/cos positional encoding with "
|
145 |
+
"odd dim (got dim={:d})".format(size))
|
146 |
+
pe = torch.zeros(max_len, size)
|
147 |
+
position = torch.arange(0, max_len).unsqueeze(1)
|
148 |
+
div_term = torch.exp((torch.arange(0, size, 2, dtype=torch.float) *
|
149 |
+
-(math.log(10000.0) / size)))
|
150 |
+
pe[:, 0::2] = torch.sin(position.float() * div_term)
|
151 |
+
pe[:, 1::2] = torch.cos(position.float() * div_term)
|
152 |
+
pe = pe.unsqueeze(0) # shape: [1, size, max_len]
|
153 |
+
super().__init__()
|
154 |
+
self.register_buffer('pe', pe)
|
155 |
+
self.dim = size
|
156 |
+
|
157 |
+
def forward(self, emb):
|
158 |
+
"""Embed inputs.
|
159 |
+
Args:
|
160 |
+
emb (FloatTensor): Sequence of word vectors
|
161 |
+
``(seq_len, batch_size, self.dim)``
|
162 |
+
"""
|
163 |
+
# Add position encodings
|
164 |
+
return emb + self.pe[:, :emb.size(1)]
|
165 |
+
|
166 |
+
|
167 |
+
class TransformerEncoderLayer(nn.Module):
|
168 |
+
"""
|
169 |
+
One Transformer encoder layer has a Multi-head attention layer plus
|
170 |
+
a position-wise feed-forward layer.
|
171 |
+
"""
|
172 |
+
|
173 |
+
def __init__(self,
|
174 |
+
size: int = 0,
|
175 |
+
ff_size: int = 0,
|
176 |
+
num_heads: int = 0,
|
177 |
+
dropout: float = 0.1):
|
178 |
+
"""
|
179 |
+
A single Transformer layer.
|
180 |
+
:param size:
|
181 |
+
:param ff_size:
|
182 |
+
:param num_heads:
|
183 |
+
:param dropout:
|
184 |
+
"""
|
185 |
+
super().__init__()
|
186 |
+
|
187 |
+
self.layer_norm = nn.LayerNorm(size, eps=1e-6)
|
188 |
+
self.src_src_att = MultiHeadedAttention(num_heads,
|
189 |
+
size,
|
190 |
+
dropout=dropout)
|
191 |
+
self.feed_forward = PositionwiseFeedForward(size,
|
192 |
+
ff_size=ff_size,
|
193 |
+
dropout=dropout)
|
194 |
+
self.dropout = nn.Dropout(dropout)
|
195 |
+
self.size = size
|
196 |
+
|
197 |
+
# pylint: disable=arguments-differ
|
198 |
+
def forward(self, x: Tensor, mask: Tensor) -> Tensor:
|
199 |
+
"""
|
200 |
+
Forward pass for a single transformer encoder layer.
|
201 |
+
First applies layer norm, then self attention,
|
202 |
+
then dropout with residual connection (adding the input to the result),
|
203 |
+
and then a position-wise feed-forward layer.
|
204 |
+
|
205 |
+
:param x: layer input
|
206 |
+
:param mask: input mask
|
207 |
+
:return: output tensor
|
208 |
+
"""
|
209 |
+
x_norm = self.layer_norm(x)
|
210 |
+
h = self.src_src_att(x_norm, x_norm, x_norm, mask)
|
211 |
+
h = self.dropout(h) + x
|
212 |
+
o = self.feed_forward(h)
|
213 |
+
return o
|
214 |
+
|
215 |
+
|
216 |
+
class TransformerDecoderLayer(nn.Module):
|
217 |
+
"""
|
218 |
+
Transformer decoder layer.
|
219 |
+
|
220 |
+
Consists of self-attention, source-attention, and feed-forward.
|
221 |
+
"""
|
222 |
+
|
223 |
+
def __init__(self,
|
224 |
+
size: int = 0,
|
225 |
+
ff_size: int = 0,
|
226 |
+
num_heads: int = 0,
|
227 |
+
dropout: float = 0.1):
|
228 |
+
"""
|
229 |
+
Represents a single Transformer decoder layer.
|
230 |
+
|
231 |
+
It attends to the source representation and the previous decoder states.
|
232 |
+
|
233 |
+
:param size: model dimensionality
|
234 |
+
:param ff_size: size of the feed-forward intermediate layer
|
235 |
+
:param num_heads: number of heads
|
236 |
+
:param dropout: dropout to apply to input
|
237 |
+
"""
|
238 |
+
super().__init__()
|
239 |
+
self.size = size
|
240 |
+
|
241 |
+
self.trg_trg_att = MultiHeadedAttention(num_heads,
|
242 |
+
size,
|
243 |
+
dropout=dropout)
|
244 |
+
self.src_trg_att = MultiHeadedAttention(num_heads,
|
245 |
+
size,
|
246 |
+
dropout=dropout)
|
247 |
+
|
248 |
+
self.feed_forward = PositionwiseFeedForward(size,
|
249 |
+
ff_size=ff_size,
|
250 |
+
dropout=dropout)
|
251 |
+
|
252 |
+
self.x_layer_norm = nn.LayerNorm(size, eps=1e-6)
|
253 |
+
self.dec_layer_norm = nn.LayerNorm(size, eps=1e-6)
|
254 |
+
|
255 |
+
self.dropout = nn.Dropout(dropout)
|
256 |
+
|
257 |
+
# pylint: disable=arguments-differ
|
258 |
+
def forward(self,
|
259 |
+
x: Tensor = None,
|
260 |
+
memory: Tensor = None,
|
261 |
+
src_mask: Tensor = None,
|
262 |
+
trg_mask: Tensor = None) -> Tensor:
|
263 |
+
"""
|
264 |
+
Forward pass of a single Transformer decoder layer.
|
265 |
+
|
266 |
+
:param x: inputs
|
267 |
+
:param memory: source representations
|
268 |
+
:param src_mask: source mask
|
269 |
+
:param trg_mask: target mask (so as to not condition on future steps)
|
270 |
+
:return: output tensor
|
271 |
+
"""
|
272 |
+
# decoder/target self-attention
|
273 |
+
x_norm = self.x_layer_norm(x) # torch.Size([48, 183, 256])
|
274 |
+
h1 = self.trg_trg_att(x_norm, x_norm, x_norm, mask=trg_mask)
|
275 |
+
h1 = self.dropout(h1) + x
|
276 |
+
|
277 |
+
# source-target attention
|
278 |
+
h1_norm = self.dec_layer_norm(
|
279 |
+
h1) # torch.Size([48, 183, 256]) (same for memory)
|
280 |
+
h2 = self.src_trg_att(memory, memory, h1_norm, mask=src_mask)
|
281 |
+
|
282 |
+
# final position-wise feed-forward layer
|
283 |
+
o = self.feed_forward(self.dropout(h2) + h1)
|
284 |
+
|
285 |
+
return o
|
mGPT/callback.py
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from pytorch_lightning import LightningModule, Trainer
|
3 |
+
from pytorch_lightning.callbacks import Callback, RichProgressBar, ModelCheckpoint
|
4 |
+
|
5 |
+
|
6 |
+
def build_callbacks(cfg, logger=None, phase='test', **kwargs):
|
7 |
+
callbacks = []
|
8 |
+
logger = logger
|
9 |
+
|
10 |
+
# Rich Progress Bar
|
11 |
+
callbacks.append(progressBar())
|
12 |
+
|
13 |
+
# Checkpoint Callback
|
14 |
+
if phase == 'train':
|
15 |
+
callbacks.extend(getCheckpointCallback(cfg, logger=logger, **kwargs))
|
16 |
+
|
17 |
+
return callbacks
|
18 |
+
|
19 |
+
def getCheckpointCallback(cfg, logger=None, **kwargs):
|
20 |
+
callbacks = []
|
21 |
+
# Logging
|
22 |
+
metric_monitor = {
|
23 |
+
"loss_total": "total/train",
|
24 |
+
"Train_jf": "recons/text2jfeats/train",
|
25 |
+
"Val_jf": "recons/text2jfeats/val",
|
26 |
+
"Train_rf": "recons/text2rfeats/train",
|
27 |
+
"Val_rf": "recons/text2rfeats/val",
|
28 |
+
"APE root": "Metrics/APE_root",
|
29 |
+
"APE mean pose": "Metrics/APE_mean_pose",
|
30 |
+
"AVE root": "Metrics/AVE_root",
|
31 |
+
"AVE mean pose": "Metrics/AVE_mean_pose",
|
32 |
+
"R_TOP_1": "Metrics/R_precision_top_1",
|
33 |
+
"R_TOP_2": "Metrics/R_precision_top_2",
|
34 |
+
"R_TOP_3": "Metrics/R_precision_top_3",
|
35 |
+
"gt_R_TOP_3": "Metrics/gt_R_precision_top_3",
|
36 |
+
"FID": "Metrics/FID",
|
37 |
+
"gt_FID": "Metrics/gt_FID",
|
38 |
+
"Diversity": "Metrics/Diversity",
|
39 |
+
"MM dist": "Metrics/Matching_score",
|
40 |
+
"Accuracy": "Metrics/accuracy",
|
41 |
+
}
|
42 |
+
callbacks.append(
|
43 |
+
progressLogger(logger,metric_monitor=metric_monitor,log_every_n_steps=1))
|
44 |
+
|
45 |
+
# Save 10 latest checkpoints
|
46 |
+
checkpointParams = {
|
47 |
+
'dirpath': os.path.join(cfg.FOLDER_EXP, "checkpoints"),
|
48 |
+
'filename': "{epoch}",
|
49 |
+
'monitor': "step",
|
50 |
+
'mode': "max",
|
51 |
+
'every_n_epochs': cfg.LOGGER.VAL_EVERY_STEPS,
|
52 |
+
'save_top_k': 8,
|
53 |
+
'save_last': True,
|
54 |
+
'save_on_train_epoch_end': True
|
55 |
+
}
|
56 |
+
callbacks.append(ModelCheckpoint(**checkpointParams))
|
57 |
+
|
58 |
+
# Save checkpoint every n*10 epochs
|
59 |
+
checkpointParams.update({
|
60 |
+
'every_n_epochs':
|
61 |
+
cfg.LOGGER.VAL_EVERY_STEPS * 10,
|
62 |
+
'save_top_k':
|
63 |
+
-1,
|
64 |
+
'save_last':
|
65 |
+
False
|
66 |
+
})
|
67 |
+
callbacks.append(ModelCheckpoint(**checkpointParams))
|
68 |
+
|
69 |
+
metrics = cfg.METRIC.TYPE
|
70 |
+
metric_monitor_map = {
|
71 |
+
'TemosMetric': {
|
72 |
+
'Metrics/APE_root': {
|
73 |
+
'abbr': 'APEroot',
|
74 |
+
'mode': 'min'
|
75 |
+
},
|
76 |
+
},
|
77 |
+
'TM2TMetrics': {
|
78 |
+
'Metrics/FID': {
|
79 |
+
'abbr': 'FID',
|
80 |
+
'mode': 'min'
|
81 |
+
},
|
82 |
+
'Metrics/R_precision_top_3': {
|
83 |
+
'abbr': 'R3',
|
84 |
+
'mode': 'max'
|
85 |
+
}
|
86 |
+
},
|
87 |
+
'MRMetrics': {
|
88 |
+
'Metrics/MPJPE': {
|
89 |
+
'abbr': 'MPJPE',
|
90 |
+
'mode': 'min'
|
91 |
+
}
|
92 |
+
},
|
93 |
+
'HUMANACTMetrics': {
|
94 |
+
'Metrics/Accuracy': {
|
95 |
+
'abbr': 'Accuracy',
|
96 |
+
'mode': 'max'
|
97 |
+
}
|
98 |
+
},
|
99 |
+
'UESTCMetrics': {
|
100 |
+
'Metrics/Accuracy': {
|
101 |
+
'abbr': 'Accuracy',
|
102 |
+
'mode': 'max'
|
103 |
+
}
|
104 |
+
},
|
105 |
+
'UncondMetrics': {
|
106 |
+
'Metrics/FID': {
|
107 |
+
'abbr': 'FID',
|
108 |
+
'mode': 'min'
|
109 |
+
}
|
110 |
+
}
|
111 |
+
}
|
112 |
+
|
113 |
+
checkpointParams.update({
|
114 |
+
'every_n_epochs': cfg.LOGGER.VAL_EVERY_STEPS,
|
115 |
+
'save_top_k': 1,
|
116 |
+
})
|
117 |
+
|
118 |
+
for metric in metrics:
|
119 |
+
if metric in metric_monitor_map.keys():
|
120 |
+
metric_monitors = metric_monitor_map[metric]
|
121 |
+
|
122 |
+
# Delete R3 if training VAE
|
123 |
+
if cfg.TRAIN.STAGE == 'vae' and metric == 'TM2TMetrics':
|
124 |
+
del metric_monitors['Metrics/R_precision_top_3']
|
125 |
+
|
126 |
+
for metric_monitor in metric_monitors:
|
127 |
+
checkpointParams.update({
|
128 |
+
'filename':
|
129 |
+
metric_monitor_map[metric][metric_monitor]['mode']
|
130 |
+
+ "-" +
|
131 |
+
metric_monitor_map[metric][metric_monitor]['abbr']
|
132 |
+
+ "{ep}",
|
133 |
+
'monitor':
|
134 |
+
metric_monitor,
|
135 |
+
'mode':
|
136 |
+
metric_monitor_map[metric][metric_monitor]['mode'],
|
137 |
+
})
|
138 |
+
callbacks.append(
|
139 |
+
ModelCheckpoint(**checkpointParams))
|
140 |
+
return callbacks
|
141 |
+
|
142 |
+
class progressBar(RichProgressBar):
|
143 |
+
def __init__(self, ):
|
144 |
+
super().__init__()
|
145 |
+
|
146 |
+
def get_metrics(self, trainer, model):
|
147 |
+
# Don't show the version number
|
148 |
+
items = super().get_metrics(trainer, model)
|
149 |
+
items.pop("v_num", None)
|
150 |
+
return items
|
151 |
+
|
152 |
+
class progressLogger(Callback):
|
153 |
+
def __init__(self,
|
154 |
+
logger,
|
155 |
+
metric_monitor: dict,
|
156 |
+
precision: int = 3,
|
157 |
+
log_every_n_steps: int = 1):
|
158 |
+
# Metric to monitor
|
159 |
+
self.logger = logger
|
160 |
+
self.metric_monitor = metric_monitor
|
161 |
+
self.precision = precision
|
162 |
+
self.log_every_n_steps = log_every_n_steps
|
163 |
+
|
164 |
+
def on_train_start(self, trainer: Trainer, pl_module: LightningModule,
|
165 |
+
**kwargs) -> None:
|
166 |
+
self.logger.info("Training started")
|
167 |
+
|
168 |
+
def on_train_end(self, trainer: Trainer, pl_module: LightningModule,
|
169 |
+
**kwargs) -> None:
|
170 |
+
self.logger.info("Training done")
|
171 |
+
|
172 |
+
def on_validation_epoch_end(self, trainer: Trainer,
|
173 |
+
pl_module: LightningModule, **kwargs) -> None:
|
174 |
+
if trainer.sanity_checking:
|
175 |
+
self.logger.info("Sanity checking ok.")
|
176 |
+
|
177 |
+
def on_train_epoch_end(self,
|
178 |
+
trainer: Trainer,
|
179 |
+
pl_module: LightningModule,
|
180 |
+
padding=False,
|
181 |
+
**kwargs) -> None:
|
182 |
+
metric_format = f"{{:.{self.precision}e}}"
|
183 |
+
line = f"Epoch {trainer.current_epoch}"
|
184 |
+
if padding:
|
185 |
+
line = f"{line:>{len('Epoch xxxx')}}" # Right padding
|
186 |
+
|
187 |
+
if trainer.current_epoch % self.log_every_n_steps == 0:
|
188 |
+
metrics_str = []
|
189 |
+
|
190 |
+
losses_dict = trainer.callback_metrics
|
191 |
+
for metric_name, dico_name in self.metric_monitor.items():
|
192 |
+
if dico_name in losses_dict:
|
193 |
+
metric = losses_dict[dico_name].item()
|
194 |
+
metric = metric_format.format(metric)
|
195 |
+
metric = f"{metric_name} {metric}"
|
196 |
+
metrics_str.append(metric)
|
197 |
+
|
198 |
+
line = line + ": " + " ".join(metrics_str)
|
199 |
+
|
200 |
+
self.logger.info(line)
|
mGPT/config.py
ADDED
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
from argparse import ArgumentParser
|
3 |
+
from omegaconf import OmegaConf
|
4 |
+
from os.path import join as pjoin
|
5 |
+
import os
|
6 |
+
import glob
|
7 |
+
|
8 |
+
|
9 |
+
def get_module_config(cfg, filepath="./configs"):
|
10 |
+
"""
|
11 |
+
Load yaml config files from subfolders
|
12 |
+
"""
|
13 |
+
|
14 |
+
yamls = glob.glob(pjoin(filepath, '*', '*.yaml'))
|
15 |
+
yamls = [y.replace(filepath, '') for y in yamls]
|
16 |
+
for yaml in yamls:
|
17 |
+
nodes = yaml.replace('.yaml', '').replace('/', '.')
|
18 |
+
nodes = nodes[1:] if nodes[0] == '.' else nodes
|
19 |
+
OmegaConf.update(cfg, nodes, OmegaConf.load('./configs' + yaml))
|
20 |
+
|
21 |
+
return cfg
|
22 |
+
|
23 |
+
|
24 |
+
def get_obj_from_str(string, reload=False):
|
25 |
+
"""
|
26 |
+
Get object from string
|
27 |
+
"""
|
28 |
+
|
29 |
+
module, cls = string.rsplit(".", 1)
|
30 |
+
if reload:
|
31 |
+
module_imp = importlib.import_module(module)
|
32 |
+
importlib.reload(module_imp)
|
33 |
+
return getattr(importlib.import_module(module, package=None), cls)
|
34 |
+
|
35 |
+
|
36 |
+
def instantiate_from_config(config):
|
37 |
+
"""
|
38 |
+
Instantiate object from config
|
39 |
+
"""
|
40 |
+
if not "target" in config:
|
41 |
+
raise KeyError("Expected key `target` to instantiate.")
|
42 |
+
return get_obj_from_str(config["target"])(**config.get("params", dict()))
|
43 |
+
|
44 |
+
|
45 |
+
def resume_config(cfg: OmegaConf):
|
46 |
+
"""
|
47 |
+
Resume model and wandb
|
48 |
+
"""
|
49 |
+
|
50 |
+
if cfg.TRAIN.RESUME:
|
51 |
+
resume = cfg.TRAIN.RESUME
|
52 |
+
if os.path.exists(resume):
|
53 |
+
# Checkpoints
|
54 |
+
cfg.TRAIN.PRETRAINED = pjoin(resume, "checkpoints", "last.ckpt")
|
55 |
+
# Wandb
|
56 |
+
wandb_files = os.listdir(pjoin(resume, "wandb", "latest-run"))
|
57 |
+
wandb_run = [item for item in wandb_files if "run-" in item][0]
|
58 |
+
cfg.LOGGER.WANDB.params.id = wandb_run.replace("run-","").replace(".wandb", "")
|
59 |
+
else:
|
60 |
+
raise ValueError("Resume path is not right.")
|
61 |
+
|
62 |
+
return cfg
|
63 |
+
|
64 |
+
def parse_args(phase="train"):
|
65 |
+
"""
|
66 |
+
Parse arguments and load config files
|
67 |
+
"""
|
68 |
+
|
69 |
+
parser = ArgumentParser()
|
70 |
+
group = parser.add_argument_group("Training options")
|
71 |
+
|
72 |
+
# Assets
|
73 |
+
group.add_argument(
|
74 |
+
"--cfg_assets",
|
75 |
+
type=str,
|
76 |
+
required=False,
|
77 |
+
default="./configs/assets.yaml",
|
78 |
+
help="config file for asset paths",
|
79 |
+
)
|
80 |
+
|
81 |
+
# Default config
|
82 |
+
if phase in ["train", "test"]:
|
83 |
+
cfg_defualt = "./configs/default.yaml"
|
84 |
+
elif phase == "render":
|
85 |
+
cfg_defualt = "./configs/render.yaml"
|
86 |
+
elif phase == "webui":
|
87 |
+
cfg_defualt = "./configs/webui.yaml"
|
88 |
+
|
89 |
+
group.add_argument(
|
90 |
+
"--cfg",
|
91 |
+
type=str,
|
92 |
+
required=False,
|
93 |
+
default=cfg_defualt,
|
94 |
+
help="config file",
|
95 |
+
)
|
96 |
+
|
97 |
+
# Parse for each phase
|
98 |
+
if phase in ["train", "test"]:
|
99 |
+
group.add_argument("--batch_size",
|
100 |
+
type=int,
|
101 |
+
required=False,
|
102 |
+
help="training batch size")
|
103 |
+
group.add_argument("--num_nodes",
|
104 |
+
type=int,
|
105 |
+
required=False,
|
106 |
+
help="number of nodes")
|
107 |
+
group.add_argument("--device",
|
108 |
+
type=int,
|
109 |
+
nargs="+",
|
110 |
+
required=False,
|
111 |
+
help="training device")
|
112 |
+
group.add_argument("--task",
|
113 |
+
type=str,
|
114 |
+
required=False,
|
115 |
+
help="evaluation task type")
|
116 |
+
group.add_argument("--nodebug",
|
117 |
+
action="store_true",
|
118 |
+
required=False,
|
119 |
+
help="debug or not")
|
120 |
+
|
121 |
+
|
122 |
+
if phase == "demo":
|
123 |
+
group.add_argument(
|
124 |
+
"--example",
|
125 |
+
type=str,
|
126 |
+
required=False,
|
127 |
+
help="input text and lengths with txt format",
|
128 |
+
)
|
129 |
+
group.add_argument(
|
130 |
+
"--out_dir",
|
131 |
+
type=str,
|
132 |
+
required=False,
|
133 |
+
help="output dir",
|
134 |
+
)
|
135 |
+
group.add_argument("--task",
|
136 |
+
type=str,
|
137 |
+
required=False,
|
138 |
+
help="evaluation task type")
|
139 |
+
|
140 |
+
if phase == "render":
|
141 |
+
group.add_argument("--npy",
|
142 |
+
type=str,
|
143 |
+
required=False,
|
144 |
+
default=None,
|
145 |
+
help="npy motion files")
|
146 |
+
group.add_argument("--dir",
|
147 |
+
type=str,
|
148 |
+
required=False,
|
149 |
+
default=None,
|
150 |
+
help="npy motion folder")
|
151 |
+
group.add_argument("--fps",
|
152 |
+
type=int,
|
153 |
+
required=False,
|
154 |
+
default=30,
|
155 |
+
help="render fps")
|
156 |
+
group.add_argument(
|
157 |
+
"--mode",
|
158 |
+
type=str,
|
159 |
+
required=False,
|
160 |
+
default="sequence",
|
161 |
+
help="render target: video, sequence, frame",
|
162 |
+
)
|
163 |
+
|
164 |
+
params = parser.parse_args()
|
165 |
+
|
166 |
+
# Load yaml config files
|
167 |
+
OmegaConf.register_new_resolver("eval", eval)
|
168 |
+
cfg_assets = OmegaConf.load(params.cfg_assets)
|
169 |
+
cfg_base = OmegaConf.load(pjoin(cfg_assets.CONFIG_FOLDER, 'default.yaml'))
|
170 |
+
cfg_exp = OmegaConf.merge(cfg_base, OmegaConf.load(params.cfg))
|
171 |
+
if not cfg_exp.FULL_CONFIG:
|
172 |
+
cfg_exp = get_module_config(cfg_exp, cfg_assets.CONFIG_FOLDER)
|
173 |
+
cfg = OmegaConf.merge(cfg_exp, cfg_assets)
|
174 |
+
|
175 |
+
# Update config with arguments
|
176 |
+
if phase in ["train", "test"]:
|
177 |
+
cfg.TRAIN.BATCH_SIZE = params.batch_size if params.batch_size else cfg.TRAIN.BATCH_SIZE
|
178 |
+
cfg.DEVICE = params.device if params.device else cfg.DEVICE
|
179 |
+
cfg.NUM_NODES = params.num_nodes if params.num_nodes else cfg.NUM_NODES
|
180 |
+
cfg.model.params.task = params.task if params.task else cfg.model.params.task
|
181 |
+
cfg.DEBUG = not params.nodebug if params.nodebug is not None else cfg.DEBUG
|
182 |
+
|
183 |
+
# Force no debug in test
|
184 |
+
if phase == "test":
|
185 |
+
cfg.DEBUG = False
|
186 |
+
cfg.DEVICE = [0]
|
187 |
+
print("Force no debugging and one gpu when testing")
|
188 |
+
|
189 |
+
if phase == "demo":
|
190 |
+
cfg.DEMO.RENDER = params.render
|
191 |
+
cfg.DEMO.FRAME_RATE = params.frame_rate
|
192 |
+
cfg.DEMO.EXAMPLE = params.example
|
193 |
+
cfg.DEMO.TASK = params.task
|
194 |
+
cfg.TEST.FOLDER = params.out_dir if params.out_dir else cfg.TEST.FOLDER
|
195 |
+
os.makedirs(cfg.TEST.FOLDER, exist_ok=True)
|
196 |
+
|
197 |
+
if phase == "render":
|
198 |
+
if params.npy:
|
199 |
+
cfg.RENDER.NPY = params.npy
|
200 |
+
cfg.RENDER.INPUT_MODE = "npy"
|
201 |
+
if params.dir:
|
202 |
+
cfg.RENDER.DIR = params.dir
|
203 |
+
cfg.RENDER.INPUT_MODE = "dir"
|
204 |
+
if params.fps:
|
205 |
+
cfg.RENDER.FPS = float(params.fps)
|
206 |
+
cfg.RENDER.MODE = params.mode
|
207 |
+
|
208 |
+
# Debug mode
|
209 |
+
if cfg.DEBUG:
|
210 |
+
cfg.NAME = "debug--" + cfg.NAME
|
211 |
+
cfg.LOGGER.WANDB.params.offline = True
|
212 |
+
cfg.LOGGER.VAL_EVERY_STEPS = 1
|
213 |
+
|
214 |
+
# Resume config
|
215 |
+
cfg = resume_config(cfg)
|
216 |
+
|
217 |
+
return cfg
|
mGPT/data/HumanML3D.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from os.path import join as pjoin
|
4 |
+
from .humanml.utils.word_vectorizer import WordVectorizer
|
5 |
+
from .humanml.scripts.motion_process import (process_file, recover_from_ric)
|
6 |
+
from . import BASEDataModule
|
7 |
+
from .humanml import Text2MotionDatasetEval, Text2MotionDataset, Text2MotionDatasetCB, MotionDataset, MotionDatasetVQ, Text2MotionDatasetToken, Text2MotionDatasetM2T
|
8 |
+
from .utils import humanml3d_collate
|
9 |
+
|
10 |
+
|
11 |
+
class HumanML3DDataModule(BASEDataModule):
|
12 |
+
def __init__(self, cfg, **kwargs):
|
13 |
+
|
14 |
+
super().__init__(collate_fn=humanml3d_collate)
|
15 |
+
self.cfg = cfg
|
16 |
+
self.save_hyperparameters(logger=False)
|
17 |
+
|
18 |
+
# Basic info of the dataset
|
19 |
+
cfg.DATASET.JOINT_TYPE = 'humanml3d'
|
20 |
+
self.name = "humanml3d"
|
21 |
+
self.njoints = 22
|
22 |
+
|
23 |
+
# Path to the dataset
|
24 |
+
data_root = cfg.DATASET.HUMANML3D.ROOT
|
25 |
+
self.hparams.data_root = data_root
|
26 |
+
self.hparams.text_dir = pjoin(data_root, "texts")
|
27 |
+
self.hparams.motion_dir = pjoin(data_root, 'new_joint_vecs')
|
28 |
+
|
29 |
+
# Mean and std of the dataset
|
30 |
+
self.hparams.mean = np.load(pjoin('assets/meta', "mean.npy"))
|
31 |
+
self.hparams.std = np.load(pjoin('assets/meta', "std.npy"))
|
32 |
+
|
33 |
+
# Mean and std for fair evaluation
|
34 |
+
self.hparams.mean_eval = np.load(pjoin('assets/meta', "mean_eval.npy"))
|
35 |
+
self.hparams.std_eval = np.load(pjoin('assets/meta', "std_eval.npy"))
|
36 |
+
|
37 |
+
# Length of the dataset
|
38 |
+
self.hparams.max_motion_length = cfg.DATASET.HUMANML3D.MAX_MOTION_LEN
|
39 |
+
self.hparams.min_motion_length = cfg.DATASET.HUMANML3D.MIN_MOTION_LEN
|
40 |
+
self.hparams.max_text_len = cfg.DATASET.HUMANML3D.MAX_TEXT_LEN
|
41 |
+
self.hparams.unit_length = cfg.DATASET.HUMANML3D.UNIT_LEN
|
42 |
+
|
43 |
+
# Additional parameters
|
44 |
+
self.hparams.debug = cfg.DEBUG
|
45 |
+
self.hparams.stage = cfg.TRAIN.STAGE
|
46 |
+
|
47 |
+
# Dataset switch
|
48 |
+
self.DatasetEval = Text2MotionDatasetEval
|
49 |
+
|
50 |
+
if cfg.TRAIN.STAGE == "vae":
|
51 |
+
if cfg.model.params.motion_vae.target.split('.')[-1].lower() == "vqvae":
|
52 |
+
self.hparams.win_size = 64
|
53 |
+
self.Dataset = MotionDatasetVQ
|
54 |
+
else:
|
55 |
+
self.Dataset = MotionDataset
|
56 |
+
elif 'lm' in cfg.TRAIN.STAGE:
|
57 |
+
self.hparams.code_path = cfg.DATASET.CODE_PATH
|
58 |
+
self.hparams.task_path = cfg.DATASET.TASK_PATH
|
59 |
+
self.hparams.std_text = cfg.DATASET.HUMANML3D.STD_TEXT
|
60 |
+
self.Dataset = Text2MotionDatasetCB
|
61 |
+
elif cfg.TRAIN.STAGE == "token":
|
62 |
+
self.Dataset = Text2MotionDatasetToken
|
63 |
+
self.DatasetEval = Text2MotionDatasetToken
|
64 |
+
elif cfg.TRAIN.STAGE == "m2t":
|
65 |
+
self.Dataset = Text2MotionDatasetM2T
|
66 |
+
self.DatasetEval = Text2MotionDatasetM2T
|
67 |
+
else:
|
68 |
+
self.Dataset = Text2MotionDataset
|
69 |
+
|
70 |
+
# Get additional info of the dataset
|
71 |
+
self.nfeats = 263
|
72 |
+
cfg.DATASET.NFEATS = self.nfeats
|
73 |
+
|
74 |
+
|
75 |
+
def feats2joints(self, features):
|
76 |
+
mean = torch.tensor(self.hparams.mean).to(features)
|
77 |
+
std = torch.tensor(self.hparams.std).to(features)
|
78 |
+
features = features * std + mean
|
79 |
+
return recover_from_ric(features, self.njoints)
|
80 |
+
|
81 |
+
def joints2feats(self, features):
|
82 |
+
features = process_file(features, self.njoints)[0]
|
83 |
+
return features
|
84 |
+
|
85 |
+
def normalize(self, features):
|
86 |
+
mean = torch.tensor(self.hparams.mean).to(features)
|
87 |
+
std = torch.tensor(self.hparams.std).to(features)
|
88 |
+
features = (features - mean) / std
|
89 |
+
return features
|
90 |
+
|
91 |
+
def denormalize(self, features):
|
92 |
+
mean = torch.tensor(self.hparams.mean).to(features)
|
93 |
+
std = torch.tensor(self.hparams.std).to(features)
|
94 |
+
features = features * std + mean
|
95 |
+
return features
|
96 |
+
|
97 |
+
def renorm4t2m(self, features):
|
98 |
+
# renorm to t2m norms for using t2m evaluators
|
99 |
+
ori_mean = torch.tensor(self.hparams.mean).to(features)
|
100 |
+
ori_std = torch.tensor(self.hparams.std).to(features)
|
101 |
+
eval_mean = torch.tensor(self.hparams.mean_eval).to(features)
|
102 |
+
eval_std = torch.tensor(self.hparams.std_eval).to(features)
|
103 |
+
features = features * ori_std + ori_mean
|
104 |
+
features = (features - eval_mean) / eval_std
|
105 |
+
return features
|
106 |
+
|
107 |
+
def mm_mode(self, mm_on=True):
|
108 |
+
if mm_on:
|
109 |
+
self.is_mm = True
|
110 |
+
self.name_list = self.test_dataset.name_list
|
111 |
+
self.mm_list = np.random.choice(self.name_list,
|
112 |
+
self.cfg.METRIC.MM_NUM_SAMPLES,
|
113 |
+
replace=False)
|
114 |
+
self.test_dataset.name_list = self.mm_list
|
115 |
+
else:
|
116 |
+
self.is_mm = False
|
117 |
+
self.test_dataset.name_list = self.name_list
|
mGPT/data/Kit.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from os.path import join as pjoin
|
4 |
+
from .humanml.utils.word_vectorizer import WordVectorizer
|
5 |
+
from .humanml.scripts.motion_process import (process_file, recover_from_ric)
|
6 |
+
from .HumanML3D import HumanML3DDataModule
|
7 |
+
from .humanml import Text2MotionDatasetEval, Text2MotionDataset, Text2MotionDatasetCB, MotionDataset, MotionDatasetVQ, Text2MotionDatasetToken
|
8 |
+
|
9 |
+
|
10 |
+
class KitDataModule(HumanML3DDataModule):
|
11 |
+
def __init__(self, cfg, **kwargs):
|
12 |
+
|
13 |
+
super().__init__(cfg, **kwargs)
|
14 |
+
|
15 |
+
# Basic info of the dataset
|
16 |
+
self.name = "kit"
|
17 |
+
self.njoints = 21
|
18 |
+
|
19 |
+
# Path to the dataset
|
20 |
+
data_root = cfg.DATASET.KIT.ROOT
|
21 |
+
self.hparams.data_root = data_root
|
22 |
+
self.hparams.text_dir = pjoin(data_root, "texts")
|
23 |
+
self.hparams.motion_dir = pjoin(data_root, 'new_joint_vecs')
|
24 |
+
|
25 |
+
# Mean and std of the dataset
|
26 |
+
dis_data_root = pjoin(cfg.DATASET.KIT.MEAN_STD_PATH, 'kit',
|
27 |
+
"VQVAEV3_CB1024_CMT_H1024_NRES3", "meta")
|
28 |
+
self.hparams.mean = np.load(pjoin(dis_data_root, "mean.npy"))
|
29 |
+
self.hparams.std = np.load(pjoin(dis_data_root, "std.npy"))
|
30 |
+
|
31 |
+
# Mean and std for fair evaluation
|
32 |
+
dis_data_root_eval = pjoin(cfg.DATASET.KIT.MEAN_STD_PATH, 't2m',
|
33 |
+
"Comp_v6_KLD005", "meta")
|
34 |
+
self.hparams.mean_eval = np.load(pjoin(dis_data_root_eval, "mean.npy"))
|
35 |
+
self.hparams.std_eval = np.load(pjoin(dis_data_root_eval, "std.npy"))
|
36 |
+
|
37 |
+
# Length of the dataset
|
38 |
+
self.hparams.max_motion_length = cfg.DATASET.KIT.MAX_MOTION_LEN
|
39 |
+
self.hparams.min_motion_length = cfg.DATASET.KIT.MIN_MOTION_LEN
|
40 |
+
self.hparams.max_text_len = cfg.DATASET.KIT.MAX_TEXT_LEN
|
41 |
+
self.hparams.unit_length = cfg.DATASET.KIT.UNIT_LEN
|
42 |
+
|
43 |
+
# Get additional info of the dataset
|
44 |
+
self._sample_set = self.get_sample_set(overrides={"split": "test", "tiny": True})
|
45 |
+
self.nfeats = self._sample_set.nfeats
|
46 |
+
cfg.DATASET.NFEATS = self.nfeats
|
47 |
+
|
48 |
+
def feats2joints(self, features):
|
49 |
+
mean = torch.tensor(self.hparams.mean).to(features)
|
50 |
+
std = torch.tensor(self.hparams.std).to(features)
|
51 |
+
features = features * std + mean
|
52 |
+
return recover_from_ric(features, self.njoints)
|
53 |
+
|
54 |
+
def joints2feats(self, features):
|
55 |
+
features = process_file(features, self.njoints)[0]
|
56 |
+
# mean = torch.tensor(self.hparams.mean).to(features)
|
57 |
+
# std = torch.tensor(self.hparams.std).to(features)
|
58 |
+
# features = (features - mean) / std
|
59 |
+
return features
|
60 |
+
|
61 |
+
def normalize(self, features):
|
62 |
+
mean = torch.tensor(self.hparams.mean).to(features)
|
63 |
+
std = torch.tensor(self.hparams.std).to(features)
|
64 |
+
features = (features - mean) / std
|
65 |
+
return features
|
66 |
+
|
67 |
+
def renorm4t2m(self, features):
|
68 |
+
# renorm to t2m norms for using t2m evaluators
|
69 |
+
ori_mean = torch.tensor(self.hparams.mean).to(features)
|
70 |
+
ori_std = torch.tensor(self.hparams.std).to(features)
|
71 |
+
eval_mean = torch.tensor(self.hparams.mean_eval).to(features)
|
72 |
+
eval_std = torch.tensor(self.hparams.std_eval).to(features)
|
73 |
+
features = features * ori_std + ori_mean
|
74 |
+
features = (features - eval_mean) / eval_std
|
75 |
+
return features
|
76 |
+
|
77 |
+
def mm_mode(self, mm_on=True):
|
78 |
+
# random select samples for mm
|
79 |
+
if mm_on:
|
80 |
+
self.is_mm = True
|
81 |
+
self.name_list = self.test_dataset.name_list
|
82 |
+
self.mm_list = np.random.choice(self.name_list,
|
83 |
+
self.cfg.METRIC.MM_NUM_SAMPLES,
|
84 |
+
replace=False)
|
85 |
+
self.test_dataset.name_list = self.mm_list
|
86 |
+
else:
|
87 |
+
self.is_mm = False
|
88 |
+
self.test_dataset.name_list = self.name_list
|
mGPT/data/__init__.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pytorch_lightning as pl
|
2 |
+
from torch.utils.data import DataLoader
|
3 |
+
|
4 |
+
|
5 |
+
class BASEDataModule(pl.LightningDataModule):
|
6 |
+
def __init__(self, collate_fn):
|
7 |
+
super().__init__()
|
8 |
+
|
9 |
+
self.dataloader_options = {"collate_fn": collate_fn}
|
10 |
+
self.persistent_workers = True
|
11 |
+
self.is_mm = False
|
12 |
+
|
13 |
+
self._train_dataset = None
|
14 |
+
self._val_dataset = None
|
15 |
+
self._test_dataset = None
|
16 |
+
|
17 |
+
def get_sample_set(self, overrides={}):
|
18 |
+
sample_params = self.hparams.copy()
|
19 |
+
sample_params.update(overrides)
|
20 |
+
return self.DatasetEval(**sample_params)
|
21 |
+
|
22 |
+
@property
|
23 |
+
def train_dataset(self):
|
24 |
+
if self._train_dataset is None:
|
25 |
+
self._train_dataset = self.Dataset(split=self.cfg.TRAIN.SPLIT,
|
26 |
+
**self.hparams)
|
27 |
+
return self._train_dataset
|
28 |
+
|
29 |
+
@property
|
30 |
+
def val_dataset(self):
|
31 |
+
if self._val_dataset is None:
|
32 |
+
params = self.hparams.copy()
|
33 |
+
params['code_path'] = None
|
34 |
+
params['split'] = self.cfg.EVAL.SPLIT
|
35 |
+
self._val_dataset = self.DatasetEval(**params)
|
36 |
+
return self._val_dataset
|
37 |
+
|
38 |
+
@property
|
39 |
+
def test_dataset(self):
|
40 |
+
if self._test_dataset is None:
|
41 |
+
# self._test_dataset = self.DatasetEval(split=self.cfg.TEST.SPLIT,
|
42 |
+
# **self.hparams)
|
43 |
+
params = self.hparams.copy()
|
44 |
+
params['code_path'] = None
|
45 |
+
params['split'] = self.cfg.TEST.SPLIT
|
46 |
+
self._test_dataset = self.DatasetEval( **params)
|
47 |
+
return self._test_dataset
|
48 |
+
|
49 |
+
def setup(self, stage=None):
|
50 |
+
# Use the getter the first time to load the data
|
51 |
+
if stage in (None, "fit"):
|
52 |
+
_ = self.train_dataset
|
53 |
+
_ = self.val_dataset
|
54 |
+
if stage in (None, "test"):
|
55 |
+
_ = self.test_dataset
|
56 |
+
|
57 |
+
def train_dataloader(self):
|
58 |
+
dataloader_options = self.dataloader_options.copy()
|
59 |
+
dataloader_options["batch_size"] = self.cfg.TRAIN.BATCH_SIZE
|
60 |
+
dataloader_options["num_workers"] = self.cfg.TRAIN.NUM_WORKERS
|
61 |
+
return DataLoader(
|
62 |
+
self.train_dataset,
|
63 |
+
shuffle=False,
|
64 |
+
persistent_workers=True,
|
65 |
+
**dataloader_options,
|
66 |
+
)
|
67 |
+
|
68 |
+
def predict_dataloader(self):
|
69 |
+
dataloader_options = self.dataloader_options.copy()
|
70 |
+
dataloader_options[
|
71 |
+
"batch_size"] = 1 if self.is_mm else self.cfg.TEST.BATCH_SIZE
|
72 |
+
dataloader_options["num_workers"] = self.cfg.TEST.NUM_WORKERS
|
73 |
+
dataloader_options["shuffle"] = False
|
74 |
+
return DataLoader(
|
75 |
+
self.test_dataset,
|
76 |
+
persistent_workers=True,
|
77 |
+
**dataloader_options,
|
78 |
+
)
|
79 |
+
|
80 |
+
def val_dataloader(self):
|
81 |
+
# overrides batch_size and num_workers
|
82 |
+
dataloader_options = self.dataloader_options.copy()
|
83 |
+
dataloader_options["batch_size"] = self.cfg.EVAL.BATCH_SIZE
|
84 |
+
dataloader_options["num_workers"] = self.cfg.EVAL.NUM_WORKERS
|
85 |
+
dataloader_options["shuffle"] = False
|
86 |
+
return DataLoader(
|
87 |
+
self.val_dataset,
|
88 |
+
persistent_workers=True,
|
89 |
+
**dataloader_options,
|
90 |
+
)
|
91 |
+
|
92 |
+
def test_dataloader(self):
|
93 |
+
# overrides batch_size and num_workers
|
94 |
+
dataloader_options = self.dataloader_options.copy()
|
95 |
+
dataloader_options[
|
96 |
+
"batch_size"] = 1 if self.is_mm else self.cfg.TEST.BATCH_SIZE
|
97 |
+
dataloader_options["num_workers"] = self.cfg.TEST.NUM_WORKERS
|
98 |
+
dataloader_options["shuffle"] = False
|
99 |
+
return DataLoader(
|
100 |
+
self.test_dataset,
|
101 |
+
persistent_workers=True,
|
102 |
+
**dataloader_options,
|
103 |
+
)
|
mGPT/data/build_data.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from omegaconf import OmegaConf
|
2 |
+
from os.path import join as pjoin
|
3 |
+
from mGPT.config import instantiate_from_config
|
4 |
+
|
5 |
+
|
6 |
+
def build_data(cfg, phase="train"):
|
7 |
+
data_config = OmegaConf.to_container(cfg.DATASET, resolve=True)
|
8 |
+
data_config['params'] = {'cfg': cfg, 'phase': phase}
|
9 |
+
if isinstance(data_config['target'], str):
|
10 |
+
return instantiate_from_config(data_config)
|
11 |
+
elif isinstance(data_config['target'], list):
|
12 |
+
data_config_tmp = data_config.copy()
|
13 |
+
data_config_tmp['params']['dataModules'] = data_config['target']
|
14 |
+
data_config_tmp['target'] = 'mGPT.data.Concat.ConcatDataModule'
|
15 |
+
return instantiate_from_config(data_config)
|