bill-jiang committed on
Commit
4409449
1 Parent(s): a0563b6
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitignore +165 -0
  2. README.md +4 -4
  3. app.py +511 -0
  4. assets/css/custom.css +359 -0
  5. assets/images/avatar_bot.jpg +0 -0
  6. assets/meta/mean.npy +3 -0
  7. assets/meta/mean_eval.npy +3 -0
  8. assets/meta/std.npy +3 -0
  9. assets/meta/std_eval.npy +3 -0
  10. assets/videos/m2t_0.mp4 +0 -0
  11. assets/videos/t2m_0.mp4 +0 -0
  12. configs/assets.yaml +32 -0
  13. configs/default.yaml +141 -0
  14. configs/evaluator/tm2t.yaml +19 -0
  15. configs/lm/default.yaml +7 -0
  16. configs/render.yaml +23 -0
  17. configs/vq/default.yaml +15 -0
  18. configs/webui.yaml +74 -0
  19. deps/smpl/smpl_models/SMPL_downsample_index.pkl +3 -0
  20. deps/smpl/smpl_models/gmm_08.pkl +3 -0
  21. deps/smpl/smpl_models/neutral_smpl_mean_params.h5 +3 -0
  22. deps/smpl/smpl_models/smpl.faces +0 -0
  23. deps/smpl/smpl_models/smpl.tar.gz +3 -0
  24. deps/smpl/smpl_models/smpl/SMPL_FEMALE.pkl +3 -0
  25. deps/smpl/smpl_models/smpl/SMPL_MALE.pkl +3 -0
  26. deps/smpl/smpl_models/smpl/SMPL_NEUTRAL.pkl +3 -0
  27. deps/smpl/smpl_models/smpl/readme.txt +1 -0
  28. deps/smpl/smpl_models/smplh/SMPLH_FEMALE.npz +3 -0
  29. deps/smpl/smpl_models/smplh/SMPLH_MALE.npz +3 -0
  30. deps/smpl/smpl_models/smplh/SMPLH_NEUTRAL.npz +3 -0
  31. deps/smpl/smpl_models/smplh/mano_v1_2.zip +3 -0
  32. deps/smpl/smpl_models/smplh/smplh.faces +0 -0
  33. deps/smpl/smpl_models/smplh/smplh.tar.xz +3 -0
  34. deps/smpl/smpl_models/smplx_parts_segm.pkl +3 -0
  35. mGPT/__init__.py +0 -0
  36. mGPT/archs/__init__.py +0 -0
  37. mGPT/archs/mgpt_lm.py +592 -0
  38. mGPT/archs/mgpt_vq.py +190 -0
  39. mGPT/archs/tm2t_evaluator.py +111 -0
  40. mGPT/archs/tools/embeddings.py +322 -0
  41. mGPT/archs/tools/quantize_cnn.py +414 -0
  42. mGPT/archs/tools/resnet.py +82 -0
  43. mGPT/archs/tools/token_emb.py +73 -0
  44. mGPT/archs/tools/transformer_layers.py +285 -0
  45. mGPT/callback.py +200 -0
  46. mGPT/config.py +217 -0
  47. mGPT/data/HumanML3D.py +117 -0
  48. mGPT/data/Kit.py +88 -0
  49. mGPT/data/__init__.py +103 -0
  50. mGPT/data/build_data.py +15 -0
.gitignore ADDED
@@ -0,0 +1,165 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+ .DS_Store
9
+ pyglet
10
+ app2.py
11
+ render.py
12
+ cache
13
+
14
+ # Distribution / packaging
15
+ .Python
16
+ build/
17
+ develop-eggs/
18
+ dist/
19
+ downloads/
20
+ eggs/
21
+ .eggs/
22
+ lib/
23
+ lib64/
24
+ parts/
25
+ sdist/
26
+ var/
27
+ wheels/
28
+ share/python-wheels/
29
+ *.egg-info/
30
+ .installed.cfg
31
+ *.egg
32
+ MANIFEST
33
+
34
+ # PyInstaller
35
+ # Usually these files are written by a python script from a template
36
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
37
+ *.manifest
38
+ *.spec
39
+
40
+ # Installer logs
41
+ pip-log.txt
42
+ pip-delete-this-directory.txt
43
+
44
+ # Unit test / coverage reports
45
+ htmlcov/
46
+ .tox/
47
+ .nox/
48
+ .coverage
49
+ .coverage.*
50
+ .cache
51
+ nosetests.xml
52
+ coverage.xml
53
+ *.cover
54
+ *.py,cover
55
+ .hypothesis/
56
+ .pytest_cache/
57
+ cover/
58
+
59
+ # Translations
60
+ *.mo
61
+ *.pot
62
+
63
+ # Django stuff:
64
+ *.log
65
+ local_settings.py
66
+ db.sqlite3
67
+ db.sqlite3-journal
68
+
69
+ # Flask stuff:
70
+ instance/
71
+ .webassets-cache
72
+
73
+ # Scrapy stuff:
74
+ .scrapy
75
+
76
+ # Sphinx documentation
77
+ docs/_build/
78
+
79
+ # PyBuilder
80
+ .pybuilder/
81
+ target/
82
+
83
+ # Jupyter Notebook
84
+ .ipynb_checkpoints
85
+
86
+ # IPython
87
+ profile_default/
88
+ ipython_config.py
89
+
90
+ # pyenv
91
+ # For a library or package, you might want to ignore these files since the code is
92
+ # intended to run in multiple environments; otherwise, check them in:
93
+ # .python-version
94
+
95
+ # pipenv
96
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
97
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
98
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
99
+ # install all needed dependencies.
100
+ #Pipfile.lock
101
+
102
+ # poetry
103
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
104
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
105
+ # commonly ignored for libraries.
106
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
107
+ #poetry.lock
108
+
109
+ # pdm
110
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
111
+ #pdm.lock
112
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
113
+ # in version control.
114
+ # https://pdm.fming.dev/#use-with-ide
115
+ .pdm.toml
116
+
117
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
118
+ __pypackages__/
119
+
120
+ # Celery stuff
121
+ celerybeat-schedule
122
+ celerybeat.pid
123
+
124
+ # SageMath parsed files
125
+ *.sage.py
126
+
127
+ # Environments
128
+ .env
129
+ .venv
130
+ env/
131
+ venv/
132
+ ENV/
133
+ env.bak/
134
+ venv.bak/
135
+
136
+ # Spyder project settings
137
+ .spyderproject
138
+ .spyproject
139
+
140
+ # Rope project settings
141
+ .ropeproject
142
+
143
+ # mkdocs documentation
144
+ /site
145
+
146
+ # mypy
147
+ .mypy_cache/
148
+ .dmypy.json
149
+ dmypy.json
150
+
151
+ # Pyre type checker
152
+ .pyre/
153
+
154
+ # pytype static type analyzer
155
+ .pytype/
156
+
157
+ # Cython debug symbols
158
+ cython_debug/
159
+
160
+ # PyCharm
161
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
162
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
163
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
164
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
165
+ #.idea/
README.md CHANGED
@@ -1,13 +1,13 @@
1
  ---
2
  title: MotionGPT
3
- emoji: 🌖
4
- colorFrom: blue
5
- colorTo: indigo
6
  sdk: gradio
7
  sdk_version: 3.43.2
8
  app_file: app.py
9
  pinned: false
10
- license: cc
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: MotionGPT
3
+ emoji: 🏃
4
+ colorFrom: yellow
5
+ colorTo: blue
6
  sdk: gradio
7
  sdk_version: 3.43.2
8
  app_file: app.py
9
  pinned: false
10
+ license: mit
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,511 @@
1
+ import gradio as gr
2
+ import random
3
+ import torch
4
+ import time
5
+ import cv2
6
+ import os
7
+ import numpy as np
8
+ import OpenGL.GL as gl
9
+ import pytorch_lightning as pl
10
+ import moviepy.editor as mp
11
+ from pathlib import Path
12
+ from mGPT.data.build_data import build_data
13
+ from mGPT.models.build_model import build_model
14
+ from mGPT.config import parse_args
15
+ from scipy.spatial.transform import Rotation as RRR
16
+ import mGPT.render.matplot.plot_3d_global as plot_3d
17
+ from mGPT.render.pyrender.hybrik_loc2rot import HybrIKJointsToRotmat
18
+ from mGPT.render.pyrender.smpl_render import SMPLRender
19
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration
20
+ import librosa
21
+ from huggingface_hub import snapshot_download
22
+
23
+ os.environ["PYOPENGL_PLATFORM"] = "egl"
24
+ os.environ["MESA_GL_VERSION_OVERRIDE"] = "4.1"
25
+ os.system('pip install /home/user/app/pyrender')
26
+
27
+ # Load model
28
+ cfg = parse_args(phase="webui") # parse config file
29
+ cfg.FOLDER = 'cache'
30
+ output_dir = Path(cfg.FOLDER)
31
+ output_dir.mkdir(parents=True, exist_ok=True)
32
+ pl.seed_everything(cfg.SEED_VALUE)
33
+ if torch.cuda.is_available():
34
+ device = torch.device("cuda")
35
+ else:
36
+ device = torch.device("cpu")
37
+
38
+ model_path = snapshot_download(repo_id="bill-jiang/MotionGPT-base")
39
+
40
+ datamodule = build_data(cfg, phase="test")
41
+ model = build_model(cfg, datamodule)
42
+ state_dict = torch.load(f'{model_path}/motiongpt_s3_h3d.tar',
43
+ map_location="cpu")["state_dict"]
44
+ model.load_state_dict(state_dict)
45
+ model.to(device)
46
+
47
+ audio_processor = WhisperProcessor.from_pretrained(cfg.model.whisper_path)
48
+ audio_model = WhisperForConditionalGeneration.from_pretrained(
49
+ cfg.model.whisper_path).to(device)
50
+ forced_decoder_ids_zh = audio_processor.get_decoder_prompt_ids(
51
+ language="zh", task="translate")
52
+ forced_decoder_ids_en = audio_processor.get_decoder_prompt_ids(
53
+ language="en", task="translate")
54
+
55
+ # HTML Style
56
+
57
+ Video_Components = """
58
+ <div class="side-video" style="position: relative;">
59
+ <video width="340" autoplay loop>
60
+ <source src="file/{video_path}" type="video/mp4">
61
+ </video>
62
+ <a class="videodl-button" href="file/{video_path}" download="{video_fname}" title="Download Video">
63
+ <svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="#000000" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-video"><path d="m22 8-6 4 6 4V8Z"/><rect width="14" height="12" x="2" y="6" rx="2" ry="2"/></svg>
64
+ </a>
65
+ <a class="npydl-button" href="file/{motion_path}" download="{motion_fname}" title="Download Motion">
66
+ <svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="#000000" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-file-box"><path d="M14.5 22H18a2 2 0 0 0 2-2V7.5L14.5 2H6a2 2 0 0 0-2 2v4"/><polyline points="14 2 14 8 20 8"/><path d="M2.97 13.12c-.6.36-.97 1.02-.97 1.74v3.28c0 .72.37 1.38.97 1.74l3 1.83c.63.39 1.43.39 2.06 0l3-1.83c.6-.36.97-1.02.97-1.74v-3.28c0-.72-.37-1.38-.97-1.74l-3-1.83a1.97 1.97 0 0 0-2.06 0l-3 1.83Z"/><path d="m7 17-4.74-2.85"/><path d="m7 17 4.74-2.85"/><path d="M7 17v5"/></svg>
67
+ </a>
68
+ </div>
69
+ """
70
+
71
+ Video_Components_example = """
72
+ <div class="side-video" style="position: relative;">
73
+ <video width="340" autoplay loop controls>
74
+ <source src="file/{video_path}" type="video/mp4">
75
+ </video>
76
+ <a class="npydl-button" href="file/{video_path}" download="{video_fname}" title="Download Video">
77
+ <svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-video"><path d="m22 8-6 4 6 4V8Z"/><rect width="14" height="12" x="2" y="6" rx="2" ry="2"/></svg>
78
+ </a>
79
+ </div>
80
+ """
81
+
82
+ Text_Components = """
83
+ <h3 class="side-content" >{msg}</h3>
84
+ """
85
+
86
+
87
+ def motion_token_to_string(motion_token, lengths, codebook_size=512):
88
+ motion_string = []
89
+ for i in range(motion_token.shape[0]):
90
+ motion_i = motion_token[i].cpu(
91
+ ) if motion_token.device.type == 'cuda' else motion_token[i]
92
+ motion_list = motion_i.tolist()[:lengths[i]]
93
+ motion_string.append(
94
+ (f'<motion_id_{codebook_size}>' +
95
+ ''.join([f'<motion_id_{int(i)}>' for i in motion_list]) +
96
+ f'<motion_id_{codebook_size + 1}>'))
97
+ return motion_string
98
+
99
+
100
+ def render_motion(data, feats, method='fast'):
101
+ fname = time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime(
102
+ time.time())) + str(np.random.randint(10000, 99999))
103
+ video_fname = fname + '.mp4'
104
+ feats_fname = fname + '.npy'
105
+ output_npy_path = os.path.join(output_dir, feats_fname)
106
+ output_mp4_path = os.path.join(output_dir, video_fname)
107
+ np.save(output_npy_path, feats)
108
+
109
+ if method == 'slow':
110
+ if len(data.shape) == 4:
111
+ data = data[0]
112
+ data = data - data[0, 0]
113
+ pose_generator = HybrIKJointsToRotmat()
114
+ pose = pose_generator(data)
115
+ pose = np.concatenate([
116
+ pose,
117
+ np.stack([np.stack([np.eye(3)] * pose.shape[0], 0)] * 2, 1)
118
+ ], 1)
119
+ shape = [768, 768]
120
+ render = SMPLRender(cfg.RENDER.SMPL_MODEL_PATH)
121
+
122
+ if not os.environ.get("PYOPENGL_PLATFORM"):
123
+ os.environ["DISPLAY"] = ":0.0"
124
+ os.environ["PYOPENGL_PLATFORM"] = "egl"
125
+
126
+ size = (shape[1], shape[0])
127
+ fps = 20.0
128
+ fourcc = cv2.VideoWriter_fourcc('M', 'P', '4', 'V')
129
+ videoWriter = cv2.VideoWriter(output_mp4_path, fourcc, fps, size)
130
+ r = RRR.from_rotvec(np.array([np.pi, 0.0, 0.0]))
131
+ pose[:, 0] = np.matmul(r.as_matrix().reshape(1, 3, 3), pose[:, 0])
132
+ for i in range(data.shape[0]):
133
+ img = np.zeros([shape[0], shape[1], 3])
134
+ aroot = data[[i], 0] + np.array([[0.0, 0.0, 30.0]])
135
+ aroot[:, 1] = -aroot[:, 1]
136
+ params = dict(pred_shape=np.zeros([1, 10]),
137
+ pred_root=aroot,
138
+ pred_pose=pose[[i]])
139
+ renderImg = render.render(img.copy(), params)
140
+ renderImg = (renderImg * 255).astype(np.uint8)
141
+ videoWriter.write(renderImg)
142
+ videoWriter.release()
143
+ output_video_h264_name = output_mp4_path[:-4] + '_h264.mp4'
144
+ command = 'ffmpeg -y -i {} -vcodec h264 {}'.format(
145
+ output_mp4_path, output_video_h264_name)
146
+ os.system(command)
147
+ output_mp4_path = output_video_h264_name
148
+ video_fname = video_fname[:-4] + '_h264.mp4'
149
+ elif method == 'fast':
150
+ output_gif_path = output_mp4_path[:-4] + '.gif'
151
+ if len(data.shape) == 3:
152
+ data = data[None]
153
+ if isinstance(data, torch.Tensor):
154
+ data = data.cpu().numpy()
155
+ pose_vis = plot_3d.draw_to_batch(data, [''], [output_gif_path])
156
+ out_video = mp.VideoFileClip(output_gif_path)
157
+ out_video.write_videofile(output_mp4_path)
158
+
159
+ return output_mp4_path, video_fname, output_npy_path, feats_fname
160
+
161
+
162
+ def load_motion(motion_uploaded, method):
163
+ file = motion_uploaded['file']
164
+
165
+ feats = torch.tensor(np.load(file), device=model.device)
166
+ if len(feats.shape) == 2:
167
+ feats = feats[None]
168
+ # feats = model.datamodule.normalize(feats)
169
+
170
+ # Motion tokens
171
+ motion_lengths = feats.shape[0]
172
+ motion_token, _ = model.vae.encode(feats)
173
+
174
+ motion_token_string = model.lm.motion_token_to_string(
175
+ motion_token, [motion_token.shape[1]])[0]
176
+ motion_token_length = motion_token.shape[1]
177
+
178
+ # Motion rendered
179
+ joints = model.datamodule.feats2joints(feats.cpu()).cpu().numpy()
180
+ output_mp4_path, video_fname, output_npy_path, joints_fname = render_motion(
181
+ joints,
182
+ feats.to('cpu').numpy(), method)
183
+
184
+ motion_uploaded.update({
185
+ "feats": feats,
186
+ "joints": joints,
187
+ "motion_video": output_mp4_path,
188
+ "motion_video_fname": video_fname,
189
+ "motion_joints": output_npy_path,
190
+ "motion_joints_fname": joints_fname,
191
+ "motion_lengths": motion_lengths,
192
+ "motion_token": motion_token,
193
+ "motion_token_string": motion_token_string,
194
+ "motion_token_length": motion_token_length,
195
+ })
196
+
197
+ return motion_uploaded
198
+
199
+
200
+ def add_text(history, text, motion_uploaded, data_stored, method):
201
+ data_stored = data_stored + [{'user_input': text}]
202
+
203
+ text = f"""<h3>{text}</h3>"""
204
+ history = history + [(text, None)]
205
+ if 'file' in motion_uploaded.keys():
206
+ motion_uploaded = load_motion(motion_uploaded, method)
207
+ output_mp4_path = motion_uploaded['motion_video']
208
+ video_fname = motion_uploaded['motion_video_fname']
209
+ output_npy_path = motion_uploaded['motion_joints']
210
+ joints_fname = motion_uploaded['motion_joints_fname']
211
+ history = history + [(Video_Components.format(
212
+ video_path=output_mp4_path,
213
+ video_fname=video_fname,
214
+ motion_path=output_npy_path,
215
+ motion_fname=joints_fname), None)]
216
+
217
+ return history, gr.update(value="",
218
+ interactive=False), motion_uploaded, data_stored
219
+
220
+
221
+ def add_audio(history, audio_path, data_stored, language='en'):
222
+ audio, sampling_rate = librosa.load(audio_path, sr=16000)
223
+ input_features = audio_processor(
224
+ audio, sampling_rate, return_tensors="pt"
225
+ ).input_features # whisper training sampling rate, do not modify
226
+ input_features = torch.Tensor(input_features).to(device)
227
+
228
+ if language == 'English':
229
+ forced_decoder_ids = forced_decoder_ids_en
230
+ else:
231
+ forced_decoder_ids = forced_decoder_ids_zh
232
+ predicted_ids = audio_model.generate(input_features,
233
+ forced_decoder_ids=forced_decoder_ids)
234
+ text_input = audio_processor.batch_decode(predicted_ids,
235
+ skip_special_tokens=True)
236
+ text_input = str(text_input).strip('[]"')
237
+ data_stored = data_stored + [{'user_input': text_input}]
238
+ gr.update(value=data_stored, interactive=False)
239
+ history = history + [(text_input, None)]
240
+
241
+ return history, data_stored
242
+
243
+
244
+ def add_file(history, file, txt, motion_uploaded):
245
+ motion_uploaded['file'] = file.name
246
+ txt = txt.replace(" <Motion_Placeholder>", "") + " <Motion_Placeholder>"
247
+ return history, gr.update(value=txt, interactive=True), motion_uploaded
248
+
249
+
250
+ def bot(history, motion_uploaded, data_stored, method):
251
+
252
+ motion_length, motion_token_string = motion_uploaded[
253
+ "motion_lengths"], motion_uploaded["motion_token_string"]
254
+
255
+ input = data_stored[-1]['user_input']
256
+ prompt = model.lm.placeholder_fulfill(input, motion_length,
257
+ motion_token_string, "")
258
+ data_stored[-1]['model_input'] = prompt
259
+ batch = {
260
+ "length": [motion_length],
261
+ "text": [prompt],
262
+ }
263
+
264
+ outputs = model(batch, task="t2m")
265
+ out_feats = outputs["feats"][0]
266
+ out_lengths = outputs["length"][0]
267
+ out_joints = outputs["joints"][:out_lengths].detach().cpu().numpy()
268
+ out_texts = outputs["texts"][0]
269
+ output_mp4_path, video_fname, output_npy_path, joints_fname = render_motion(
270
+ out_joints,
271
+ out_feats.to('cpu').numpy(), method)
272
+
273
+ motion_uploaded = {
274
+ "feats": None,
275
+ "joints": None,
276
+ "motion_video": None,
277
+ "motion_lengths": 0,
278
+ "motion_token": None,
279
+ "motion_token_string": '',
280
+ "motion_token_length": 0,
281
+ }
282
+
283
+ data_stored[-1]['model_output'] = {
284
+ "feats": out_feats,
285
+ "joints": out_joints,
286
+ "length": out_lengths,
287
+ "texts": out_texts,
288
+ "motion_video": output_mp4_path,
289
+ "motion_video_fname": video_fname,
290
+ "motion_joints": output_npy_path,
291
+ "motion_joints_fname": joints_fname,
292
+ }
293
+
294
+ if '<Motion_Placeholder>' == out_texts:
295
+ response = [
296
+ Video_Components.format(video_path=output_mp4_path,
297
+ video_fname=video_fname,
298
+ motion_path=output_npy_path,
299
+ motion_fname=joints_fname)
300
+ ]
301
+ elif '<Motion_Placeholder>' in out_texts:
302
+ response = [
303
+ Text_Components.format(
304
+ msg=out_texts.split("<Motion_Placeholder>")[0]),
305
+ Video_Components.format(video_path=output_mp4_path,
306
+ video_fname=video_fname,
307
+ motion_path=output_npy_path,
308
+ motion_fname=joints_fname),
309
+ Text_Components.format(
310
+ msg=out_texts.split("<Motion_Placeholder>")[1]),
311
+ ]
312
+ else:
313
+ response = f"""<h3>{out_texts}</h3>"""
314
+
315
+ history[-1][1] = ""
316
+ for character in response:
317
+ history[-1][1] += character
318
+ time.sleep(0.02)
319
+ yield history, motion_uploaded, data_stored
320
+
321
+
322
+ def bot_example(history, responses):
323
+ for response in responses:
324
+ history[-1][1] = ""
325
+ for character in response:
326
+ history[-1][1] += character
327
+ time.sleep(0.02)
328
+ yield history, motion_uploaded, data_stored
329
+
330
+
331
+ # Examples
332
+ chat_instruct = [
333
+ (None,
334
+ "**👋 Hi, I'm MotionGPT! I can generate realistic human motion from text, or generate text from motion.**"
335
+ ),
336
+ (None,
337
+ "You can chat with me in pure text like generating human motion following your descriptions."
338
+ ),
339
+ (None,
340
+ "After generation, you can click the button in the top right of generation human motion result to download the human motion video or feature stored in .npy format."
341
+ ),
342
+ (None,
343
+ "With the human motion feature file downloaded or got from dataset, you are able to ask me to translate it!"
344
+ ),
345
+ (None,
346
+ "Of courser, you can also purely chat with me and let me give you human motion in text, here are some examples!"
347
+ ),
348
+ (None,
349
+ "We provide two motion visulization methods. The default fast method is skeleton line ploting which is like the examples below:"
350
+ ),
351
+ (None,
352
+ Video_Components_example.format(video_path="assets/videos/t2m_0.mp4",
353
+ video_fname="example1.mp4")),
354
+ (None,
355
+ "And the slow method is SMPL model rendering which is more realistic but slower."
356
+ ),
357
+ (None,
358
+ Video_Components_example.format(video_path="assets/videos/t2m_0.mp4",
359
+ video_fname="example1.mp4")),
360
+ (None, "👉 Follow the examples and try yourself!"),
361
+ ]
362
+
363
+ t2m_examples = [
364
+ (None,
365
+ "You can chat with me in pure text, following are some examples of text-to-motion generation!"
366
+ ),
367
+ ("Generate a person is walking forwards, but stumbles and steps back, then carries on forward.",
368
+ Video_Components_example.format(video_path="assets/videos/t2m_0.mp4",
369
+ video_fname="example1.mp4")),
370
+ ("Generate a person is walking forwards, but stumbles and steps back, then carries on forward.",
371
+ Video_Components_example.format(video_path="assets/videos/t2m_0.mp4",
372
+ video_fname="example1.mp4")),
373
+ ("Generate a person is walking forwards, but stumbles and steps back, then carries on forward.",
374
+ Video_Components_example.format(video_path="assets/videos/t2m_0.mp4",
375
+ video_fname="example1.mp4")),
376
+ ]
377
+
378
+ m2t_examples = [
379
+ (None,
380
+ "With the human motion feature file downloaded or got from dataset, you are able to ask me to translate it, here are some examples!"
381
+ ),
382
+ ("Please explain the movement shown in [Motion_tokens] using natural language.",
383
+ None),
384
+ (Video_Components_example.format(video_path="assets/videos/m2t_0.mp4",
385
+ video_fname="example2.mp4"),
386
+ "a person walks forward then does a backwards z-shape movement to its left side. then back to the right."
387
+ ),
388
+ ("Please explain the movement shown in [Motion_tokens] using natural language.",
389
+ None),
390
+ (Video_Components_example.format(video_path="assets/videos/m2t_0.mp4",
391
+ video_fname="example2.mp4"),
392
+ "a person walks forward then does a backwards z-shape movement to its left side. then back to the right."
393
+ ),
394
+ ]
395
+
396
+ t2t_examples = [
397
+ (None,
398
+ "Of courser, you can also purely chat with me and let me give you human motion in text, here are some examples!"
399
+ ),
400
+ ('Depict a motion as like you have seen it.',
401
+ "The person walks while swaying their hips along a curved path to the left slowly then stops to look down at the edge of the grey platform at something."
402
+ ),
403
+ ('Depict a motion as like you have seen it.',
404
+ "The person walks while swaying their hips along a curved path to the left slowly then stops to look down at the edge of the grey platform at something."
405
+ ),
406
+ ]
407
+
408
+ Init_chatbot = [
409
+ (None,
410
+ "**👋 Hi, I'm MotionGPT! I can generate realistic human motion from text, or generate text from motion.**"
411
+ )
412
+ ] + t2m_examples[:3] + m2t_examples[:2] + t2t_examples[:2] + chat_instruct[-4:]
413
+
414
+ with open("assets/css/custom.css", "r", encoding="utf-8") as f:
415
+ customCSS = f.read()
416
+
417
+ with gr.Blocks(css=customCSS) as demo:
418
+
419
+ # Variables
420
+ motion_uploaded = gr.State({
421
+ "feats": None,
422
+ "joints": None,
423
+ "motion_video": None,
424
+ "motion_lengths": 0,
425
+ "motion_token": None,
426
+ "motion_token_string": '',
427
+ "motion_token_length": 0,
428
+ })
429
+ data_stored = gr.State([])
430
+
431
+ gr.Markdown("# MotionGPT")
432
+
433
+ chatbot = gr.Chatbot(Init_chatbot,
434
+ elem_id="mGPT",
435
+ height=600,
436
+ label="MotionGPT",
437
+ avatar_images=(None,
438
+ ("assets/images/avatar_bot.jpg")),
439
+ bubble_full_width=False)
440
+
441
+ with gr.Row():
442
+ with gr.Column(scale=0.85):
443
+ with gr.Row():
444
+ txt = gr.Textbox(
445
+ label="Text",
446
+ show_label=False,
447
+ placeholder=
448
+ "Enter text and press ENTER or speak to input. You can also upload motion.",
449
+ container=False)
450
+
451
+ with gr.Row():
452
+ aud = gr.Audio(source="microphone",
453
+ label="Speak input",
454
+ type='filepath')
455
+ btn = gr.UploadButton("📁 Upload motion",
456
+ elem_id="upload",
457
+ file_types=["file"],
458
+ variant='primary')
459
+ regen = gr.Button("🔄 Regenerate", elem_id="regen")
460
+ clear = gr.ClearButton([txt, chatbot, aud], value='🗑️ Clear')
461
+
462
+ with gr.Row():
463
+ gr.Markdown('''
464
+ ### You can get more examples (pre-generated for faster response) by clicking the buttons below:
465
+ ''')
466
+
467
+ with gr.Row():
468
+ instruct = gr.Button("Instructions", elem_id="instruction")
469
+ t2m_eg = gr.Button("Text-to-Motion", elem_id="t2m")
470
+ m2t_eg = gr.Button("Motion-to-Text", elem_id="m2t")
471
+ t2t_eg = gr.Button("Random description", elem_id="t2t")
472
+
473
+ with gr.Column(scale=0.15, min_width=150):
474
+ method = gr.Dropdown(["slow", "fast"],
475
+ label="Visulization method",
476
+ interactive=True,
477
+ elem_id="method",
478
+ value="fast")
479
+
480
+ language = gr.Dropdown(["English", "中文"],
481
+ label="Speech language",
482
+ interactive=True,
483
+ elem_id="language",
484
+ value="English")
485
+
486
+ txt_msg = txt.submit(
487
+ add_text, [chatbot, txt, motion_uploaded, data_stored, method],
488
+ [chatbot, txt, motion_uploaded, data_stored],
489
+ queue=False).then(bot, [chatbot, motion_uploaded, data_stored, method],
490
+ [chatbot, motion_uploaded, data_stored])
491
+
492
+ txt_msg.then(lambda: gr.update(interactive=True), None, [txt], queue=False)
493
+
494
+ file_msg = btn.upload(add_file, [chatbot, btn, txt, motion_uploaded],
495
+ [chatbot, txt, motion_uploaded],
496
+ queue=False)
497
+ aud_msg = aud.stop_recording(
498
+ add_audio, [chatbot, aud, data_stored, language],
499
+ [chatbot, data_stored],
500
+ queue=False).then(bot, [chatbot, motion_uploaded, data_stored, method],
501
+ [chatbot, motion_uploaded, data_stored])
502
+ regen_msg = regen.click(bot,
503
+ [chatbot, motion_uploaded, data_stored, method],
504
+ [chatbot, motion_uploaded, data_stored],
505
+ queue=False)
506
+ chatbot.change(scroll_to_output=True)
507
+
508
+ demo.queue()
509
+
510
+ if __name__ == "__main__":
511
+ demo.launch(debug=True)
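
app.py serializes VQ codebook indices into <motion_id_k> strings (bracketed by the <motion_id_512> start and <motion_id_513> end sentinels) before handing them to the language model. Below is a minimal sketch of the reverse direction under the same naming scheme; the helper name motion_string_to_ids is hypothetical here, while the project's own parser lives in mGPT/archs/mgpt_lm.py.

import re
import torch

def motion_string_to_ids(motion_string, codebook_size=512):
    # Pull every <motion_id_k> token out of the string, in order.
    ids = [int(k) for k in re.findall(r"<motion_id_(\d+)>", motion_string)]
    # Drop the start/end sentinels (codebook_size and codebook_size + 1).
    ids = [k for k in ids if k < codebook_size]
    return torch.tensor(ids, dtype=torch.long)

# Round-trip against the format produced by motion_token_to_string:
s = "<motion_id_512>" + "".join(f"<motion_id_{k}>" for k in (3, 17, 256)) + "<motion_id_513>"
print(motion_string_to_ids(s))  # tensor([  3,  17, 256])
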
assets/css/custom.css ADDED
@@ -0,0 +1,359 @@
1
+ /* Borrowed from https://huggingface.co/spaces/project-baize/chat-with-baize */
2
+
3
+ :root {
4
+ --chatbot-color-light: #f6f6f6;
5
+ --chatbot-color-dark: #121111;
6
+ }
7
+
8
+ /* Light mode (default) */
9
+ #mGPT {
10
+ background-color: var(--chatbot-color-light) !important;
11
+ color: #000000 !important;
12
+ }
13
+ [data-testid='bot'] {
14
+ background-color: #ffffff !important;
15
+ }
16
+ [data-testid='user'] {
17
+ background-color: #95ec69 !important;
18
+ }
19
+
20
+ /* Dark mode */
21
+ .dark #mGPT {
22
+ background-color: var(--chatbot-color-dark) !important;
23
+ color: #ffffff !important;
24
+ }
25
+ .dark [data-testid='bot'] {
26
+ background-color: #2c2c2c !important;
27
+ }
28
+
29
+ .dark [data-testid='user'] {
30
+ background-color: #26b561 !important;
31
+ }
32
+
33
+ #mGPT {
34
+ height: 100%;
35
+ min-height: 500px;
36
+ }
37
+
38
+ [class*='message-buttons'] {
39
+ visibility: hidden;
40
+ }
41
+
42
+ [class*='message'] {
43
+ border: none;
44
+ font-size: var(--text-xl) !important;
45
+ line-height: var(--line-xl) !important;
46
+ }
47
+ /* [data-testid='bot'] {
48
+ max-width: 85%;
49
+ width: auto !important;
50
+ border-bottom-left-radius: 0 !important;
51
+ }
52
+ [data-testid='user'] {
53
+ max-width: 85%;
54
+ width: auto !important;
55
+ border-bottom-right-radius: 0 !important;
56
+ } */
57
+
58
+ /* Text & Video */
59
+ #method {
60
+ line-height: 1.95 !important;
61
+ }
62
+
63
+ .side-content {
64
+ max-width: 340px;
65
+ }
66
+
67
+ /* @media only screen and (min-width: 768px) {
68
+ .side-content {
69
+ float: left;
70
+ overflow-wrap: break-word;
71
+ padding-right: 2rem;
72
+ }
73
+
74
+ .side-video {
75
+ float: right;
76
+ }
77
+ } */
78
+
79
+ /* Button */
80
+ #upload {
81
+ color: #000000;
82
+ }
83
+
84
+ .videodl-button {
85
+ position: absolute;
86
+ left: 80%;
87
+ top: 5px;
88
+ width: 24px;
89
+ height: 24px;
90
+ }
91
+
92
+ .videodl-button svg {
93
+ width: 24px;
94
+ height: 24px;
95
+ }
96
+
97
+ .npydl-button {
98
+ position: absolute;
99
+ left: 90%;
100
+ top: 5px;
101
+ width: 24px;
102
+ height: 24px;
103
+ }
104
+
105
+ .npydl-button svg {
106
+ width: 24px;
107
+ height: 24px;
108
+ }
109
+
110
+ /* Table */
111
+ table {
112
+ margin: 1em 0;
113
+ border-collapse: collapse;
114
+ empty-cells: show;
115
+ }
116
+ td,
117
+ th {
118
+ border: 1.2px solid var(--border-color-primary) !important;
119
+ padding: 0.2em;
120
+ }
121
+ thead {
122
+ background-color: rgba(175, 184, 193, 0.2);
123
+ }
124
+ thead th {
125
+ padding: 0.5em 0.2em;
126
+ }
127
+ /* Inline code */
128
+ #mGPT code {
129
+ display: inline;
130
+ white-space: break-spaces;
131
+ border-radius: 6px;
132
+ margin: 0 2px 0 2px;
133
+ padding: 0.2em 0.4em 0.1em 0.4em;
134
+ background-color: rgba(175, 184, 193, 0.2);
135
+ }
136
+ /* Code block */
137
+ #mGPT pre code {
138
+ display: block;
139
+ overflow: auto;
140
+ white-space: pre;
141
+ background-color: hsla(0, 0%, 0%, 80%) !important;
142
+ border-radius: 10px;
143
+ padding: 1.4em 1.2em 0em 1.4em;
144
+ margin: 1.2em 2em 1.2em 0.5em;
145
+ color: #fff;
146
+ box-shadow: 6px 6px 16px hsla(0, 0%, 0%, 0.2);
147
+ }
148
+ /* Highlight */
149
+ #mGPT .highlight {
150
+ background-color: transparent;
151
+ }
152
+ #mGPT .highlight .hll {
153
+ background-color: #49483e;
154
+ }
155
+ #mGPT .highlight .c {
156
+ color: #75715e;
157
+ } /* Comment */
158
+ #mGPT .highlight .err {
159
+ color: #960050;
160
+ background-color: #1e0010;
161
+ } /* Error */
162
+ #mGPT .highlight .k {
163
+ color: #66d9ef;
164
+ } /* Keyword */
165
+ #mGPT .highlight .l {
166
+ color: #ae81ff;
167
+ } /* Literal */
168
+ #mGPT .highlight .n {
169
+ color: #f8f8f2;
170
+ } /* Name */
171
+ #mGPT .highlight .o {
172
+ color: #f92672;
173
+ } /* Operator */
174
+ #mGPT .highlight .p {
175
+ color: #f8f8f2;
176
+ } /* Punctuation */
177
+ #mGPT .highlight .ch {
178
+ color: #75715e;
179
+ } /* Comment.Hashbang */
180
+ #mGPT .highlight .cm {
181
+ color: #75715e;
182
+ } /* Comment.Multiline */
183
+ #mGPT .highlight .cp {
184
+ color: #75715e;
185
+ } /* Comment.Preproc */
186
+ #mGPT .highlight .cpf {
187
+ color: #75715e;
188
+ } /* Comment.PreprocFile */
189
+ #mGPT .highlight .c1 {
190
+ color: #75715e;
191
+ } /* Comment.Single */
192
+ #mGPT .highlight .cs {
193
+ color: #75715e;
194
+ } /* Comment.Special */
195
+ #mGPT .highlight .gd {
196
+ color: #f92672;
197
+ } /* Generic.Deleted */
198
+ #mGPT .highlight .ge {
199
+ font-style: italic;
200
+ } /* Generic.Emph */
201
+ #mGPT .highlight .gi {
202
+ color: #a6e22e;
203
+ } /* Generic.Inserted */
204
+ #mGPT .highlight .gs {
205
+ font-weight: bold;
206
+ } /* Generic.Strong */
207
+ #mGPT .highlight .gu {
208
+ color: #75715e;
209
+ } /* Generic.Subheading */
210
+ #mGPT .highlight .kc {
211
+ color: #66d9ef;
212
+ } /* Keyword.Constant */
213
+ #mGPT .highlight .kd {
214
+ color: #66d9ef;
215
+ } /* Keyword.Declaration */
216
+ #mGPT .highlight .kn {
217
+ color: #f92672;
218
+ } /* Keyword.Namespace */
219
+ #mGPT .highlight .kp {
220
+ color: #66d9ef;
221
+ } /* Keyword.Pseudo */
222
+ #mGPT .highlight .kr {
223
+ color: #66d9ef;
224
+ } /* Keyword.Reserved */
225
+ #mGPT .highlight .kt {
226
+ color: #66d9ef;
227
+ } /* Keyword.Type */
228
+ #mGPT .highlight .ld {
229
+ color: #e6db74;
230
+ } /* Literal.Date */
231
+ #mGPT .highlight .m {
232
+ color: #ae81ff;
233
+ } /* Literal.Number */
234
+ #mGPT .highlight .s {
235
+ color: #e6db74;
236
+ } /* Literal.String */
237
+ #mGPT .highlight .na {
238
+ color: #a6e22e;
239
+ } /* Name.Attribute */
240
+ #mGPT .highlight .nb {
241
+ color: #f8f8f2;
242
+ } /* Name.Builtin */
243
+ #mGPT .highlight .nc {
244
+ color: #a6e22e;
245
+ } /* Name.Class */
246
+ #mGPT .highlight .no {
247
+ color: #66d9ef;
248
+ } /* Name.Constant */
249
+ #mGPT .highlight .nd {
250
+ color: #a6e22e;
251
+ } /* Name.Decorator */
252
+ #mGPT .highlight .ni {
253
+ color: #f8f8f2;
254
+ } /* Name.Entity */
255
+ #mGPT .highlight .ne {
256
+ color: #a6e22e;
257
+ } /* Name.Exception */
258
+ #mGPT .highlight .nf {
259
+ color: #a6e22e;
260
+ } /* Name.Function */
261
+ #mGPT .highlight .nl {
262
+ color: #f8f8f2;
263
+ } /* Name.Label */
264
+ #mGPT .highlight .nn {
265
+ color: #f8f8f2;
266
+ } /* Name.Namespace */
267
+ #mGPT .highlight .nx {
268
+ color: #a6e22e;
269
+ } /* Name.Other */
270
+ #mGPT .highlight .py {
271
+ color: #f8f8f2;
272
+ } /* Name.Property */
273
+ #mGPT .highlight .nt {
274
+ color: #f92672;
275
+ } /* Name.Tag */
276
+ #mGPT .highlight .nv {
277
+ color: #f8f8f2;
278
+ } /* Name.Variable */
279
+ #mGPT .highlight .ow {
280
+ color: #f92672;
281
+ } /* Operator.Word */
282
+ #mGPT .highlight .w {
283
+ color: #f8f8f2;
284
+ } /* Text.Whitespace */
285
+ #mGPT .highlight .mb {
286
+ color: #ae81ff;
287
+ } /* Literal.Number.Bin */
288
+ #mGPT .highlight .mf {
289
+ color: #ae81ff;
290
+ } /* Literal.Number.Float */
291
+ #mGPT .highlight .mh {
292
+ color: #ae81ff;
293
+ } /* Literal.Number.Hex */
294
+ #mGPT .highlight .mi {
295
+ color: #ae81ff;
296
+ } /* Literal.Number.Integer */
297
+ #mGPT .highlight .mo {
298
+ color: #ae81ff;
299
+ } /* Literal.Number.Oct */
300
+ #mGPT .highlight .sa {
301
+ color: #e6db74;
302
+ } /* Literal.String.Affix */
303
+ #mGPT .highlight .sb {
304
+ color: #e6db74;
305
+ } /* Literal.String.Backtick */
306
+ #mGPT .highlight .sc {
307
+ color: #e6db74;
308
+ } /* Literal.String.Char */
309
+ #mGPT .highlight .dl {
310
+ color: #e6db74;
311
+ } /* Literal.String.Delimiter */
312
+ #mGPT .highlight .sd {
313
+ color: #e6db74;
314
+ } /* Literal.String.Doc */
315
+ #mGPT .highlight .s2 {
316
+ color: #e6db74;
317
+ } /* Literal.String.Double */
318
+ #mGPT .highlight .se {
319
+ color: #ae81ff;
320
+ } /* Literal.String.Escape */
321
+ #mGPT .highlight .sh {
322
+ color: #e6db74;
323
+ } /* Literal.String.Heredoc */
324
+ #mGPT .highlight .si {
325
+ color: #e6db74;
326
+ } /* Literal.String.Interpol */
327
+ #mGPT .highlight .sx {
328
+ color: #e6db74;
329
+ } /* Literal.String.Other */
330
+ #mGPT .highlight .sr {
331
+ color: #e6db74;
332
+ } /* Literal.String.Regex */
333
+ #mGPT .highlight .s1 {
334
+ color: #e6db74;
335
+ } /* Literal.String.Single */
336
+ #mGPT .highlight .ss {
337
+ color: #e6db74;
338
+ } /* Literal.String.Symbol */
339
+ #mGPT .highlight .bp {
340
+ color: #f8f8f2;
341
+ } /* Name.Builtin.Pseudo */
342
+ #mGPT .highlight .fm {
343
+ color: #a6e22e;
344
+ } /* Name.Function.Magic */
345
+ #mGPT .highlight .vc {
346
+ color: #f8f8f2;
347
+ } /* Name.Variable.Class */
348
+ #mGPT .highlight .vg {
349
+ color: #f8f8f2;
350
+ } /* Name.Variable.Global */
351
+ #mGPT .highlight .vi {
352
+ color: #f8f8f2;
353
+ } /* Name.Variable.Instance */
354
+ #mGPT .highlight .vm {
355
+ color: #f8f8f2;
356
+ } /* Name.Variable.Magic */
357
+ #mGPT .highlight .il {
358
+ color: #ae81ff;
359
+ } /* Literal.Number.Integer.Long */
assets/images/avatar_bot.jpg ADDED
assets/meta/mean.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0bdb5ba69a3a9e34d71990db15bc535ebc024c8d95ddb5574196f96058faa7d3
3
+ size 2232
assets/meta/mean_eval.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0bdb5ba69a3a9e34d71990db15bc535ebc024c8d95ddb5574196f96058faa7d3
3
+ size 2232
assets/meta/std.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6a5f7d60301c9465972fc225f8ad0ee8f957e7720431189123eb6d15873a9557
3
+ size 2232
assets/meta/std_eval.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6a5f7d60301c9465972fc225f8ad0ee8f957e7720431189123eb6d15873a9557
3
+ size 2232
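
The four .npy files above are Git LFS pointers to per-dimension normalization statistics (the *_eval pair feeds the TM2T evaluator). Their 2232-byte size is consistent with a float64 vector of 263 values, one entry per HumanML3D feature dimension (DATASET.NFEATS: 263 in configs/default.yaml). A minimal sketch of the standardization they imply follows; the exact application is an assumption here, since the real logic sits in the datamodule's normalize/feats2joints helpers.

import numpy as np

mean = np.load("assets/meta/mean.npy")  # expected shape: (263,)
std = np.load("assets/meta/std.npy")

def normalize(feats, eps=1e-8):
    # Standardize raw motion features before encoding them with the VQ-VAE.
    return (feats - mean) / (std + eps)

def denormalize(feats):
    # Undo standardization before converting generated features back to joints.
    return feats * std + mean
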
assets/videos/m2t_0.mp4 ADDED
Binary file (500 kB).
 
assets/videos/t2m_0.mp4 ADDED
Binary file (811 kB).
 
configs/assets.yaml ADDED
@@ -0,0 +1,32 @@
1
+ CONFIG_FOLDER: configs # Config files path
2
+ FOLDER: experiments # Experiment files saving path
3
+
4
+ TEST:
5
+ FOLDER: results # Testing files saving path
6
+
7
+ DATASET:
8
+ TASK_ROOT: deps/mGPT_instructions
9
+ SMPL_PATH: deps/smpl
10
+ TRANSFORM_PATH: deps/transforms/
11
+ WORD_VERTILIZER_PATH: deps/glove/
12
+ KIT:
13
+ ROOT: datasets/kit-ml # KIT directory
14
+ SPLIT_ROOT: datasets/kit-ml # KIT splits directory
15
+ MEAN_STD_PATH: deps/t2m/
16
+ HUMANML3D:
17
+ ROOT: datasets/humanml3d # HumanML3D directory
18
+ SPLIT_ROOT: datasets/humanml3d # HumanML3D splits directory
19
+ MEAN_STD_PATH: deps/t2m/
20
+
21
+ METRIC:
22
+ TM2T:
23
+ t2m_path: deps/t2m/ # path for tm2t evaluator
24
+
25
+ model:
26
+ whisper_path: openai/whisper-large-v2 # path for whisper model, webui only
27
+
28
+ RENDER:
29
+ BLENDER_PATH: libs/blender-2.93.2-linux-x64/blender
30
+ SMPL_MODEL_PATH: deps/smpl/smpl_models/smpl
31
+ MODEL_PATH: deps/smpl/smpl_models/
32
+ FACES_PATH: deps/smplh/smplh.faces
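
assets.yaml only declares dependency paths; nothing checks them until a renderer or evaluator fails. Below is a small startup check, sketched with OmegaConf (an assumption consistent with the ${...} interpolation syntax used throughout these configs), that the LFS-tracked SMPL assets are actually on disk.

from pathlib import Path
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/assets.yaml")

# Fail early if the SMPL assets referenced by the render pipeline are missing.
for key in ("SMPL_MODEL_PATH", "MODEL_PATH", "FACES_PATH"):
    path = Path(cfg.RENDER[key])
    if not path.exists():
        raise FileNotFoundError(f"Missing asset for RENDER.{key}: {path}")
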
configs/default.yaml ADDED
@@ -0,0 +1,141 @@
1
+ SEED_VALUE: 1234 # Seed value
2
+ DEBUG: True # Debug mode
3
+ FULL_CONFIG: false
4
+
5
+ TRAIN:
6
+ SPLIT: 'train' # Training split name
7
+ NUM_WORKERS: 8 # Number of workers
8
+ BATCH_SIZE: 8 # Size of batches
9
+ END_EPOCH: 2000 # End epoch
10
+
11
+ RESUME: '' # Experiment path to be resumed training
12
+ PRETRAINED_VAE: '' # Pretrained vae/vqvae model path
13
+ PRETRAINED: '' # Pretrained model path
14
+
15
+ OPTIM:
16
+ target: AdamW
17
+ params:
18
+ lr: 2e-4
19
+ betas: [0.9, 0.99]
20
+ weight_decay: 0.0
21
+
22
+ LR_SCHEDULER:
23
+ target: CosineAnnealingLR
24
+ params:
25
+ T_max: ${eval:${LOGGER.VAL_EVERY_STEPS} * 100}
26
+ eta_min: 1e-6
27
+
28
+ EVAL:
29
+ SPLIT: 'val' # Validation split name
30
+ BATCH_SIZE: 16 # Validation Batch size
31
+ NUM_WORKERS: 8 # Number of validation workers
32
+
33
+ TEST:
34
+ CHECKPOINTS: '' # Pretrained model path
35
+ SPLIT: 'test' # Testing split name
36
+ BATCH_SIZE: 16 # Testing Batch size
37
+ NUM_WORKERS: 8 # Number of testing workers
38
+
39
+ SAVE_PREDICTIONS: False # Whether to save predictions
40
+ COUNT_TIME: False # Whether to count time during test
41
+ REPLICATION_TIMES: 20 # Number of times to replicate the test
42
+ REP_I: 0 # For counting replication times
43
+
44
+ model:
45
+ target: mGPT.models.mgpt.MotionGPT
46
+ params:
47
+ condition: 'text'
48
+ task: 't2m'
49
+ lm: ${lm.default}
50
+ motion_vae: ${vq.default}
51
+
52
+ # Related parameters
53
+ stage: ${TRAIN.STAGE}
54
+ debug: ${DEBUG}
55
+ codebook_size: ${model.params.motion_vae.params.code_num}
56
+ metrics_dict: ${METRIC.TYPE}
57
+
58
+ LOSS:
59
+ LAMBDA_REC: 1.0 # Lambda for reconstruction losses
60
+ LAMBDA_JOINT: 1.0 # Lambda for joint losses
61
+
62
+ LAMBDA_LATENT: 1e-5 # Lambda for latent losses
63
+ LAMBDA_KL: 1e-5 # Lambda for kl losses
64
+ LAMBDA_GEN: 1.0 # Lambda for text-motion generation losses
65
+ LAMBDA_CROSS: 1.0 # Lambda for cross-reconstruction losses
66
+ LAMBDA_CYCLE: 1.0 # Lambda for cycle losses
67
+ LAMBDA_PRIOR: 0.0 # Lambda for diffusion prior losses
68
+
69
+ LAMBDA_VELOCITY: 0.5 # Lambda for velocity losses
70
+ LAMBDA_COMMIT: 0.02 # Lambda for commitment losses
71
+
72
+ ABLATION:
73
+ RECONS_LOSS: 'l1_smooth'
74
+
75
+ METRIC:
76
+ TASK: 't2m'
77
+ FORCE_IN_METER: True
78
+ DIST_SYNC_ON_STEP: True
79
+ MM_NUM_SAMPLES: 100 # Number of samples for multimodal test
80
+ MM_NUM_REPEATS: 30 # Number of repeats for multimodal test
81
+ MM_NUM_TIMES: 10 # Number of times to repeat the multimodal test
82
+ DIVERSITY_TIMES: 300 # Number of times to repeat the diversity test
83
+ TM2T: ${evaluator.tm2t}
84
+
85
+ DATASET:
86
+ target: mGPT.data.HumanML3D.HumanML3DDataModule
87
+ CODE_PATH: 'VQVAE'
88
+ TASK_ROOT: ''
89
+ TASK_PATH: ''
90
+ NFEATS: 263
91
+ KIT:
92
+ MAX_MOTION_LEN: 196
93
+ MIN_MOTION_LEN: 24
94
+ MAX_TEXT_LEN: 20
95
+ PICK_ONE_TEXT: true
96
+ FRAME_RATE: 12.5
97
+ UNIT_LEN: 4
98
+ HUMANML3D:
99
+ MAX_MOTION_LEN: 196
100
+ MIN_MOTION_LEN: 40
101
+ MAX_TEXT_LEN: 20
102
+ PICK_ONE_TEXT: true
103
+ FRAME_RATE: 20.0
104
+ UNIT_LEN: 4
105
+ STD_TEXT: False
106
+
107
+ ABLATION:
108
+ # For MotionGPT
109
+ use_length: False
110
+ predict_ratio: 0.2
111
+ inbetween_ratio: 0.25
112
+ image_size: 256
113
+
114
+ # For Motion-latent-diffusion
115
+ VAE_TYPE: 'actor' # vae ablation: actor or mcross
116
+ VAE_ARCH: 'encoder_decoder' # mdiffusion vae architecture
117
+ PE_TYPE: 'actor' # mdiffusion mld or actor
118
+ DIFF_PE_TYPE: 'actor' # mdiffusion mld or actor
119
+ SKIP_CONNECT: False # skip connection for denoiser va
120
+ MLP_DIST: False # use linear to expand mean and std rather expand token nums
121
+ IS_DIST: False # Mcross distribution kl
122
+ PREDICT_EPSILON: True # noise or motion
123
+
124
+ LOGGER:
125
+ VAL_EVERY_STEPS: 10
126
+ LOGGERS: ['tensorboard', 'wandb']
127
+ TENSORBOARD:
128
+ target: pytorch_lightning.loggers.TensorBoardLogger
129
+ params:
130
+ save_dir: ${FOLDER_EXP}
131
+ name: 'tensorboard'
132
+ version: ''
133
+ WANDB:
134
+ target: pytorch_lightning.loggers.WandbLogger
135
+ params:
136
+ project: null
137
+ offline: False
138
+ id: null
139
+ version: ''
140
+ name: ${NAME}
141
+ save_dir: ${FOLDER_EXP}
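
default.yaml relies on a non-standard ${eval:...} interpolation (LR_SCHEDULER.params.T_max here, and the evaluator's input_size later). Plain OmegaConf does not ship an "eval" resolver, so the project's config loader presumably registers one; the following is a sketch of that assumption, not code from the repo.

from omegaconf import OmegaConf

# Resolve strings such as "${eval:${LOGGER.VAL_EVERY_STEPS} * 100}".
OmegaConf.register_new_resolver("eval", eval, replace=True)

cfg = OmegaConf.create({
    "LOGGER": {"VAL_EVERY_STEPS": 10},
    "T_max": "${eval:${LOGGER.VAL_EVERY_STEPS} * 100}",
})
print(cfg.T_max)  # 1000
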
configs/evaluator/tm2t.yaml ADDED
@@ -0,0 +1,19 @@
1
+ t2m_textencoder:
2
+ target: mGPT.archs.tm2t_evaluator.TextEncoderBiGRUCo
3
+ params:
4
+ word_size: 300
5
+ pos_size: 15
6
+ hidden_size: 512
7
+ output_size: 512
8
+ t2m_moveencoder:
9
+ target: mGPT.archs.tm2t_evaluator.MovementConvEncoder
10
+ params:
11
+ input_size: ${eval:${DATASET.NFEATS} - 4}
12
+ hidden_size: 512
13
+ output_size: 512
14
+ t2m_motionencoder:
15
+ target: mGPT.archs.tm2t_evaluator.MotionEncoderBiGRUCo
16
+ params:
17
+ input_size: ${evaluator.tm2t.t2m_moveencoder.params.output_size}
18
+ hidden_size: 1024
19
+ output_size: 512
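
Every block in these configs follows the same target/params layout, so the evaluator encoders are presumably built by a generic instantiation helper. A hedged sketch of such a helper is shown below; the project's actual builder in mGPT/config.py may differ in details.

import importlib

def instantiate_from_config(node):
    # node holds a dotted "target" class path plus optional constructor "params".
    module_name, class_name = node["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**dict(node.get("params", {})))

# e.g. instantiate_from_config(cfg.METRIC.TM2T.t2m_textencoder) would construct
# mGPT.archs.tm2t_evaluator.TextEncoderBiGRUCo(word_size=300, pos_size=15, ...).
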
configs/lm/default.yaml ADDED
@@ -0,0 +1,7 @@
1
+ target: mGPT.archs.mgpt_lm.MLM
2
+ params:
3
+ model_type: t5
4
+ model_path: google/flan-t5-base
5
+ stage: ${TRAIN.STAGE}
6
+ motion_codebook_size: ${model.params.codebook_size}
7
+ ablation: ${ABLATION}
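
The language-model config points at google/flan-t5-base and wires in the VQ codebook size. As mGPT/archs/mgpt_lm.py (further down in this diff) shows, the codebook is exposed to the LM by adding codebook_size + 3 <motion_id_i> tokens and resizing the embeddings; stripped down to the essential step:

from transformers import AutoTokenizer, T5ForConditionalGeneration

codebook_size = 512  # vq.default.params.code_num
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-base")

# 512 codebook entries plus start / end / mask sentinel tokens.
tokenizer.add_tokens([f"<motion_id_{i}>" for i in range(codebook_size + 3)])
model.resize_token_embeddings(len(tokenizer))
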
configs/render.yaml ADDED
@@ -0,0 +1,23 @@
1
+ NAME: '___render_do_not_need_name__' # Experiment name
2
+ ACCELERATOR: 'gpu' # Devices, optional: “cpu”, “gpu”, “tpu”, “ipu”, “hpu”, “mps”, “auto”
3
+ DEVICE: [0] # Index of GPUs, e.g. [0] or [0,1,2,3]
4
+
5
+ RENDER:
6
+ FOLDER: '___no_need__'
7
+ INPUT_MODE: 'npy'
8
+ DIR: ''
9
+ NPY: '___no_need__'
10
+ DENOISING: True
11
+ OLDRENDER: True
12
+ # ["ultra", "high", "med", "low"]
13
+ # RES: 'high'
14
+ RES: 'med'
15
+ DOWNSAMPLE: False
16
+ FPS: 20.0
17
+ CANONICALIZE: True
18
+ EXACT_FRAME: 0.5
19
+ NUM: 8
20
+ MODE: '___no_need__' #sequence frame video
21
+ VID_EXT: mp4
22
+ ALWAYS_ON_FLOOR: false
23
+ GT: false
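
render.yaml configures the offline Blender pipeline, while the webui's 'slow' path in app.py writes frames with OpenCV's MP4V codec and then shells out to ffmpeg for an H.264 re-encode via os.system. Below is a small sketch of the same re-encode through subprocess, which handles generated file names more safely; this is a suggested variant, not code from the repo.

import subprocess

def to_h264(src_mp4, dst_mp4):
    # Re-encode an OpenCV-written MP4V file into H.264 so browsers can play it.
    subprocess.run(
        ["ffmpeg", "-y", "-i", src_mp4, "-vcodec", "h264", dst_mp4],
        check=True,
    )
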
configs/vq/default.yaml ADDED
@@ -0,0 +1,15 @@
1
+ target: mGPT.archs.mgpt_vq.VQVae
2
+ params:
3
+ quantizer: 'ema_reset'
4
+ code_num: 512
5
+ code_dim: 512
6
+ output_emb_width: 512
7
+ down_t: 2
8
+ stride_t: 2
9
+ width: 512
10
+ depth: 3
11
+ dilation_growth_rate: 3
12
+ norm: None
13
+ activation: 'relu'
14
+ nfeats: ${DATASET.NFEATS}
15
+ ablation: ${ABLATION}
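
With down_t: 2 and stride_t: 2, the VQ-VAE downsamples time twice with stride 2, so each motion token covers stride_t ** down_t = 4 frames. The quick sanity check of the resulting token counts below is an inference from this config, not code from the repo.

down_t, stride_t = 2, 2
frames_per_token = stride_t ** down_t  # 4 frames per motion token

def num_motion_tokens(n_frames):
    # A 196-frame HumanML3D clip (about 9.8 s at 20 fps) maps to 49 tokens.
    return n_frames // frames_per_token

print(num_motion_tokens(196))  # 49
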
configs/webui.yaml ADDED
@@ -0,0 +1,74 @@
1
+ NAME: Webui # Experiment name
2
+ DEBUG: False # Debug mode
3
+ ACCELERATOR: 'cpu' # Devices, optional: “cpu”, “gpu”, “tpu”, “ipu”, “hpu”, “mps”, “auto”
4
+ DEVICE: [0] # Index of GPUs, e.g. [0] or [0,1,2,3]
5
+
6
+ # Training configuration
7
+ TRAIN:
8
+ #---------------------------------
9
+ STAGE: lm_instruct
10
+ DATASETS: ['humanml3d'] # Training datasets
11
+ NUM_WORKERS: 32 # Number of workers
12
+ BATCH_SIZE: 16 # Size of batches
13
+ START_EPOCH: 0 # Start epoch
14
+ END_EPOCH: 99999 # End epoch
15
+ ABLATION:
16
+ pkeep: 0.5
17
+ OPTIM:
18
+ TYPE: AdamW # Optimizer type
19
+ LR: 2e-4 # Learning rate
20
+ WEIGHT_DECAY: 0.0
21
+ LR_SCHEDULER: [100, 200, 300, 400]
22
+ GAMMA: 0.8
23
+
24
+ # Evaluating Configuration
25
+ EVAL:
26
+ DATASETS: ['humanml3d'] # Evaluating datasets
27
+ BATCH_SIZE: 32 # Evaluating Batch size
28
+ SPLIT: test
29
+
30
+ # Test Configuration
31
+ TEST:
32
+ CHECKPOINTS: checkpoints/MotionGPT-base/motiongpt_s3_h3d.ckpt
33
+ DATASETS: ['humanml3d'] # Testing datasets
34
+ SPLIT: test
35
+ BATCH_SIZE: 32 # Testing batch size
36
+ MEAN: False
37
+ NUM_SAMPLES: 1
38
+ FACT: 1
39
+
40
+ # Datasets Configuration
41
+ DATASET:
42
+ JOINT_TYPE: 'humanml3d' # joint type
43
+ CODE_PATH: 'VQBEST'
44
+ METRIC:
45
+ TYPE: ['TM2TMetrics']
46
+ # Losses Configuration
47
+ LOSS:
48
+ TYPE: t2mgpt # Losses type
49
+ LAMBDA_FEATURE: 1.0
50
+ LAMBDA_VELOCITY: 0.5
51
+ LAMBDA_COMMIT: 0.02
52
+ LAMBDA_CLS: 1.0
53
+ LAMBDA_M2T2M: 1.0
54
+ LAMBDA_T2M2T: 10.0
55
+ ABLATION:
56
+ RECONS_LOSS: 'l1_smooth'
57
+
58
+ # Model Configuration
59
+ model:
60
+ target: mGPT.models.mgpt.MotionGPT
61
+ params:
62
+ condition: 'text'
63
+ task: 't2m'
64
+ lm: ${lm.default}
65
+ motion_vae: ${vq.default}
66
+
67
+ # Logger configuration
68
+ LOGGER:
69
+ LOG_EVERY_STEPS: 5
70
+ VAL_EVERY_STEPS: 10
71
+ TENSORBOARD: True
72
+ wandb:
73
+ params:
74
+ project: null
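
webui.yaml pins ACCELERATOR to 'cpu' with DEVICE: [0], whereas app.py separately falls back from CUDA to CPU at runtime. One way to derive a torch device from these two fields so both paths agree is sketched below; this is a hedged helper, not the repo's logic.

import torch

def pick_device(accelerator, device_ids):
    # Map the config's ACCELERATOR / DEVICE fields onto a torch.device.
    if accelerator == "gpu" and torch.cuda.is_available():
        return torch.device(f"cuda:{device_ids[0]}")
    return torch.device("cpu")

print(pick_device("cpu", [0]))  # cpu
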
deps/smpl/smpl_models/SMPL_downsample_index.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e5b783c1677079397ee4bc26df5c72d73b8bb393bea41fa295b951187443daec
3
+ size 3556
deps/smpl/smpl_models/gmm_08.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e1374908aae055a2afa01a2cd9a169bc6cfec1ceb7aa590e201a47b383060491
3
+ size 839127
deps/smpl/smpl_models/neutral_smpl_mean_params.h5 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ac9b474c74daec0253ed084720f662059336e976850f08a4a9a3f76d06613776
3
+ size 4848
deps/smpl/smpl_models/smpl.faces ADDED
Binary file (331 kB).
 
deps/smpl/smpl_models/smpl.tar.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bf4793af6b29677b0841c58db392642cb70b477890dc91de01128c7f34738d8d
3
+ size 45
deps/smpl/smpl_models/smpl/SMPL_FEMALE.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a583c1b98e4afc19042641f1bae5cd8a1f712a6724886291a7627ec07acd408d
3
+ size 39056454
deps/smpl/smpl_models/smpl/SMPL_MALE.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0e8c0bbbbc635dcb166ed29c303fb4bef16ea5f623e5a89263495a9e403575bd
3
+ size 39056404
deps/smpl/smpl_models/smpl/SMPL_NEUTRAL.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:98e65c74ad9b998783132f00880d1025a8d64b158e040e6ef13a557e5098bc42
3
+ size 39001280
deps/smpl/smpl_models/smpl/readme.txt ADDED
@@ -0,0 +1 @@
1
+ This directory is reserved for SMPL models
deps/smpl/smpl_models/smplh/SMPLH_FEMALE.npz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f0fba73ef2494b26de243c1d88a1dbe1047e5566128cf7222c942089543f4560
3
+ size 39708434
deps/smpl/smpl_models/smplh/SMPLH_MALE.npz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:10b617fdd329557937d6fe38e8a542afab236a8887522d9da0bd42e7f2b76eaa
3
+ size 39686902
deps/smpl/smpl_models/smplh/SMPLH_NEUTRAL.npz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:42969b34d8cd383e172515a7bca6ff3b2c37aa2c5c78088c69d20e517fa96026
3
+ size 39708959
deps/smpl/smpl_models/smplh/mano_v1_2.zip ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:50976831790ea9657d8110e0c94e50e90eaf35cd76169f0b27e5d32f3fcd951f
3
+ size 175200815
deps/smpl/smpl_models/smplh/smplh.faces ADDED
Binary file (165 kB).
 
deps/smpl/smpl_models/smplh/smplh.tar.xz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:46d5b8687be48c91181fa88271feff3a5e83aa62a481fad8a0bcb9254b2a74f1
3
+ size 113231292
deps/smpl/smpl_models/smplx_parts_segm.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bb69c10801205c9cfb5353fdeb1b9cc5ade53d14c265c3339421cdde8b9c91e7
3
+ size 1323168
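
All of the deps/smpl files above are committed as Git LFS pointers: three lines of text carrying a sha256 and the real size. If the Space is cloned without fetching LFS objects, those pointer files are what the SMPL renderer will try to open. A small detection helper for that failure mode is sketched below (hedged, not part of the repo).

from pathlib import Path

def is_lfs_pointer(path):
    # A fetched model file is megabytes; an unfetched pointer is ~130 bytes of text.
    p = Path(path)
    if not p.exists() or p.stat().st_size > 1024:
        return False
    return p.read_bytes().startswith(b"version https://git-lfs.github.com/spec/v1")

print(is_lfs_pointer("deps/smpl/smpl_models/smpl/SMPL_NEUTRAL.pkl"))
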
mGPT/__init__.py ADDED
File without changes
mGPT/archs/__init__.py ADDED
File without changes
mGPT/archs/mgpt_lm.py ADDED
@@ -0,0 +1,592 @@
1
+ import os
2
+ from typing import List, Union
3
+ import numpy as np
4
+ import math
5
+ import time
6
+ import heapq
7
+ import torch
8
+ from torch import Tensor, nn
9
+ from torch.distributions.distribution import Distribution
10
+ from transformers import AutoModelForSeq2SeqLM, T5ForConditionalGeneration, T5Tokenizer, AutoTokenizer, GPT2LMHeadModel, GPT2Tokenizer
11
+ import random
12
+ from typing import Optional
13
+ from .tools.token_emb import NewTokenEmb
14
+
15
+
16
+ class MLM(nn.Module):
17
+
18
+ def __init__(
19
+ self,
20
+ model_path: str,
21
+ model_type: str = "t5",
22
+ stage: str = "lm_pretrain",
23
+ new_token_type: str = "insert",
24
+ motion_codebook_size: int = 512,
25
+ framerate: float = 20.0,
26
+ down_t: int = 4,
27
+ predict_ratio: float = 0.2,
28
+ inbetween_ratio: float = 0.25,
29
+ max_length: int = 256,
30
+ lora: bool = False,
31
+ quota_ratio: float = 0.5,
32
+ noise_density: float = 0.15,
33
+ mean_noise_span_length: int = 3,
34
+ **kwargs,
35
+ ) -> None:
36
+
37
+ super().__init__()
38
+
39
+ # Parameters
40
+ self.m_codebook_size = motion_codebook_size
41
+ self.max_length = max_length
42
+ self.framerate = framerate
43
+ self.down_t = down_t
44
+ self.predict_ratio = predict_ratio
45
+ self.inbetween_ratio = inbetween_ratio
46
+ self.noise_density = noise_density
47
+ self.mean_noise_span_length = mean_noise_span_length
48
+ self.quota_ratio = quota_ratio
49
+ self.stage = stage
50
+
51
+ # Instantiate language model
52
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path, legacy=True)
53
+ if model_type == "t5":
54
+ self.language_model = T5ForConditionalGeneration.from_pretrained(
55
+ model_path)
56
+ self.lm_type = 'encdec'
57
+ elif model_type == "gpt2":
58
+ self.language_model = GPT2LMHeadModel.from_pretrained(model_path)
59
+ self.lm_type = 'dec'
60
+ else:
61
+ raise ValueError("type must be either seq2seq or conditional")
62
+
63
+ if self.lm_type == 'dec':
64
+ self.tokenizer.pad_token = self.tokenizer.eos_token
65
+
66
+ # Add motion tokens
67
+ self.tokenizer.add_tokens(
68
+ [f'<motion_id_{i}>' for i in range(self.m_codebook_size + 3)])
69
+
70
+ if new_token_type == "insert":
71
+ self.language_model.resize_token_embeddings(len(self.tokenizer))
72
+ elif new_token_type == "mlp":
73
+ shared = NewTokenEmb(self.language_model.shared,
74
+ self.m_codebook_size + 3)
75
+ # lm_head = NewTokenEmb(self.language_model.lm_head,
76
+ # self.m_codebook_size + 3)
77
+ self.language_model.resize_token_embeddings(len(self.tokenizer))
78
+ self.language_model.shared = shared
79
+ # self.language_model.lm_head = lm_head
80
+
81
+ # Lora
82
+ if lora:
83
+ from peft import LoraConfig, TaskType, get_peft_model, get_peft_model_state_dict
84
+ from peft.utils.other import fsdp_auto_wrap_policy
85
+ peft_config = LoraConfig(
86
+ bias="none",
87
+ task_type="CAUSAL_LM",
88
+ # inference_mode=False,
89
+ r=8,
90
+ lora_alpha=16,
91
+ lora_dropout=0.05)
92
+ self.language_model = get_peft_model(self.language_model,
93
+ peft_config)
94
+
95
+ def forward(self, texts: List[str], motion_tokens: Tensor,
96
+ lengths: List[int], tasks: dict):
97
+ if self.lm_type == 'encdec':
98
+ return self.forward_encdec(texts, motion_tokens, lengths, tasks)
99
+ elif self.lm_type == 'dec':
100
+ return self.forward_dec(texts, motion_tokens, lengths, tasks)
101
+ else:
102
+ raise NotImplementedError("Only conditional_multitask supported")
103
+
104
+ def forward_encdec(
105
+ self,
106
+ texts: List[str],
107
+ motion_tokens: Tensor,
108
+ lengths: List[int],
109
+ tasks: dict,
110
+ ):
111
+
112
+ # Tensor to string
113
+ motion_strings = self.motion_token_to_string(motion_tokens, lengths)
114
+
115
+ # Supervised or unsupervised
116
+ # condition = random.choice(
117
+ # ['text', 'motion', 'supervised', 'supervised', 'supervised'])
118
+ condition = random.choice(['supervised', 'supervised', 'supervised'])
119
+
120
+ if condition == 'text':
121
+ inputs = texts
122
+ outputs = texts
123
+ elif condition == 'motion':
124
+ inputs = motion_strings
125
+ outputs = motion_strings
126
+ else:
127
+ inputs, outputs = self.template_fulfill(tasks, lengths,
128
+ motion_strings, texts)
129
+
130
+ # Tokenize
131
+ source_encoding = self.tokenizer(inputs,
132
+ padding='max_length',
133
+ max_length=self.max_length,
134
+ truncation=True,
135
+ return_attention_mask=True,
136
+ add_special_tokens=True,
137
+ return_tensors="pt")
138
+
139
+ source_attention_mask = source_encoding.attention_mask.to(
140
+ motion_tokens.device)
141
+ source_input_ids = source_encoding.input_ids.to(motion_tokens.device)
142
+
143
+ if condition in ['text', 'motion']:
144
+ batch_size, expanded_input_length = source_input_ids.shape
145
+ mask_indices = np.asarray([
146
+ self.random_spans_noise_mask(expanded_input_length)
147
+ for i in range(batch_size)
148
+ ])
149
+ target_mask = ~mask_indices
150
+ input_ids_sentinel = self.create_sentinel_ids(
151
+ mask_indices.astype(np.int8))
152
+ target_sentinel = self.create_sentinel_ids(
153
+ target_mask.astype(np.int8))
154
+
155
+ labels_input_ids = self.filter_input_ids(source_input_ids,
156
+ target_sentinel)
157
+ source_input_ids = self.filter_input_ids(source_input_ids,
158
+ input_ids_sentinel)
159
+
160
+ else:
161
+ target_inputs = self.tokenizer(outputs,
162
+ padding='max_length',
163
+ max_length=self.max_length,
164
+ truncation=True,
165
+ return_attention_mask=True,
166
+ add_special_tokens=True,
167
+ return_tensors="pt")
168
+
169
+ labels_input_ids = target_inputs.input_ids.to(motion_tokens.device)
170
+ labels_attention_mask = target_inputs.attention_mask.to(
171
+ motion_tokens.device)
172
+
173
+ labels_input_ids[labels_input_ids == 0] = -100
174
+ outputs = self.language_model(
175
+ input_ids=source_input_ids,
176
+ attention_mask=source_attention_mask
177
+ if condition == 'supervised' else None,
178
+ labels=labels_input_ids,
179
+ decoder_attention_mask=labels_attention_mask
180
+ if condition == 'supervised' else None,
181
+ )
182
+
183
+ return outputs
184
+
185
+ def forward_dec(
186
+ self,
187
+ texts: List[str],
188
+ motion_tokens: Tensor,
189
+ lengths: List[int],
190
+ tasks: dict,
191
+ ):
192
+ self.tokenizer.padding_side = "right"
193
+
194
+ # Tensor to string
195
+ motion_strings = self.motion_token_to_string(motion_tokens, lengths)
196
+
197
+ # Supervised or unsupervised
198
+ condition = random.choice(
199
+ ['text', 'motion', 'supervised', 'supervised', 'supervised'])
200
+
201
+ if condition == 'text':
202
+ labels = texts
203
+ elif condition == 'motion':
204
+ labels = motion_strings
205
+ else:
206
+ inputs, outputs = self.template_fulfill(tasks, lengths,
207
+ motion_strings, texts)
208
+ labels = []
209
+ for i in range(len(inputs)):
210
+ labels.append(inputs[i] + ' \n ' + outputs[i] +
211
+ self.tokenizer.eos_token)
212
+
213
+ # Tokenize
214
+ inputs = self.tokenizer(labels,
215
+ padding='max_length',
216
+ max_length=self.max_length,
217
+ truncation=True,
218
+ return_attention_mask=True,
219
+ return_tensors="pt")
220
+
221
+ labels_input_ids = inputs.input_ids.to(motion_tokens.device)
222
+ labels_attention_mask = inputs.attention_mask.to(motion_tokens.device)
223
+
224
+ # print(labels_input_ids[0:5])
225
+
226
+ outputs = self.language_model(input_ids=labels_input_ids,
227
+ attention_mask=labels_attention_mask,
228
+ labels=inputs["input_ids"])
229
+
230
+ return outputs
231
+
232
+ def generate_direct(self,
233
+ texts: List[str],
234
+ max_length: int = 256,
235
+ num_beams: int = 1,
236
+ do_sample: bool = True,
237
+ bad_words_ids: List[int] = None):
238
+
239
+ # Device
240
+ self.device = self.language_model.device
241
+
242
+ # Tokenize
243
+ if self.lm_type == 'dec':
244
+ texts = [text + " \n " for text in texts]
245
+
246
+ source_encoding = self.tokenizer(texts,
247
+ padding='max_length',
248
+ max_length=self.max_length,
249
+ truncation=True,
250
+ return_attention_mask=True,
251
+ add_special_tokens=True,
252
+ return_tensors="pt")
253
+
254
+ source_input_ids = source_encoding.input_ids.to(self.device)
255
+ source_attention_mask = source_encoding.attention_mask.to(self.device)
256
+
257
+ if self.lm_type == 'encdec':
258
+ outputs = self.language_model.generate(
259
+ source_input_ids,
260
+ max_length=max_length,
261
+ num_beams=num_beams,
262
+ do_sample=do_sample,
263
+ bad_words_ids=bad_words_ids,
264
+ )
265
+ elif self.lm_type == 'dec':
266
+ outputs = self.language_model.generate(
267
+ input_ids=source_input_ids,
268
+ attention_mask=source_attention_mask,
269
+ pad_token_id=self.tokenizer.pad_token_id,
270
+ do_sample=do_sample,
271
+ max_new_tokens=max_length)
272
+ self.tokenizer.padding_side = 'left'
273
+
274
+ outputs_string = self.tokenizer.batch_decode(outputs,
275
+ skip_special_tokens=True)
276
+
277
+ print(texts[:2])
278
+ print(outputs_string[:2])
279
+
280
+ outputs_tokens, cleaned_text = self.motion_string_to_token(
281
+ outputs_string)
282
+
283
+ return outputs_tokens, cleaned_text
284
+
285
+ def generate_conditional(self,
286
+ texts: Optional[List[str]] = None,
287
+ motion_tokens: Optional[Tensor] = None,
288
+ lengths: Optional[List[int]] = None,
289
+ task: str = "t2m",
290
+ with_len: bool = False,
291
+ stage: str = 'train',
292
+ tasks: dict = None):
293
+
294
+ self.device = self.language_model.device
295
+
296
+ if task in ["t2m", "m2m", "pred", "inbetween"]:
297
+
298
+ if task == "t2m":
299
+ assert texts is not None
300
+ motion_strings = [''] * len(texts)
301
+ if not with_len:
302
+ if tasks is None:
303
+ tasks = [{
304
+ 'input':
305
+ ['Generate motion: <Caption_Placeholder>'],
306
+ 'output': ['']
307
+ }] * len(texts)
308
+
309
+ lengths = [0] * len(texts)
310
+ else:
311
+ tasks = [{
312
+ 'input': [
313
+ 'Generate motion with <Frame_Placeholder> frames: <Caption_Placeholder>'
314
+ ],
315
+ 'output': ['']
316
+ }] * len(texts)
317
+
318
+ elif task == "pred":
319
+ assert motion_tokens is not None and lengths is not None
320
+ texts = [''] * len(lengths)
321
+ tasks = [{
322
+ 'input': ['Predict motion: <Motion_Placeholder_s1>'],
323
+ 'output': ['']
324
+ }] * len(lengths)
325
+
326
+ motion_strings_old = self.motion_token_to_string(
327
+ motion_tokens, lengths)
328
+ motion_strings = []
329
+ for i, length in enumerate(lengths):
330
+ split = length // 5
331
+ motion_strings.append(
332
+ '>'.join(motion_strings_old[i].split('>')[:split]) +
333
+ '>')
334
+
335
+ elif task == "inbetween":
336
+ assert motion_tokens is not None and lengths is not None
337
+ texts = [''] * len(lengths)
338
+ tasks = [{
339
+ 'input': [
340
+ "Complete the masked motion: <Motion_Placeholder_Masked>"
341
+ ],
342
+ 'output': ['']
343
+ }] * len(lengths)
344
+ motion_strings = self.motion_token_to_string(
345
+ motion_tokens, lengths)
346
+
347
+ inputs, outputs = self.template_fulfill(tasks, lengths,
348
+ motion_strings, texts,
349
+ stage)
350
+
351
+ outputs_tokens, cleaned_text = self.generate_direct(inputs,
352
+ max_length=128,
353
+ num_beams=1,
354
+ do_sample=True)
355
+
356
+ return outputs_tokens
357
+
358
+ elif task == "m2t":
359
+ assert motion_tokens is not None and lengths is not None
360
+
361
+ motion_strings = self.motion_token_to_string(
362
+ motion_tokens, lengths)
363
+
364
+ if not with_len:
365
+ tasks = [{
366
+ 'input': ['Generate text: <Motion_Placeholder>'],
367
+ 'output': ['']
368
+ }] * len(lengths)
369
+ else:
370
+ tasks = [{
371
+ 'input': [
372
+ 'Generate text with <Frame_Placeholder> frames: <Motion_Placeholder>'
373
+ ],
374
+ 'output': ['']
375
+ }] * len(lengths)
376
+
377
+ texts = [''] * len(lengths)
378
+
379
+ inputs, outputs = self.template_fulfill(tasks, lengths,
380
+ motion_strings, texts)
381
+ outputs_tokens, cleaned_text = self.generate_direct(
382
+ inputs,
383
+ max_length=40,
384
+ num_beams=1,
385
+ do_sample=False,
386
+ # bad_words_ids=self.bad_words_ids
387
+ )
388
+ return cleaned_text
389
+
390
+ def motion_token_to_string(self, motion_token: Tensor, lengths: List[int]):
391
+ motion_string = []
392
+ for i in range(len(motion_token)):
393
+ motion_i = motion_token[i].cpu(
394
+ ) if motion_token[i].device.type == 'cuda' else motion_token[i]
395
+ motion_list = motion_i.tolist()[:lengths[i]]
396
+ motion_string.append(
397
+ (f'<motion_id_{self.m_codebook_size}>' +
398
+ ''.join([f'<motion_id_{int(i)}>' for i in motion_list]) +
399
+ f'<motion_id_{self.m_codebook_size + 1}>'))
400
+ return motion_string
401
+
402
+ def motion_token_list_to_string(self, motion_token: Tensor):
403
+ motion_string = []
404
+ for i in range(len(motion_token)):
405
+ motion_i = motion_token[i].cpu(
406
+ ) if motion_token[i].device.type == 'cuda' else motion_token[i]
407
+ motion_list = motion_i.tolist()
408
+ motion_string.append(
409
+ (f'<motion_id_{self.m_codebook_size}>' +
410
+ ''.join([f'<motion_id_{int(i)}>' for i in motion_list]) +
411
+ f'<motion_id_{self.m_codebook_size + 1}>'))
412
+ return motion_string
413
+
414
+ def motion_string_to_token(self, motion_string: List[str]):
415
+ motion_tokens = []
416
+ output_string = []
417
+ for i in range(len(motion_string)):
418
+ string = self.get_middle_str(
419
+ motion_string[i], f'<motion_id_{self.m_codebook_size}>',
420
+ f'<motion_id_{self.m_codebook_size + 1}>')
421
+ string_list = string.split('><')
422
+ token_list = [
423
+ int(i.split('_')[-1].replace('>', ''))
424
+ for i in string_list[1:-1]
425
+ ]
426
+ if len(token_list) == 0:
427
+ token_list = [0]
428
+ token_list_padded = torch.tensor(token_list,
429
+ dtype=int).to(self.device)
430
+ motion_tokens.append(token_list_padded)
431
+ output_string.append(motion_string[i].replace(
432
+ string, '<Motion_Placeholder>'))
433
+
434
+ return motion_tokens, output_string
435
+
436
+ def placeholder_fulfill(self, prompt: str, length: int, motion_string: str,
437
+ text: str):
438
+
439
+ seconds = math.floor(length / self.framerate)
440
+ motion_splited = motion_string.split('>')
441
+ token_length = length / self.down_t
442
+ predict_head = int(token_length * self.predict_ratio + 1)
443
+ masked_head = int(token_length * self.inbetween_ratio + 1)
444
+ masked_tail = int(token_length * (1 - self.inbetween_ratio) + 1)
445
+
446
+ motion_predict_head = '>'.join(
447
+ motion_splited[:predict_head]
448
+ ) + f'><motion_id_{self.m_codebook_size+1}>'
449
+ motion_predict_last = f'<motion_id_{self.m_codebook_size}>' + '>'.join(
450
+ motion_splited[predict_head:])
451
+
452
+ motion_masked = '>'.join(
453
+ motion_splited[:masked_head]
454
+ ) + '>' + f'<motion_id_{self.m_codebook_size+2}>' * (
455
+ masked_tail - masked_head) + '>'.join(motion_splited[masked_tail:])
456
+
457
+ if random.random() < self.quota_ratio:
458
+ text = f'\"{text}\"'
459
+
460
+ prompt = prompt.replace('<Caption_Placeholder>', text).replace(
461
+ '<Motion_Placeholder>',
462
+ motion_string).replace('<Frame_Placeholder>', f'{length}').replace(
463
+ '<Second_Placeholder>', '%.1f' % seconds).replace(
464
+ '<Motion_Placeholder_s1>', motion_predict_head).replace(
465
+ '<Motion_Placeholder_s2>',
466
+ motion_predict_last).replace(
467
+ '<Motion_Placeholder_Masked>', motion_masked)
468
+
469
+ return prompt
470
+
471
+ def template_fulfill(self,
472
+ tasks,
473
+ lengths,
474
+ motion_strings,
475
+ texts,
476
+ stage='test'):
477
+ inputs = []
478
+ outputs = []
479
+ for i in range(len(lengths)):
480
+ input_template = random.choice(tasks[i]['input'])
481
+ output_template = random.choice(tasks[i]['output'])
482
+ length = lengths[i]
483
+ inputs.append(
484
+ self.placeholder_fulfill(input_template, length,
485
+ motion_strings[i], texts[i]))
486
+ outputs.append(
487
+ self.placeholder_fulfill(output_template, length,
488
+ motion_strings[i], texts[i]))
489
+
490
+ return inputs, outputs
491
+
492
+ def get_middle_str(self, content, startStr, endStr):
493
+ try:
494
+ startIndex = content.index(startStr)
495
+ if startIndex >= 0:
496
+ startIndex += len(startStr)
497
+ endIndex = content.index(endStr)
498
+ except ValueError:
499
+ return f'<motion_id_{self.m_codebook_size}><motion_id_0><motion_id_{self.m_codebook_size+1}>'
500
+
501
+ return f'<motion_id_{self.m_codebook_size}>' + content[
502
+ startIndex:endIndex] + f'<motion_id_{self.m_codebook_size+1}>'
503
+
504
+ def random_spans_noise_mask(self, length):
505
+ # From https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py
506
+
507
+ orig_length = length
508
+
509
+ num_noise_tokens = int(np.round(length * self.noise_density))
510
+ # avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens.
511
+ num_noise_tokens = min(max(num_noise_tokens, 1), length - 1)
512
+ num_noise_spans = int(
513
+ np.round(num_noise_tokens / self.mean_noise_span_length))
514
+
515
+ # avoid degeneracy by ensuring positive number of noise spans
516
+ num_noise_spans = max(num_noise_spans, 1)
517
+ num_nonnoise_tokens = length - num_noise_tokens
518
+
519
+ # pick the lengths of the noise spans and the non-noise spans
520
+ def _random_segmentation(num_items, num_segments):
521
+ """Partition a sequence of items randomly into non-empty segments.
522
+ Args:
523
+ num_items: an integer scalar > 0
524
+ num_segments: an integer scalar in [1, num_items]
525
+ Returns:
526
+ a Tensor with shape [num_segments] containing positive integers that add
527
+ up to num_items
528
+ """
529
+ mask_indices = np.arange(num_items - 1) < (num_segments - 1)
530
+ np.random.shuffle(mask_indices)
531
+ first_in_segment = np.pad(mask_indices, [[1, 0]])
532
+ segment_id = np.cumsum(first_in_segment)
533
+ # count length of sub segments assuming that list is sorted
534
+ _, segment_length = np.unique(segment_id, return_counts=True)
535
+ return segment_length
536
+
537
+ noise_span_lengths = _random_segmentation(num_noise_tokens,
538
+ num_noise_spans)
539
+ nonnoise_span_lengths = _random_segmentation(num_nonnoise_tokens,
540
+ num_noise_spans)
541
+
542
+ interleaved_span_lengths = np.reshape(
543
+ np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1),
544
+ [num_noise_spans * 2],
545
+ )
546
+ span_starts = np.cumsum(interleaved_span_lengths)[:-1]
547
+ span_start_indicator = np.zeros((length, ), dtype=np.int8)
548
+ span_start_indicator[span_starts] = True
549
+ span_num = np.cumsum(span_start_indicator)
550
+ is_noise = np.equal(span_num % 2, 1)
551
+
552
+ return is_noise[:orig_length]
553
+
554
+ def create_sentinel_ids(self, mask_indices):
555
+ # From https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_t5_mlm_flax.py
556
+ start_indices = mask_indices - np.roll(mask_indices, 1,
557
+ axis=-1) * mask_indices
558
+ start_indices[:, 0] = mask_indices[:, 0]
559
+
560
+ sentinel_ids = np.where(start_indices != 0,
561
+ np.cumsum(start_indices, axis=-1),
562
+ start_indices)
563
+ sentinel_ids = np.where(sentinel_ids != 0,
564
+ (len(self.tokenizer) - sentinel_ids), 0)
565
+ sentinel_ids -= mask_indices - start_indices
566
+
567
+ return sentinel_ids
568
+
569
+ def filter_input_ids(self, input_ids, sentinel_ids):
570
+ # From https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_t5_mlm_flax.py
571
+ batch_size = input_ids.shape[0]
572
+
573
+ input_ids_full = np.where(sentinel_ids != 0, sentinel_ids,
574
+ input_ids.to('cpu'))
575
+
576
+ # input_ids tokens and sentinel tokens are >= 0, tokens < 0 are
577
+ # masked tokens coming after sentinel tokens and should be removed
578
+ input_ids = input_ids_full[input_ids_full >= 0].reshape(
579
+ (batch_size, -1))
580
+ input_ids = np.concatenate(
581
+ [
582
+ input_ids,
583
+ np.full((batch_size, 1),
584
+ self.tokenizer.eos_token_id,
585
+ dtype=np.int32),
586
+ ],
587
+ axis=-1,
588
+ )
589
+
590
+ input_ids = torch.tensor(input_ids, device=self.device)
591
+
592
+ return input_ids
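A minimal, self-contained sketch of the T5-style span corruption that random_spans_noise_mask and create_sentinel_ids implement above, using plain NumPy and a single toy span instead of the real tokenizer; the length and the 15% noise density are illustrative assumptions mirroring the defaults in MLM.__init__ (the real create_sentinel_ids additionally maps these flags to sentinel token ids at the top of the vocabulary).

import numpy as np

length, noise_density = 12, 0.15
num_noise = max(1, min(int(round(length * noise_density)), length - 1))

# mark one contiguous noise span (the real code samples several spans at random)
mask = np.zeros((1, length), dtype=np.int8)
mask[0, 4:4 + num_noise] = 1

# flag the first position of each noise span, as create_sentinel_ids does
start_indices = mask - np.roll(mask, 1, axis=-1) * mask
start_indices[:, 0] = mask[:, 0]
sentinel_ids = np.where(start_indices != 0, np.cumsum(start_indices, axis=-1), 0)

print(mask)          # positions to corrupt
print(sentinel_ids)  # a positive id at the start of each corrupted span, 0 elsewhere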
mGPT/archs/mgpt_vq.py ADDED
@@ -0,0 +1,190 @@
1
+ # Partially from https://github.com/Mael-zys/T2M-GPT
2
+
3
+ from typing import List, Optional, Union
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch import Tensor
7
+ from torch.distributions.distribution import Distribution
8
+ from .tools.resnet import Resnet1D
9
+ from .tools.quantize_cnn import QuantizeEMAReset, Quantizer, QuantizeEMA, QuantizeReset
10
+ from collections import OrderedDict
11
+
12
+
13
+ class VQVae(nn.Module):
14
+
15
+ def __init__(self,
16
+ nfeats: int,
17
+ quantizer: str = "ema_reset",
18
+ code_num=512,
19
+ code_dim=512,
20
+ output_emb_width=512,
21
+ down_t=3,
22
+ stride_t=2,
23
+ width=512,
24
+ depth=3,
25
+ dilation_growth_rate=3,
26
+ norm=None,
27
+ activation: str = "relu",
28
+ **kwargs) -> None:
29
+
30
+ super().__init__()
31
+
32
+ self.code_dim = code_dim
33
+
34
+ self.encoder = Encoder(nfeats,
35
+ output_emb_width,
36
+ down_t,
37
+ stride_t,
38
+ width,
39
+ depth,
40
+ dilation_growth_rate,
41
+ activation=activation,
42
+ norm=norm)
43
+
44
+ self.decoder = Decoder(nfeats,
45
+ output_emb_width,
46
+ down_t,
47
+ stride_t,
48
+ width,
49
+ depth,
50
+ dilation_growth_rate,
51
+ activation=activation,
52
+ norm=norm)
53
+
54
+ if quantizer == "ema_reset":
55
+ self.quantizer = QuantizeEMAReset(code_num, code_dim, mu=0.99)
56
+ elif quantizer == "orig":
57
+ self.quantizer = Quantizer(code_num, code_dim, beta=1.0)
58
+ elif quantizer == "ema":
59
+ self.quantizer = QuantizeEMA(code_num, code_dim, mu=0.99)
60
+ elif quantizer == "reset":
61
+ self.quantizer = QuantizeReset(code_num, code_dim)
62
+
63
+ def preprocess(self, x):
64
+ # (bs, T, Jx3) -> (bs, Jx3, T)
65
+ x = x.permute(0, 2, 1)
66
+ return x
67
+
68
+ def postprocess(self, x):
69
+ # (bs, Jx3, T) -> (bs, T, Jx3)
70
+ x = x.permute(0, 2, 1)
71
+ return x
72
+
73
+ def forward(self, features: Tensor):
74
+ # Preprocess
75
+ x_in = self.preprocess(features)
76
+
77
+ # Encode
78
+ x_encoder = self.encoder(x_in)
79
+
80
+ # quantization
81
+ x_quantized, loss, perplexity = self.quantizer(x_encoder)
82
+
83
+ # decoder
84
+ x_decoder = self.decoder(x_quantized)
85
+ x_out = self.postprocess(x_decoder)
86
+
87
+ return x_out, loss, perplexity
88
+
89
+ def encode(
90
+ self,
91
+ features: Tensor,
92
+ ) -> Union[Tensor, Distribution]:
93
+
94
+ N, T, _ = features.shape
95
+ x_in = self.preprocess(features)
96
+ x_encoder = self.encoder(x_in)
97
+ x_encoder = self.postprocess(x_encoder)
98
+ x_encoder = x_encoder.contiguous().view(-1,
99
+ x_encoder.shape[-1]) # (NT, C)
100
+ code_idx = self.quantizer.quantize(x_encoder)
101
+ code_idx = code_idx.view(N, -1)
102
+
103
+ # latent, dist
104
+ return code_idx, None
105
+
106
+ def decode(self, z: Tensor):
107
+
108
+ x_d = self.quantizer.dequantize(z)
109
+ x_d = x_d.view(1, -1, self.code_dim).permute(0, 2, 1).contiguous()
110
+
111
+ # decoder
112
+ x_decoder = self.decoder(x_d)
113
+ x_out = self.postprocess(x_decoder)
114
+ return x_out
115
+
116
+
117
+ class Encoder(nn.Module):
118
+
119
+ def __init__(self,
120
+ input_emb_width=3,
121
+ output_emb_width=512,
122
+ down_t=3,
123
+ stride_t=2,
124
+ width=512,
125
+ depth=3,
126
+ dilation_growth_rate=3,
127
+ activation='relu',
128
+ norm=None):
129
+ super().__init__()
130
+
131
+ blocks = []
132
+ filter_t, pad_t = stride_t * 2, stride_t // 2
133
+ blocks.append(nn.Conv1d(input_emb_width, width, 3, 1, 1))
134
+ blocks.append(nn.ReLU())
135
+
136
+ for i in range(down_t):
137
+ input_dim = width
138
+ block = nn.Sequential(
139
+ nn.Conv1d(input_dim, width, filter_t, stride_t, pad_t),
140
+ Resnet1D(width,
141
+ depth,
142
+ dilation_growth_rate,
143
+ activation=activation,
144
+ norm=norm),
145
+ )
146
+ blocks.append(block)
147
+ blocks.append(nn.Conv1d(width, output_emb_width, 3, 1, 1))
148
+ self.model = nn.Sequential(*blocks)
149
+
150
+ def forward(self, x):
151
+ return self.model(x)
152
+
153
+
154
+ class Decoder(nn.Module):
155
+
156
+ def __init__(self,
157
+ input_emb_width=3,
158
+ output_emb_width=512,
159
+ down_t=3,
160
+ stride_t=2,
161
+ width=512,
162
+ depth=3,
163
+ dilation_growth_rate=3,
164
+ activation='relu',
165
+ norm=None):
166
+ super().__init__()
167
+ blocks = []
168
+
169
+ filter_t, pad_t = stride_t * 2, stride_t // 2
170
+ blocks.append(nn.Conv1d(output_emb_width, width, 3, 1, 1))
171
+ blocks.append(nn.ReLU())
172
+ for i in range(down_t):
173
+ out_dim = width
174
+ block = nn.Sequential(
175
+ Resnet1D(width,
176
+ depth,
177
+ dilation_growth_rate,
178
+ reverse_dilation=True,
179
+ activation=activation,
180
+ norm=norm), nn.Upsample(scale_factor=2,
181
+ mode='nearest'),
182
+ nn.Conv1d(width, out_dim, 3, 1, 1))
183
+ blocks.append(block)
184
+ blocks.append(nn.Conv1d(width, width, 3, 1, 1))
185
+ blocks.append(nn.ReLU())
186
+ blocks.append(nn.Conv1d(width, input_emb_width, 3, 1, 1))
187
+ self.model = nn.Sequential(*blocks)
188
+
189
+ def forward(self, x):
190
+ return self.model(x)
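A shape-level usage sketch of the VQVae above. It assumes the classes in this file are importable, and the feature width of 263 (the HumanML3D representation) and the 64-frame clip are illustrative assumptions, not values taken from this file.

import torch

vqvae = VQVae(nfeats=263, quantizer="ema_reset", code_num=512, code_dim=512,
              output_emb_width=512, down_t=2, stride_t=2, width=512, depth=3)
vqvae.eval()

motion = torch.randn(1, 64, 263)                      # (batch, frames, features)
with torch.no_grad():
    recon, commit_loss, perplexity = vqvae(motion)    # reconstruction through the bottleneck
    codes, _ = vqvae.encode(motion)                   # (1, 64 // 2**down_t) codebook indices
    decoded = vqvae.decode(codes[0])                  # back to (1, 64, 263)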
mGPT/archs/tm2t_evaluator.py ADDED
@@ -0,0 +1,111 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn.utils.rnn import pack_padded_sequence
4
+
5
+
6
+ class MovementConvEncoder(nn.Module):
7
+ def __init__(self, input_size, hidden_size, output_size):
8
+ super(MovementConvEncoder, self).__init__()
9
+ self.main = nn.Sequential(
10
+ nn.Conv1d(input_size, hidden_size, 4, 2, 1),
11
+ nn.Dropout(0.2, inplace=True),
12
+ nn.LeakyReLU(0.2, inplace=True),
13
+ nn.Conv1d(hidden_size, output_size, 4, 2, 1),
14
+ nn.Dropout(0.2, inplace=True),
15
+ nn.LeakyReLU(0.2, inplace=True),
16
+ )
17
+ self.out_net = nn.Linear(output_size, output_size)
18
+ # self.main.apply(init_weight)
19
+ # self.out_net.apply(init_weight)
20
+
21
+ def forward(self, inputs):
22
+ inputs = inputs.permute(0, 2, 1)
23
+ outputs = self.main(inputs).permute(0, 2, 1)
24
+ # print(outputs.shape)
25
+ return self.out_net(outputs)
26
+
27
+
28
+ class MotionEncoderBiGRUCo(nn.Module):
29
+ def __init__(self, input_size, hidden_size, output_size):
30
+ super(MotionEncoderBiGRUCo, self).__init__()
31
+
32
+ self.input_emb = nn.Linear(input_size, hidden_size)
33
+ self.gru = nn.GRU(
34
+ hidden_size, hidden_size, batch_first=True, bidirectional=True
35
+ )
36
+ self.output_net = nn.Sequential(
37
+ nn.Linear(hidden_size * 2, hidden_size),
38
+ nn.LayerNorm(hidden_size),
39
+ nn.LeakyReLU(0.2, inplace=True),
40
+ nn.Linear(hidden_size, output_size),
41
+ )
42
+
43
+ # self.input_emb.apply(init_weight)
44
+ # self.output_net.apply(init_weight)
45
+ self.hidden_size = hidden_size
46
+ self.hidden = nn.Parameter(
47
+ torch.randn((2, 1, self.hidden_size), requires_grad=True)
48
+ )
49
+
50
+ # input(batch_size, seq_len, dim)
51
+ def forward(self, inputs, m_lens):
52
+ num_samples = inputs.shape[0]
53
+
54
+ input_embs = self.input_emb(inputs)
55
+ hidden = self.hidden.repeat(1, num_samples, 1)
56
+
57
+ cap_lens = m_lens.data.tolist()
58
+
59
+ # emb = pack_padded_sequence(input=input_embs, lengths=cap_lens, batch_first=True)
60
+ emb = input_embs
61
+
62
+ gru_seq, gru_last = self.gru(emb, hidden)
63
+
64
+ gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1)
65
+
66
+ return self.output_net(gru_last)
67
+
68
+
69
+ class TextEncoderBiGRUCo(nn.Module):
70
+ def __init__(self, word_size, pos_size, hidden_size, output_size):
71
+ super(TextEncoderBiGRUCo, self).__init__()
72
+
73
+ self.pos_emb = nn.Linear(pos_size, word_size)
74
+ self.input_emb = nn.Linear(word_size, hidden_size)
75
+ self.gru = nn.GRU(
76
+ hidden_size, hidden_size, batch_first=True, bidirectional=True
77
+ )
78
+ self.output_net = nn.Sequential(
79
+ nn.Linear(hidden_size * 2, hidden_size),
80
+ nn.LayerNorm(hidden_size),
81
+ nn.LeakyReLU(0.2, inplace=True),
82
+ nn.Linear(hidden_size, output_size),
83
+ )
84
+
85
+ # self.input_emb.apply(init_weight)
86
+ # self.pos_emb.apply(init_weight)
87
+ # self.output_net.apply(init_weight)
88
+ # self.linear2.apply(init_weight)
89
+ # self.batch_size = batch_size
90
+ self.hidden_size = hidden_size
91
+ self.hidden = nn.Parameter(
92
+ torch.randn((2, 1, self.hidden_size), requires_grad=True)
93
+ )
94
+
95
+ # input(batch_size, seq_len, dim)
96
+ def forward(self, word_embs, pos_onehot, cap_lens):
97
+ num_samples = word_embs.shape[0]
98
+
99
+ pos_embs = self.pos_emb(pos_onehot)
100
+ inputs = word_embs + pos_embs
101
+ input_embs = self.input_emb(inputs)
102
+ hidden = self.hidden.repeat(1, num_samples, 1)
103
+
104
+ cap_lens = cap_lens.data.tolist()
105
+ emb = pack_padded_sequence(input=input_embs, lengths=cap_lens, batch_first=True)
106
+
107
+ gru_seq, gru_last = self.gru(emb, hidden)
108
+
109
+ gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1)
110
+
111
+ return self.output_net(gru_last)
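A shape-only sketch of calling the movement and motion encoders above. The sizes (a 263-dimensional pose minus 4 root features, hidden 1024, output 512) are assumptions in the style of the TM2T evaluator configs, not values defined in this file.

import torch

movement_enc = MovementConvEncoder(input_size=263 - 4, hidden_size=512, output_size=512)
motion_enc = MotionEncoderBiGRUCo(input_size=512, hidden_size=1024, output_size=512)

motions = torch.randn(4, 196, 263 - 4)            # (batch, frames, features)
lengths = torch.tensor([196, 180, 120, 64])

movements = movement_enc(motions)                 # (4, 196 // 4, 512): two stride-2 convs
motion_emb = motion_enc(movements, lengths // 4)  # (4, 512) global motion embedding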
mGPT/archs/tools/embeddings.py ADDED
@@ -0,0 +1,322 @@
1
+ # This file is taken from signjoey repository
2
+ import math
3
+
4
+ import torch
5
+ from torch import Tensor, nn
6
+
7
+
8
+ def get_activation(activation_type):
9
+ if activation_type == "relu":
10
+ return nn.ReLU()
11
+ elif activation_type == "relu6":
12
+ return nn.ReLU6()
13
+ elif activation_type == "prelu":
14
+ return nn.PReLU()
15
+ elif activation_type == "selu":
16
+ return nn.SELU()
17
+ elif activation_type == "celu":
18
+ return nn.CELU()
19
+ elif activation_type == "gelu":
20
+ return nn.GELU()
21
+ elif activation_type == "sigmoid":
22
+ return nn.Sigmoid()
23
+ elif activation_type == "softplus":
24
+ return nn.Softplus()
25
+ elif activation_type == "softshrink":
26
+ return nn.Softshrink()
27
+ elif activation_type == "softsign":
28
+ return nn.Softsign()
29
+ elif activation_type == "tanh":
30
+ return nn.Tanh()
31
+ elif activation_type == "tanhshrink":
32
+ return nn.Tanhshrink()
33
+ else:
34
+ raise ValueError("Unknown activation type {}".format(activation_type))
35
+
36
+
37
+ class MaskedNorm(nn.Module):
38
+ """
39
+ Original Code from:
40
+ https://discuss.pytorch.org/t/batchnorm-for-different-sized-samples-in-batch/44251/8
41
+ """
42
+
43
+ def __init__(self, norm_type, num_groups, num_features):
44
+ super().__init__()
45
+ self.norm_type = norm_type
46
+ if self.norm_type == "batch":
47
+ self.norm = nn.BatchNorm1d(num_features=num_features)
48
+ elif self.norm_type == "group":
49
+ self.norm = nn.GroupNorm(num_groups=num_groups, num_channels=num_features)
50
+ elif self.norm_type == "layer":
51
+ self.norm = nn.LayerNorm(normalized_shape=num_features)
52
+ else:
53
+ raise ValueError("Unsupported Normalization Layer")
54
+
55
+ self.num_features = num_features
56
+
57
+ def forward(self, x: Tensor, mask: Tensor):
58
+ if self.training:
59
+ reshaped = x.reshape([-1, self.num_features])
60
+ reshaped_mask = mask.reshape([-1, 1]) > 0
61
+ selected = torch.masked_select(reshaped, reshaped_mask).reshape(
62
+ [-1, self.num_features]
63
+ )
64
+ batch_normed = self.norm(selected)
65
+ scattered = reshaped.masked_scatter(reshaped_mask, batch_normed)
66
+ return scattered.reshape([x.shape[0], -1, self.num_features])
67
+ else:
68
+ reshaped = x.reshape([-1, self.num_features])
69
+ batched_normed = self.norm(reshaped)
70
+ return batched_normed.reshape([x.shape[0], -1, self.num_features])
71
+
72
+
73
+ # TODO (Cihan): Spatial and Word Embeddings are pretty much the same
74
+ # We might as well convert them into a single module class.
75
+ # Only difference is the lut vs linear layers.
76
+ class Embeddings(nn.Module):
77
+
78
+ """
79
+ Simple embeddings class
80
+ """
81
+
82
+ # pylint: disable=unused-argument
83
+ def __init__(
84
+ self,
85
+ embedding_dim: int = 64,
86
+ num_heads: int = 8,
87
+ scale: bool = False,
88
+ scale_factor: float = None,
89
+ norm_type: str = None,
90
+ activation_type: str = None,
91
+ vocab_size: int = 0,
92
+ padding_idx: int = 1,
93
+ freeze: bool = False,
94
+ **kwargs
95
+ ):
96
+ """
97
+ Create new embeddings for the vocabulary.
98
+ Use scaling for the Transformer.
99
+
100
+ :param embedding_dim:
101
+ :param scale:
102
+ :param vocab_size:
103
+ :param padding_idx:
104
+ :param freeze: freeze the embeddings during training
105
+ """
106
+ super().__init__()
107
+
108
+ self.embedding_dim = embedding_dim
109
+ self.vocab_size = vocab_size
110
+ self.lut = nn.Embedding(vocab_size, self.embedding_dim, padding_idx=padding_idx)
111
+
112
+ self.norm_type = norm_type
113
+ if self.norm_type:
114
+ self.norm = MaskedNorm(
115
+ norm_type=norm_type, num_groups=num_heads, num_features=embedding_dim
116
+ )
117
+
118
+ self.activation_type = activation_type
119
+ if self.activation_type:
120
+ self.activation = get_activation(activation_type)
121
+
122
+ self.scale = scale
123
+ if self.scale:
124
+ if scale_factor:
125
+ self.scale_factor = scale_factor
126
+ else:
127
+ self.scale_factor = math.sqrt(self.embedding_dim)
128
+
129
+ if freeze:
130
+ # freeze all parameters in place; the signjoey helper freeze_params is not imported in this file
+ for param in self.parameters():
+ param.requires_grad = False
131
+
132
+ # pylint: disable=arguments-differ
133
+ def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
134
+ """
135
+ Perform lookup for input `x` in the embedding table.
136
+
137
+ :param mask: token masks
138
+ :param x: index in the vocabulary
139
+ :return: embedded representation for `x`
140
+ """
141
+
142
+ x = self.lut(x)
143
+
144
+ if self.norm_type:
145
+ x = self.norm(x, mask)
146
+
147
+ if self.activation_type:
148
+ x = self.activation(x)
149
+
150
+ if self.scale:
151
+ return x * self.scale_factor
152
+ else:
153
+ return x
154
+
155
+ def __repr__(self):
156
+ return "%s(embedding_dim=%d, vocab_size=%d)" % (
157
+ self.__class__.__name__,
158
+ self.embedding_dim,
159
+ self.vocab_size,
160
+ )
161
+
162
+
163
+ class SpatialEmbeddings(nn.Module):
164
+
165
+ """
166
+ Simple Linear Projection Layer
167
+ (For encoder outputs to predict glosses)
168
+ """
169
+
170
+ # pylint: disable=unused-argument
171
+ def __init__(
172
+ self,
173
+ embedding_dim: int,
174
+ input_size: int,
175
+ num_heads: int,
176
+ freeze: bool = False,
177
+ norm_type: str = "batch",
178
+ activation_type: str = "softsign",
179
+ scale: bool = False,
180
+ scale_factor: float = None,
181
+ **kwargs
182
+ ):
183
+ """
184
+ Create new embeddings for the vocabulary.
185
+ Use scaling for the Transformer.
186
+
187
+ :param embedding_dim:
188
+ :param input_size:
189
+ :param freeze: freeze the embeddings during training
190
+ """
191
+ super().__init__()
192
+
193
+ self.embedding_dim = embedding_dim
194
+ self.input_size = input_size
195
+ self.ln = nn.Linear(self.input_size, self.embedding_dim)
196
+
197
+ self.norm_type = norm_type
198
+ if self.norm_type:
199
+ self.norm = MaskedNorm(
200
+ norm_type=norm_type, num_groups=num_heads, num_features=embedding_dim
201
+ )
202
+
203
+ self.activation_type = activation_type
204
+ if self.activation_type:
205
+ self.activation = get_activation(activation_type)
206
+
207
+ self.scale = scale
208
+ if self.scale:
209
+ if scale_factor:
210
+ self.scale_factor = scale_factor
211
+ else:
212
+ self.scale_factor = math.sqrt(self.embedding_dim)
213
+
214
+ if freeze:
215
+ # freeze all parameters in place; the signjoey helper freeze_params is not imported in this file
+ for param in self.parameters():
+ param.requires_grad = False
216
+
217
+ # pylint: disable=arguments-differ
218
+ def forward(self, x: Tensor, mask: Tensor) -> Tensor:
219
+ """
220
+ :param mask: frame masks
221
+ :param x: input frame features
222
+ :return: embedded representation for `x`
223
+ """
224
+
225
+ x = self.ln(x)
226
+
227
+ if self.norm_type:
228
+ x = self.norm(x, mask)
229
+
230
+ if self.activation_type:
231
+ x = self.activation(x)
232
+
233
+ if self.scale:
234
+ return x * self.scale_factor
235
+ else:
236
+ return x
237
+
238
+ def __repr__(self):
239
+ return "%s(embedding_dim=%d, input_size=%d)" % (
240
+ self.__class__.__name__,
241
+ self.embedding_dim,
242
+ self.input_size,
243
+ )
244
+
245
+ def get_timestep_embedding(
246
+ timesteps: torch.Tensor,
247
+ embedding_dim: int,
248
+ flip_sin_to_cos: bool = False,
249
+ downscale_freq_shift: float = 1,
250
+ scale: float = 1,
251
+ max_period: int = 10000,
252
+ ):
253
+ """
254
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
255
+
256
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
257
+ These may be fractional.
258
+ :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
259
+ embeddings. :return: an [N x dim] Tensor of positional embeddings.
260
+ """
261
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
262
+
263
+ half_dim = embedding_dim // 2
264
+ exponent = -math.log(max_period) * torch.arange(
265
+ start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
266
+ )
267
+ exponent = exponent / (half_dim - downscale_freq_shift)
268
+
269
+ emb = torch.exp(exponent)
270
+ emb = timesteps[:, None].float() * emb[None, :]
271
+
272
+ # scale embeddings
273
+ emb = scale * emb
274
+
275
+ # concat sine and cosine embeddings
276
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
277
+
278
+ # flip sine and cosine embeddings
279
+ if flip_sin_to_cos:
280
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
281
+
282
+ # zero pad
283
+ if embedding_dim % 2 == 1:
284
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
285
+ return emb
286
+
287
+
288
+ class TimestepEmbedding(nn.Module):
289
+ def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "silu"):
290
+ super().__init__()
291
+
292
+ self.linear_1 = nn.Linear(channel, time_embed_dim)
293
+ self.act = None
294
+ if act_fn == "silu":
295
+ self.act = nn.SiLU()
296
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
297
+
298
+ def forward(self, sample):
299
+ sample = self.linear_1(sample)
300
+
301
+ if self.act is not None:
302
+ sample = self.act(sample)
303
+
304
+ sample = self.linear_2(sample)
305
+ return sample
306
+
307
+
308
+ class Timesteps(nn.Module):
309
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
310
+ super().__init__()
311
+ self.num_channels = num_channels
312
+ self.flip_sin_to_cos = flip_sin_to_cos
313
+ self.downscale_freq_shift = downscale_freq_shift
314
+
315
+ def forward(self, timesteps):
316
+ t_emb = get_timestep_embedding(
317
+ timesteps,
318
+ self.num_channels,
319
+ flip_sin_to_cos=self.flip_sin_to_cos,
320
+ downscale_freq_shift=self.downscale_freq_shift,
321
+ )
322
+ return t_emb
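A small sketch of the sinusoidal timestep utilities above; the channel sizes are assumptions for illustration.

import torch

time_proj = Timesteps(num_channels=128, flip_sin_to_cos=False, downscale_freq_shift=1.0)
time_mlp = TimestepEmbedding(channel=128, time_embed_dim=512)

t = torch.arange(4)              # one timestep index per batch element
emb = time_mlp(time_proj(t))     # (4, 128) sinusoidal features projected to (4, 512)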
mGPT/archs/tools/quantize_cnn.py ADDED
@@ -0,0 +1,414 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ class QuantizeEMAReset(nn.Module):
7
+ def __init__(self, nb_code, code_dim, mu):
8
+ super().__init__()
9
+ self.nb_code = nb_code
10
+ self.code_dim = code_dim
11
+ self.mu = mu
12
+ self.reset_codebook()
13
+
14
+ def reset_codebook(self):
15
+ self.init = False
16
+ self.code_sum = None
17
+ self.code_count = None
18
+ device = "cuda" if torch.cuda.is_available() else "cpu"
19
+ self.register_buffer('codebook', torch.zeros(self.nb_code, self.code_dim).to(device))
20
+
21
+ def _tile(self, x):
22
+ nb_code_x, code_dim = x.shape
23
+ if nb_code_x < self.nb_code:
24
+ n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x
25
+ std = 0.01 / np.sqrt(code_dim)
26
+ out = x.repeat(n_repeats, 1)
27
+ out = out + torch.randn_like(out) * std
28
+ else :
29
+ out = x
30
+ return out
31
+
32
+ def init_codebook(self, x):
33
+ out = self._tile(x)
34
+ self.codebook = out[:self.nb_code]
35
+ self.code_sum = self.codebook.clone()
36
+ self.code_count = torch.ones(self.nb_code, device=self.codebook.device)
37
+ self.init = True
38
+
39
+ @torch.no_grad()
40
+ def compute_perplexity(self, code_idx) :
41
+ # Calculate new centres
42
+ code_onehot = torch.zeros(self.nb_code, code_idx.shape[0], device=code_idx.device) # nb_code, N * L
43
+ code_onehot.scatter_(0, code_idx.view(1, code_idx.shape[0]), 1)
44
+
45
+ code_count = code_onehot.sum(dim=-1) # nb_code
46
+ prob = code_count / torch.sum(code_count)
47
+ perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
48
+ return perplexity
49
+
50
+ @torch.no_grad()
51
+ def update_codebook(self, x, code_idx):
52
+
53
+ code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device) # nb_code, N * L
54
+ code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1)
55
+
56
+ code_sum = torch.matmul(code_onehot, x) # nb_code, w
57
+ code_count = code_onehot.sum(dim=-1) # nb_code
58
+
59
+ out = self._tile(x)
60
+ code_rand = out[:self.nb_code]
61
+
62
+ # Update centres
63
+ self.code_sum = self.mu * self.code_sum + (1. - self.mu) * code_sum # w, nb_code
64
+ self.code_count = self.mu * self.code_count + (1. - self.mu) * code_count # nb_code
65
+
66
+ usage = (self.code_count.view(self.nb_code, 1) >= 1.0).float()
67
+ code_update = self.code_sum.view(self.nb_code, self.code_dim) / self.code_count.view(self.nb_code, 1)
68
+
69
+ self.codebook = usage * code_update + (1 - usage) * code_rand
70
+ prob = code_count / torch.sum(code_count)
71
+ perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
72
+
73
+
74
+ return perplexity
75
+
76
+ def preprocess(self, x):
77
+ # NCT -> NTC -> [NT, C]
78
+ x = x.permute(0, 2, 1).contiguous()
79
+ x = x.view(-1, x.shape[-1])
80
+ return x
81
+
82
+ def quantize(self, x):
83
+ # Calculate latent code x_l
84
+ k_w = self.codebook.t()
85
+ distance = torch.sum(x ** 2, dim=-1, keepdim=True) - 2 * torch.matmul(x, k_w) + torch.sum(k_w ** 2, dim=0,
86
+ keepdim=True) # (N * L, b)
87
+ _, code_idx = torch.min(distance, dim=-1)
88
+ return code_idx
89
+
90
+ def dequantize(self, code_idx):
91
+ x = F.embedding(code_idx, self.codebook)
92
+ return x
93
+
94
+
95
+ def forward(self, x):
96
+ N, width, T = x.shape
97
+
98
+ # Preprocess
99
+ x = self.preprocess(x)
100
+
101
+ # Init codebook if not inited
102
+ if self.training and not self.init:
103
+ self.init_codebook(x)
104
+
105
+ # quantize and dequantize through bottleneck
106
+ code_idx = self.quantize(x)
107
+ x_d = self.dequantize(code_idx)
108
+
109
+ # Update embeddings
110
+ if self.training:
111
+ perplexity = self.update_codebook(x, code_idx)
112
+ else :
113
+ perplexity = self.compute_perplexity(code_idx)
114
+
115
+ # Loss
116
+ commit_loss = F.mse_loss(x, x_d.detach())
117
+
118
+ # Passthrough
119
+ x_d = x + (x_d - x).detach()
120
+
121
+ # Postprocess
122
+ x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous() #(N, DIM, T)
123
+
124
+ return x_d, commit_loss, perplexity
125
+
126
+
127
+
128
+ class Quantizer(nn.Module):
129
+ def __init__(self, n_e, e_dim, beta):
130
+ super(Quantizer, self).__init__()
131
+
132
+ self.e_dim = e_dim
133
+ self.n_e = n_e
134
+ self.beta = beta
135
+
136
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
137
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
138
+
139
+ def forward(self, z):
140
+
141
+ N, width, T = z.shape
142
+ z = self.preprocess(z)
143
+ assert z.shape[-1] == self.e_dim
144
+ z_flattened = z.contiguous().view(-1, self.e_dim)
145
+
146
+ # B x V
147
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
148
+ torch.sum(self.embedding.weight**2, dim=1) - 2 * \
149
+ torch.matmul(z_flattened, self.embedding.weight.t())
150
+ # B x 1
151
+ min_encoding_indices = torch.argmin(d, dim=1)
152
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
153
+
154
+ # compute loss for embedding
155
+ loss = torch.mean((z_q - z.detach())**2) + self.beta * \
156
+ torch.mean((z_q.detach() - z)**2)
157
+
158
+ # preserve gradients
159
+ z_q = z + (z_q - z).detach()
160
+ z_q = z_q.view(N, T, -1).permute(0, 2, 1).contiguous() #(N, DIM, T)
161
+
162
+ min_encodings = F.one_hot(min_encoding_indices, self.n_e).type(z.dtype)
163
+ e_mean = torch.mean(min_encodings, dim=0)
164
+ perplexity = torch.exp(-torch.sum(e_mean*torch.log(e_mean + 1e-10)))
165
+ return z_q, loss, perplexity
166
+
167
+ def quantize(self, z):
168
+
169
+ assert z.shape[-1] == self.e_dim
170
+
171
+ # B x V
172
+ d = torch.sum(z ** 2, dim=1, keepdim=True) + \
173
+ torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \
174
+ torch.matmul(z, self.embedding.weight.t())
175
+ # B x 1
176
+ min_encoding_indices = torch.argmin(d, dim=1)
177
+ return min_encoding_indices
178
+
179
+ def dequantize(self, indices):
180
+
181
+ index_flattened = indices.view(-1)
182
+ z_q = self.embedding(index_flattened)
183
+ z_q = z_q.view(indices.shape + (self.e_dim, )).contiguous()
184
+ return z_q
185
+
186
+ def preprocess(self, x):
187
+ # NCT -> NTC -> [NT, C]
188
+ x = x.permute(0, 2, 1).contiguous()
189
+ x = x.view(-1, x.shape[-1])
190
+ return x
191
+
192
+
193
+
194
+ class QuantizeReset(nn.Module):
195
+ def __init__(self, nb_code, code_dim):
196
+ super().__init__()
197
+ self.nb_code = nb_code
198
+ self.code_dim = code_dim
199
+ self.reset_codebook()
200
+ self.codebook = nn.Parameter(torch.randn(nb_code, code_dim))
201
+
202
+ def reset_codebook(self):
203
+ self.init = False
204
+ self.code_count = None
205
+
206
+ def _tile(self, x):
207
+ nb_code_x, code_dim = x.shape
208
+ if nb_code_x < self.nb_code:
209
+ n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x
210
+ std = 0.01 / np.sqrt(code_dim)
211
+ out = x.repeat(n_repeats, 1)
212
+ out = out + torch.randn_like(out) * std
213
+ else :
214
+ out = x
215
+ return out
216
+
217
+ def init_codebook(self, x):
218
+ out = self._tile(x)
219
+ self.codebook = nn.Parameter(out[:self.nb_code])
220
+ self.code_count = torch.ones(self.nb_code, device=self.codebook.device)
221
+ self.init = True
222
+
223
+ @torch.no_grad()
224
+ def compute_perplexity(self, code_idx) :
225
+ # Calculate new centres
226
+ code_onehot = torch.zeros(self.nb_code, code_idx.shape[0], device=code_idx.device) # nb_code, N * L
227
+ code_onehot.scatter_(0, code_idx.view(1, code_idx.shape[0]), 1)
228
+
229
+ code_count = code_onehot.sum(dim=-1) # nb_code
230
+ prob = code_count / torch.sum(code_count)
231
+ perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
232
+ return perplexity
233
+
234
+ def update_codebook(self, x, code_idx):
235
+
236
+ code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device) # nb_code, N * L
237
+ code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1)
238
+
239
+ code_count = code_onehot.sum(dim=-1) # nb_code
240
+
241
+ out = self._tile(x)
242
+ code_rand = out[:self.nb_code]
243
+
244
+ # Update centres
245
+ self.code_count = code_count # nb_code
246
+ usage = (self.code_count.view(self.nb_code, 1) >= 1.0).float()
247
+
248
+ self.codebook.data = usage * self.codebook.data + (1 - usage) * code_rand
249
+ prob = code_count / torch.sum(code_count)
250
+ perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
251
+
252
+
253
+ return perplexity
254
+
255
+ def preprocess(self, x):
256
+ # NCT -> NTC -> [NT, C]
257
+ x = x.permute(0, 2, 1).contiguous()
258
+ x = x.view(-1, x.shape[-1])
259
+ return x
260
+
261
+ def quantize(self, x):
262
+ # Calculate latent code x_l
263
+ k_w = self.codebook.t()
264
+ distance = torch.sum(x ** 2, dim=-1, keepdim=True) - 2 * torch.matmul(x, k_w) + torch.sum(k_w ** 2, dim=0,
265
+ keepdim=True) # (N * L, b)
266
+ _, code_idx = torch.min(distance, dim=-1)
267
+ return code_idx
268
+
269
+ def dequantize(self, code_idx):
270
+ x = F.embedding(code_idx, self.codebook)
271
+ return x
272
+
273
+
274
+ def forward(self, x):
275
+ N, width, T = x.shape
276
+ # Preprocess
277
+ x = self.preprocess(x)
278
+ # Init codebook if not inited
279
+ if self.training and not self.init:
280
+ self.init_codebook(x)
281
+ # quantize and dequantize through bottleneck
282
+ code_idx = self.quantize(x)
283
+ x_d = self.dequantize(code_idx)
284
+ # Update embeddings
285
+ if self.training:
286
+ perplexity = self.update_codebook(x, code_idx)
287
+ else :
288
+ perplexity = self.compute_perplexity(code_idx)
289
+
290
+ # Loss
291
+ commit_loss = F.mse_loss(x, x_d.detach())
292
+
293
+ # Passthrough
294
+ x_d = x + (x_d - x).detach()
295
+
296
+ # Postprocess
297
+ x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous() #(N, DIM, T)
298
+
299
+ return x_d, commit_loss, perplexity
300
+
301
+
302
+ class QuantizeEMA(nn.Module):
303
+ def __init__(self, nb_code, code_dim, mu):
304
+ super().__init__()
305
+ self.nb_code = nb_code
306
+ self.code_dim = code_dim
307
+ self.mu = mu
308
+ self.reset_codebook()
309
+
310
+ def reset_codebook(self):
311
+ self.init = False
312
+ self.code_sum = None
313
+ self.code_count = None
314
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ self.register_buffer('codebook', torch.zeros(self.nb_code, self.code_dim).to(device))
315
+
316
+ def _tile(self, x):
317
+ nb_code_x, code_dim = x.shape
318
+ if nb_code_x < self.nb_code:
319
+ n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x
320
+ std = 0.01 / np.sqrt(code_dim)
321
+ out = x.repeat(n_repeats, 1)
322
+ out = out + torch.randn_like(out) * std
323
+ else :
324
+ out = x
325
+ return out
326
+
327
+ def init_codebook(self, x):
328
+ out = self._tile(x)
329
+ self.codebook = out[:self.nb_code]
330
+ self.code_sum = self.codebook.clone()
331
+ self.code_count = torch.ones(self.nb_code, device=self.codebook.device)
332
+ self.init = True
333
+
334
+ @torch.no_grad()
335
+ def compute_perplexity(self, code_idx) :
336
+ # Calculate new centres
337
+ code_onehot = torch.zeros(self.nb_code, code_idx.shape[0], device=code_idx.device) # nb_code, N * L
338
+ code_onehot.scatter_(0, code_idx.view(1, code_idx.shape[0]), 1)
339
+
340
+ code_count = code_onehot.sum(dim=-1) # nb_code
341
+ prob = code_count / torch.sum(code_count)
342
+ perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
343
+ return perplexity
344
+
345
+ @torch.no_grad()
346
+ def update_codebook(self, x, code_idx):
347
+
348
+ code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device) # nb_code, N * L
349
+ code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1)
350
+
351
+ code_sum = torch.matmul(code_onehot, x) # nb_code, w
352
+ code_count = code_onehot.sum(dim=-1) # nb_code
353
+
354
+ # Update centres
355
+ self.code_sum = self.mu * self.code_sum + (1. - self.mu) * code_sum # w, nb_code
356
+ self.code_count = self.mu * self.code_count + (1. - self.mu) * code_count # nb_code
357
+
358
+ code_update = self.code_sum.view(self.nb_code, self.code_dim) / self.code_count.view(self.nb_code, 1)
359
+
360
+ self.codebook = code_update
361
+ prob = code_count / torch.sum(code_count)
362
+ perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
363
+
364
+ return perplexity
365
+
366
+ def preprocess(self, x):
367
+ # NCT -> NTC -> [NT, C]
368
+ x = x.permute(0, 2, 1).contiguous()
369
+ x = x.view(-1, x.shape[-1])
370
+ return x
371
+
372
+ def quantize(self, x):
373
+ # Calculate latent code x_l
374
+ k_w = self.codebook.t()
375
+ distance = torch.sum(x ** 2, dim=-1, keepdim=True) - 2 * torch.matmul(x, k_w) + torch.sum(k_w ** 2, dim=0,
376
+ keepdim=True) # (N * L, b)
377
+ _, code_idx = torch.min(distance, dim=-1)
378
+ return code_idx
379
+
380
+ def dequantize(self, code_idx):
381
+ x = F.embedding(code_idx, self.codebook)
382
+ return x
383
+
384
+
385
+ def forward(self, x):
386
+ N, width, T = x.shape
387
+
388
+ # Preprocess
389
+ x = self.preprocess(x)
390
+
391
+ # Init codebook if not inited
392
+ if self.training and not self.init:
393
+ self.init_codebook(x)
394
+
395
+ # quantize and dequantize through bottleneck
396
+ code_idx = self.quantize(x)
397
+ x_d = self.dequantize(code_idx)
398
+
399
+ # Update embeddings
400
+ if self.training:
401
+ perplexity = self.update_codebook(x, code_idx)
402
+ else :
403
+ perplexity = self.compute_perplexity(code_idx)
404
+
405
+ # Loss
406
+ commit_loss = F.mse_loss(x, x_d.detach())
407
+
408
+ # Passthrough
409
+ x_d = x + (x_d - x).detach()
410
+
411
+ # Postprocess
412
+ x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous() #(N, DIM, T)
413
+
414
+ return x_d, commit_loss, perplexity
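A minimal sketch of the straight-through quantization interface shared by the codebook variants above; batch and temporal sizes are assumptions.

import torch

quant = QuantizeEMAReset(nb_code=512, code_dim=512, mu=0.99)
quant.train()                        # the first training-mode forward initializes the codebook

latents = torch.randn(2, 512, 16)    # (batch, code_dim, downsampled frames) from the encoder
x_d, commit_loss, perplexity = quant(latents)
codes = quant.quantize(quant.preprocess(latents))   # flat (batch * frames,) codebook indices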
mGPT/archs/tools/resnet.py ADDED
@@ -0,0 +1,82 @@
1
+ import torch.nn as nn
2
+ import torch
3
+
4
+ class nonlinearity(nn.Module):
5
+ def __init__(self):
6
+ super().__init__()
7
+
8
+ def forward(self, x):
9
+ # swish
10
+ return x * torch.sigmoid(x)
11
+
12
+ class ResConv1DBlock(nn.Module):
13
+ def __init__(self, n_in, n_state, dilation=1, activation='silu', norm=None, dropout=None):
14
+ super().__init__()
15
+ padding = dilation
16
+ self.norm = norm
17
+ if norm == "LN":
18
+ self.norm1 = nn.LayerNorm(n_in)
19
+ self.norm2 = nn.LayerNorm(n_in)
20
+ elif norm == "GN":
21
+ self.norm1 = nn.GroupNorm(num_groups=32, num_channels=n_in, eps=1e-6, affine=True)
22
+ self.norm2 = nn.GroupNorm(num_groups=32, num_channels=n_in, eps=1e-6, affine=True)
23
+ elif norm == "BN":
24
+ self.norm1 = nn.BatchNorm1d(num_features=n_in, eps=1e-6, affine=True)
25
+ self.norm2 = nn.BatchNorm1d(num_features=n_in, eps=1e-6, affine=True)
26
+
27
+ else:
28
+ self.norm1 = nn.Identity()
29
+ self.norm2 = nn.Identity()
30
+
31
+ if activation == "relu":
32
+ self.activation1 = nn.ReLU()
33
+ self.activation2 = nn.ReLU()
34
+
35
+ elif activation == "silu":
36
+ self.activation1 = nonlinearity()
37
+ self.activation2 = nonlinearity()
38
+
39
+ elif activation == "gelu":
40
+ self.activation1 = nn.GELU()
41
+ self.activation2 = nn.GELU()
42
+
43
+
44
+
45
+ self.conv1 = nn.Conv1d(n_in, n_state, 3, 1, padding, dilation)
46
+ self.conv2 = nn.Conv1d(n_state, n_in, 1, 1, 0,)
47
+
48
+
49
+ def forward(self, x):
50
+ x_orig = x
51
+ if self.norm == "LN":
52
+ x = self.norm1(x.transpose(-2, -1))
53
+ x = self.activation1(x.transpose(-2, -1))
54
+ else:
55
+ x = self.norm1(x)
56
+ x = self.activation1(x)
57
+
58
+ x = self.conv1(x)
59
+
60
+ if self.norm == "LN":
61
+ x = self.norm2(x.transpose(-2, -1))
62
+ x = self.activation2(x.transpose(-2, -1))
63
+ else:
64
+ x = self.norm2(x)
65
+ x = self.activation2(x)
66
+
67
+ x = self.conv2(x)
68
+ x = x + x_orig
69
+ return x
70
+
71
+ class Resnet1D(nn.Module):
72
+ def __init__(self, n_in, n_depth, dilation_growth_rate=1, reverse_dilation=True, activation='relu', norm=None):
73
+ super().__init__()
74
+
75
+ blocks = [ResConv1DBlock(n_in, n_in, dilation=dilation_growth_rate ** depth, activation=activation, norm=norm) for depth in range(n_depth)]
76
+ if reverse_dilation:
77
+ blocks = blocks[::-1]
78
+
79
+ self.model = nn.Sequential(*blocks)
80
+
81
+ def forward(self, x):
82
+ return self.model(x)
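A tiny sketch of the dilated residual stack above (channel-first 1D input; sizes are assumptions).

import torch

block = Resnet1D(n_in=512, n_depth=3, dilation_growth_rate=3, activation="relu", norm=None)
x = torch.randn(1, 512, 16)      # (batch, channels, frames)
y = block(x)                     # same shape; dilations run 9, 3, 1 since reverse_dilation=True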
mGPT/archs/tools/token_emb.py ADDED
@@ -0,0 +1,73 @@
1
+
2
+ import torch
+ from torch import Tensor, nn
3
+
4
+ class NewTokenEmb(nn.Module):
5
+ """
6
+ For adding new tokens to a pretrained model
7
+ """
8
+
9
+ def __init__(self,
10
+ old_embeddings: nn.Embedding,
11
+ new_num_tokens: int = None) -> None:
12
+
13
+ super().__init__()
14
+
15
+ self.num_tokens = old_embeddings.num_embeddings + new_num_tokens
16
+ self.old_num_tokens = old_embeddings.num_embeddings
17
+ self.new_num_tokens = new_num_tokens
18
+ self.embedding_dim = old_embeddings.embedding_dim
19
+
20
+ # For text embeddings
21
+ self.text_embeddings = nn.Embedding(
22
+ self.num_tokens,
23
+ self.embedding_dim,
24
+ device=old_embeddings.weight.device,
25
+ dtype=old_embeddings.weight.dtype)
26
+ with torch.no_grad():
27
+ self.text_embeddings.weight.data[:old_embeddings.
28
+ num_embeddings] = old_embeddings.weight.data
29
+ self.text_embeddings.weight.data[
30
+ self.old_num_tokens:] = torch.zeros(
31
+ self.new_num_tokens,
32
+ self.embedding_dim,
33
+ dtype=old_embeddings.weight.dtype,
34
+ device=old_embeddings.weight.device)
35
+ self.text_embeddings.weight.requires_grad_(False)
36
+
37
+ # For motion embeddings
38
+ self.motion_embeddings = nn.Embedding(
39
+ new_num_tokens,
40
+ self.embedding_dim,
41
+ device=old_embeddings.weight.device,
42
+ dtype=old_embeddings.weight.dtype)
43
+ with torch.no_grad():
44
+ self.motion_embeddings.weight.data[:self.
45
+ old_num_tokens] = torch.zeros(
46
+ new_num_tokens,
47
+ self.embedding_dim,
48
+ dtype=old_embeddings.weight.
49
+ dtype,
50
+ device=old_embeddings.
51
+ weight.device)
52
+ self.word2motionProj = nn.Linear(self.old_num_tokens, new_num_tokens)
53
+
54
+ def forward(self, input: Tensor) -> Tensor:
55
+
56
+ with torch.no_grad():
57
+ self.motion_embeddings.weight.data[:self.
58
+ old_num_tokens] = torch.zeros(
59
+ self.new_num_tokens,
60
+ self.embedding_dim,
61
+ dtype=self.motion_embeddings
62
+ .weight.dtype,
63
+ device=self.
64
+ motion_embeddings.weight.
65
+ device)
66
+
67
+ self.motion_embeddings.weight.data[
68
+ self.old_num_tokens:] = self.word2motionProj(
69
+ self.text_embeddings.weight.data[:self.old_num_tokens].permute(
70
+ 1, 0)).permute(1, 0)
71
+
72
+ return self.text_embeddings(input) + self.motion_embeddings(input)
73
+
mGPT/archs/tools/transformer_layers.py ADDED
@@ -0,0 +1,285 @@
+ # -*- coding: utf-8 -*-
+ import math
+ import torch
+ import torch.nn as nn
+ from torch import Tensor
+ 
+ # Taken from https://github.com/joeynmt/joeynmt/blob/fb66afcbe1beef9acd59283bcc084c4d4c1e6343/joeynmt/transformer_layers.py
+ 
+ 
+ # pylint: disable=arguments-differ
+ class MultiHeadedAttention(nn.Module):
+     """
+     Multi-Head Attention module from "Attention is All You Need"
+ 
+     Implementation modified from OpenNMT-py.
+     https://github.com/OpenNMT/OpenNMT-py
+     """
+ 
+     def __init__(self, num_heads: int, size: int, dropout: float = 0.1):
+         """
+         Create a multi-headed attention layer.
+         :param num_heads: the number of heads
+         :param size: model size (must be divisible by num_heads)
+         :param dropout: probability of dropping a unit
+         """
+         super().__init__()
+ 
+         assert size % num_heads == 0
+ 
+         self.head_size = head_size = size // num_heads
+         self.model_size = size
+         self.num_heads = num_heads
+ 
+         self.k_layer = nn.Linear(size, num_heads * head_size)
+         self.v_layer = nn.Linear(size, num_heads * head_size)
+         self.q_layer = nn.Linear(size, num_heads * head_size)
+ 
+         self.output_layer = nn.Linear(size, size)
+         self.softmax = nn.Softmax(dim=-1)
+         self.dropout = nn.Dropout(dropout)
+ 
+     def forward(self, k: Tensor, v: Tensor, q: Tensor, mask: Tensor = None):
+         """
+         Computes multi-headed attention.
+ 
+         :param k: keys [B, M, D] with M being the sentence length.
+         :param v: values [B, M, D]
+         :param q: query [B, M, D]
+         :param mask: optional mask [B, 1, M] or [B, M, M]
+         :return:
+         """
+         batch_size = k.size(0)
+         num_heads = self.num_heads
+ 
+         # project the queries (q), keys (k), and values (v)
+         k = self.k_layer(k)
+         v = self.v_layer(v)
+         q = self.q_layer(q)
+ 
+         # reshape q, k, v for our computation to [batch_size, num_heads, ..]
+         k = k.view(batch_size, -1, num_heads, self.head_size).transpose(1, 2)
+         v = v.view(batch_size, -1, num_heads, self.head_size).transpose(1, 2)
+         q = q.view(batch_size, -1, num_heads, self.head_size).transpose(1, 2)
+ 
+         # compute scores
+         q = q / math.sqrt(self.head_size)
+ 
+         # batch x num_heads x query_len x key_len
+         scores = torch.matmul(q, k.transpose(2, 3))
+         # torch.Size([48, 8, 183, 183])
+ 
+         # apply the mask (if we have one)
+         # we add a dimension for the heads to it below: [B, 1, 1, M]
+         if mask is not None:
+             scores = scores.masked_fill(~mask.unsqueeze(1), float('-inf'))
+ 
+         # apply attention dropout and compute context vectors.
+         attention = self.softmax(scores)
+         attention = self.dropout(attention)
+         # torch.Size([48, 8, 183, 183]) [bs, nheads, time, time] (for decoding)
+ 
+         # v: torch.Size([48, 8, 183, 32]) (32 is 256/8)
+         # get context vector (select values with attention) and reshape
+         # back to [B, M, D]
+         context = torch.matmul(attention, v)  # torch.Size([48, 8, 183, 32])
+         context = context.transpose(1, 2).contiguous().view(
+             batch_size, -1, num_heads * self.head_size)
+         # torch.Size([48, 183, 256]) put back to 256 (combine the heads)
+ 
+         output = self.output_layer(context)
+         # torch.Size([48, 183, 256]): 1 output per time step
+ 
+         return output
+ 
+ 
+ # pylint: disable=arguments-differ
+ class PositionwiseFeedForward(nn.Module):
+     """
+     Position-wise Feed-forward layer
+     Projects to ff_size and then back down to input_size.
+     """
+ 
+     def __init__(self, input_size, ff_size, dropout=0.1):
+         """
+         Initializes position-wise feed-forward layer.
+         :param input_size: dimensionality of the input.
+         :param ff_size: dimensionality of intermediate representation
+         :param dropout: dropout probability
+         """
+         super().__init__()
+         self.layer_norm = nn.LayerNorm(input_size, eps=1e-6)
+         self.pwff_layer = nn.Sequential(
+             nn.Linear(input_size, ff_size),
+             nn.ReLU(),
+             nn.Dropout(dropout),
+             nn.Linear(ff_size, input_size),
+             nn.Dropout(dropout),
+         )
+ 
+     def forward(self, x):
+         x_norm = self.layer_norm(x)
+         return self.pwff_layer(x_norm) + x
+ 
+ 
+ # pylint: disable=arguments-differ
+ class PositionalEncoding(nn.Module):
+     """
+     Pre-compute position encodings (PE).
+     In forward pass, this adds the position-encodings to the
+     input for as many time steps as necessary.
+ 
+     Implementation based on OpenNMT-py.
+     https://github.com/OpenNMT/OpenNMT-py
+     """
+ 
+     def __init__(self, size: int = 0, max_len: int = 5000):
+         """
+         Positional Encoding with maximum length max_len
+         :param size: model dimensionality (must be even)
+         :param max_len: maximum sequence length to pre-compute
+         :param dropout:
+         """
+         if size % 2 != 0:
+             raise ValueError("Cannot use sin/cos positional encoding with "
+                              "odd dim (got dim={:d})".format(size))
+         pe = torch.zeros(max_len, size)
+         position = torch.arange(0, max_len).unsqueeze(1)
+         div_term = torch.exp((torch.arange(0, size, 2, dtype=torch.float) *
+                               -(math.log(10000.0) / size)))
+         pe[:, 0::2] = torch.sin(position.float() * div_term)
+         pe[:, 1::2] = torch.cos(position.float() * div_term)
+         pe = pe.unsqueeze(0)  # shape: [1, max_len, size]
+         super().__init__()
+         self.register_buffer('pe', pe)
+         self.dim = size
+ 
+     def forward(self, emb):
+         """Embed inputs.
+         Args:
+             emb (FloatTensor): Sequence of word vectors
+                 ``(seq_len, batch_size, self.dim)``
+         """
+         # Add position encodings
+         return emb + self.pe[:, :emb.size(1)]
+ 
+ 
+ class TransformerEncoderLayer(nn.Module):
+     """
+     One Transformer encoder layer has a Multi-head attention layer plus
+     a position-wise feed-forward layer.
+     """
+ 
+     def __init__(self,
+                  size: int = 0,
+                  ff_size: int = 0,
+                  num_heads: int = 0,
+                  dropout: float = 0.1):
+         """
+         A single Transformer layer.
+         :param size:
+         :param ff_size:
+         :param num_heads:
+         :param dropout:
+         """
+         super().__init__()
+ 
+         self.layer_norm = nn.LayerNorm(size, eps=1e-6)
+         self.src_src_att = MultiHeadedAttention(num_heads,
+                                                 size,
+                                                 dropout=dropout)
+         self.feed_forward = PositionwiseFeedForward(size,
+                                                     ff_size=ff_size,
+                                                     dropout=dropout)
+         self.dropout = nn.Dropout(dropout)
+         self.size = size
+ 
+     # pylint: disable=arguments-differ
+     def forward(self, x: Tensor, mask: Tensor) -> Tensor:
+         """
+         Forward pass for a single transformer encoder layer.
+         First applies layer norm, then self attention,
+         then dropout with residual connection (adding the input to the result),
+         and then a position-wise feed-forward layer.
+ 
+         :param x: layer input
+         :param mask: input mask
+         :return: output tensor
+         """
+         x_norm = self.layer_norm(x)
+         h = self.src_src_att(x_norm, x_norm, x_norm, mask)
+         h = self.dropout(h) + x
+         o = self.feed_forward(h)
+         return o
+ 
+ 
+ class TransformerDecoderLayer(nn.Module):
+     """
+     Transformer decoder layer.
+ 
+     Consists of self-attention, source-attention, and feed-forward.
+     """
+ 
+     def __init__(self,
+                  size: int = 0,
+                  ff_size: int = 0,
+                  num_heads: int = 0,
+                  dropout: float = 0.1):
+         """
+         Represents a single Transformer decoder layer.
+ 
+         It attends to the source representation and the previous decoder states.
+ 
+         :param size: model dimensionality
+         :param ff_size: size of the feed-forward intermediate layer
+         :param num_heads: number of heads
+         :param dropout: dropout to apply to input
+         """
+         super().__init__()
+         self.size = size
+ 
+         self.trg_trg_att = MultiHeadedAttention(num_heads,
+                                                 size,
+                                                 dropout=dropout)
+         self.src_trg_att = MultiHeadedAttention(num_heads,
+                                                 size,
+                                                 dropout=dropout)
+ 
+         self.feed_forward = PositionwiseFeedForward(size,
+                                                     ff_size=ff_size,
+                                                     dropout=dropout)
+ 
+         self.x_layer_norm = nn.LayerNorm(size, eps=1e-6)
+         self.dec_layer_norm = nn.LayerNorm(size, eps=1e-6)
+ 
+         self.dropout = nn.Dropout(dropout)
+ 
+     # pylint: disable=arguments-differ
+     def forward(self,
+                 x: Tensor = None,
+                 memory: Tensor = None,
+                 src_mask: Tensor = None,
+                 trg_mask: Tensor = None) -> Tensor:
+         """
+         Forward pass of a single Transformer decoder layer.
+ 
+         :param x: inputs
+         :param memory: source representations
+         :param src_mask: source mask
+         :param trg_mask: target mask (so as to not condition on future steps)
+         :return: output tensor
+         """
+         # decoder/target self-attention
+         x_norm = self.x_layer_norm(x)  # torch.Size([48, 183, 256])
+         h1 = self.trg_trg_att(x_norm, x_norm, x_norm, mask=trg_mask)
+         h1 = self.dropout(h1) + x
+ 
+         # source-target attention
+         h1_norm = self.dec_layer_norm(
+             h1)  # torch.Size([48, 183, 256]) (same for memory)
+         h2 = self.src_trg_att(memory, memory, h1_norm, mask=src_mask)
+ 
+         # final position-wise feed-forward layer
+         o = self.feed_forward(self.dropout(h2) + h1)
+ 
+         return o
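Editor's note: the layers above can be exercised standalone; the snippet below is a minimal sketch of a forward pass through PositionalEncoding and TransformerEncoderLayer (tensor sizes are illustrative and the boolean padding mask uses the [B, 1, M] convention accepted by MultiHeadedAttention).

# Minimal usage sketch for the layers defined in transformer_layers.py.
import torch
from mGPT.archs.tools.transformer_layers import (PositionalEncoding,
                                                 TransformerEncoderLayer)

B, M, D = 2, 16, 256                            # batch, sequence length, model size
x = torch.randn(B, M, D)
mask = torch.ones(B, 1, M, dtype=torch.bool)    # True = position may be attended to

pe = PositionalEncoding(size=D)
layer = TransformerEncoderLayer(size=D, ff_size=1024, num_heads=8, dropout=0.1)

out = layer(pe(x), mask)
print(out.shape)                                # torch.Size([2, 16, 256])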
mGPT/callback.py ADDED
@@ -0,0 +1,200 @@
+ import os
+ from pytorch_lightning import LightningModule, Trainer
+ from pytorch_lightning.callbacks import Callback, RichProgressBar, ModelCheckpoint
+ 
+ 
+ def build_callbacks(cfg, logger=None, phase='test', **kwargs):
+     callbacks = []
+     logger = logger
+ 
+     # Rich Progress Bar
+     callbacks.append(progressBar())
+ 
+     # Checkpoint Callback
+     if phase == 'train':
+         callbacks.extend(getCheckpointCallback(cfg, logger=logger, **kwargs))
+ 
+     return callbacks
+ 
+ def getCheckpointCallback(cfg, logger=None, **kwargs):
+     callbacks = []
+     # Logging
+     metric_monitor = {
+         "loss_total": "total/train",
+         "Train_jf": "recons/text2jfeats/train",
+         "Val_jf": "recons/text2jfeats/val",
+         "Train_rf": "recons/text2rfeats/train",
+         "Val_rf": "recons/text2rfeats/val",
+         "APE root": "Metrics/APE_root",
+         "APE mean pose": "Metrics/APE_mean_pose",
+         "AVE root": "Metrics/AVE_root",
+         "AVE mean pose": "Metrics/AVE_mean_pose",
+         "R_TOP_1": "Metrics/R_precision_top_1",
+         "R_TOP_2": "Metrics/R_precision_top_2",
+         "R_TOP_3": "Metrics/R_precision_top_3",
+         "gt_R_TOP_3": "Metrics/gt_R_precision_top_3",
+         "FID": "Metrics/FID",
+         "gt_FID": "Metrics/gt_FID",
+         "Diversity": "Metrics/Diversity",
+         "MM dist": "Metrics/Matching_score",
+         "Accuracy": "Metrics/accuracy",
+     }
+     callbacks.append(
+         progressLogger(logger, metric_monitor=metric_monitor, log_every_n_steps=1))
+ 
+     # Keep the 8 latest checkpoints (by training step)
+     checkpointParams = {
+         'dirpath': os.path.join(cfg.FOLDER_EXP, "checkpoints"),
+         'filename': "{epoch}",
+         'monitor': "step",
+         'mode': "max",
+         'every_n_epochs': cfg.LOGGER.VAL_EVERY_STEPS,
+         'save_top_k': 8,
+         'save_last': True,
+         'save_on_train_epoch_end': True
+     }
+     callbacks.append(ModelCheckpoint(**checkpointParams))
+ 
+     # Save checkpoint every n*10 epochs
+     checkpointParams.update({
+         'every_n_epochs':
+         cfg.LOGGER.VAL_EVERY_STEPS * 10,
+         'save_top_k':
+         -1,
+         'save_last':
+         False
+     })
+     callbacks.append(ModelCheckpoint(**checkpointParams))
+ 
+     metrics = cfg.METRIC.TYPE
+     metric_monitor_map = {
+         'TemosMetric': {
+             'Metrics/APE_root': {
+                 'abbr': 'APEroot',
+                 'mode': 'min'
+             },
+         },
+         'TM2TMetrics': {
+             'Metrics/FID': {
+                 'abbr': 'FID',
+                 'mode': 'min'
+             },
+             'Metrics/R_precision_top_3': {
+                 'abbr': 'R3',
+                 'mode': 'max'
+             }
+         },
+         'MRMetrics': {
+             'Metrics/MPJPE': {
+                 'abbr': 'MPJPE',
+                 'mode': 'min'
+             }
+         },
+         'HUMANACTMetrics': {
+             'Metrics/Accuracy': {
+                 'abbr': 'Accuracy',
+                 'mode': 'max'
+             }
+         },
+         'UESTCMetrics': {
+             'Metrics/Accuracy': {
+                 'abbr': 'Accuracy',
+                 'mode': 'max'
+             }
+         },
+         'UncondMetrics': {
+             'Metrics/FID': {
+                 'abbr': 'FID',
+                 'mode': 'min'
+             }
+         }
+     }
+ 
+     checkpointParams.update({
+         'every_n_epochs': cfg.LOGGER.VAL_EVERY_STEPS,
+         'save_top_k': 1,
+     })
+ 
+     for metric in metrics:
+         if metric in metric_monitor_map.keys():
+             metric_monitors = metric_monitor_map[metric]
+ 
+             # Delete R3 if training VAE
+             if cfg.TRAIN.STAGE == 'vae' and metric == 'TM2TMetrics':
+                 del metric_monitors['Metrics/R_precision_top_3']
+ 
+             for metric_monitor in metric_monitors:
+                 checkpointParams.update({
+                     'filename':
+                     metric_monitor_map[metric][metric_monitor]['mode']
+                     + "-" +
+                     metric_monitor_map[metric][metric_monitor]['abbr']
+                     + "{ep}",
+                     'monitor':
+                     metric_monitor,
+                     'mode':
+                     metric_monitor_map[metric][metric_monitor]['mode'],
+                 })
+                 callbacks.append(
+                     ModelCheckpoint(**checkpointParams))
+     return callbacks
+ 
+ class progressBar(RichProgressBar):
+     def __init__(self, ):
+         super().__init__()
+ 
+     def get_metrics(self, trainer, model):
+         # Don't show the version number
+         items = super().get_metrics(trainer, model)
+         items.pop("v_num", None)
+         return items
+ 
+ class progressLogger(Callback):
+     def __init__(self,
+                  logger,
+                  metric_monitor: dict,
+                  precision: int = 3,
+                  log_every_n_steps: int = 1):
+         # Metric to monitor
+         self.logger = logger
+         self.metric_monitor = metric_monitor
+         self.precision = precision
+         self.log_every_n_steps = log_every_n_steps
+ 
+     def on_train_start(self, trainer: Trainer, pl_module: LightningModule,
+                        **kwargs) -> None:
+         self.logger.info("Training started")
+ 
+     def on_train_end(self, trainer: Trainer, pl_module: LightningModule,
+                      **kwargs) -> None:
+         self.logger.info("Training done")
+ 
+     def on_validation_epoch_end(self, trainer: Trainer,
+                                 pl_module: LightningModule, **kwargs) -> None:
+         if trainer.sanity_checking:
+             self.logger.info("Sanity checking ok.")
+ 
+     def on_train_epoch_end(self,
+                            trainer: Trainer,
+                            pl_module: LightningModule,
+                            padding=False,
+                            **kwargs) -> None:
+         metric_format = f"{{:.{self.precision}e}}"
+         line = f"Epoch {trainer.current_epoch}"
+         if padding:
+             line = f"{line:>{len('Epoch xxxx')}}"  # Right padding
+ 
+         if trainer.current_epoch % self.log_every_n_steps == 0:
+             metrics_str = []
+ 
+             losses_dict = trainer.callback_metrics
+             for metric_name, dico_name in self.metric_monitor.items():
+                 if dico_name in losses_dict:
+                     metric = losses_dict[dico_name].item()
+                     metric = metric_format.format(metric)
+                     metric = f"{metric_name} {metric}"
+                     metrics_str.append(metric)
+ 
+             line = line + ": " + " ".join(metrics_str)
+ 
+         self.logger.info(line)
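Editor's note: build_callbacks only touches a handful of config fields (FOLDER_EXP, LOGGER.VAL_EVERY_STEPS, METRIC.TYPE, TRAIN.STAGE). The sketch below builds the callback list from a hand-made OmegaConf config; the field values are illustrative, not the project defaults.

# Minimal sketch: construct the callback list without a full experiment config.
import logging
from omegaconf import OmegaConf
from mGPT.callback import build_callbacks

cfg = OmegaConf.create({
    "FOLDER_EXP": "./experiments/debug",          # illustrative path
    "LOGGER": {"VAL_EVERY_STEPS": 1},
    "METRIC": {"TYPE": ["TM2TMetrics"]},
    "TRAIN": {"STAGE": "lm_pretrain"},
})
callbacks = build_callbacks(cfg, logger=logging.getLogger(__name__), phase="train")
print([type(c).__name__ for c in callbacks])      # progress bar, logger, ModelCheckpoints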
mGPT/config.py ADDED
@@ -0,0 +1,217 @@
+ import importlib
+ from argparse import ArgumentParser
+ from omegaconf import OmegaConf
+ from os.path import join as pjoin
+ import os
+ import glob
+ 
+ 
+ def get_module_config(cfg, filepath="./configs"):
+     """
+     Load yaml config files from subfolders
+     """
+ 
+     yamls = glob.glob(pjoin(filepath, '*', '*.yaml'))
+     yamls = [y.replace(filepath, '') for y in yamls]
+     for yaml in yamls:
+         nodes = yaml.replace('.yaml', '').replace('/', '.')
+         nodes = nodes[1:] if nodes[0] == '.' else nodes
+         OmegaConf.update(cfg, nodes, OmegaConf.load('./configs' + yaml))
+ 
+     return cfg
+ 
+ 
+ def get_obj_from_str(string, reload=False):
+     """
+     Get object from string
+     """
+ 
+     module, cls = string.rsplit(".", 1)
+     if reload:
+         module_imp = importlib.import_module(module)
+         importlib.reload(module_imp)
+     return getattr(importlib.import_module(module, package=None), cls)
+ 
+ 
+ def instantiate_from_config(config):
+     """
+     Instantiate object from config
+     """
+     if "target" not in config:
+         raise KeyError("Expected key `target` to instantiate.")
+     return get_obj_from_str(config["target"])(**config.get("params", dict()))
+ 
+ 
+ def resume_config(cfg: OmegaConf):
+     """
+     Resume model and wandb
+     """
+ 
+     if cfg.TRAIN.RESUME:
+         resume = cfg.TRAIN.RESUME
+         if os.path.exists(resume):
+             # Checkpoints
+             cfg.TRAIN.PRETRAINED = pjoin(resume, "checkpoints", "last.ckpt")
+             # Wandb
+             wandb_files = os.listdir(pjoin(resume, "wandb", "latest-run"))
+             wandb_run = [item for item in wandb_files if "run-" in item][0]
+             cfg.LOGGER.WANDB.params.id = wandb_run.replace("run-", "").replace(".wandb", "")
+         else:
+             raise ValueError("Resume path does not exist.")
+ 
+     return cfg
+ 
+ def parse_args(phase="train"):
+     """
+     Parse arguments and load config files
+     """
+ 
+     parser = ArgumentParser()
+     group = parser.add_argument_group("Training options")
+ 
+     # Assets
+     group.add_argument(
+         "--cfg_assets",
+         type=str,
+         required=False,
+         default="./configs/assets.yaml",
+         help="config file for asset paths",
+     )
+ 
+     # Default config
+     if phase in ["train", "test"]:
+         cfg_default = "./configs/default.yaml"
+     elif phase == "render":
+         cfg_default = "./configs/render.yaml"
+     elif phase == "webui":
+         cfg_default = "./configs/webui.yaml"
+ 
+     group.add_argument(
+         "--cfg",
+         type=str,
+         required=False,
+         default=cfg_default,
+         help="config file",
+     )
+ 
+     # Parse for each phase
+     if phase in ["train", "test"]:
+         group.add_argument("--batch_size",
+                            type=int,
+                            required=False,
+                            help="training batch size")
+         group.add_argument("--num_nodes",
+                            type=int,
+                            required=False,
+                            help="number of nodes")
+         group.add_argument("--device",
+                            type=int,
+                            nargs="+",
+                            required=False,
+                            help="training device")
+         group.add_argument("--task",
+                            type=str,
+                            required=False,
+                            help="evaluation task type")
+         group.add_argument("--nodebug",
+                            action="store_true",
+                            required=False,
+                            help="debug or not")
+ 
+ 
+     if phase == "demo":
+         group.add_argument(
+             "--example",
+             type=str,
+             required=False,
+             help="input text and lengths with txt format",
+         )
+         group.add_argument(
+             "--out_dir",
+             type=str,
+             required=False,
+             help="output dir",
+         )
+         group.add_argument("--task",
+                            type=str,
+                            required=False,
+                            help="evaluation task type")
+ 
+     if phase == "render":
+         group.add_argument("--npy",
+                            type=str,
+                            required=False,
+                            default=None,
+                            help="npy motion files")
+         group.add_argument("--dir",
+                            type=str,
+                            required=False,
+                            default=None,
+                            help="npy motion folder")
+         group.add_argument("--fps",
+                            type=int,
+                            required=False,
+                            default=30,
+                            help="render fps")
+         group.add_argument(
+             "--mode",
+             type=str,
+             required=False,
+             default="sequence",
+             help="render target: video, sequence, frame",
+         )
+ 
+     params = parser.parse_args()
+ 
+     # Load yaml config files
+     OmegaConf.register_new_resolver("eval", eval)
+     cfg_assets = OmegaConf.load(params.cfg_assets)
+     cfg_base = OmegaConf.load(pjoin(cfg_assets.CONFIG_FOLDER, 'default.yaml'))
+     cfg_exp = OmegaConf.merge(cfg_base, OmegaConf.load(params.cfg))
+     if not cfg_exp.FULL_CONFIG:
+         cfg_exp = get_module_config(cfg_exp, cfg_assets.CONFIG_FOLDER)
+     cfg = OmegaConf.merge(cfg_exp, cfg_assets)
+ 
+     # Update config with arguments
+     if phase in ["train", "test"]:
+         cfg.TRAIN.BATCH_SIZE = params.batch_size if params.batch_size else cfg.TRAIN.BATCH_SIZE
+         cfg.DEVICE = params.device if params.device else cfg.DEVICE
+         cfg.NUM_NODES = params.num_nodes if params.num_nodes else cfg.NUM_NODES
+         cfg.model.params.task = params.task if params.task else cfg.model.params.task
+         cfg.DEBUG = not params.nodebug if params.nodebug is not None else cfg.DEBUG
+ 
+     # Force no debug in test
+     if phase == "test":
+         cfg.DEBUG = False
+         cfg.DEVICE = [0]
+         print("Force no debugging and one gpu when testing")
+ 
+     if phase == "demo":
+         cfg.DEMO.RENDER = params.render
+         cfg.DEMO.FRAME_RATE = params.frame_rate
+         cfg.DEMO.EXAMPLE = params.example
+         cfg.DEMO.TASK = params.task
+         cfg.TEST.FOLDER = params.out_dir if params.out_dir else cfg.TEST.FOLDER
+         os.makedirs(cfg.TEST.FOLDER, exist_ok=True)
+ 
+     if phase == "render":
+         if params.npy:
+             cfg.RENDER.NPY = params.npy
+             cfg.RENDER.INPUT_MODE = "npy"
+         if params.dir:
+             cfg.RENDER.DIR = params.dir
+             cfg.RENDER.INPUT_MODE = "dir"
+         if params.fps:
+             cfg.RENDER.FPS = float(params.fps)
+         cfg.RENDER.MODE = params.mode
+ 
+     # Debug mode
+     if cfg.DEBUG:
+         cfg.NAME = "debug--" + cfg.NAME
+         cfg.LOGGER.WANDB.params.offline = True
+         cfg.LOGGER.VAL_EVERY_STEPS = 1
+ 
+     # Resume config
+     cfg = resume_config(cfg)
+ 
+     return cfg
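Editor's note: instantiate_from_config above resolves the `target` dotted path with get_obj_from_str and forwards `params` as keyword arguments; the self-contained sketch below demonstrates the convention with a plain torch class rather than a project-specific one.

# Small sketch of the `target` / `params` convention used by the configs.
from mGPT.config import instantiate_from_config

layer = instantiate_from_config({
    "target": "torch.nn.Linear",
    "params": {"in_features": 8, "out_features": 4},
})
print(layer)   # Linear(in_features=8, out_features=4, bias=True)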
mGPT/data/HumanML3D.py ADDED
@@ -0,0 +1,117 @@
+ import numpy as np
+ import torch
+ from os.path import join as pjoin
+ from .humanml.utils.word_vectorizer import WordVectorizer
+ from .humanml.scripts.motion_process import (process_file, recover_from_ric)
+ from . import BASEDataModule
+ from .humanml import Text2MotionDatasetEval, Text2MotionDataset, Text2MotionDatasetCB, MotionDataset, MotionDatasetVQ, Text2MotionDatasetToken, Text2MotionDatasetM2T
+ from .utils import humanml3d_collate
+ 
+ 
+ class HumanML3DDataModule(BASEDataModule):
+     def __init__(self, cfg, **kwargs):
+ 
+         super().__init__(collate_fn=humanml3d_collate)
+         self.cfg = cfg
+         self.save_hyperparameters(logger=False)
+ 
+         # Basic info of the dataset
+         cfg.DATASET.JOINT_TYPE = 'humanml3d'
+         self.name = "humanml3d"
+         self.njoints = 22
+ 
+         # Path to the dataset
+         data_root = cfg.DATASET.HUMANML3D.ROOT
+         self.hparams.data_root = data_root
+         self.hparams.text_dir = pjoin(data_root, "texts")
+         self.hparams.motion_dir = pjoin(data_root, 'new_joint_vecs')
+ 
+         # Mean and std of the dataset
+         self.hparams.mean = np.load(pjoin('assets/meta', "mean.npy"))
+         self.hparams.std = np.load(pjoin('assets/meta', "std.npy"))
+ 
+         # Mean and std for fair evaluation
+         self.hparams.mean_eval = np.load(pjoin('assets/meta', "mean_eval.npy"))
+         self.hparams.std_eval = np.load(pjoin('assets/meta', "std_eval.npy"))
+ 
+         # Length of the dataset
+         self.hparams.max_motion_length = cfg.DATASET.HUMANML3D.MAX_MOTION_LEN
+         self.hparams.min_motion_length = cfg.DATASET.HUMANML3D.MIN_MOTION_LEN
+         self.hparams.max_text_len = cfg.DATASET.HUMANML3D.MAX_TEXT_LEN
+         self.hparams.unit_length = cfg.DATASET.HUMANML3D.UNIT_LEN
+ 
+         # Additional parameters
+         self.hparams.debug = cfg.DEBUG
+         self.hparams.stage = cfg.TRAIN.STAGE
+ 
+         # Dataset switch
+         self.DatasetEval = Text2MotionDatasetEval
+ 
+         if cfg.TRAIN.STAGE == "vae":
+             if cfg.model.params.motion_vae.target.split('.')[-1].lower() == "vqvae":
+                 self.hparams.win_size = 64
+                 self.Dataset = MotionDatasetVQ
+             else:
+                 self.Dataset = MotionDataset
+         elif 'lm' in cfg.TRAIN.STAGE:
+             self.hparams.code_path = cfg.DATASET.CODE_PATH
+             self.hparams.task_path = cfg.DATASET.TASK_PATH
+             self.hparams.std_text = cfg.DATASET.HUMANML3D.STD_TEXT
+             self.Dataset = Text2MotionDatasetCB
+         elif cfg.TRAIN.STAGE == "token":
+             self.Dataset = Text2MotionDatasetToken
+             self.DatasetEval = Text2MotionDatasetToken
+         elif cfg.TRAIN.STAGE == "m2t":
+             self.Dataset = Text2MotionDatasetM2T
+             self.DatasetEval = Text2MotionDatasetM2T
+         else:
+             self.Dataset = Text2MotionDataset
+ 
+         # Get additional info of the dataset
+         self.nfeats = 263
+         cfg.DATASET.NFEATS = self.nfeats
+ 
+ 
+     def feats2joints(self, features):
+         mean = torch.tensor(self.hparams.mean).to(features)
+         std = torch.tensor(self.hparams.std).to(features)
+         features = features * std + mean
+         return recover_from_ric(features, self.njoints)
+ 
+     def joints2feats(self, features):
+         features = process_file(features, self.njoints)[0]
+         return features
+ 
+     def normalize(self, features):
+         mean = torch.tensor(self.hparams.mean).to(features)
+         std = torch.tensor(self.hparams.std).to(features)
+         features = (features - mean) / std
+         return features
+ 
+     def denormalize(self, features):
+         mean = torch.tensor(self.hparams.mean).to(features)
+         std = torch.tensor(self.hparams.std).to(features)
+         features = features * std + mean
+         return features
+ 
+     def renorm4t2m(self, features):
+         # renorm to t2m norms for using t2m evaluators
+         ori_mean = torch.tensor(self.hparams.mean).to(features)
+         ori_std = torch.tensor(self.hparams.std).to(features)
+         eval_mean = torch.tensor(self.hparams.mean_eval).to(features)
+         eval_std = torch.tensor(self.hparams.std_eval).to(features)
+         features = features * ori_std + ori_mean
+         features = (features - eval_mean) / eval_std
+         return features
+ 
+     def mm_mode(self, mm_on=True):
+         if mm_on:
+             self.is_mm = True
+             self.name_list = self.test_dataset.name_list
+             self.mm_list = np.random.choice(self.name_list,
+                                             self.cfg.METRIC.MM_NUM_SAMPLES,
+                                             replace=False)
+             self.test_dataset.name_list = self.mm_list
+         else:
+             self.is_mm = False
+             self.test_dataset.name_list = self.name_list
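Editor's note: renorm4t2m undoes the dataset normalization and re-applies the evaluator statistics, i.e. (x * std + mean - mean_eval) / std_eval. The sketch below shows only that arithmetic with random stand-ins for the stored mean/std files; shapes follow the 263-dimensional HumanML3D features.

# Sketch of the renormalization performed by renorm4t2m (stand-in statistics).
import torch

feats = torch.randn(4, 196, 263)                     # [batch, frames, nfeats], normalized features
mean, std = torch.randn(263), torch.rand(263) + 0.5  # stand-ins for mean.npy / std.npy
mean_eval, std_eval = torch.randn(263), torch.rand(263) + 0.5

renormed = (feats * std + mean - mean_eval) / std_eval
print(renormed.shape)                                # torch.Size([4, 196, 263])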
mGPT/data/Kit.py ADDED
@@ -0,0 +1,88 @@
+ import numpy as np
+ import torch
+ from os.path import join as pjoin
+ from .humanml.utils.word_vectorizer import WordVectorizer
+ from .humanml.scripts.motion_process import (process_file, recover_from_ric)
+ from .HumanML3D import HumanML3DDataModule
+ from .humanml import Text2MotionDatasetEval, Text2MotionDataset, Text2MotionDatasetCB, MotionDataset, MotionDatasetVQ, Text2MotionDatasetToken
+ 
+ 
+ class KitDataModule(HumanML3DDataModule):
+     def __init__(self, cfg, **kwargs):
+ 
+         super().__init__(cfg, **kwargs)
+ 
+         # Basic info of the dataset
+         self.name = "kit"
+         self.njoints = 21
+ 
+         # Path to the dataset
+         data_root = cfg.DATASET.KIT.ROOT
+         self.hparams.data_root = data_root
+         self.hparams.text_dir = pjoin(data_root, "texts")
+         self.hparams.motion_dir = pjoin(data_root, 'new_joint_vecs')
+ 
+         # Mean and std of the dataset
+         dis_data_root = pjoin(cfg.DATASET.KIT.MEAN_STD_PATH, 'kit',
+                               "VQVAEV3_CB1024_CMT_H1024_NRES3", "meta")
+         self.hparams.mean = np.load(pjoin(dis_data_root, "mean.npy"))
+         self.hparams.std = np.load(pjoin(dis_data_root, "std.npy"))
+ 
+         # Mean and std for fair evaluation
+         dis_data_root_eval = pjoin(cfg.DATASET.KIT.MEAN_STD_PATH, 't2m',
+                                    "Comp_v6_KLD005", "meta")
+         self.hparams.mean_eval = np.load(pjoin(dis_data_root_eval, "mean.npy"))
+         self.hparams.std_eval = np.load(pjoin(dis_data_root_eval, "std.npy"))
+ 
+         # Length of the dataset
+         self.hparams.max_motion_length = cfg.DATASET.KIT.MAX_MOTION_LEN
+         self.hparams.min_motion_length = cfg.DATASET.KIT.MIN_MOTION_LEN
+         self.hparams.max_text_len = cfg.DATASET.KIT.MAX_TEXT_LEN
+         self.hparams.unit_length = cfg.DATASET.KIT.UNIT_LEN
+ 
+         # Get additional info of the dataset
+         self._sample_set = self.get_sample_set(overrides={"split": "test", "tiny": True})
+         self.nfeats = self._sample_set.nfeats
+         cfg.DATASET.NFEATS = self.nfeats
+ 
+     def feats2joints(self, features):
+         mean = torch.tensor(self.hparams.mean).to(features)
+         std = torch.tensor(self.hparams.std).to(features)
+         features = features * std + mean
+         return recover_from_ric(features, self.njoints)
+ 
+     def joints2feats(self, features):
+         features = process_file(features, self.njoints)[0]
+         # mean = torch.tensor(self.hparams.mean).to(features)
+         # std = torch.tensor(self.hparams.std).to(features)
+         # features = (features - mean) / std
+         return features
+ 
+     def normalize(self, features):
+         mean = torch.tensor(self.hparams.mean).to(features)
+         std = torch.tensor(self.hparams.std).to(features)
+         features = (features - mean) / std
+         return features
+ 
+     def renorm4t2m(self, features):
+         # renorm to t2m norms for using t2m evaluators
+         ori_mean = torch.tensor(self.hparams.mean).to(features)
+         ori_std = torch.tensor(self.hparams.std).to(features)
+         eval_mean = torch.tensor(self.hparams.mean_eval).to(features)
+         eval_std = torch.tensor(self.hparams.std_eval).to(features)
+         features = features * ori_std + ori_mean
+         features = (features - eval_mean) / eval_std
+         return features
+ 
+     def mm_mode(self, mm_on=True):
+         # random select samples for mm
+         if mm_on:
+             self.is_mm = True
+             self.name_list = self.test_dataset.name_list
+             self.mm_list = np.random.choice(self.name_list,
+                                             self.cfg.METRIC.MM_NUM_SAMPLES,
+                                             replace=False)
+             self.test_dataset.name_list = self.mm_list
+         else:
+             self.is_mm = False
+             self.test_dataset.name_list = self.name_list
mGPT/data/__init__.py ADDED
@@ -0,0 +1,103 @@
+ import pytorch_lightning as pl
+ from torch.utils.data import DataLoader
+ 
+ 
+ class BASEDataModule(pl.LightningDataModule):
+     def __init__(self, collate_fn):
+         super().__init__()
+ 
+         self.dataloader_options = {"collate_fn": collate_fn}
+         self.persistent_workers = True
+         self.is_mm = False
+ 
+         self._train_dataset = None
+         self._val_dataset = None
+         self._test_dataset = None
+ 
+     def get_sample_set(self, overrides={}):
+         sample_params = self.hparams.copy()
+         sample_params.update(overrides)
+         return self.DatasetEval(**sample_params)
+ 
+     @property
+     def train_dataset(self):
+         if self._train_dataset is None:
+             self._train_dataset = self.Dataset(split=self.cfg.TRAIN.SPLIT,
+                                                **self.hparams)
+         return self._train_dataset
+ 
+     @property
+     def val_dataset(self):
+         if self._val_dataset is None:
+             params = self.hparams.copy()
+             params['code_path'] = None
+             params['split'] = self.cfg.EVAL.SPLIT
+             self._val_dataset = self.DatasetEval(**params)
+         return self._val_dataset
+ 
+     @property
+     def test_dataset(self):
+         if self._test_dataset is None:
+             # self._test_dataset = self.DatasetEval(split=self.cfg.TEST.SPLIT,
+             #                                       **self.hparams)
+             params = self.hparams.copy()
+             params['code_path'] = None
+             params['split'] = self.cfg.TEST.SPLIT
+             self._test_dataset = self.DatasetEval(**params)
+         return self._test_dataset
+ 
+     def setup(self, stage=None):
+         # Use the getter the first time to load the data
+         if stage in (None, "fit"):
+             _ = self.train_dataset
+             _ = self.val_dataset
+         if stage in (None, "test"):
+             _ = self.test_dataset
+ 
+     def train_dataloader(self):
+         dataloader_options = self.dataloader_options.copy()
+         dataloader_options["batch_size"] = self.cfg.TRAIN.BATCH_SIZE
+         dataloader_options["num_workers"] = self.cfg.TRAIN.NUM_WORKERS
+         return DataLoader(
+             self.train_dataset,
+             shuffle=False,
+             persistent_workers=True,
+             **dataloader_options,
+         )
+ 
+     def predict_dataloader(self):
+         dataloader_options = self.dataloader_options.copy()
+         dataloader_options[
+             "batch_size"] = 1 if self.is_mm else self.cfg.TEST.BATCH_SIZE
+         dataloader_options["num_workers"] = self.cfg.TEST.NUM_WORKERS
+         dataloader_options["shuffle"] = False
+         return DataLoader(
+             self.test_dataset,
+             persistent_workers=True,
+             **dataloader_options,
+         )
+ 
+     def val_dataloader(self):
+         # overrides batch_size and num_workers
+         dataloader_options = self.dataloader_options.copy()
+         dataloader_options["batch_size"] = self.cfg.EVAL.BATCH_SIZE
+         dataloader_options["num_workers"] = self.cfg.EVAL.NUM_WORKERS
+         dataloader_options["shuffle"] = False
+         return DataLoader(
+             self.val_dataset,
+             persistent_workers=True,
+             **dataloader_options,
+         )
+ 
+     def test_dataloader(self):
+         # overrides batch_size and num_workers
+         dataloader_options = self.dataloader_options.copy()
+         dataloader_options[
+             "batch_size"] = 1 if self.is_mm else self.cfg.TEST.BATCH_SIZE
+         dataloader_options["num_workers"] = self.cfg.TEST.NUM_WORKERS
+         dataloader_options["shuffle"] = False
+         return DataLoader(
+             self.test_dataset,
+             persistent_workers=True,
+             **dataloader_options,
+         )
mGPT/data/build_data.py ADDED
@@ -0,0 +1,15 @@
+ from omegaconf import OmegaConf
+ from os.path import join as pjoin
+ from mGPT.config import instantiate_from_config
+ 
+ 
+ def build_data(cfg, phase="train"):
+     data_config = OmegaConf.to_container(cfg.DATASET, resolve=True)
+     data_config['params'] = {'cfg': cfg, 'phase': phase}
+     if isinstance(data_config['target'], str):
+         return instantiate_from_config(data_config)
+     elif isinstance(data_config['target'], list):
+         # Multiple dataset targets: wrap them in the concat DataModule
+         data_config_tmp = data_config.copy()
+         data_config_tmp['params']['dataModules'] = data_config['target']
+         data_config_tmp['target'] = 'mGPT.data.Concat.ConcatDataModule'
+         return instantiate_from_config(data_config_tmp)
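Editor's note: build_data expects cfg.DATASET to carry a `target` entry (a dotted class path, or a list of them). The sketch below shows the expected call; the config path is illustrative, and the referenced dataset files must exist on disk before the resulting DataModule can actually load anything.

# Sketch of building the DataModule from a merged config.
from omegaconf import OmegaConf
from mGPT.data.build_data import build_data

cfg = OmegaConf.load("./configs/default.yaml")   # or the merged cfg returned by parse_args()
print(cfg.DATASET.target)                        # e.g. mGPT.data.HumanML3D.HumanML3DDataModule
datamodule = build_data(cfg, phase="train")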