diff --git a/.gitattributes b/.gitattributes index c7d9f3332a950355d5a77d85000f05e6f45435ea..78bca0ba017c38e6187cb95c2ee7ec31e099241b 100644 --- a/.gitattributes +++ b/.gitattributes @@ -32,3 +32,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +main/assets/in_between_edit.gif filter=lfs diff=lfs merge=lfs -text +main/mydiffusion_zeggs/0001-0933.mkv filter=lfs diff=lfs merge=lfs -text +main/mydiffusion_zeggs/0001-0933.mp4 filter=lfs diff=lfs merge=lfs -text +main/mydiffusion_zeggs/015_Happy_4_x_1_0.wav filter=lfs diff=lfs merge=lfs -text +ubisoft-laforge-ZeroEGGS-main/ZEGGS/bvh2fbx/LaForgeFemale.fbx filter=lfs diff=lfs merge=lfs -text +ubisoft-laforge-ZeroEGGS-main/ZEGGS/bvh2fbx/Rendered/001_Neutral_0_x_0_9.bvh filter=lfs diff=lfs merge=lfs -text +ubisoft-laforge-ZeroEGGS-main/ZEGGS/bvh2fbx/Rendered/001_Neutral_0_x_0_9.fbx filter=lfs diff=lfs merge=lfs -text +ubisoft-laforge-ZeroEGGS-main/ZEGGS/bvh2fbx/Rendered/001_Neutral_0_x_0_9.wav filter=lfs diff=lfs merge=lfs -text diff --git a/Framework.png b/Framework.png new file mode 100644 index 0000000000000000000000000000000000000000..4ee69ca32e6fc43252c391d42d0b174e34f83a94 Binary files /dev/null and b/Framework.png differ diff --git a/README.md b/README.md index b187bb7e7d837a367ccd0862441947ad412c77f7..cbd368fbdd1a2d4339a1edc61cf4a5f30648aa66 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,86 @@ ---- -license: cc-by-4.0 ---- +# DiffuseStyleGesture: Stylized Audio-Driven Co-Speech Gesture Generation with Diffusion Models + +[![arXiv](https://img.shields.io/badge/arXiv-2305.04919-red.svg)](https://arxiv.org/abs/2305.04919) + + + +
+ +
+
+
+## News
+
+📢 **9/May/23** - First release: arXiv paper, code, and pre-trained models.
+
+
+## 1. Getting started
+
+This code was tested on an `NVIDIA GeForce RTX 2080 Ti` and requires:
+
+* conda3 or miniconda3
+
+```
+conda create -n DiffuseStyleGesture python=3.7
+pip install -r requirements.txt
+```
+
+[//]: # (-i https://pypi.tuna.tsinghua.edu.cn/simple)
+
+## 2. Quick Start
+
+1. Download the pre-trained model from [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/8ade7c73e05c4549ac6b/) or [Google Cloud](https://drive.google.com/file/d/1RlusxWJFJMyauXdbfbI_XreJwVRnrBv_/view?usp=share_link)
+and put it into `./main/mydiffusion_zeggs/`.
+2. Download the [WavLM Large](https://github.com/microsoft/unilm/tree/master/wavlm) model and put it into `./main/mydiffusion_zeggs/WavLM/`.
+3. `cd ./main/mydiffusion_zeggs/` and run
+```
+python sample.py --config=./configs/DiffuseStyleGesture.yml --no_cuda 0 --gpu 0 --model_path './model000450000.pt' --audiowavlm_path "./015_Happy_4_x_1_0.wav" --max_len 320
+```
+You will get a `.bvh` file named `yyyymmdd_hhmmss_smoothing_SG_minibatch_320_[1, 0, 0, 0, 0, 0]_123456.bvh` in the `sample_dir` folder, which can then be visualized using [Blender](https://www.blender.org/).
+
+## 3. Train your own model
+
+### (1) Get the ZEGGS dataset
+
+The procedure is the same as for [ZEGGS](https://github.com/ubisoft/ubisoft-laforge-ZeroEGGS).
+
+An example is as follows.
+Download the original ZEGGS dataset from [here](https://github.com/ubisoft/ubisoft-laforge-ZeroEGGS) and put it in the `./ubisoft-laforge-ZeroEGGS-main/data/` folder.
+Then `cd ./ubisoft-laforge-ZeroEGGS-main/ZEGGS` and run `python data_pipeline.py` to process the dataset.
+You will get the `./ubisoft-laforge-ZeroEGGS-main/data/processed_v1/trimmed/train/` and `./ubisoft-laforge-ZeroEGGS-main/data/processed_v1/trimmed/test/` folders.
+
+If you find it difficult to obtain and process the data, you can download the data already processed by ZEGGS from [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/ba5f3b33d94b4cba875b/) or [Baidu Cloud](https://pan.baidu.com/s/1KakkGpRZWfaJzfN5gQvPAw?pwd=vfuc)
+and put it in the `./ubisoft-laforge-ZeroEGGS-main/data/processed_v1/trimmed/` folder.
+
+
+### (2) Process the ZEGGS dataset
+
+```
+cd ./main/mydiffusion_zeggs/
+python zeggs_data_to_lmdb.py
+```
+
+### (3) Train
+
+```
+python end2end.py --config=./configs/DiffuseStyleGesture.yml --no_cuda 0 --gpu 0
+```
+The model will be saved in the `./main/mydiffusion_zeggs/zeggs_mymodel3_wavlm/` folder.
+
+## Reference
+Our work is mainly inspired by [MDM](https://github.com/GuyTevet/motion-diffusion-model), [Text2Gesture](https://github.com/youngwoo-yoon/Co-Speech_Gesture_Generation), and [Listen, denoise, action!](https://arxiv.org/abs/2211.09707).
+
+## Citation
+If you find this code useful in your research, please cite:
+
+```
+@inproceedings{yang2023DiffuseStyleGesture,
+  author    = {Sicheng Yang and Zhiyong Wu and Minglei Li and Zhensong Zhang and Lei Hao and Weihong Bao and Ming Cheng and Long Xiao},
+  title     = {DiffuseStyleGesture: Stylized Audio-Driven Co-Speech Gesture Generation with Diffusion Models},
+  booktitle = {Proceedings of the 32nd International Joint Conference on Artificial Intelligence, {IJCAI} 2023},
+  publisher = {ijcai.org},
+  year      = {2023},
+}
+```
+
+Please feel free to contact us ([yangsc21@mails.tsinghua.edu.cn](mailto:yangsc21@mails.tsinghua.edu.cn)) with any questions or concerns.
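As a minimal sketch of how the Quick Start command above can be scripted over several audio files, the loop below invokes `sample.py` with the same flags shown in the README. It assumes you run it from `./main/mydiffusion_zeggs/` with the checkpoint and WavLM weights already downloaded; the input folder name `my_audio/` is a hypothetical placeholder, not part of the repository.

```python
# Illustrative batch driver, not part of the original diff.
# Assumptions: run from ./main/mydiffusion_zeggs/, model000450000.pt and the
# WavLM weights are in place as described in the Quick Start; "my_audio/" is
# a hypothetical folder holding the .wav files you want to synthesize for.
import subprocess
from pathlib import Path

AUDIO_DIR = Path("./my_audio")  # hypothetical input folder of .wav files

for wav in sorted(AUDIO_DIR.glob("*.wav")):
    cmd = [
        "python", "sample.py",
        "--config=./configs/DiffuseStyleGesture.yml",
        "--no_cuda", "0",
        "--gpu", "0",
        "--model_path", "./model000450000.pt",
        "--audiowavlm_path", str(wav),
        "--max_len", "320",
    ]
    print("Running:", " ".join(cmd))
    # Each run writes a .bvh result into the sample_dir folder, as in the Quick Start.
    subprocess.run(cmd, check=True)
```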
diff --git a/main/.gitignore b/main/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..b6e47617de110dea7ca47e087ff1347cc2646eda --- /dev/null +++ b/main/.gitignore @@ -0,0 +1,129 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ diff --git a/main/LICENSE b/main/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..54055df7f1b68c2a10a2de314e160182de405e91 --- /dev/null +++ b/main/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 Guy Tevet + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/main/assets/example_action_names_humanact12.txt b/main/assets/example_action_names_humanact12.txt new file mode 100644 index 0000000000000000000000000000000000000000..31b0728a52f155e585a1d5c4afe63e5d9dc8bdf5 --- /dev/null +++ b/main/assets/example_action_names_humanact12.txt @@ -0,0 +1,2 @@ +drink +lift_dumbbell diff --git a/main/assets/example_action_names_uestc.txt b/main/assets/example_action_names_uestc.txt new file mode 100644 index 0000000000000000000000000000000000000000..a3095abdecfed8cd63c4b73e8cf2d20cad606872 --- /dev/null +++ b/main/assets/example_action_names_uestc.txt @@ -0,0 +1,7 @@ +jumping-jack +left-lunging +left-stretching +raising-hand-and-jumping +rotation-clapping +front-raising +pulling-chest-expanders diff --git a/main/assets/example_stick_fig.gif b/main/assets/example_stick_fig.gif new file mode 100644 index 0000000000000000000000000000000000000000..78db0677019f8e73c82376ca3003631ef90a4057 Binary files /dev/null and b/main/assets/example_stick_fig.gif differ diff --git a/main/assets/example_text_prompts.txt b/main/assets/example_text_prompts.txt new file mode 100644 index 0000000000000000000000000000000000000000..57e6ed42fc8810ab071edef9fc310f988c10d2e3 --- /dev/null +++ b/main/assets/example_text_prompts.txt @@ -0,0 +1,8 @@ +person got down and is crawling across the floor. +a person walks forward with wide steps. +a person drops their hands then brings them together in front of their face clasped. +a person lifts their right arm and slaps something, then repeats the motion again. +a person walks forward and stops. +a person marches forward, turns around, and then marches back. +a person is stretching their arms. +person is making attention gesture \ No newline at end of file diff --git a/main/assets/in_between_edit.gif b/main/assets/in_between_edit.gif new file mode 100644 index 0000000000000000000000000000000000000000..ead5432f68670a17c347bf445777a3135a8fadd6 --- /dev/null +++ b/main/assets/in_between_edit.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d1ed52e9f08d96500c8a414830065f263d6305e803d0ddec393f5dbb3981ecd0 +size 1518198 diff --git a/main/assets/upper_body_edit.gif b/main/assets/upper_body_edit.gif new file mode 100644 index 0000000000000000000000000000000000000000..1e03e6f10795427b18263673b1dcfb68e25875c2 Binary files /dev/null and b/main/assets/upper_body_edit.gif differ diff --git a/main/body_models/README.md b/main/body_models/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e63996a323ce132e06a343e7ba109ab970affab2 --- /dev/null +++ b/main/body_models/README.md @@ -0,0 +1,3 @@ +## Body models + +Put SMPL models here (full instractions in the main README) \ No newline at end of file diff --git a/main/data_loaders/a2m/dataset.py b/main/data_loaders/a2m/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..b34ce34362f98f3d0158cb0ca2912df35f512bd5 --- /dev/null +++ b/main/data_loaders/a2m/dataset.py @@ -0,0 +1,255 @@ +import random + +import numpy as np +import torch +# from utils.action_label_to_idx import action_label_to_idx +from data_loaders.tensors import collate +from utils.misc import to_torch +import utils.rotation_conversions as geometry + +class Dataset(torch.utils.data.Dataset): + def __init__(self, num_frames=1, sampling="conseq", sampling_step=1, split="train", + pose_rep="rot6d", translation=True, glob=True, max_len=-1, min_len=-1, num_seq_max=-1, **kwargs): + self.num_frames = num_frames + self.sampling = sampling + self.sampling_step = 
sampling_step + self.split = split + self.pose_rep = pose_rep + self.translation = translation + self.glob = glob + self.max_len = max_len + self.min_len = min_len + self.num_seq_max = num_seq_max + + self.align_pose_frontview = kwargs.get('align_pose_frontview', False) + self.use_action_cat_as_text_labels = kwargs.get('use_action_cat_as_text_labels', False) + self.only_60_classes = kwargs.get('only_60_classes', False) + self.leave_out_15_classes = kwargs.get('leave_out_15_classes', False) + self.use_only_15_classes = kwargs.get('use_only_15_classes', False) + + if self.split not in ["train", "val", "test"]: + raise ValueError(f"{self.split} is not a valid split") + + super().__init__() + + # to remove shuffling + self._original_train = None + self._original_test = None + + def action_to_label(self, action): + return self._action_to_label[action] + + def label_to_action(self, label): + import numbers + if isinstance(label, numbers.Integral): + return self._label_to_action[label] + else: # if it is one hot vector + label = np.argmax(label) + return self._label_to_action[label] + + def get_pose_data(self, data_index, frame_ix): + pose = self._load(data_index, frame_ix) + label = self.get_label(data_index) + return pose, label + + def get_label(self, ind): + action = self.get_action(ind) + return self.action_to_label(action) + + def get_action(self, ind): + return self._actions[ind] + + def action_to_action_name(self, action): + return self._action_classes[action] + + def action_name_to_action(self, action_name): + # self._action_classes is either a list or a dictionary. If it's a dictionary, we 1st convert it to a list + all_action_names = self._action_classes + if isinstance(all_action_names, dict): + all_action_names = list(all_action_names.values()) + assert list(self._action_classes.keys()) == list(range(len(all_action_names))) # the keys should be ordered from 0 to num_actions + + sorter = np.argsort(all_action_names) + actions = sorter[np.searchsorted(all_action_names, action_name, sorter=sorter)] + return actions + + def __getitem__(self, index): + if self.split == 'train': + data_index = self._train[index] + else: + data_index = self._test[index] + + # inp, target = self._get_item_data_index(data_index) + # return inp, target + return self._get_item_data_index(data_index) + + def _load(self, ind, frame_ix): + pose_rep = self.pose_rep + if pose_rep == "xyz" or self.translation: + if getattr(self, "_load_joints3D", None) is not None: + # Locate the root joint of initial pose at origin + joints3D = self._load_joints3D(ind, frame_ix) + joints3D = joints3D - joints3D[0, 0, :] + ret = to_torch(joints3D) + if self.translation: + ret_tr = ret[:, 0, :] + else: + if pose_rep == "xyz": + raise ValueError("This representation is not possible.") + if getattr(self, "_load_translation") is None: + raise ValueError("Can't extract translations.") + ret_tr = self._load_translation(ind, frame_ix) + ret_tr = to_torch(ret_tr - ret_tr[0]) + + if pose_rep != "xyz": + if getattr(self, "_load_rotvec", None) is None: + raise ValueError("This representation is not possible.") + else: + pose = self._load_rotvec(ind, frame_ix) + if not self.glob: + pose = pose[:, 1:, :] + pose = to_torch(pose) + if self.align_pose_frontview: + first_frame_root_pose_matrix = geometry.axis_angle_to_matrix(pose[0][0]) + all_root_poses_matrix = geometry.axis_angle_to_matrix(pose[:, 0, :]) + aligned_root_poses_matrix = torch.matmul(torch.transpose(first_frame_root_pose_matrix, 0, 1), + all_root_poses_matrix) + pose[:, 0, :] = 
geometry.matrix_to_axis_angle(aligned_root_poses_matrix) + + if self.translation: + ret_tr = torch.matmul(torch.transpose(first_frame_root_pose_matrix, 0, 1).float(), + torch.transpose(ret_tr, 0, 1)) + ret_tr = torch.transpose(ret_tr, 0, 1) + + if pose_rep == "rotvec": + ret = pose + elif pose_rep == "rotmat": + ret = geometry.axis_angle_to_matrix(pose).view(*pose.shape[:2], 9) + elif pose_rep == "rotquat": + ret = geometry.axis_angle_to_quaternion(pose) + elif pose_rep == "rot6d": + ret = geometry.matrix_to_rotation_6d(geometry.axis_angle_to_matrix(pose)) + if pose_rep != "xyz" and self.translation: + padded_tr = torch.zeros((ret.shape[0], ret.shape[2]), dtype=ret.dtype) + padded_tr[:, :3] = ret_tr + ret = torch.cat((ret, padded_tr[:, None]), 1) + ret = ret.permute(1, 2, 0).contiguous() + return ret.float() + + def _get_item_data_index(self, data_index): + nframes = self._num_frames_in_video[data_index] + + if self.num_frames == -1 and (self.max_len == -1 or nframes <= self.max_len): + frame_ix = np.arange(nframes) + else: + if self.num_frames == -2: + if self.min_len <= 0: + raise ValueError("You should put a min_len > 0 for num_frames == -2 mode") + if self.max_len != -1: + max_frame = min(nframes, self.max_len) + else: + max_frame = nframes + + num_frames = random.randint(self.min_len, max(max_frame, self.min_len)) + else: + num_frames = self.num_frames if self.num_frames != -1 else self.max_len + + if num_frames > nframes: + fair = False # True + if fair: + # distills redundancy everywhere + choices = np.random.choice(range(nframes), + num_frames, + replace=True) + frame_ix = sorted(choices) + else: + # adding the last frame until done + ntoadd = max(0, num_frames - nframes) + lastframe = nframes - 1 + padding = lastframe * np.ones(ntoadd, dtype=int) + frame_ix = np.concatenate((np.arange(0, nframes), + padding)) + + elif self.sampling in ["conseq", "random_conseq"]: + step_max = (nframes - 1) // (num_frames - 1) + if self.sampling == "conseq": + if self.sampling_step == -1 or self.sampling_step * (num_frames - 1) >= nframes: + step = step_max + else: + step = self.sampling_step + elif self.sampling == "random_conseq": + step = random.randint(1, step_max) + + lastone = step * (num_frames - 1) + shift_max = nframes - lastone - 1 + shift = random.randint(0, max(0, shift_max - 1)) + frame_ix = shift + np.arange(0, lastone + 1, step) + + elif self.sampling == "random": + choices = np.random.choice(range(nframes), + num_frames, + replace=False) + frame_ix = sorted(choices) + + else: + raise ValueError("Sampling not recognized.") + + inp, action = self.get_pose_data(data_index, frame_ix) + + + output = {'inp': inp, 'action': action} + + if hasattr(self, '_actions') and hasattr(self, '_action_classes'): + output['action_text'] = self.action_to_action_name(self.get_action(data_index)) + + return output + + + def get_mean_length_label(self, label): + if self.num_frames != -1: + return self.num_frames + + if self.split == 'train': + index = self._train + else: + index = self._test + + action = self.label_to_action(label) + choices = np.argwhere(self._actions[index] == action).squeeze(1) + lengths = self._num_frames_in_video[np.array(index)[choices]] + + if self.max_len == -1: + return np.mean(lengths) + else: + # make the lengths less than max_len + lengths[lengths > self.max_len] = self.max_len + return np.mean(lengths) + + def __len__(self): + num_seq_max = getattr(self, "num_seq_max", -1) + if num_seq_max == -1: + from math import inf + num_seq_max = inf + + if self.split == 'train': + 
return min(len(self._train), num_seq_max) + else: + return min(len(self._test), num_seq_max) + + def shuffle(self): + if self.split == 'train': + random.shuffle(self._train) + else: + random.shuffle(self._test) + + def reset_shuffle(self): + if self.split == 'train': + if self._original_train is None: + self._original_train = self._train + else: + self._train = self._original_train + else: + if self._original_test is None: + self._original_test = self._test + else: + self._test = self._original_test diff --git a/main/data_loaders/a2m/humanact12poses.py b/main/data_loaders/a2m/humanact12poses.py new file mode 100644 index 0000000000000000000000000000000000000000..d9b8894a5e7435f0f35aee1d326fead5a3123bae --- /dev/null +++ b/main/data_loaders/a2m/humanact12poses.py @@ -0,0 +1,57 @@ +import pickle as pkl +import numpy as np +import os +from .dataset import Dataset + + +class HumanAct12Poses(Dataset): + dataname = "humanact12" + + def __init__(self, datapath="dataset/HumanAct12Poses", split="train", **kargs): + self.datapath = datapath + + super().__init__(**kargs) + + pkldatafilepath = os.path.join(datapath, "humanact12poses.pkl") + data = pkl.load(open(pkldatafilepath, "rb")) + + self._pose = [x for x in data["poses"]] + self._num_frames_in_video = [p.shape[0] for p in self._pose] + self._joints = [x for x in data["joints3D"]] + + self._actions = [x for x in data["y"]] + + total_num_actions = 12 + self.num_actions = total_num_actions + + self._train = list(range(len(self._pose))) + + keep_actions = np.arange(0, total_num_actions) + + self._action_to_label = {x: i for i, x in enumerate(keep_actions)} + self._label_to_action = {i: x for i, x in enumerate(keep_actions)} + + self._action_classes = humanact12_coarse_action_enumerator + + def _load_joints3D(self, ind, frame_ix): + return self._joints[ind][frame_ix] + + def _load_rotvec(self, ind, frame_ix): + pose = self._pose[ind][frame_ix].reshape(-1, 24, 3) + return pose + + +humanact12_coarse_action_enumerator = { + 0: "warm_up", + 1: "walk", + 2: "run", + 3: "jump", + 4: "drink", + 5: "lift_dumbbell", + 6: "sit", + 7: "eat", + 8: "turn steering wheel", + 9: "phone", + 10: "boxing", + 11: "throw", +} diff --git a/main/data_loaders/a2m/uestc.py b/main/data_loaders/a2m/uestc.py new file mode 100644 index 0000000000000000000000000000000000000000..e818b9831f587b360cf90f134074855ee5100484 --- /dev/null +++ b/main/data_loaders/a2m/uestc.py @@ -0,0 +1,226 @@ +import os +from tqdm import tqdm +import numpy as np +import pickle as pkl +import utils.rotation_conversions as geometry +import torch + +from .dataset import Dataset +# from torch.utils.data import Dataset + +action2motion_joints = [8, 1, 2, 3, 4, 5, 6, 7, 0, 9, 10, 11, 12, 13, 14, 21, 24, 38] + + +def get_z(cam_s, cam_pos, joints, img_size, flength): + """ + Solves for the depth offset of the model to approx. orth with persp camera. + """ + # Translate the model itself: Solve the best z that maps to orth_proj points + joints_orth_target = (cam_s * (joints[:, :2] + cam_pos) + 1) * 0.5 * img_size + height3d = np.linalg.norm(np.max(joints[:, :2], axis=0) - np.min(joints[:, :2], axis=0)) + height2d = np.linalg.norm(np.max(joints_orth_target, axis=0) - np.min(joints_orth_target, axis=0)) + tz = np.array(flength * (height3d / height2d)) + return float(tz) + + +def get_trans_from_vibe(vibe, index, use_z=True): + alltrans = [] + for t in range(vibe["joints3d"][index].shape[0]): + # Convert crop cam to orig cam + # No need! 
Because `convert_crop_cam_to_orig_img` from demoutils of vibe + # does this already for us :) + # Its format is: [sx, sy, tx, ty] + cam_orig = vibe["orig_cam"][index][t] + x = cam_orig[2] + y = cam_orig[3] + if use_z: + z = get_z(cam_s=cam_orig[0], # TODO: There are two scales instead of 1. + cam_pos=cam_orig[2:4], + joints=vibe['joints3d'][index][t], + img_size=540, + flength=500) + # z = 500 / (0.5 * 480 * cam_orig[0]) + else: + z = 0 + trans = [x, y, z] + alltrans.append(trans) + alltrans = np.array(alltrans) + return alltrans - alltrans[0] + + +class UESTC(Dataset): + dataname = "uestc" + + def __init__(self, datapath="dataset/uestc", method_name="vibe", view="all", **kargs): + + self.datapath = datapath + self.method_name = method_name + self.view = view + super().__init__(**kargs) + + # Load pre-computed #frames data + with open(os.path.join(datapath, 'info', 'num_frames_min.txt'), 'r') as f: + num_frames_video = np.asarray([int(s) for s in f.read().splitlines()]) + + # Out of 118 subjects -> 51 training, 67 in test + all_subjects = np.arange(1, 119) + self._tr_subjects = [ + 1, 2, 6, 12, 13, 16, 21, 24, 28, 29, 30, 31, 33, 35, 39, 41, 42, 45, 47, 50, + 52, 54, 55, 57, 59, 61, 63, 64, 67, 69, 70, 71, 73, 77, 81, 84, 86, 87, 88, + 90, 91, 93, 96, 99, 102, 103, 104, 107, 108, 112, 113] + self._test_subjects = [s for s in all_subjects if s not in self._tr_subjects] + + # Load names of 25600 videos + with open(os.path.join(datapath, 'info', 'names.txt'), 'r') as f: + videos = f.read().splitlines() + + self._videos = videos + + if self.method_name == "vibe": + vibe_data_path = os.path.join(datapath, "vibe_cache_refined.pkl") + vibe_data = pkl.load(open(vibe_data_path, "rb")) + + self._pose = vibe_data["pose"] + num_frames_method = [p.shape[0] for p in self._pose] + globpath = os.path.join(datapath, "globtrans_usez.pkl") + + if os.path.exists(globpath): + self._globtrans = pkl.load(open(globpath, "rb")) + else: + self._globtrans = [] + for index in tqdm(range(len(self._pose))): + self._globtrans.append(get_trans_from_vibe(vibe_data, index, use_z=True)) + pkl.dump(self._globtrans, open("globtrans_usez.pkl", "wb")) + self._joints = vibe_data["joints3d"] + self._jointsIx = action2motion_joints + else: + raise ValueError("This method name is not recognized.") + + num_frames_video = np.minimum(num_frames_video, num_frames_method) + num_frames_video = num_frames_video.astype(int) + self._num_frames_in_video = [x for x in num_frames_video] + + N = len(videos) + self._actions = np.zeros(N, dtype=int) + for ind in range(N): + self._actions[ind] = self.parse_action(videos[ind]) + + self._actions = [x for x in self._actions] + + total_num_actions = 40 + self.num_actions = total_num_actions + keep_actions = np.arange(0, total_num_actions) + + self._action_to_label = {x: i for i, x in enumerate(keep_actions)} + self._label_to_action = {i: x for i, x in enumerate(keep_actions)} + self.num_classes = len(keep_actions) + + self._train = [] + self._test = [] + + self.info_actions = [] + + def get_rotation(view): + theta = - view * np.pi/4 + axis = torch.tensor([0, 1, 0], dtype=torch.float) + axisangle = theta*axis + matrix = geometry.axis_angle_to_matrix(axisangle) + return matrix + + # 0 is identity if needed + rotations = {key: get_rotation(key) for key in [0, 1, 2, 3, 4, 5, 6, 7]} + + for index, video in enumerate(tqdm(videos, desc='Preparing UESTC data..')): + act, view, subject, side = self._get_action_view_subject_side(video) + self.info_actions.append({"action": act, + "view": view, + "subject": 
subject, + "side": side}) + if self.view == "frontview": + if side != 1: + continue + # rotate to front view + if side != 1: + # don't take the view 8 in side 2 + if view == 8: + continue + rotation = rotations[view] + global_matrix = geometry.axis_angle_to_matrix(torch.from_numpy(self._pose[index][:, :3])) + # rotate the global pose + self._pose[index][:, :3] = geometry.matrix_to_axis_angle(rotation @ global_matrix).numpy() + # rotate the joints + self._joints[index] = self._joints[index] @ rotation.T.numpy() + self._globtrans[index] = (self._globtrans[index] @ rotation.T.numpy()) + + # add the global translation to the joints + self._joints[index] = self._joints[index] + self._globtrans[index][:, None] + + if subject in self._tr_subjects: + self._train.append(index) + elif subject in self._test_subjects: + self._test.append(index) + else: + raise ValueError("This subject doesn't belong to any set.") + + # if index > 200: + # break + + # Select only sequences which have a minimum number of frames + if self.num_frames > 0: + threshold = self.num_frames*3/4 + else: + threshold = 0 + + method_extracted_ix = np.where(num_frames_video >= threshold)[0].tolist() + self._train = list(set(self._train) & set(method_extracted_ix)) + # keep the test set without modification + self._test = list(set(self._test)) + + action_classes_file = os.path.join(datapath, "info/action_classes.txt") + with open(action_classes_file, 'r') as f: + self._action_classes = np.array(f.read().splitlines()) + + # with open(processd_path, 'wb') as file: + # pkl.dump(xxx, file) + + def _load_joints3D(self, ind, frame_ix): + if len(self._joints[ind]) == 0: + raise ValueError( + f"Cannot load index {ind} in _load_joints3D function.") + if self._jointsIx is not None: + joints3D = self._joints[ind][frame_ix][:, self._jointsIx] + else: + joints3D = self._joints[ind][frame_ix] + + return joints3D + + def _load_rotvec(self, ind, frame_ix): + # 72 dim smpl + pose = self._pose[ind][frame_ix, :].reshape(-1, 24, 3) + return pose + + def _get_action_view_subject_side(self, videopath): + # TODO: Can be moved to tools.py + spl = videopath.split('_') + action = int(spl[0][1:]) + view = int(spl[1][1:]) + subject = int(spl[2][1:]) + side = int(spl[3][1:]) + return action, view, subject, side + + def _get_videopath(self, action, view, subject, side): + # Unused function + return 'a{:d}_d{:d}_p{:03d}_c{:d}_color.avi'.format( + action, view, subject, side) + + def parse_action(self, path, return_int=True): + # Override parent method + info, _, _, _ = self._get_action_view_subject_side(path) + if return_int: + return int(info) + else: + return info + + +if __name__ == "__main__": + dataset = UESTC() diff --git a/main/data_loaders/get_data.py b/main/data_loaders/get_data.py new file mode 100644 index 0000000000000000000000000000000000000000..cb6400ff00e67c63436fce219396e6138d3cdbcc --- /dev/null +++ b/main/data_loaders/get_data.py @@ -0,0 +1,52 @@ +from torch.utils.data import DataLoader +from data_loaders.tensors import collate as all_collate +from data_loaders.tensors import t2m_collate + +def get_dataset_class(name): + if name == "amass": + from .amass import AMASS + return AMASS + elif name == "uestc": + from .a2m.uestc import UESTC + return UESTC + elif name == "humanact12": + from .a2m.humanact12poses import HumanAct12Poses + return HumanAct12Poses + elif name == "humanml": + from data_loaders.humanml.data.dataset import HumanML3D + return HumanML3D + elif name == "kit": + from data_loaders.humanml.data.dataset import KIT + return KIT + 
else: + raise ValueError(f'Unsupported dataset name [{name}]') + +def get_collate_fn(name, hml_mode='train'): + if hml_mode == 'gt': + from data_loaders.humanml.data.dataset import collate_fn as t2m_eval_collate + return t2m_eval_collate + if name in ["humanml", "kit"]: + return t2m_collate + else: + return all_collate + + +def get_dataset(name, num_frames, split='train', hml_mode='train'): + DATA = get_dataset_class(name) + if name in ["humanml", "kit"]: + dataset = DATA(split=split, num_frames=num_frames, mode=hml_mode) + else: + dataset = DATA(split=split, num_frames=num_frames) + return dataset + + +def get_dataset_loader(name, batch_size, num_frames, split='train', hml_mode='train'): + dataset = get_dataset(name, num_frames, split, hml_mode) + collate = get_collate_fn(name, hml_mode) + + loader = DataLoader( + dataset, batch_size=batch_size, shuffle=True, + num_workers=8, drop_last=True, collate_fn=collate + ) + + return loader \ No newline at end of file diff --git a/main/data_loaders/humanml/README.md b/main/data_loaders/humanml/README.md new file mode 100644 index 0000000000000000000000000000000000000000..4bf224f6b341e21f549a27a000d8400c4909c6c1 --- /dev/null +++ b/main/data_loaders/humanml/README.md @@ -0,0 +1 @@ +This code is based on https://github.com/EricGuo5513/text-to-motion.git \ No newline at end of file diff --git a/main/data_loaders/humanml/common/quaternion.py b/main/data_loaders/humanml/common/quaternion.py new file mode 100644 index 0000000000000000000000000000000000000000..e2daa00aef1df60e43775864d1dd3d551f89ded8 --- /dev/null +++ b/main/data_loaders/humanml/common/quaternion.py @@ -0,0 +1,423 @@ +# Copyright (c) 2018-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import torch +import numpy as np + +_EPS4 = np.finfo(float).eps * 4.0 + +_FLOAT_EPS = np.finfo(np.float).eps + +# PyTorch-backed implementations +def qinv(q): + assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)' + mask = torch.ones_like(q) + mask[..., 1:] = -mask[..., 1:] + return q * mask + + +def qinv_np(q): + assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)' + return qinv(torch.from_numpy(q).float()).numpy() + + +def qnormalize(q): + assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)' + return q / torch.norm(q, dim=-1, keepdim=True) + + +def qmul(q, r): + """ + Multiply quaternion(s) q with quaternion(s) r. + Expects two equally-sized tensors of shape (*, 4), where * denotes any number of dimensions. + Returns q*r as a tensor of shape (*, 4). + """ + assert q.shape[-1] == 4 + assert r.shape[-1] == 4 + + original_shape = q.shape + + # Compute outer product + terms = torch.bmm(r.view(-1, 4, 1), q.view(-1, 1, 4)) + + w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3] + x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2] + y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1] + z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0] + return torch.stack((w, x, y, z), dim=1).view(original_shape) + + +def qrot(q, v): + """ + Rotate vector(s) v about the rotation described by quaternion(s) q. + Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v, + where * denotes any number of dimensions. + Returns a tensor of shape (*, 3). 
+ """ + assert q.shape[-1] == 4 + assert v.shape[-1] == 3 + assert q.shape[:-1] == v.shape[:-1] + + original_shape = list(v.shape) + # print(q.shape) + q = q.contiguous().view(-1, 4) + v = v.contiguous().view(-1, 3) + + qvec = q[:, 1:] + uv = torch.cross(qvec, v, dim=1) + uuv = torch.cross(qvec, uv, dim=1) + return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape) + + +def qeuler(q, order, epsilon=0, deg=True): + """ + Convert quaternion(s) q to Euler angles. + Expects a tensor of shape (*, 4), where * denotes any number of dimensions. + Returns a tensor of shape (*, 3). + """ + assert q.shape[-1] == 4 + + original_shape = list(q.shape) + original_shape[-1] = 3 + q = q.view(-1, 4) + + q0 = q[:, 0] + q1 = q[:, 1] + q2 = q[:, 2] + q3 = q[:, 3] + + if order == 'xyz': + x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) + y = torch.asin(torch.clamp(2 * (q1 * q3 + q0 * q2), -1 + epsilon, 1 - epsilon)) + z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3)) + elif order == 'yzx': + x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) + y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3)) + z = torch.asin(torch.clamp(2 * (q1 * q2 + q0 * q3), -1 + epsilon, 1 - epsilon)) + elif order == 'zxy': + x = torch.asin(torch.clamp(2 * (q0 * q1 + q2 * q3), -1 + epsilon, 1 - epsilon)) + y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) + z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q1 * q1 + q3 * q3)) + elif order == 'xzy': + x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) + y = torch.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3)) + z = torch.asin(torch.clamp(2 * (q0 * q3 - q1 * q2), -1 + epsilon, 1 - epsilon)) + elif order == 'yxz': + x = torch.asin(torch.clamp(2 * (q0 * q1 - q2 * q3), -1 + epsilon, 1 - epsilon)) + y = torch.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2 * (q1 * q1 + q2 * q2)) + z = torch.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) + elif order == 'zyx': + x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) + y = torch.asin(torch.clamp(2 * (q0 * q2 - q1 * q3), -1 + epsilon, 1 - epsilon)) + z = torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3)) + else: + raise + + if deg: + return torch.stack((x, y, z), dim=1).view(original_shape) * 180 / np.pi + else: + return torch.stack((x, y, z), dim=1).view(original_shape) + + +# Numpy-backed implementations + +def qmul_np(q, r): + q = torch.from_numpy(q).contiguous().float() + r = torch.from_numpy(r).contiguous().float() + return qmul(q, r).numpy() + + +def qrot_np(q, v): + q = torch.from_numpy(q).contiguous().float() + v = torch.from_numpy(v).contiguous().float() + return qrot(q, v).numpy() + + +def qeuler_np(q, order, epsilon=0, use_gpu=False): + if use_gpu: + q = torch.from_numpy(q).cuda().float() + return qeuler(q, order, epsilon).cpu().numpy() + else: + q = torch.from_numpy(q).contiguous().float() + return qeuler(q, order, epsilon).numpy() + + +def qfix(q): + """ + Enforce quaternion continuity across the time dimension by selecting + the representation (q or -q) with minimal distance (or, equivalently, maximal dot product) + between two consecutive frames. + + Expects a tensor of shape (L, J, 4), where L is the sequence length and J is the number of joints. + Returns a tensor of the same shape. 
+ """ + assert len(q.shape) == 3 + assert q.shape[-1] == 4 + + result = q.copy() + dot_products = np.sum(q[1:] * q[:-1], axis=2) + mask = dot_products < 0 + mask = (np.cumsum(mask, axis=0) % 2).astype(bool) + result[1:][mask] *= -1 + return result + + +def euler2quat(e, order, deg=True): + """ + Convert Euler angles to quaternions. + """ + assert e.shape[-1] == 3 + + original_shape = list(e.shape) + original_shape[-1] = 4 + + e = e.view(-1, 3) + + ## if euler angles in degrees + if deg: + e = e * np.pi / 180. + + x = e[:, 0] + y = e[:, 1] + z = e[:, 2] + + rx = torch.stack((torch.cos(x / 2), torch.sin(x / 2), torch.zeros_like(x), torch.zeros_like(x)), dim=1) + ry = torch.stack((torch.cos(y / 2), torch.zeros_like(y), torch.sin(y / 2), torch.zeros_like(y)), dim=1) + rz = torch.stack((torch.cos(z / 2), torch.zeros_like(z), torch.zeros_like(z), torch.sin(z / 2)), dim=1) + + result = None + for coord in order: + if coord == 'x': + r = rx + elif coord == 'y': + r = ry + elif coord == 'z': + r = rz + else: + raise + if result is None: + result = r + else: + result = qmul(result, r) + + # Reverse antipodal representation to have a non-negative "w" + if order in ['xyz', 'yzx', 'zxy']: + result *= -1 + + return result.view(original_shape) + + +def expmap_to_quaternion(e): + """ + Convert axis-angle rotations (aka exponential maps) to quaternions. + Stable formula from "Practical Parameterization of Rotations Using the Exponential Map". + Expects a tensor of shape (*, 3), where * denotes any number of dimensions. + Returns a tensor of shape (*, 4). + """ + assert e.shape[-1] == 3 + + original_shape = list(e.shape) + original_shape[-1] = 4 + e = e.reshape(-1, 3) + + theta = np.linalg.norm(e, axis=1).reshape(-1, 1) + w = np.cos(0.5 * theta).reshape(-1, 1) + xyz = 0.5 * np.sinc(0.5 * theta / np.pi) * e + return np.concatenate((w, xyz), axis=1).reshape(original_shape) + + +def euler_to_quaternion(e, order): + """ + Convert Euler angles to quaternions. + """ + assert e.shape[-1] == 3 + + original_shape = list(e.shape) + original_shape[-1] = 4 + + e = e.reshape(-1, 3) + + x = e[:, 0] + y = e[:, 1] + z = e[:, 2] + + rx = np.stack((np.cos(x / 2), np.sin(x / 2), np.zeros_like(x), np.zeros_like(x)), axis=1) + ry = np.stack((np.cos(y / 2), np.zeros_like(y), np.sin(y / 2), np.zeros_like(y)), axis=1) + rz = np.stack((np.cos(z / 2), np.zeros_like(z), np.zeros_like(z), np.sin(z / 2)), axis=1) + + result = None + for coord in order: + if coord == 'x': + r = rx + elif coord == 'y': + r = ry + elif coord == 'z': + r = rz + else: + raise + if result is None: + result = r + else: + result = qmul_np(result, r) + + # Reverse antipodal representation to have a non-negative "w" + if order in ['xyz', 'yzx', 'zxy']: + result *= -1 + + return result.reshape(original_shape) + + +def quaternion_to_matrix(quaternions): + """ + Convert rotations given as quaternions to rotation matrices. + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + Returns: + Rotation matrices as tensor of shape (..., 3, 3). 
+ """ + r, i, j, k = torch.unbind(quaternions, -1) + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def quaternion_to_matrix_np(quaternions): + q = torch.from_numpy(quaternions).contiguous().float() + return quaternion_to_matrix(q).numpy() + + +def quaternion_to_cont6d_np(quaternions): + rotation_mat = quaternion_to_matrix_np(quaternions) + cont_6d = np.concatenate([rotation_mat[..., 0], rotation_mat[..., 1]], axis=-1) + return cont_6d + + +def quaternion_to_cont6d(quaternions): + rotation_mat = quaternion_to_matrix(quaternions) + cont_6d = torch.cat([rotation_mat[..., 0], rotation_mat[..., 1]], dim=-1) + return cont_6d + + +def cont6d_to_matrix(cont6d): + assert cont6d.shape[-1] == 6, "The last dimension must be 6" + x_raw = cont6d[..., 0:3] + y_raw = cont6d[..., 3:6] + + x = x_raw / torch.norm(x_raw, dim=-1, keepdim=True) + z = torch.cross(x, y_raw, dim=-1) + z = z / torch.norm(z, dim=-1, keepdim=True) + + y = torch.cross(z, x, dim=-1) + + x = x[..., None] + y = y[..., None] + z = z[..., None] + + mat = torch.cat([x, y, z], dim=-1) + return mat + + +def cont6d_to_matrix_np(cont6d): + q = torch.from_numpy(cont6d).contiguous().float() + return cont6d_to_matrix(q).numpy() + + +def qpow(q0, t, dtype=torch.float): + ''' q0 : tensor of quaternions + t: tensor of powers + ''' + q0 = qnormalize(q0) + theta0 = torch.acos(q0[..., 0]) + + ## if theta0 is close to zero, add epsilon to avoid NaNs + mask = (theta0 <= 10e-10) * (theta0 >= -10e-10) + theta0 = (1 - mask) * theta0 + mask * 10e-10 + v0 = q0[..., 1:] / torch.sin(theta0).view(-1, 1) + + if isinstance(t, torch.Tensor): + q = torch.zeros(t.shape + q0.shape) + theta = t.view(-1, 1) * theta0.view(1, -1) + else: ## if t is a number + q = torch.zeros(q0.shape) + theta = t * theta0 + + q[..., 0] = torch.cos(theta) + q[..., 1:] = v0 * torch.sin(theta).unsqueeze(-1) + + return q.to(dtype) + + +def qslerp(q0, q1, t): + ''' + q0: starting quaternion + q1: ending quaternion + t: array of points along the way + + Returns: + Tensor of Slerps: t.shape + q0.shape + ''' + + q0 = qnormalize(q0) + q1 = qnormalize(q1) + q_ = qpow(qmul(q1, qinv(q0)), t) + + return qmul(q_, + q0.contiguous().view(torch.Size([1] * len(t.shape)) + q0.shape).expand(t.shape + q0.shape).contiguous()) + + +def qbetween(v0, v1): + ''' + find the quaternion used to rotate v0 to v1 + ''' + assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)' + assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)' + + v = torch.cross(v0, v1) + w = torch.sqrt((v0 ** 2).sum(dim=-1, keepdim=True) * (v1 ** 2).sum(dim=-1, keepdim=True)) + (v0 * v1).sum(dim=-1, + keepdim=True) + return qnormalize(torch.cat([w, v], dim=-1)) + + +def qbetween_np(v0, v1): + ''' + find the quaternion used to rotate v0 to v1 + ''' + assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)' + assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)' + + v0 = torch.from_numpy(v0).float() + v1 = torch.from_numpy(v1).float() + return qbetween(v0, v1).numpy() + + +def lerp(p0, p1, t): + if not isinstance(t, torch.Tensor): + t = torch.Tensor([t]) + + new_shape = t.shape + p0.shape + new_view_t = t.shape + torch.Size([1] * len(p0.shape)) + new_view_p = torch.Size([1] * 
len(t.shape)) + p0.shape + p0 = p0.view(new_view_p).expand(new_shape) + p1 = p1.view(new_view_p).expand(new_shape) + t = t.view(new_view_t).expand(new_shape) + + return p0 + t * (p1 - p0) diff --git a/main/data_loaders/humanml/common/skeleton.py b/main/data_loaders/humanml/common/skeleton.py new file mode 100644 index 0000000000000000000000000000000000000000..ceaad100e96af0fe9c2fccae4c66b6bdbe39d5ca --- /dev/null +++ b/main/data_loaders/humanml/common/skeleton.py @@ -0,0 +1,199 @@ +from data_loaders.humanml.common.quaternion import * +import scipy.ndimage.filters as filters + +class Skeleton(object): + def __init__(self, offset, kinematic_tree, device): + self.device = device + self._raw_offset_np = offset.numpy() + self._raw_offset = offset.clone().detach().to(device).float() + self._kinematic_tree = kinematic_tree + self._offset = None + self._parents = [0] * len(self._raw_offset) + self._parents[0] = -1 + for chain in self._kinematic_tree: + for j in range(1, len(chain)): + self._parents[chain[j]] = chain[j-1] + + def njoints(self): + return len(self._raw_offset) + + def offset(self): + return self._offset + + def set_offset(self, offsets): + self._offset = offsets.clone().detach().to(self.device).float() + + def kinematic_tree(self): + return self._kinematic_tree + + def parents(self): + return self._parents + + # joints (batch_size, joints_num, 3) + def get_offsets_joints_batch(self, joints): + assert len(joints.shape) == 3 + _offsets = self._raw_offset.expand(joints.shape[0], -1, -1).clone() + for i in range(1, self._raw_offset.shape[0]): + _offsets[:, i] = torch.norm(joints[:, i] - joints[:, self._parents[i]], p=2, dim=1)[:, None] * _offsets[:, i] + + self._offset = _offsets.detach() + return _offsets + + # joints (joints_num, 3) + def get_offsets_joints(self, joints): + assert len(joints.shape) == 2 + _offsets = self._raw_offset.clone() + for i in range(1, self._raw_offset.shape[0]): + # print(joints.shape) + _offsets[i] = torch.norm(joints[i] - joints[self._parents[i]], p=2, dim=0) * _offsets[i] + + self._offset = _offsets.detach() + return _offsets + + # face_joint_idx should follow the order of right hip, left hip, right shoulder, left shoulder + # joints (batch_size, joints_num, 3) + def inverse_kinematics_np(self, joints, face_joint_idx, smooth_forward=False): + assert len(face_joint_idx) == 4 + '''Get Forward Direction''' + l_hip, r_hip, sdr_r, sdr_l = face_joint_idx + across1 = joints[:, r_hip] - joints[:, l_hip] + across2 = joints[:, sdr_r] - joints[:, sdr_l] + across = across1 + across2 + across = across / np.sqrt((across**2).sum(axis=-1))[:, np.newaxis] + # print(across1.shape, across2.shape) + + # forward (batch_size, 3) + forward = np.cross(np.array([[0, 1, 0]]), across, axis=-1) + if smooth_forward: + forward = filters.gaussian_filter1d(forward, 20, axis=0, mode='nearest') + # forward (batch_size, 3) + forward = forward / np.sqrt((forward**2).sum(axis=-1))[..., np.newaxis] + + '''Get Root Rotation''' + target = np.array([[0,0,1]]).repeat(len(forward), axis=0) + root_quat = qbetween_np(forward, target) + + '''Inverse Kinematics''' + # quat_params (batch_size, joints_num, 4) + # print(joints.shape[:-1]) + quat_params = np.zeros(joints.shape[:-1] + (4,)) + # print(quat_params.shape) + root_quat[0] = np.array([[1.0, 0.0, 0.0, 0.0]]) + quat_params[:, 0] = root_quat + # quat_params[0, 0] = np.array([[1.0, 0.0, 0.0, 0.0]]) + for chain in self._kinematic_tree: + R = root_quat + for j in range(len(chain) - 1): + # (batch, 3) + u = 
self._raw_offset_np[chain[j+1]][np.newaxis,...].repeat(len(joints), axis=0) + # print(u.shape) + # (batch, 3) + v = joints[:, chain[j+1]] - joints[:, chain[j]] + v = v / np.sqrt((v**2).sum(axis=-1))[:, np.newaxis] + # print(u.shape, v.shape) + rot_u_v = qbetween_np(u, v) + + R_loc = qmul_np(qinv_np(R), rot_u_v) + + quat_params[:,chain[j + 1], :] = R_loc + R = qmul_np(R, R_loc) + + return quat_params + + # Be sure root joint is at the beginning of kinematic chains + def forward_kinematics(self, quat_params, root_pos, skel_joints=None, do_root_R=True): + # quat_params (batch_size, joints_num, 4) + # joints (batch_size, joints_num, 3) + # root_pos (batch_size, 3) + if skel_joints is not None: + offsets = self.get_offsets_joints_batch(skel_joints) + if len(self._offset.shape) == 2: + offsets = self._offset.expand(quat_params.shape[0], -1, -1) + joints = torch.zeros(quat_params.shape[:-1] + (3,)).to(self.device) + joints[:, 0] = root_pos + for chain in self._kinematic_tree: + if do_root_R: + R = quat_params[:, 0] + else: + R = torch.tensor([[1.0, 0.0, 0.0, 0.0]]).expand(len(quat_params), -1).detach().to(self.device) + for i in range(1, len(chain)): + R = qmul(R, quat_params[:, chain[i]]) + offset_vec = offsets[:, chain[i]] + joints[:, chain[i]] = qrot(R, offset_vec) + joints[:, chain[i-1]] + return joints + + # Be sure root joint is at the beginning of kinematic chains + def forward_kinematics_np(self, quat_params, root_pos, skel_joints=None, do_root_R=True): + # quat_params (batch_size, joints_num, 4) + # joints (batch_size, joints_num, 3) + # root_pos (batch_size, 3) + if skel_joints is not None: + skel_joints = torch.from_numpy(skel_joints) + offsets = self.get_offsets_joints_batch(skel_joints) + if len(self._offset.shape) == 2: + offsets = self._offset.expand(quat_params.shape[0], -1, -1) + offsets = offsets.numpy() + joints = np.zeros(quat_params.shape[:-1] + (3,)) + joints[:, 0] = root_pos + for chain in self._kinematic_tree: + if do_root_R: + R = quat_params[:, 0] + else: + R = np.array([[1.0, 0.0, 0.0, 0.0]]).repeat(len(quat_params), axis=0) + for i in range(1, len(chain)): + R = qmul_np(R, quat_params[:, chain[i]]) + offset_vec = offsets[:, chain[i]] + joints[:, chain[i]] = qrot_np(R, offset_vec) + joints[:, chain[i - 1]] + return joints + + def forward_kinematics_cont6d_np(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True): + # cont6d_params (batch_size, joints_num, 6) + # joints (batch_size, joints_num, 3) + # root_pos (batch_size, 3) + if skel_joints is not None: + skel_joints = torch.from_numpy(skel_joints) + offsets = self.get_offsets_joints_batch(skel_joints) + if len(self._offset.shape) == 2: + offsets = self._offset.expand(cont6d_params.shape[0], -1, -1) + offsets = offsets.numpy() + joints = np.zeros(cont6d_params.shape[:-1] + (3,)) + joints[:, 0] = root_pos + for chain in self._kinematic_tree: + if do_root_R: + matR = cont6d_to_matrix_np(cont6d_params[:, 0]) + else: + matR = np.eye(3)[np.newaxis, :].repeat(len(cont6d_params), axis=0) + for i in range(1, len(chain)): + matR = np.matmul(matR, cont6d_to_matrix_np(cont6d_params[:, chain[i]])) + offset_vec = offsets[:, chain[i]][..., np.newaxis] + # print(matR.shape, offset_vec.shape) + joints[:, chain[i]] = np.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]] + return joints + + def forward_kinematics_cont6d(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True): + # cont6d_params (batch_size, joints_num, 6) + # joints (batch_size, joints_num, 3) + # root_pos (batch_size, 3) + if skel_joints is 
not None: + # skel_joints = torch.from_numpy(skel_joints) + offsets = self.get_offsets_joints_batch(skel_joints) + if len(self._offset.shape) == 2: + offsets = self._offset.expand(cont6d_params.shape[0], -1, -1) + joints = torch.zeros(cont6d_params.shape[:-1] + (3,)).to(cont6d_params.device) + joints[..., 0, :] = root_pos + for chain in self._kinematic_tree: + if do_root_R: + matR = cont6d_to_matrix(cont6d_params[:, 0]) + else: + matR = torch.eye(3).expand((len(cont6d_params), -1, -1)).detach().to(cont6d_params.device) + for i in range(1, len(chain)): + matR = torch.matmul(matR, cont6d_to_matrix(cont6d_params[:, chain[i]])) + offset_vec = offsets[:, chain[i]].unsqueeze(-1) + # print(matR.shape, offset_vec.shape) + joints[:, chain[i]] = torch.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]] + return joints + + + + + diff --git a/main/data_loaders/humanml/data/__init__.py b/main/data_loaders/humanml/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/main/data_loaders/humanml/data/dataset.py b/main/data_loaders/humanml/data/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..28db3d4206f552c29b92efd170663e55e4bbce3c --- /dev/null +++ b/main/data_loaders/humanml/data/dataset.py @@ -0,0 +1,783 @@ +import torch +from torch.utils import data +import numpy as np +import os +from os.path import join as pjoin +import random +import codecs as cs +from tqdm import tqdm +import spacy + +from torch.utils.data._utils.collate import default_collate +from data_loaders.humanml.utils.word_vectorizer import WordVectorizer +from data_loaders.humanml.utils.get_opt import get_opt + +# import spacy + +def collate_fn(batch): + batch.sort(key=lambda x: x[3], reverse=True) + return default_collate(batch) + + +'''For use of training text-2-motion generative model''' +class Text2MotionDataset(data.Dataset): + def __init__(self, opt, mean, std, split_file, w_vectorizer): + self.opt = opt + self.w_vectorizer = w_vectorizer + self.max_length = 20 + self.pointer = 0 + min_motion_len = 40 if self.opt.dataset_name =='t2m' else 24 + + joints_num = opt.joints_num + + data_dict = {} + id_list = [] + with cs.open(split_file, 'r') as f: + for line in f.readlines(): + id_list.append(line.strip()) + + new_name_list = [] + length_list = [] + for name in tqdm(id_list): + try: + motion = np.load(pjoin(opt.motion_dir, name + '.npy')) + if (len(motion)) < min_motion_len or (len(motion) >= 200): + continue + text_data = [] + flag = False + with cs.open(pjoin(opt.text_dir, name + '.txt')) as f: + for line in f.readlines(): + text_dict = {} + line_split = line.strip().split('#') + caption = line_split[0] + tokens = line_split[1].split(' ') + f_tag = float(line_split[2]) + to_tag = float(line_split[3]) + f_tag = 0.0 if np.isnan(f_tag) else f_tag + to_tag = 0.0 if np.isnan(to_tag) else to_tag + + text_dict['caption'] = caption + text_dict['tokens'] = tokens + if f_tag == 0.0 and to_tag == 0.0: + flag = True + text_data.append(text_dict) + else: + try: + n_motion = motion[int(f_tag*20) : int(to_tag*20)] + if (len(n_motion)) < min_motion_len or (len(n_motion) >= 200): + continue + new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name + while new_name in data_dict: + new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name + data_dict[new_name] = {'motion': n_motion, + 'length': len(n_motion), + 'text':[text_dict]} + new_name_list.append(new_name) + length_list.append(len(n_motion)) + except: + 
print(line_split) + print(line_split[2], line_split[3], f_tag, to_tag, name) + # break + + if flag: + data_dict[name] = {'motion': motion, + 'length': len(motion), + 'text':text_data} + new_name_list.append(name) + length_list.append(len(motion)) + except: + # Some motion may not exist in KIT dataset + pass + + + name_list, length_list = zip(*sorted(zip(new_name_list, length_list), key=lambda x: x[1])) + + if opt.is_train: + # root_rot_velocity (B, seq_len, 1) + std[0:1] = std[0:1] / opt.feat_bias + # root_linear_velocity (B, seq_len, 2) + std[1:3] = std[1:3] / opt.feat_bias + # root_y (B, seq_len, 1) + std[3:4] = std[3:4] / opt.feat_bias + # ric_data (B, seq_len, (joint_num - 1)*3) + std[4: 4 + (joints_num - 1) * 3] = std[4: 4 + (joints_num - 1) * 3] / 1.0 + # rot_data (B, seq_len, (joint_num - 1)*6) + std[4 + (joints_num - 1) * 3: 4 + (joints_num - 1) * 9] = std[4 + (joints_num - 1) * 3: 4 + ( + joints_num - 1) * 9] / 1.0 + # local_velocity (B, seq_len, joint_num*3) + std[4 + (joints_num - 1) * 9: 4 + (joints_num - 1) * 9 + joints_num * 3] = std[ + 4 + (joints_num - 1) * 9: 4 + ( + joints_num - 1) * 9 + joints_num * 3] / 1.0 + # foot contact (B, seq_len, 4) + std[4 + (joints_num - 1) * 9 + joints_num * 3:] = std[ + 4 + (joints_num - 1) * 9 + joints_num * 3:] / opt.feat_bias + + assert 4 + (joints_num - 1) * 9 + joints_num * 3 + 4 == mean.shape[-1] + np.save(pjoin(opt.meta_dir, 'mean.npy'), mean) + np.save(pjoin(opt.meta_dir, 'std.npy'), std) + + self.mean = mean + self.std = std + self.length_arr = np.array(length_list) + self.data_dict = data_dict + self.name_list = name_list + self.reset_max_len(self.max_length) + + def reset_max_len(self, length): + assert length <= self.opt.max_motion_length + self.pointer = np.searchsorted(self.length_arr, length) + print("Pointer Pointing at %d"%self.pointer) + self.max_length = length + + def inv_transform(self, data): + return data * self.std + self.mean + + def __len__(self): + return len(self.data_dict) - self.pointer + + def __getitem__(self, item): + idx = self.pointer + item + data = self.data_dict[self.name_list[idx]] + motion, m_length, text_list = data['motion'], data['length'], data['text'] + # Randomly select a caption + text_data = random.choice(text_list) + caption, tokens = text_data['caption'], text_data['tokens'] + + if len(tokens) < self.opt.max_text_len: + # pad with "unk" + tokens = ['sos/OTHER'] + tokens + ['eos/OTHER'] + sent_len = len(tokens) + tokens = tokens + ['unk/OTHER'] * (self.opt.max_text_len + 2 - sent_len) + else: + # crop + tokens = tokens[:self.opt.max_text_len] + tokens = ['sos/OTHER'] + tokens + ['eos/OTHER'] + sent_len = len(tokens) + pos_one_hots = [] + word_embeddings = [] + for token in tokens: + word_emb, pos_oh = self.w_vectorizer[token] + pos_one_hots.append(pos_oh[None, :]) + word_embeddings.append(word_emb[None, :]) + pos_one_hots = np.concatenate(pos_one_hots, axis=0) + word_embeddings = np.concatenate(word_embeddings, axis=0) + + len_gap = (m_length - self.max_length) // self.opt.unit_length + + if self.opt.is_train: + if m_length != self.max_length: + # print("Motion original length:%d_%d"%(m_length, len(motion))) + if self.opt.unit_length < 10: + coin2 = np.random.choice(['single', 'single', 'double']) + else: + coin2 = 'single' + if len_gap == 0 or (len_gap == 1 and coin2 == 'double'): + m_length = self.max_length + idx = random.randint(0, m_length - self.max_length) + motion = motion[idx:idx+self.max_length] + else: + if coin2 == 'single': + n_m_length = self.max_length + self.opt.unit_length * 
len_gap + else: + n_m_length = self.max_length + self.opt.unit_length * (len_gap - 1) + idx = random.randint(0, m_length - n_m_length) + motion = motion[idx:idx + self.max_length] + m_length = n_m_length + # print(len_gap, idx, coin2) + else: + if self.opt.unit_length < 10: + coin2 = np.random.choice(['single', 'single', 'double']) + else: + coin2 = 'single' + + if coin2 == 'double': + m_length = (m_length // self.opt.unit_length - 1) * self.opt.unit_length + elif coin2 == 'single': + m_length = (m_length // self.opt.unit_length) * self.opt.unit_length + idx = random.randint(0, len(motion) - m_length) + motion = motion[idx:idx+m_length] + + "Z Normalization" + motion = (motion - self.mean) / self.std + + return word_embeddings, pos_one_hots, caption, sent_len, motion, m_length + + +'''For use of training text motion matching model, and evaluations''' +class Text2MotionDatasetV2(data.Dataset): + def __init__(self, opt, mean, std, split_file, w_vectorizer): + self.opt = opt + self.w_vectorizer = w_vectorizer + self.max_length = 20 + self.pointer = 0 + self.max_motion_length = opt.max_motion_length + min_motion_len = 40 if self.opt.dataset_name =='t2m' else 24 + + data_dict = {} + id_list = [] + with cs.open(split_file, 'r') as f: + for line in f.readlines(): + id_list.append(line.strip()) + id_list = id_list[:100] # debug + + new_name_list = [] + length_list = [] + for name in tqdm(id_list): + try: + motion = np.load(pjoin(opt.motion_dir, name + '.npy')) + if (len(motion)) < min_motion_len or (len(motion) >= 200): + continue + text_data = [] + flag = False + with cs.open(pjoin(opt.text_dir, name + '.txt')) as f: + for line in f.readlines(): + text_dict = {} + line_split = line.strip().split('#') + caption = line_split[0] + tokens = line_split[1].split(' ') + f_tag = float(line_split[2]) + to_tag = float(line_split[3]) + f_tag = 0.0 if np.isnan(f_tag) else f_tag + to_tag = 0.0 if np.isnan(to_tag) else to_tag + + text_dict['caption'] = caption + text_dict['tokens'] = tokens + if f_tag == 0.0 and to_tag == 0.0: + flag = True + text_data.append(text_dict) + else: + try: + n_motion = motion[int(f_tag*20) : int(to_tag*20)] + if (len(n_motion)) < min_motion_len or (len(n_motion) >= 200): + continue + new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name + while new_name in data_dict: + new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name + data_dict[new_name] = {'motion': n_motion, + 'length': len(n_motion), + 'text':[text_dict]} + new_name_list.append(new_name) + length_list.append(len(n_motion)) + except: + print(line_split) + print(line_split[2], line_split[3], f_tag, to_tag, name) + # break + + if flag: + data_dict[name] = {'motion': motion, + 'length': len(motion), + 'text': text_data} + new_name_list.append(name) + length_list.append(len(motion)) + except: + pass + + name_list, length_list = zip(*sorted(zip(new_name_list, length_list), key=lambda x: x[1])) + + self.mean = mean + self.std = std + self.length_arr = np.array(length_list) + self.data_dict = data_dict + self.name_list = name_list + self.reset_max_len(self.max_length) + + def reset_max_len(self, length): + assert length <= self.max_motion_length + self.pointer = np.searchsorted(self.length_arr, length) + print("Pointer Pointing at %d"%self.pointer) + self.max_length = length + + def inv_transform(self, data): + return data * self.std + self.mean + + def __len__(self): + return len(self.data_dict) - self.pointer + + def __getitem__(self, item): + idx = self.pointer + item + data = 
self.data_dict[self.name_list[idx]] + motion, m_length, text_list = data['motion'], data['length'], data['text'] + # Randomly select a caption + text_data = random.choice(text_list) + caption, tokens = text_data['caption'], text_data['tokens'] + + if len(tokens) < self.opt.max_text_len: + # pad with "unk" + tokens = ['sos/OTHER'] + tokens + ['eos/OTHER'] + sent_len = len(tokens) + tokens = tokens + ['unk/OTHER'] * (self.opt.max_text_len + 2 - sent_len) + else: + # crop + tokens = tokens[:self.opt.max_text_len] + tokens = ['sos/OTHER'] + tokens + ['eos/OTHER'] + sent_len = len(tokens) + pos_one_hots = [] + word_embeddings = [] + for token in tokens: + word_emb, pos_oh = self.w_vectorizer[token] + pos_one_hots.append(pos_oh[None, :]) + word_embeddings.append(word_emb[None, :]) + pos_one_hots = np.concatenate(pos_one_hots, axis=0) + word_embeddings = np.concatenate(word_embeddings, axis=0) + + # Crop the motions in to times of 4, and introduce small variations + if self.opt.unit_length < 10: + coin2 = np.random.choice(['single', 'single', 'double']) + else: + coin2 = 'single' + + if coin2 == 'double': + m_length = (m_length // self.opt.unit_length - 1) * self.opt.unit_length + elif coin2 == 'single': + m_length = (m_length // self.opt.unit_length) * self.opt.unit_length + idx = random.randint(0, len(motion) - m_length) + motion = motion[idx:idx+m_length] + + "Z Normalization" + motion = (motion - self.mean) / self.std + + if m_length < self.max_motion_length: + motion = np.concatenate([motion, + np.zeros((self.max_motion_length - m_length, motion.shape[1])) + ], axis=0) + # print(word_embeddings.shape, motion.shape) + # print(tokens) + return word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, '_'.join(tokens) + + +'''For use of training baseline''' +class Text2MotionDatasetBaseline(data.Dataset): + def __init__(self, opt, mean, std, split_file, w_vectorizer): + self.opt = opt + self.w_vectorizer = w_vectorizer + self.max_length = 20 + self.pointer = 0 + self.max_motion_length = opt.max_motion_length + min_motion_len = 40 if self.opt.dataset_name =='t2m' else 24 + + data_dict = {} + id_list = [] + with cs.open(split_file, 'r') as f: + for line in f.readlines(): + id_list.append(line.strip()) + # id_list = id_list[:200] + + new_name_list = [] + length_list = [] + for name in tqdm(id_list): + try: + motion = np.load(pjoin(opt.motion_dir, name + '.npy')) + if (len(motion)) < min_motion_len or (len(motion) >= 200): + continue + text_data = [] + flag = False + with cs.open(pjoin(opt.text_dir, name + '.txt')) as f: + for line in f.readlines(): + text_dict = {} + line_split = line.strip().split('#') + caption = line_split[0] + tokens = line_split[1].split(' ') + f_tag = float(line_split[2]) + to_tag = float(line_split[3]) + f_tag = 0.0 if np.isnan(f_tag) else f_tag + to_tag = 0.0 if np.isnan(to_tag) else to_tag + + text_dict['caption'] = caption + text_dict['tokens'] = tokens + if f_tag == 0.0 and to_tag == 0.0: + flag = True + text_data.append(text_dict) + else: + try: + n_motion = motion[int(f_tag*20) : int(to_tag*20)] + if (len(n_motion)) < min_motion_len or (len(n_motion) >= 200): + continue + new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name + while new_name in data_dict: + new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name + data_dict[new_name] = {'motion': n_motion, + 'length': len(n_motion), + 'text':[text_dict]} + new_name_list.append(new_name) + length_list.append(len(n_motion)) + except: + print(line_split) + print(line_split[2], line_split[3], 
f_tag, to_tag, name) + # break + + if flag: + data_dict[name] = {'motion': motion, + 'length': len(motion), + 'text': text_data} + new_name_list.append(name) + length_list.append(len(motion)) + except: + pass + + name_list, length_list = zip(*sorted(zip(new_name_list, length_list), key=lambda x: x[1])) + + self.mean = mean + self.std = std + self.length_arr = np.array(length_list) + self.data_dict = data_dict + self.name_list = name_list + self.reset_max_len(self.max_length) + + def reset_max_len(self, length): + assert length <= self.max_motion_length + self.pointer = np.searchsorted(self.length_arr, length) + print("Pointer Pointing at %d"%self.pointer) + self.max_length = length + + def inv_transform(self, data): + return data * self.std + self.mean + + def __len__(self): + return len(self.data_dict) - self.pointer + + def __getitem__(self, item): + idx = self.pointer + item + data = self.data_dict[self.name_list[idx]] + motion, m_length, text_list = data['motion'], data['length'], data['text'] + # Randomly select a caption + text_data = random.choice(text_list) + caption, tokens = text_data['caption'], text_data['tokens'] + + if len(tokens) < self.opt.max_text_len: + # pad with "unk" + tokens = ['sos/OTHER'] + tokens + ['eos/OTHER'] + sent_len = len(tokens) + tokens = tokens + ['unk/OTHER'] * (self.opt.max_text_len + 2 - sent_len) + else: + # crop + tokens = tokens[:self.opt.max_text_len] + tokens = ['sos/OTHER'] + tokens + ['eos/OTHER'] + sent_len = len(tokens) + pos_one_hots = [] + word_embeddings = [] + for token in tokens: + word_emb, pos_oh = self.w_vectorizer[token] + pos_one_hots.append(pos_oh[None, :]) + word_embeddings.append(word_emb[None, :]) + pos_one_hots = np.concatenate(pos_one_hots, axis=0) + word_embeddings = np.concatenate(word_embeddings, axis=0) + + len_gap = (m_length - self.max_length) // self.opt.unit_length + + if m_length != self.max_length: + # print("Motion original length:%d_%d"%(m_length, len(motion))) + if self.opt.unit_length < 10: + coin2 = np.random.choice(['single', 'single', 'double']) + else: + coin2 = 'single' + if len_gap == 0 or (len_gap == 1 and coin2 == 'double'): + m_length = self.max_length + s_idx = random.randint(0, m_length - self.max_length) + else: + if coin2 == 'single': + n_m_length = self.max_length + self.opt.unit_length * len_gap + else: + n_m_length = self.max_length + self.opt.unit_length * (len_gap - 1) + s_idx = random.randint(0, m_length - n_m_length) + m_length = n_m_length + else: + s_idx = 0 + + src_motion = motion[s_idx: s_idx + m_length] + tgt_motion = motion[s_idx: s_idx + self.max_length] + + "Z Normalization" + src_motion = (src_motion - self.mean) / self.std + tgt_motion = (tgt_motion - self.mean) / self.std + + if m_length < self.max_motion_length: + src_motion = np.concatenate([src_motion, + np.zeros((self.max_motion_length - m_length, motion.shape[1])) + ], axis=0) + # print(m_length, src_motion.shape, tgt_motion.shape) + # print(word_embeddings.shape, motion.shape) + # print(tokens) + return word_embeddings, caption, sent_len, src_motion, tgt_motion, m_length + + +class MotionDatasetV2(data.Dataset): + def __init__(self, opt, mean, std, split_file): + self.opt = opt + joints_num = opt.joints_num + + self.data = [] + self.lengths = [] + id_list = [] + with cs.open(split_file, 'r') as f: + for line in f.readlines(): + id_list.append(line.strip()) + + for name in tqdm(id_list): + try: + motion = np.load(pjoin(opt.motion_dir, name + '.npy')) + if motion.shape[0] < opt.window_size: + continue + 
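# Each clip yields (num_frames - window_size) valid start offsets; the per-clip counts
# appended below feed self.cumsum, which __getitem__ uses to map a flat sample index back
# to a (clip, start frame) pair.
# Illustrative example (hypothetical numbers): lengths [100, 60] give cumsum [0, 100, 160],
# so item 120 falls in clip 1 at offset 120 - 100 - 1 = 19.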
self.lengths.append(motion.shape[0] - opt.window_size) + self.data.append(motion) + except: + # Some motion may not exist in KIT dataset + pass + + self.cumsum = np.cumsum([0] + self.lengths) + + if opt.is_train: + # root_rot_velocity (B, seq_len, 1) + std[0:1] = std[0:1] / opt.feat_bias + # root_linear_velocity (B, seq_len, 2) + std[1:3] = std[1:3] / opt.feat_bias + # root_y (B, seq_len, 1) + std[3:4] = std[3:4] / opt.feat_bias + # ric_data (B, seq_len, (joint_num - 1)*3) + std[4: 4 + (joints_num - 1) * 3] = std[4: 4 + (joints_num - 1) * 3] / 1.0 + # rot_data (B, seq_len, (joint_num - 1)*6) + std[4 + (joints_num - 1) * 3: 4 + (joints_num - 1) * 9] = std[4 + (joints_num - 1) * 3: 4 + ( + joints_num - 1) * 9] / 1.0 + # local_velocity (B, seq_len, joint_num*3) + std[4 + (joints_num - 1) * 9: 4 + (joints_num - 1) * 9 + joints_num * 3] = std[ + 4 + (joints_num - 1) * 9: 4 + ( + joints_num - 1) * 9 + joints_num * 3] / 1.0 + # foot contact (B, seq_len, 4) + std[4 + (joints_num - 1) * 9 + joints_num * 3:] = std[ + 4 + (joints_num - 1) * 9 + joints_num * 3:] / opt.feat_bias + + assert 4 + (joints_num - 1) * 9 + joints_num * 3 + 4 == mean.shape[-1] + np.save(pjoin(opt.meta_dir, 'mean.npy'), mean) + np.save(pjoin(opt.meta_dir, 'std.npy'), std) + + self.mean = mean + self.std = std + print("Total number of motions {}, snippets {}".format(len(self.data), self.cumsum[-1])) + + def inv_transform(self, data): + return data * self.std + self.mean + + def __len__(self): + return self.cumsum[-1] + + def __getitem__(self, item): + if item != 0: + motion_id = np.searchsorted(self.cumsum, item) - 1 + idx = item - self.cumsum[motion_id] - 1 + else: + motion_id = 0 + idx = 0 + motion = self.data[motion_id][idx:idx+self.opt.window_size] + "Z Normalization" + motion = (motion - self.mean) / self.std + + return motion + + +class RawTextDataset(data.Dataset): + def __init__(self, opt, mean, std, text_file, w_vectorizer): + self.mean = mean + self.std = std + self.opt = opt + self.data_dict = [] + self.nlp = spacy.load('en_core_web_sm') + + with cs.open(text_file) as f: + for line in f.readlines(): + word_list, pos_list = self.process_text(line.strip()) + tokens = ['%s/%s'%(word_list[i], pos_list[i]) for i in range(len(word_list))] + self.data_dict.append({'caption':line.strip(), "tokens":tokens}) + + self.w_vectorizer = w_vectorizer + print("Total number of descriptions {}".format(len(self.data_dict))) + + + def process_text(self, sentence): + sentence = sentence.replace('-', '') + doc = self.nlp(sentence) + word_list = [] + pos_list = [] + for token in doc: + word = token.text + if not word.isalpha(): + continue + if (token.pos_ == 'NOUN' or token.pos_ == 'VERB') and (word != 'left'): + word_list.append(token.lemma_) + else: + word_list.append(word) + pos_list.append(token.pos_) + return word_list, pos_list + + def inv_transform(self, data): + return data * self.std + self.mean + + def __len__(self): + return len(self.data_dict) + + def __getitem__(self, item): + data = self.data_dict[item] + caption, tokens = data['caption'], data['tokens'] + + if len(tokens) < self.opt.max_text_len: + # pad with "unk" + tokens = ['sos/OTHER'] + tokens + ['eos/OTHER'] + sent_len = len(tokens) + tokens = tokens + ['unk/OTHER'] * (self.opt.max_text_len + 2 - sent_len) + else: + # crop + tokens = tokens[:self.opt.max_text_len] + tokens = ['sos/OTHER'] + tokens + ['eos/OTHER'] + sent_len = len(tokens) + pos_one_hots = [] + word_embeddings = [] + for token in tokens: + word_emb, pos_oh = self.w_vectorizer[token] + 
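# self.w_vectorizer[token] returns a (word embedding, POS one-hot) pair for each 'word/POS'
# token; the [None, :] indexing below adds a leading axis so the per-token vectors can be
# concatenated into (num_tokens, dim) arrays.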
pos_one_hots.append(pos_oh[None, :]) + word_embeddings.append(word_emb[None, :]) + pos_one_hots = np.concatenate(pos_one_hots, axis=0) + word_embeddings = np.concatenate(word_embeddings, axis=0) + + return word_embeddings, pos_one_hots, caption, sent_len + +class TextOnlyDataset(data.Dataset): + def __init__(self, opt, mean, std, split_file): + self.mean = mean + self.std = std + self.opt = opt + self.data_dict = [] + self.max_length = 20 + self.pointer = 0 + self.fixed_length = 120 + + + data_dict = {} + id_list = [] + with cs.open(split_file, 'r') as f: + for line in f.readlines(): + id_list.append(line.strip()) + # id_list = id_list[:200] + + new_name_list = [] + length_list = [] + for name in tqdm(id_list): + try: + text_data = [] + flag = False + with cs.open(pjoin(opt.text_dir, name + '.txt')) as f: + for line in f.readlines(): + text_dict = {} + line_split = line.strip().split('#') + caption = line_split[0] + tokens = line_split[1].split(' ') + f_tag = float(line_split[2]) + to_tag = float(line_split[3]) + f_tag = 0.0 if np.isnan(f_tag) else f_tag + to_tag = 0.0 if np.isnan(to_tag) else to_tag + + text_dict['caption'] = caption + text_dict['tokens'] = tokens + if f_tag == 0.0 and to_tag == 0.0: + flag = True + text_data.append(text_dict) + else: + try: + new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name + while new_name in data_dict: + new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name + data_dict[new_name] = {'text':[text_dict]} + new_name_list.append(new_name) + except: + print(line_split) + print(line_split[2], line_split[3], f_tag, to_tag, name) + # break + + if flag: + data_dict[name] = {'text': text_data} + new_name_list.append(name) + except: + pass + + self.length_arr = np.array(length_list) + self.data_dict = data_dict + self.name_list = new_name_list + + def inv_transform(self, data): + return data * self.std + self.mean + + def __len__(self): + return len(self.data_dict) + + def __getitem__(self, item): + idx = self.pointer + item + data = self.data_dict[self.name_list[idx]] + text_list = data['text'] + + # Randomly select a caption + text_data = random.choice(text_list) + caption, tokens = text_data['caption'], text_data['tokens'] + return None, None, caption, None, np.array([0]), self.fixed_length, None + # fixed_length can be set from outside before sampling + +# A wrapper class for t2m original dataset for MDM purposes +class HumanML3D(data.Dataset): + def __init__(self, mode, datapath='./dataset/humanml_opt.txt', split="train", **kwargs): + self.mode = mode + + self.dataset_name = 't2m' + self.dataname = 't2m' + + # Configurations of T2M dataset and KIT dataset is almost the same + abs_base_path = f'../motion-diffusion-model' + # abs_base_path = f'.' + + dataset_opt_path = pjoin(abs_base_path, datapath) + device = None # torch.device('cuda:4') # This param is not in use in this context + opt = get_opt(dataset_opt_path, device) + opt.meta_dir = pjoin(abs_base_path, opt.meta_dir) + opt.motion_dir = pjoin(abs_base_path, opt.motion_dir) + opt.text_dir = pjoin(abs_base_path, opt.text_dir) + opt.model_dir = pjoin(abs_base_path, opt.model_dir) + opt.checkpoints_dir = pjoin(abs_base_path, opt.checkpoints_dir) + opt.data_root = pjoin(abs_base_path, opt.data_root) + opt.save_root = pjoin(abs_base_path, opt.save_root) + opt.meta_dir = './dataset' + self.opt = opt + print('Loading dataset %s ...' 
% opt.dataset_name) + + if mode == 'gt': + # used by T2M models (including evaluators) + self.mean = np.load(pjoin(opt.meta_dir, f'{opt.dataset_name}_mean.npy')) + self.std = np.load(pjoin(opt.meta_dir, f'{opt.dataset_name}_std.npy')) + elif mode in ['train', 'eval', 'text_only']: + # used by our models + self.mean = np.load(pjoin(opt.data_root, 'Mean.npy')) + self.std = np.load(pjoin(opt.data_root, 'Std.npy')) + + if mode == 'eval': + # used by T2M models (including evaluators) + # this is to translate their norms to ours + self.mean_for_eval = np.load(pjoin(opt.meta_dir, f'{opt.dataset_name}_mean.npy')) + self.std_for_eval = np.load(pjoin(opt.meta_dir, f'{opt.dataset_name}_std.npy')) + + self.split_file = pjoin(opt.data_root, f'{split}.txt') + if mode == 'text_only': + self.t2m_dataset = TextOnlyDataset(self.opt, self.mean, self.std, self.split_file) + else: + self.w_vectorizer = WordVectorizer(pjoin(abs_base_path, 'glove'), 'our_vab') + self.t2m_dataset = Text2MotionDatasetV2(self.opt, self.mean, self.std, self.split_file, self.w_vectorizer) + self.num_actions = 1 # dummy placeholder + + assert len(self.t2m_dataset) > 1, 'You loaded an empty dataset, ' \ + 'it is probably because your data dir has only texts and no motions.\n' \ + 'To train and evaluate MDM you should get the FULL data as described ' \ + 'in the README file.' + + def __getitem__(self, item): + return self.t2m_dataset.__getitem__(item) + + def __len__(self): + return self.t2m_dataset.__len__() + +# A wrapper class for t2m original dataset for MDM purposes +class KIT(HumanML3D): + def __init__(self, mode, datapath='./dataset/kit_opt.txt', split="train", **kwargs): + super(KIT, self).__init__(mode, datapath, split, **kwargs) \ No newline at end of file diff --git a/main/data_loaders/humanml/motion_loaders/__init__.py b/main/data_loaders/humanml/motion_loaders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/main/data_loaders/humanml/motion_loaders/comp_v6_model_dataset.py b/main/data_loaders/humanml/motion_loaders/comp_v6_model_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..c64c261f577026afce98e0b5cd47764ddf265c18 --- /dev/null +++ b/main/data_loaders/humanml/motion_loaders/comp_v6_model_dataset.py @@ -0,0 +1,262 @@ +import torch +from data_loaders.humanml.networks.modules import * +from data_loaders.humanml.networks.trainers import CompTrainerV6 +from torch.utils.data import Dataset, DataLoader +from os.path import join as pjoin +from tqdm import tqdm +from utils import dist_util + +def build_models(opt): + if opt.text_enc_mod == 'bigru': + text_encoder = TextEncoderBiGRU(word_size=opt.dim_word, + pos_size=opt.dim_pos_ohot, + hidden_size=opt.dim_text_hidden, + device=opt.device) + text_size = opt.dim_text_hidden * 2 + else: + raise Exception("Text Encoder Mode not Recognized!!!") + + seq_prior = TextDecoder(text_size=text_size, + input_size=opt.dim_att_vec + opt.dim_movement_latent, + output_size=opt.dim_z, + hidden_size=opt.dim_pri_hidden, + n_layers=opt.n_layers_pri) + + + seq_decoder = TextVAEDecoder(text_size=text_size, + input_size=opt.dim_att_vec + opt.dim_z + opt.dim_movement_latent, + output_size=opt.dim_movement_latent, + hidden_size=opt.dim_dec_hidden, + n_layers=opt.n_layers_dec) + + att_layer = AttLayer(query_dim=opt.dim_pos_hidden, + key_dim=text_size, + value_dim=opt.dim_att_vec) + + movement_enc = MovementConvEncoder(opt.dim_pose - 4, opt.dim_movement_enc_hidden, opt.dim_movement_latent) + 
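# The movement decoder constructed below mirrors the encoder: the encoder maps
# (dim_pose - 4)-dimensional poses (the trailing foot-contact features dropped) to
# movement latents, and the decoder maps latents back to full dim_pose vectors.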
movement_dec = MovementConvDecoder(opt.dim_movement_latent, opt.dim_movement_dec_hidden, opt.dim_pose) + + len_estimator = MotionLenEstimatorBiGRU(opt.dim_word, opt.dim_pos_ohot, 512, opt.num_classes) + + # latent_dis = LatentDis(input_size=opt.dim_z * 2) + checkpoints = torch.load(pjoin(opt.checkpoints_dir, opt.dataset_name, 'length_est_bigru', 'model', 'latest.tar'), map_location=opt.device) + len_estimator.load_state_dict(checkpoints['estimator']) + len_estimator.to(opt.device) + len_estimator.eval() + + # return text_encoder, text_decoder, att_layer, vae_pri, vae_dec, vae_pos, motion_dis, movement_dis, latent_dis + return text_encoder, seq_prior, seq_decoder, att_layer, movement_enc, movement_dec, len_estimator + +class CompV6GeneratedDataset(Dataset): + + def __init__(self, opt, dataset, w_vectorizer, mm_num_samples, mm_num_repeats): + assert mm_num_samples < len(dataset) + print(opt.model_dir) + + dataloader = DataLoader(dataset, batch_size=1, num_workers=1, shuffle=True) + text_enc, seq_pri, seq_dec, att_layer, mov_enc, mov_dec, len_estimator = build_models(opt) + trainer = CompTrainerV6(opt, text_enc, seq_pri, seq_dec, att_layer, mov_dec, mov_enc=mov_enc) + epoch, it, sub_ep, schedule_len = trainer.load(pjoin(opt.model_dir, opt.which_epoch + '.tar')) + generated_motion = [] + mm_generated_motions = [] + mm_idxs = np.random.choice(len(dataset), mm_num_samples, replace=False) + mm_idxs = np.sort(mm_idxs) + min_mov_length = 10 if opt.dataset_name == 't2m' else 6 + # print(mm_idxs) + + print('Loading model: Epoch %03d Schedule_len %03d' % (epoch, schedule_len)) + trainer.eval_mode() + trainer.to(opt.device) + with torch.no_grad(): + for i, data in tqdm(enumerate(dataloader)): + word_emb, pos_ohot, caption, cap_lens, motions, m_lens, tokens = data + tokens = tokens[0].split('_') + word_emb = word_emb.detach().to(opt.device).float() + pos_ohot = pos_ohot.detach().to(opt.device).float() + + pred_dis = len_estimator(word_emb, pos_ohot, cap_lens) + pred_dis = nn.Softmax(-1)(pred_dis).squeeze() + + mm_num_now = len(mm_generated_motions) + is_mm = True if ((mm_num_now < mm_num_samples) and (i == mm_idxs[mm_num_now])) else False + + repeat_times = mm_num_repeats if is_mm else 1 + mm_motions = [] + for t in range(repeat_times): + mov_length = torch.multinomial(pred_dis, 1, replacement=True) + if mov_length < min_mov_length: + mov_length = torch.multinomial(pred_dis, 1, replacement=True) + if mov_length < min_mov_length: + mov_length = torch.multinomial(pred_dis, 1, replacement=True) + + m_lens = mov_length * opt.unit_length + pred_motions, _, _ = trainer.generate(word_emb, pos_ohot, cap_lens, m_lens, + m_lens[0]//opt.unit_length, opt.dim_pose) + if t == 0: + # print(m_lens) + # print(text_data) + sub_dict = {'motion': pred_motions[0].cpu().numpy(), + 'length': m_lens[0].item(), + 'cap_len': cap_lens[0].item(), + 'caption': caption[0], + 'tokens': tokens} + generated_motion.append(sub_dict) + + if is_mm: + mm_motions.append({ + 'motion': pred_motions[0].cpu().numpy(), + 'length': m_lens[0].item() + }) + if is_mm: + mm_generated_motions.append({'caption': caption[0], + 'tokens': tokens, + 'cap_len': cap_lens[0].item(), + 'mm_motions': mm_motions}) + + self.generated_motion = generated_motion + self.mm_generated_motion = mm_generated_motions + self.opt = opt + self.w_vectorizer = w_vectorizer + + + def __len__(self): + return len(self.generated_motion) + + + def __getitem__(self, item): + data = self.generated_motion[item] + motion, m_length, caption, tokens = data['motion'], data['length'], 
data['caption'], data['tokens'] + sent_len = data['cap_len'] + + pos_one_hots = [] + word_embeddings = [] + for token in tokens: + word_emb, pos_oh = self.w_vectorizer[token] + pos_one_hots.append(pos_oh[None, :]) + word_embeddings.append(word_emb[None, :]) + pos_one_hots = np.concatenate(pos_one_hots, axis=0) + word_embeddings = np.concatenate(word_embeddings, axis=0) + + if m_length < self.opt.max_motion_length: + motion = np.concatenate([motion, + np.zeros((self.opt.max_motion_length - m_length, motion.shape[1])) + ], axis=0) + return word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, '_'.join(tokens) + +class CompMDMGeneratedDataset(Dataset): + + def __init__(self, model, diffusion, dataloader, mm_num_samples, mm_num_repeats, max_motion_length, num_samples_limit, scale=1.): + self.dataloader = dataloader + self.dataset = dataloader.dataset + assert mm_num_samples < len(dataloader.dataset) + use_ddim = False # FIXME - hardcoded + clip_denoised = False # FIXME - hardcoded + self.max_motion_length = max_motion_length + sample_fn = ( + diffusion.p_sample_loop if not use_ddim else diffusion.ddim_sample_loop + ) + + real_num_batches = len(dataloader) + if num_samples_limit is not None: + real_num_batches = num_samples_limit // dataloader.batch_size + 1 + print('real_num_batches', real_num_batches) + + generated_motion = [] + mm_generated_motions = [] + if mm_num_samples > 0: + mm_idxs = np.random.choice(real_num_batches, mm_num_samples // dataloader.batch_size +1, replace=False) + mm_idxs = np.sort(mm_idxs) + else: + mm_idxs = [] + print('mm_idxs', mm_idxs) + + model.eval() + + + with torch.no_grad(): + for i, (motion, model_kwargs) in tqdm(enumerate(dataloader)): + + if num_samples_limit is not None and len(generated_motion) >= num_samples_limit: + break + + tokens = [t.split('_') for t in model_kwargs['y']['tokens']] + + # add CFG scale to batch + if scale != 1.: + model_kwargs['y']['scale'] = torch.ones(motion.shape[0], + device=dist_util.dev()) * scale + + mm_num_now = len(mm_generated_motions) // dataloader.batch_size + is_mm = i in mm_idxs + repeat_times = mm_num_repeats if is_mm else 1 + mm_motions = [] + for t in range(repeat_times): + + sample = sample_fn( + model, + motion.shape, + clip_denoised=clip_denoised, + model_kwargs=model_kwargs, + skip_timesteps=0, # 0 is the default value - i.e. 
don't skip any step + init_image=None, + progress=False, + dump_steps=None, + noise=None, + const_noise=False, + # when experimenting guidance_scale we want to nutrileze the effect of noise on generation + ) + + if t == 0: + sub_dicts = [{'motion': sample[bs_i].squeeze().permute(1,0).cpu().numpy(), + 'length': model_kwargs['y']['lengths'][bs_i].cpu().numpy(), + 'caption': model_kwargs['y']['text'][bs_i], + 'tokens': tokens[bs_i], + 'cap_len': len(tokens[bs_i]), + } for bs_i in range(dataloader.batch_size)] + generated_motion += sub_dicts + + if is_mm: + mm_motions += [{'motion': sample[bs_i].squeeze().permute(1, 0).cpu().numpy(), + 'length': model_kwargs['y']['lengths'][bs_i].cpu().numpy(), + } for bs_i in range(dataloader.batch_size)] + + if is_mm: + mm_generated_motions += [{ + 'caption': model_kwargs['y']['text'][bs_i], + 'tokens': tokens[bs_i], + 'cap_len': len(tokens[bs_i]), + 'mm_motions': mm_motions[bs_i::dataloader.batch_size], # collect all 10 repeats from the (32*10) generated motions + } for bs_i in range(dataloader.batch_size)] + + + self.generated_motion = generated_motion + self.mm_generated_motion = mm_generated_motions + self.w_vectorizer = dataloader.dataset.w_vectorizer + + + def __len__(self): + return len(self.generated_motion) + + + def __getitem__(self, item): + data = self.generated_motion[item] + motion, m_length, caption, tokens = data['motion'], data['length'], data['caption'], data['tokens'] + sent_len = data['cap_len'] + + if self.dataset.mode == 'eval': + normed_motion = motion + denormed_motion = self.dataset.t2m_dataset.inv_transform(normed_motion) + renormed_motion = (denormed_motion - self.dataset.mean_for_eval) / self.dataset.std_for_eval # according to T2M norms + motion = renormed_motion + # This step is needed because T2M evaluators expect their norm convention + + pos_one_hots = [] + word_embeddings = [] + for token in tokens: + word_emb, pos_oh = self.w_vectorizer[token] + pos_one_hots.append(pos_oh[None, :]) + word_embeddings.append(word_emb[None, :]) + pos_one_hots = np.concatenate(pos_one_hots, axis=0) + word_embeddings = np.concatenate(word_embeddings, axis=0) + + return word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, '_'.join(tokens) \ No newline at end of file diff --git a/main/data_loaders/humanml/motion_loaders/dataset_motion_loader.py b/main/data_loaders/humanml/motion_loaders/dataset_motion_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..37fff1d8d7a61a26cdb4df1572bfb6fea22c34bf --- /dev/null +++ b/main/data_loaders/humanml/motion_loaders/dataset_motion_loader.py @@ -0,0 +1,27 @@ +from t2m.data.dataset import Text2MotionDatasetV2, collate_fn +from t2m.utils.word_vectorizer import WordVectorizer +import numpy as np +from os.path import join as pjoin +from torch.utils.data import DataLoader +from t2m.utils.get_opt import get_opt + +def get_dataset_motion_loader(opt_path, batch_size, device): + opt = get_opt(opt_path, device) + + # Configurations of T2M dataset and KIT dataset is almost the same + if opt.dataset_name == 't2m' or opt.dataset_name == 'kit': + print('Loading dataset %s ...' 
% opt.dataset_name) + + mean = np.load(pjoin(opt.meta_dir, 'mean.npy')) + std = np.load(pjoin(opt.meta_dir, 'std.npy')) + + w_vectorizer = WordVectorizer('./glove', 'our_vab') + split_file = pjoin(opt.data_root, 'test.txt') + dataset = Text2MotionDatasetV2(opt, mean, std, split_file, w_vectorizer) + dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=4, drop_last=True, + collate_fn=collate_fn, shuffle=True) + else: + raise KeyError('Dataset not Recognized !!') + + print('Ground Truth Dataset Loading Completed!!!') + return dataloader, dataset \ No newline at end of file diff --git a/main/data_loaders/humanml/motion_loaders/model_motion_loaders.py b/main/data_loaders/humanml/motion_loaders/model_motion_loaders.py new file mode 100644 index 0000000000000000000000000000000000000000..ecd35bf3e0d16bc9074ba01239b0771aa6fca089 --- /dev/null +++ b/main/data_loaders/humanml/motion_loaders/model_motion_loaders.py @@ -0,0 +1,91 @@ +from torch.utils.data import DataLoader, Dataset +from data_loaders.humanml.utils.get_opt import get_opt +from data_loaders.humanml.motion_loaders.comp_v6_model_dataset import CompMDMGeneratedDataset +from data_loaders.humanml.utils.word_vectorizer import WordVectorizer +import numpy as np +from torch.utils.data._utils.collate import default_collate + + +def collate_fn(batch): + batch.sort(key=lambda x: x[3], reverse=True) + return default_collate(batch) + + +class MMGeneratedDataset(Dataset): + def __init__(self, opt, motion_dataset, w_vectorizer): + self.opt = opt + self.dataset = motion_dataset.mm_generated_motion + self.w_vectorizer = w_vectorizer + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, item): + data = self.dataset[item] + mm_motions = data['mm_motions'] + m_lens = [] + motions = [] + for mm_motion in mm_motions: + m_lens.append(mm_motion['length']) + motion = mm_motion['motion'] + # We don't need the following logic because our sample func generates the full tensor anyway: + # if len(motion) < self.opt.max_motion_length: + # motion = np.concatenate([motion, + # np.zeros((self.opt.max_motion_length - len(motion), motion.shape[1])) + # ], axis=0) + motion = motion[None, :] + motions.append(motion) + m_lens = np.array(m_lens, dtype=np.int) + motions = np.concatenate(motions, axis=0) + sort_indx = np.argsort(m_lens)[::-1].copy() + # print(m_lens) + # print(sort_indx) + # print(m_lens[sort_indx]) + m_lens = m_lens[sort_indx] + motions = motions[sort_indx] + return motions, m_lens + + + +def get_motion_loader(opt_path, batch_size, ground_truth_dataset, mm_num_samples, mm_num_repeats, device): + opt = get_opt(opt_path, device) + + # Currently the configurations of two datasets are almost the same + if opt.dataset_name == 't2m' or opt.dataset_name == 'kit': + w_vectorizer = WordVectorizer('./glove', 'our_vab') + else: + raise KeyError('Dataset not recognized!!') + print('Generating %s ...' 
% opt.name) + + if 'v6' in opt.name: + dataset = CompV6GeneratedDataset(opt, ground_truth_dataset, w_vectorizer, mm_num_samples, mm_num_repeats) + else: + raise KeyError('Dataset not recognized!!') + + mm_dataset = MMGeneratedDataset(opt, dataset, w_vectorizer) + + motion_loader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn, drop_last=True, num_workers=4) + mm_motion_loader = DataLoader(mm_dataset, batch_size=1, num_workers=1) + + print('Generated Dataset Loading Completed!!!') + + return motion_loader, mm_motion_loader + +# our loader +def get_mdm_loader(model, diffusion, batch_size, ground_truth_loader, mm_num_samples, mm_num_repeats, max_motion_length, num_samples_limit, scale): + opt = { + 'name': 'test', # FIXME + } + print('Generating %s ...' % opt['name']) + # dataset = CompMDMGeneratedDataset(opt, ground_truth_dataset, ground_truth_dataset.w_vectorizer, mm_num_samples, mm_num_repeats) + dataset = CompMDMGeneratedDataset(model, diffusion, ground_truth_loader, mm_num_samples, mm_num_repeats, max_motion_length, num_samples_limit, scale) + + mm_dataset = MMGeneratedDataset(opt, dataset, ground_truth_loader.dataset.w_vectorizer) + + # NOTE: bs must not be changed! this will cause a bug in R precision calc! + motion_loader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn, drop_last=True, num_workers=4) + mm_motion_loader = DataLoader(mm_dataset, batch_size=1, num_workers=1) + + print('Generated Dataset Loading Completed!!!') + + return motion_loader, mm_motion_loader \ No newline at end of file diff --git a/main/data_loaders/humanml/networks/__init__.py b/main/data_loaders/humanml/networks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/main/data_loaders/humanml/networks/evaluator_wrapper.py b/main/data_loaders/humanml/networks/evaluator_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..87f9d569c67ad00cd265654c03afc38f152029cb --- /dev/null +++ b/main/data_loaders/humanml/networks/evaluator_wrapper.py @@ -0,0 +1,187 @@ +from data_loaders.humanml.networks.modules import * +from data_loaders.humanml.utils.word_vectorizer import POS_enumerator +from os.path import join as pjoin + +def build_models(opt): + movement_enc = MovementConvEncoder(opt.dim_pose-4, opt.dim_movement_enc_hidden, opt.dim_movement_latent) + text_enc = TextEncoderBiGRUCo(word_size=opt.dim_word, + pos_size=opt.dim_pos_ohot, + hidden_size=opt.dim_text_hidden, + output_size=opt.dim_coemb_hidden, + device=opt.device) + + motion_enc = MotionEncoderBiGRUCo(input_size=opt.dim_movement_latent, + hidden_size=opt.dim_motion_hidden, + output_size=opt.dim_coemb_hidden, + device=opt.device) + + checkpoint = torch.load(pjoin(opt.checkpoints_dir, opt.dataset_name, 'text_mot_match', 'model', 'finest.tar'), + map_location=opt.device) + movement_enc.load_state_dict(checkpoint['movement_encoder']) + text_enc.load_state_dict(checkpoint['text_encoder']) + motion_enc.load_state_dict(checkpoint['motion_encoder']) + print('Loading Evaluation Model Wrapper (Epoch %d) Completed!!' 
% (checkpoint['epoch'])) + return text_enc, motion_enc, movement_enc + + +class EvaluatorModelWrapper(object): + + def __init__(self, opt): + + if opt.dataset_name == 't2m': + opt.dim_pose = 263 + elif opt.dataset_name == 'kit': + opt.dim_pose = 251 + else: + raise KeyError('Dataset not Recognized!!!') + + opt.dim_word = 300 + opt.max_motion_length = 196 + opt.dim_pos_ohot = len(POS_enumerator) + opt.dim_motion_hidden = 1024 + opt.max_text_len = 20 + opt.dim_text_hidden = 512 + opt.dim_coemb_hidden = 512 + + self.text_encoder, self.motion_encoder, self.movement_encoder = build_models(opt) + self.opt = opt + self.device = opt.device + + self.text_encoder.to(opt.device) + self.motion_encoder.to(opt.device) + self.movement_encoder.to(opt.device) + + self.text_encoder.eval() + self.motion_encoder.eval() + self.movement_encoder.eval() + + # Please note that the results does not following the order of inputs + def get_co_embeddings(self, word_embs, pos_ohot, cap_lens, motions, m_lens): + with torch.no_grad(): + word_embs = word_embs.detach().to(self.device).float() + pos_ohot = pos_ohot.detach().to(self.device).float() + motions = motions.detach().to(self.device).float() + + align_idx = np.argsort(m_lens.data.tolist())[::-1].copy() + motions = motions[align_idx] + m_lens = m_lens[align_idx] + + '''Movement Encoding''' + movements = self.movement_encoder(motions[..., :-4]).detach() + m_lens = m_lens // self.opt.unit_length + motion_embedding = self.motion_encoder(movements, m_lens) + + '''Text Encoding''' + text_embedding = self.text_encoder(word_embs, pos_ohot, cap_lens) + text_embedding = text_embedding[align_idx] + return text_embedding, motion_embedding + + # Please note that the results does not following the order of inputs + def get_motion_embeddings(self, motions, m_lens): + with torch.no_grad(): + motions = motions.detach().to(self.device).float() + + align_idx = np.argsort(m_lens.data.tolist())[::-1].copy() + motions = motions[align_idx] + m_lens = m_lens[align_idx] + + '''Movement Encoding''' + movements = self.movement_encoder(motions[..., :-4]).detach() + m_lens = m_lens // self.opt.unit_length + motion_embedding = self.motion_encoder(movements, m_lens) + return motion_embedding + +# our version +def build_evaluators(opt): + movement_enc = MovementConvEncoder(opt['dim_pose']-4, opt['dim_movement_enc_hidden'], opt['dim_movement_latent']) + text_enc = TextEncoderBiGRUCo(word_size=opt['dim_word'], + pos_size=opt['dim_pos_ohot'], + hidden_size=opt['dim_text_hidden'], + output_size=opt['dim_coemb_hidden'], + device=opt['device']) + + motion_enc = MotionEncoderBiGRUCo(input_size=opt['dim_movement_latent'], + hidden_size=opt['dim_motion_hidden'], + output_size=opt['dim_coemb_hidden'], + device=opt['device']) + + ckpt_dir = opt['dataset_name'] + if opt['dataset_name'] == 'humanml': + ckpt_dir = 't2m' + + checkpoint = torch.load(pjoin(opt['checkpoints_dir'], ckpt_dir, 'text_mot_match', 'model', 'finest.tar'), + map_location=opt['device']) + movement_enc.load_state_dict(checkpoint['movement_encoder']) + text_enc.load_state_dict(checkpoint['text_encoder']) + motion_enc.load_state_dict(checkpoint['motion_encoder']) + print('Loading Evaluation Model Wrapper (Epoch %d) Completed!!' 
% (checkpoint['epoch'])) + return text_enc, motion_enc, movement_enc + +# our wrapper +class EvaluatorMDMWrapper(object): + + def __init__(self, dataset_name, device): + opt = { + 'dataset_name': dataset_name, + 'device': device, + 'dim_word': 300, + 'max_motion_length': 196, + 'dim_pos_ohot': len(POS_enumerator), + 'dim_motion_hidden': 1024, + 'max_text_len': 20, + 'dim_text_hidden': 512, + 'dim_coemb_hidden': 512, + 'dim_pose': 263 if dataset_name == 'humanml' else 251, + 'dim_movement_enc_hidden': 512, + 'dim_movement_latent': 512, + 'checkpoints_dir': '.', + 'unit_length': 4, + } + + self.text_encoder, self.motion_encoder, self.movement_encoder = build_evaluators(opt) + self.opt = opt + self.device = opt['device'] + + self.text_encoder.to(opt['device']) + self.motion_encoder.to(opt['device']) + self.movement_encoder.to(opt['device']) + + self.text_encoder.eval() + self.motion_encoder.eval() + self.movement_encoder.eval() + + # Please note that the results does not following the order of inputs + def get_co_embeddings(self, word_embs, pos_ohot, cap_lens, motions, m_lens): + with torch.no_grad(): + word_embs = word_embs.detach().to(self.device).float() + pos_ohot = pos_ohot.detach().to(self.device).float() + motions = motions.detach().to(self.device).float() + + align_idx = np.argsort(m_lens.data.tolist())[::-1].copy() + motions = motions[align_idx] + m_lens = m_lens[align_idx] + + '''Movement Encoding''' + movements = self.movement_encoder(motions[..., :-4]).detach() + m_lens = m_lens // self.opt['unit_length'] + motion_embedding = self.motion_encoder(movements, m_lens) + + '''Text Encoding''' + text_embedding = self.text_encoder(word_embs, pos_ohot, cap_lens) + text_embedding = text_embedding[align_idx] + return text_embedding, motion_embedding + + # Please note that the results does not following the order of inputs + def get_motion_embeddings(self, motions, m_lens): + with torch.no_grad(): + motions = motions.detach().to(self.device).float() + + align_idx = np.argsort(m_lens.data.tolist())[::-1].copy() + motions = motions[align_idx] + m_lens = m_lens[align_idx] + + '''Movement Encoding''' + movements = self.movement_encoder(motions[..., :-4]).detach() + m_lens = m_lens // self.opt['unit_length'] + motion_embedding = self.motion_encoder(movements, m_lens) + return motion_embedding \ No newline at end of file diff --git a/main/data_loaders/humanml/networks/modules.py b/main/data_loaders/humanml/networks/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..3177738d3f029a65fb4b26538d607d95fb1c84b7 --- /dev/null +++ b/main/data_loaders/humanml/networks/modules.py @@ -0,0 +1,438 @@ +import torch +import torch.nn as nn +import numpy as np +import time +import math +from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence +# from networks.layers import * +import torch.nn.functional as F + + +class ContrastiveLoss(torch.nn.Module): + """ + Contrastive loss function. 
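Pairs with label 0 are pulled together; pairs with label 1 are pushed apart
until their distance exceeds `margin` (the clamp term in forward()).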
+ Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf + """ + def __init__(self, margin=3.0): + super(ContrastiveLoss, self).__init__() + self.margin = margin + + def forward(self, output1, output2, label): + euclidean_distance = F.pairwise_distance(output1, output2, keepdim=True) + loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance, 2) + + (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2)) + return loss_contrastive + + +def init_weight(m): + if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose1d): + nn.init.xavier_normal_(m.weight) + # m.bias.data.fill_(0.01) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + +def reparameterize(mu, logvar): + s_var = logvar.mul(0.5).exp_() + eps = s_var.data.new(s_var.size()).normal_() + return eps.mul(s_var).add_(mu) + + +# batch_size, dimension and position +# output: (batch_size, dim) +def positional_encoding(batch_size, dim, pos): + assert batch_size == pos.shape[0] + positions_enc = np.array([ + [pos[j] / np.power(10000, (i-i%2)/dim) for i in range(dim)] + for j in range(batch_size) + ], dtype=np.float32) + positions_enc[:, 0::2] = np.sin(positions_enc[:, 0::2]) + positions_enc[:, 1::2] = np.cos(positions_enc[:, 1::2]) + return torch.from_numpy(positions_enc).float() + + +def get_padding_mask(batch_size, seq_len, cap_lens): + cap_lens = cap_lens.data.tolist() + mask_2d = torch.ones((batch_size, seq_len, seq_len), dtype=torch.float32) + for i, cap_len in enumerate(cap_lens): + mask_2d[i, :, :cap_len] = 0 + return mask_2d.bool(), 1 - mask_2d[:, :, 0].clone() + + +class PositionalEncoding(nn.Module): + + def __init__(self, d_model, max_len=300): + super(PositionalEncoding, self).__init__() + + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + # pe = pe.unsqueeze(0).transpose(0, 1) + self.register_buffer('pe', pe) + + def forward(self, pos): + return self.pe[pos] + + +class MovementConvEncoder(nn.Module): + def __init__(self, input_size, hidden_size, output_size): + super(MovementConvEncoder, self).__init__() + self.main = nn.Sequential( + nn.Conv1d(input_size, hidden_size, 4, 2, 1), + nn.Dropout(0.2, inplace=True), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv1d(hidden_size, output_size, 4, 2, 1), + nn.Dropout(0.2, inplace=True), + nn.LeakyReLU(0.2, inplace=True), + ) + self.out_net = nn.Linear(output_size, output_size) + self.main.apply(init_weight) + self.out_net.apply(init_weight) + + def forward(self, inputs): + inputs = inputs.permute(0, 2, 1) + outputs = self.main(inputs).permute(0, 2, 1) + # print(outputs.shape) + return self.out_net(outputs) + + +class MovementConvDecoder(nn.Module): + def __init__(self, input_size, hidden_size, output_size): + super(MovementConvDecoder, self).__init__() + self.main = nn.Sequential( + nn.ConvTranspose1d(input_size, hidden_size, 4, 2, 1), + # nn.Dropout(0.2, inplace=True), + nn.LeakyReLU(0.2, inplace=True), + nn.ConvTranspose1d(hidden_size, output_size, 4, 2, 1), + # nn.Dropout(0.2, inplace=True), + nn.LeakyReLU(0.2, inplace=True), + ) + self.out_net = nn.Linear(output_size, output_size) + + self.main.apply(init_weight) + self.out_net.apply(init_weight) + + def forward(self, inputs): + inputs = inputs.permute(0, 2, 1) + outputs = 
self.main(inputs).permute(0, 2, 1) + return self.out_net(outputs) + + +class TextVAEDecoder(nn.Module): + def __init__(self, text_size, input_size, output_size, hidden_size, n_layers): + super(TextVAEDecoder, self).__init__() + self.input_size = input_size + self.output_size = output_size + self.hidden_size = hidden_size + self.n_layers = n_layers + self.emb = nn.Sequential( + nn.Linear(input_size, hidden_size), + nn.LayerNorm(hidden_size), + nn.LeakyReLU(0.2, inplace=True)) + + self.z2init = nn.Linear(text_size, hidden_size * n_layers) + self.gru = nn.ModuleList([nn.GRUCell(hidden_size, hidden_size) for i in range(self.n_layers)]) + self.positional_encoder = PositionalEncoding(hidden_size) + + + self.output = nn.Sequential( + nn.Linear(hidden_size, hidden_size), + nn.LayerNorm(hidden_size), + nn.LeakyReLU(0.2, inplace=True), + nn.Linear(hidden_size, output_size) + ) + + # + # self.output = nn.Sequential( + # nn.Linear(hidden_size, hidden_size), + # nn.LayerNorm(hidden_size), + # nn.LeakyReLU(0.2, inplace=True), + # nn.Linear(hidden_size, output_size-4) + # ) + + # self.contact_net = nn.Sequential( + # nn.Linear(output_size-4, 64), + # nn.LayerNorm(64), + # nn.LeakyReLU(0.2, inplace=True), + # nn.Linear(64, 4) + # ) + + self.output.apply(init_weight) + self.emb.apply(init_weight) + self.z2init.apply(init_weight) + # self.contact_net.apply(init_weight) + + def get_init_hidden(self, latent): + hidden = self.z2init(latent) + hidden = torch.split(hidden, self.hidden_size, dim=-1) + return list(hidden) + + def forward(self, inputs, last_pred, hidden, p): + h_in = self.emb(inputs) + pos_enc = self.positional_encoder(p).to(inputs.device).detach() + h_in = h_in + pos_enc + for i in range(self.n_layers): + # print(h_in.shape) + hidden[i] = self.gru[i](h_in, hidden[i]) + h_in = hidden[i] + pose_pred = self.output(h_in) + # pose_pred = self.output(h_in) + last_pred.detach() + # contact = self.contact_net(pose_pred) + # return torch.cat([pose_pred, contact], dim=-1), hidden + return pose_pred, hidden + + +class TextDecoder(nn.Module): + def __init__(self, text_size, input_size, output_size, hidden_size, n_layers): + super(TextDecoder, self).__init__() + self.input_size = input_size + self.output_size = output_size + self.hidden_size = hidden_size + self.n_layers = n_layers + self.emb = nn.Sequential( + nn.Linear(input_size, hidden_size), + nn.LayerNorm(hidden_size), + nn.LeakyReLU(0.2, inplace=True)) + + self.gru = nn.ModuleList([nn.GRUCell(hidden_size, hidden_size) for i in range(self.n_layers)]) + self.z2init = nn.Linear(text_size, hidden_size * n_layers) + self.positional_encoder = PositionalEncoding(hidden_size) + + self.mu_net = nn.Linear(hidden_size, output_size) + self.logvar_net = nn.Linear(hidden_size, output_size) + + self.emb.apply(init_weight) + self.z2init.apply(init_weight) + self.mu_net.apply(init_weight) + self.logvar_net.apply(init_weight) + + def get_init_hidden(self, latent): + + hidden = self.z2init(latent) + hidden = torch.split(hidden, self.hidden_size, dim=-1) + + return list(hidden) + + def forward(self, inputs, hidden, p): + # print(inputs.shape) + x_in = self.emb(inputs) + pos_enc = self.positional_encoder(p).to(inputs.device).detach() + x_in = x_in + pos_enc + + for i in range(self.n_layers): + hidden[i] = self.gru[i](x_in, hidden[i]) + h_in = hidden[i] + mu = self.mu_net(h_in) + logvar = self.logvar_net(h_in) + z = reparameterize(mu, logvar) + return z, mu, logvar, hidden + +class AttLayer(nn.Module): + def __init__(self, query_dim, key_dim, value_dim): + super(AttLayer, 
self).__init__() + self.W_q = nn.Linear(query_dim, value_dim) + self.W_k = nn.Linear(key_dim, value_dim, bias=False) + self.W_v = nn.Linear(key_dim, value_dim) + + self.softmax = nn.Softmax(dim=1) + self.dim = value_dim + + self.W_q.apply(init_weight) + self.W_k.apply(init_weight) + self.W_v.apply(init_weight) + + def forward(self, query, key_mat): + ''' + query (batch, query_dim) + key (batch, seq_len, key_dim) + ''' + # print(query.shape) + query_vec = self.W_q(query).unsqueeze(-1) # (batch, value_dim, 1) + val_set = self.W_v(key_mat) # (batch, seq_len, value_dim) + key_set = self.W_k(key_mat) # (batch, seq_len, value_dim) + + weights = torch.matmul(key_set, query_vec) / np.sqrt(self.dim) + + co_weights = self.softmax(weights) # (batch, seq_len, 1) + values = val_set * co_weights # (batch, seq_len, value_dim) + pred = values.sum(dim=1) # (batch, value_dim) + return pred, co_weights + + def short_cut(self, querys, keys): + return self.W_q(querys), self.W_k(keys) + + +class TextEncoderBiGRU(nn.Module): + def __init__(self, word_size, pos_size, hidden_size, device): + super(TextEncoderBiGRU, self).__init__() + self.device = device + + self.pos_emb = nn.Linear(pos_size, word_size) + self.input_emb = nn.Linear(word_size, hidden_size) + self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True) + # self.linear2 = nn.Linear(hidden_size, output_size) + + self.input_emb.apply(init_weight) + self.pos_emb.apply(init_weight) + # self.linear2.apply(init_weight) + # self.batch_size = batch_size + self.hidden_size = hidden_size + self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True)) + + # input(batch_size, seq_len, dim) + def forward(self, word_embs, pos_onehot, cap_lens): + num_samples = word_embs.shape[0] + + pos_embs = self.pos_emb(pos_onehot) + inputs = word_embs + pos_embs + input_embs = self.input_emb(inputs) + hidden = self.hidden.repeat(1, num_samples, 1) + + cap_lens = cap_lens.data.tolist() + emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True) + + gru_seq, gru_last = self.gru(emb, hidden) + + gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1) + gru_seq = pad_packed_sequence(gru_seq, batch_first=True)[0] + forward_seq = gru_seq[..., :self.hidden_size] + backward_seq = gru_seq[..., self.hidden_size:].clone() + + # Concate the forward and backward word embeddings + for i, length in enumerate(cap_lens): + backward_seq[i:i+1, :length] = torch.flip(backward_seq[i:i+1, :length].clone(), dims=[1]) + gru_seq = torch.cat([forward_seq, backward_seq], dim=-1) + + return gru_seq, gru_last + + +class TextEncoderBiGRUCo(nn.Module): + def __init__(self, word_size, pos_size, hidden_size, output_size, device): + super(TextEncoderBiGRUCo, self).__init__() + self.device = device + + self.pos_emb = nn.Linear(pos_size, word_size) + self.input_emb = nn.Linear(word_size, hidden_size) + self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True) + self.output_net = nn.Sequential( + nn.Linear(hidden_size * 2, hidden_size), + nn.LayerNorm(hidden_size), + nn.LeakyReLU(0.2, inplace=True), + nn.Linear(hidden_size, output_size) + ) + + self.input_emb.apply(init_weight) + self.pos_emb.apply(init_weight) + self.output_net.apply(init_weight) + # self.linear2.apply(init_weight) + # self.batch_size = batch_size + self.hidden_size = hidden_size + self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True)) + + # input(batch_size, seq_len, dim) + def forward(self, word_embs, pos_onehot, cap_lens): + 
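# word_embs: (batch, seq_len, dim_word), pos_onehot: (batch, seq_len, dim_pos_ohot),
# cap_lens: true caption lengths, sorted in descending order (the collate_fn sorts by
# caption length) as pack_padded_sequence expects. The concatenated last forward/backward
# GRU states are projected by output_net into one co-embedding vector per caption.
# Illustrative shapes only, assuming dim_word=300 and dim_coemb_hidden=512 as set in the
# evaluator wrappers: (B, L, 300) + (B, L, n_pos) -> (B, 512).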
num_samples = word_embs.shape[0] + + pos_embs = self.pos_emb(pos_onehot) + inputs = word_embs + pos_embs + input_embs = self.input_emb(inputs) + hidden = self.hidden.repeat(1, num_samples, 1) + + cap_lens = cap_lens.data.tolist() + emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True) + + gru_seq, gru_last = self.gru(emb, hidden) + + gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1) + + return self.output_net(gru_last) + + +class MotionEncoderBiGRUCo(nn.Module): + def __init__(self, input_size, hidden_size, output_size, device): + super(MotionEncoderBiGRUCo, self).__init__() + self.device = device + + self.input_emb = nn.Linear(input_size, hidden_size) + self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True) + self.output_net = nn.Sequential( + nn.Linear(hidden_size*2, hidden_size), + nn.LayerNorm(hidden_size), + nn.LeakyReLU(0.2, inplace=True), + nn.Linear(hidden_size, output_size) + ) + + self.input_emb.apply(init_weight) + self.output_net.apply(init_weight) + self.hidden_size = hidden_size + self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True)) + + # input(batch_size, seq_len, dim) + def forward(self, inputs, m_lens): + num_samples = inputs.shape[0] + + input_embs = self.input_emb(inputs) + hidden = self.hidden.repeat(1, num_samples, 1) + + cap_lens = m_lens.data.tolist() + emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True) + + gru_seq, gru_last = self.gru(emb, hidden) + + gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1) + + return self.output_net(gru_last) + + +class MotionLenEstimatorBiGRU(nn.Module): + def __init__(self, word_size, pos_size, hidden_size, output_size): + super(MotionLenEstimatorBiGRU, self).__init__() + + self.pos_emb = nn.Linear(pos_size, word_size) + self.input_emb = nn.Linear(word_size, hidden_size) + self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True) + nd = 512 + self.output = nn.Sequential( + nn.Linear(hidden_size*2, nd), + nn.LayerNorm(nd), + nn.LeakyReLU(0.2, inplace=True), + + nn.Linear(nd, nd // 2), + nn.LayerNorm(nd // 2), + nn.LeakyReLU(0.2, inplace=True), + + nn.Linear(nd // 2, nd // 4), + nn.LayerNorm(nd // 4), + nn.LeakyReLU(0.2, inplace=True), + + nn.Linear(nd // 4, output_size) + ) + # self.linear2 = nn.Linear(hidden_size, output_size) + + self.input_emb.apply(init_weight) + self.pos_emb.apply(init_weight) + self.output.apply(init_weight) + # self.linear2.apply(init_weight) + # self.batch_size = batch_size + self.hidden_size = hidden_size + self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True)) + + # input(batch_size, seq_len, dim) + def forward(self, word_embs, pos_onehot, cap_lens): + num_samples = word_embs.shape[0] + + pos_embs = self.pos_emb(pos_onehot) + inputs = word_embs + pos_embs + input_embs = self.input_emb(inputs) + hidden = self.hidden.repeat(1, num_samples, 1) + + cap_lens = cap_lens.data.tolist() + emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True) + + gru_seq, gru_last = self.gru(emb, hidden) + + gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1) + + return self.output(gru_last) diff --git a/main/data_loaders/humanml/networks/trainers.py b/main/data_loaders/humanml/networks/trainers.py new file mode 100644 index 0000000000000000000000000000000000000000..123f497a7893cc85d915fde63add9338738f1d03 --- /dev/null +++ b/main/data_loaders/humanml/networks/trainers.py @@ -0,0 +1,1089 @@ +import torch +import torch.nn.functional as F +import random +from 
data_loaders.humanml.networks.modules import * +from torch.utils.data import DataLoader +import torch.optim as optim +from torch.nn.utils import clip_grad_norm_ +# import tensorflow as tf +from collections import OrderedDict +from data_loaders.humanml.utils.utils import * +from os.path import join as pjoin +from data_loaders.humanml.data.dataset import collate_fn +import codecs as cs + + +class Logger(object): + def __init__(self, log_dir): + self.writer = tf.summary.create_file_writer(log_dir) + + def scalar_summary(self, tag, value, step): + with self.writer.as_default(): + tf.summary.scalar(tag, value, step=step) + self.writer.flush() + +class DecompTrainerV3(object): + def __init__(self, args, movement_enc, movement_dec): + self.opt = args + self.movement_enc = movement_enc + self.movement_dec = movement_dec + self.device = args.device + + if args.is_train: + self.logger = Logger(args.log_dir) + self.sml1_criterion = torch.nn.SmoothL1Loss() + self.l1_criterion = torch.nn.L1Loss() + self.mse_criterion = torch.nn.MSELoss() + + + @staticmethod + def zero_grad(opt_list): + for opt in opt_list: + opt.zero_grad() + + @staticmethod + def clip_norm(network_list): + for network in network_list: + clip_grad_norm_(network.parameters(), 0.5) + + @staticmethod + def step(opt_list): + for opt in opt_list: + opt.step() + + def forward(self, batch_data): + motions = batch_data + self.motions = motions.detach().to(self.device).float() + self.latents = self.movement_enc(self.motions[..., :-4]) + self.recon_motions = self.movement_dec(self.latents) + + def backward(self): + self.loss_rec = self.l1_criterion(self.recon_motions, self.motions) + # self.sml1_criterion(self.recon_motions[:, 1:] - self.recon_motions[:, :-1], + # self.motions[:, 1:] - self.recon_motions[:, :-1]) + self.loss_sparsity = torch.mean(torch.abs(self.latents)) + self.loss_smooth = self.l1_criterion(self.latents[:, 1:], self.latents[:, :-1]) + self.loss = self.loss_rec + self.loss_sparsity * self.opt.lambda_sparsity +\ + self.loss_smooth*self.opt.lambda_smooth + + def update(self): + # time0 = time.time() + self.zero_grad([self.opt_movement_enc, self.opt_movement_dec]) + # time1 = time.time() + # print('\t Zero_grad Time: %.5f s' % (time1 - time0)) + self.backward() + # time2 = time.time() + # print('\t Backward Time: %.5f s' % (time2 - time1)) + self.loss.backward() + # time3 = time.time() + # print('\t Loss backward Time: %.5f s' % (time3 - time2)) + # self.clip_norm([self.movement_enc, self.movement_dec]) + # time4 = time.time() + # print('\t Clip_norm Time: %.5f s' % (time4 - time3)) + self.step([self.opt_movement_enc, self.opt_movement_dec]) + # time5 = time.time() + # print('\t Step Time: %.5f s' % (time5 - time4)) + + loss_logs = OrderedDict({}) + loss_logs['loss'] = self.loss_rec.item() + loss_logs['loss_rec'] = self.loss_rec.item() + loss_logs['loss_sparsity'] = self.loss_sparsity.item() + loss_logs['loss_smooth'] = self.loss_smooth.item() + return loss_logs + + def save(self, file_name, ep, total_it): + state = { + 'movement_enc': self.movement_enc.state_dict(), + 'movement_dec': self.movement_dec.state_dict(), + + 'opt_movement_enc': self.opt_movement_enc.state_dict(), + 'opt_movement_dec': self.opt_movement_dec.state_dict(), + + 'ep': ep, + 'total_it': total_it, + } + torch.save(state, file_name) + return + + def resume(self, model_dir): + checkpoint = torch.load(model_dir, map_location=self.device) + + self.movement_dec.load_state_dict(checkpoint['movement_dec']) + 
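# The Adam optimizer states are restored below as well, so a resumed run keeps the same
# moment estimates instead of re-warming them from scratch.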
self.movement_enc.load_state_dict(checkpoint['movement_enc']) + + self.opt_movement_enc.load_state_dict(checkpoint['opt_movement_enc']) + self.opt_movement_dec.load_state_dict(checkpoint['opt_movement_dec']) + + return checkpoint['ep'], checkpoint['total_it'] + + def train(self, train_dataloader, val_dataloader, plot_eval): + self.movement_enc.to(self.device) + self.movement_dec.to(self.device) + + self.opt_movement_enc = optim.Adam(self.movement_enc.parameters(), lr=self.opt.lr) + self.opt_movement_dec = optim.Adam(self.movement_dec.parameters(), lr=self.opt.lr) + + epoch = 0 + it = 0 + + if self.opt.is_continue: + model_dir = pjoin(self.opt.model_dir, 'latest.tar') + epoch, it = self.resume(model_dir) + + start_time = time.time() + total_iters = self.opt.max_epoch * len(train_dataloader) + print('Iters Per Epoch, Training: %04d, Validation: %03d' % (len(train_dataloader), len(val_dataloader))) + val_loss = 0 + logs = OrderedDict() + while epoch < self.opt.max_epoch: + # time0 = time.time() + for i, batch_data in enumerate(train_dataloader): + self.movement_dec.train() + self.movement_enc.train() + + # time1 = time.time() + # print('DataLoader Time: %.5f s'%(time1-time0) ) + self.forward(batch_data) + # time2 = time.time() + # print('Forward Time: %.5f s'%(time2-time1)) + log_dict = self.update() + # time3 = time.time() + # print('Update Time: %.5f s' % (time3 - time2)) + # time0 = time3 + for k, v in log_dict.items(): + if k not in logs: + logs[k] = v + else: + logs[k] += v + + it += 1 + if it % self.opt.log_every == 0: + mean_loss = OrderedDict({'val_loss': val_loss}) + self.logger.scalar_summary('val_loss', val_loss, it) + + for tag, value in logs.items(): + self.logger.scalar_summary(tag, value / self.opt.log_every, it) + mean_loss[tag] = value / self.opt.log_every + logs = OrderedDict() + print_current_loss_decomp(start_time, it, total_iters, mean_loss, epoch, i) + + if it % self.opt.save_latest == 0: + self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it) + + self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it) + + epoch += 1 + if epoch % self.opt.save_every_e == 0: + self.save(pjoin(self.opt.model_dir, 'E%04d.tar' % (epoch)), epoch, total_it=it) + + print('Validation time:') + + val_loss = 0 + val_rec_loss = 0 + val_sparcity_loss = 0 + val_smooth_loss = 0 + with torch.no_grad(): + for i, batch_data in enumerate(val_dataloader): + self.forward(batch_data) + self.backward() + val_rec_loss += self.loss_rec.item() + val_smooth_loss += self.loss.item() + val_sparcity_loss += self.loss_sparsity.item() + val_smooth_loss += self.loss_smooth.item() + val_loss += self.loss.item() + + val_loss = val_loss / (len(val_dataloader) + 1) + val_rec_loss = val_rec_loss / (len(val_dataloader) + 1) + val_sparcity_loss = val_sparcity_loss / (len(val_dataloader) + 1) + val_smooth_loss = val_smooth_loss / (len(val_dataloader) + 1) + print('Validation Loss: %.5f Reconstruction Loss: %.5f ' + 'Sparsity Loss: %.5f Smooth Loss: %.5f' % (val_loss, val_rec_loss, val_sparcity_loss, \ + val_smooth_loss)) + + if epoch % self.opt.eval_every_e == 0: + data = torch.cat([self.recon_motions[:4], self.motions[:4]], dim=0).detach().cpu().numpy() + save_dir = pjoin(self.opt.eval_dir, 'E%04d' % (epoch)) + os.makedirs(save_dir, exist_ok=True) + plot_eval(data, save_dir) + + +# VAE Sequence Decoder/Prior/Posterior latent by latent +class CompTrainerV6(object): + + def __init__(self, args, text_enc, seq_pri, seq_dec, att_layer, mov_dec, mov_enc=None, seq_post=None): + self.opt = args + self.text_enc = 
text_enc + self.seq_pri = seq_pri + self.att_layer = att_layer + self.device = args.device + self.seq_dec = seq_dec + self.mov_dec = mov_dec + self.mov_enc = mov_enc + + if args.is_train: + self.seq_post = seq_post + # self.motion_dis + self.logger = Logger(args.log_dir) + self.l1_criterion = torch.nn.SmoothL1Loss() + self.gan_criterion = torch.nn.BCEWithLogitsLoss() + self.mse_criterion = torch.nn.MSELoss() + + @staticmethod + def reparametrize(mu, logvar): + s_var = logvar.mul(0.5).exp_() + eps = s_var.data.new(s_var.size()).normal_() + return eps.mul(s_var).add_(mu) + + @staticmethod + def ones_like(tensor, val=1.): + return torch.FloatTensor(tensor.size()).fill_(val).to(tensor.device).requires_grad_(False) + + @staticmethod + def zeros_like(tensor, val=0.): + return torch.FloatTensor(tensor.size()).fill_(val).to(tensor.device).requires_grad_(False) + + @staticmethod + def zero_grad(opt_list): + for opt in opt_list: + opt.zero_grad() + + @staticmethod + def clip_norm(network_list): + for network in network_list: + clip_grad_norm_(network.parameters(), 0.5) + + @staticmethod + def step(opt_list): + for opt in opt_list: + opt.step() + + @staticmethod + def kl_criterion(mu1, logvar1, mu2, logvar2): + # KL( N(mu1, sigma2_1) || N(mu_2, sigma2_2)) + # loss = log(sigma2/sigma1) + (sigma1^2 + (mu1 - mu2)^2)/(2*sigma2^2) - 1/2 + sigma1 = logvar1.mul(0.5).exp() + sigma2 = logvar2.mul(0.5).exp() + kld = torch.log(sigma2 / sigma1) + (torch.exp(logvar1) + (mu1 - mu2) ** 2) / ( + 2 * torch.exp(logvar2)) - 1 / 2 + return kld.sum() / mu1.shape[0] + + @staticmethod + def kl_criterion_unit(mu, logvar): + # KL( N(mu1, sigma2_1) || N(mu_2, sigma2_2)) + # loss = log(sigma2/sigma1) + (sigma1^2 + (mu1 - mu2)^2)/(2*sigma2^2) - 1/2 + kld = ((torch.exp(logvar) + mu ** 2) - logvar - 1) / 2 + return kld.sum() / mu.shape[0] + + def forward(self, batch_data, tf_ratio, mov_len, eval_mode=False): + word_emb, pos_ohot, caption, cap_lens, motions, m_lens = batch_data + word_emb = word_emb.detach().to(self.device).float() + pos_ohot = pos_ohot.detach().to(self.device).float() + motions = motions.detach().to(self.device).float() + self.cap_lens = cap_lens + self.caption = caption + + # print(motions.shape) + # (batch_size, motion_len, pose_dim) + self.motions = motions + + '''Movement Encoding''' + self.movements = self.mov_enc(self.motions[..., :-4]).detach() + # Initially input a mean vector + mov_in = self.mov_enc( + torch.zeros((self.motions.shape[0], self.opt.unit_length, self.motions.shape[-1] - 4), device=self.device) + ).squeeze(1).detach() + assert self.movements.shape[1] == mov_len + + teacher_force = True if random.random() < tf_ratio else False + + '''Text Encoding''' + # time0 = time.time() + # text_input = torch.cat([word_emb, pos_ohot], dim=-1) + word_hids, hidden = self.text_enc(word_emb, pos_ohot, cap_lens) + # print(word_hids.shape, hidden.shape) + + if self.opt.text_enc_mod == 'bigru': + hidden_pos = self.seq_post.get_init_hidden(hidden) + hidden_pri = self.seq_pri.get_init_hidden(hidden) + hidden_dec = self.seq_dec.get_init_hidden(hidden) + elif self.opt.text_enc_mod == 'transformer': + hidden_pos = self.seq_post.get_init_hidden(hidden.detach()) + hidden_pri = self.seq_pri.get_init_hidden(hidden.detach()) + hidden_dec = self.seq_dec.get_init_hidden(hidden) + + mus_pri = [] + logvars_pri = [] + mus_post = [] + logvars_post = [] + fake_mov_batch = [] + + query_input = [] + + # time1 = time.time() + # print("\t Text Encoder Cost:%5f" % (time1 - time0)) + # print(self.movements.shape) + + for i in 
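`kl_criterion` is the closed-form KL divergence between two diagonal Gaussians, KL(N(mu1, var1) || N(mu2, var2)) = log(sigma2/sigma1) + (var1 + (mu1 - mu2)^2) / (2 var2) - 1/2, summed over dimensions and averaged over the batch. A quick sanity check against `torch.distributions` (illustrative only):

```python
import torch
from torch.distributions import Normal, kl_divergence

def kl_criterion(mu1, logvar1, mu2, logvar2):
    # Closed-form KL(N(mu1, var1) || N(mu2, var2)) for diagonal Gaussians.
    sigma1 = logvar1.mul(0.5).exp()
    sigma2 = logvar2.mul(0.5).exp()
    kld = torch.log(sigma2 / sigma1) \
        + (torch.exp(logvar1) + (mu1 - mu2) ** 2) / (2 * torch.exp(logvar2)) - 0.5
    return kld.sum() / mu1.shape[0]

mu1, logvar1 = torch.randn(4, 8), torch.randn(4, 8)
mu2, logvar2 = torch.randn(4, 8), torch.randn(4, 8)

ours = kl_criterion(mu1, logvar1, mu2, logvar2)
ref = kl_divergence(Normal(mu1, logvar1.mul(0.5).exp()),
                    Normal(mu2, logvar2.mul(0.5).exp())).sum() / mu1.shape[0]
print(torch.allclose(ours, ref, atol=1e-4))  # True
```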
range(mov_len): + # print("\t Sequence Measure") + # print(mov_in.shape) + mov_tgt = self.movements[:, i] + '''Local Attention Vector''' + att_vec, _ = self.att_layer(hidden_dec[-1], word_hids) + query_input.append(hidden_dec[-1]) + + tta = m_lens // self.opt.unit_length - i + + if self.opt.text_enc_mod == 'bigru': + pos_in = torch.cat([mov_in, mov_tgt, att_vec], dim=-1) + pri_in = torch.cat([mov_in, att_vec], dim=-1) + + elif self.opt.text_enc_mod == 'transformer': + pos_in = torch.cat([mov_in, mov_tgt, att_vec.detach()], dim=-1) + pri_in = torch.cat([mov_in, att_vec.detach()], dim=-1) + + '''Posterior''' + z_pos, mu_pos, logvar_pos, hidden_pos = self.seq_post(pos_in, hidden_pos, tta) + + '''Prior''' + z_pri, mu_pri, logvar_pri, hidden_pri = self.seq_pri(pri_in, hidden_pri, tta) + + '''Decoder''' + if eval_mode: + dec_in = torch.cat([mov_in, att_vec, z_pri], dim=-1) + else: + dec_in = torch.cat([mov_in, att_vec, z_pos], dim=-1) + fake_mov, hidden_dec = self.seq_dec(dec_in, mov_in, hidden_dec, tta) + + # print(fake_mov.shape) + + mus_post.append(mu_pos) + logvars_post.append(logvar_pos) + mus_pri.append(mu_pri) + logvars_pri.append(logvar_pri) + fake_mov_batch.append(fake_mov.unsqueeze(1)) + + if teacher_force: + mov_in = self.movements[:, i].detach() + else: + mov_in = fake_mov.detach() + + + self.fake_movements = torch.cat(fake_mov_batch, dim=1) + + # print(self.fake_movements.shape) + + self.fake_motions = self.mov_dec(self.fake_movements) + + self.mus_post = torch.cat(mus_post, dim=0) + self.mus_pri = torch.cat(mus_pri, dim=0) + self.logvars_post = torch.cat(logvars_post, dim=0) + self.logvars_pri = torch.cat(logvars_pri, dim=0) + + def generate(self, word_emb, pos_ohot, cap_lens, m_lens, mov_len, dim_pose): + word_emb = word_emb.detach().to(self.device).float() + pos_ohot = pos_ohot.detach().to(self.device).float() + self.cap_lens = cap_lens + + # print(motions.shape) + # (batch_size, motion_len, pose_dim) + + '''Movement Encoding''' + # Initially input a mean vector + mov_in = self.mov_enc( + torch.zeros((word_emb.shape[0], self.opt.unit_length, dim_pose - 4), device=self.device) + ).squeeze(1).detach() + + '''Text Encoding''' + # time0 = time.time() + # text_input = torch.cat([word_emb, pos_ohot], dim=-1) + word_hids, hidden = self.text_enc(word_emb, pos_ohot, cap_lens) + # print(word_hids.shape, hidden.shape) + + hidden_pri = self.seq_pri.get_init_hidden(hidden) + hidden_dec = self.seq_dec.get_init_hidden(hidden) + + mus_pri = [] + logvars_pri = [] + fake_mov_batch = [] + att_wgt = [] + + # time1 = time.time() + # print("\t Text Encoder Cost:%5f" % (time1 - time0)) + # print(self.movements.shape) + + for i in range(mov_len): + # print("\t Sequence Measure") + # print(mov_in.shape) + '''Local Attention Vector''' + att_vec, co_weights = self.att_layer(hidden_dec[-1], word_hids) + + tta = m_lens // self.opt.unit_length - i + # tta = m_lens - i + + '''Prior''' + pri_in = torch.cat([mov_in, att_vec], dim=-1) + z_pri, mu_pri, logvar_pri, hidden_pri = self.seq_pri(pri_in, hidden_pri, tta) + + '''Decoder''' + dec_in = torch.cat([mov_in, att_vec, z_pri], dim=-1) + + fake_mov, hidden_dec = self.seq_dec(dec_in, mov_in, hidden_dec, tta) + + # print(fake_mov.shape) + mus_pri.append(mu_pri) + logvars_pri.append(logvar_pri) + fake_mov_batch.append(fake_mov.unsqueeze(1)) + att_wgt.append(co_weights) + + mov_in = fake_mov.detach() + + fake_movements = torch.cat(fake_mov_batch, dim=1) + att_wgts = torch.cat(att_wgt, dim=-1) + + # print(self.fake_movements.shape) + + fake_motions = 
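Both `forward` and `generate` roll the sequence out one `unit_length` snippet at a time, feeding the previous snippet (or, under teacher forcing, the ground truth) back in and passing `tta`, the number of snippets still to come, to the recurrent modules. A stripped-down sketch of that rollout with a dummy step function:

```python
import torch

def autoregressive_rollout(step_fn, init_in, n_steps, m_lens, unit_length=4):
    """Sketch of the snippet-by-snippet rollout used in forward()/generate().

    step_fn(mov_in, tta) -> next movement snippet; `tta` is the count of
    snippets still to be generated, used as a time-to-arrival signal.
    """
    outputs = []
    mov_in = init_in
    for i in range(n_steps):
        tta = m_lens // unit_length - i      # remaining snippets per sample
        fake_mov = step_fn(mov_in, tta)
        outputs.append(fake_mov.unsqueeze(1))
        mov_in = fake_mov.detach()           # feed the prediction back in
    return torch.cat(outputs, dim=1)         # (B, n_steps, dim)

# Dummy step function: ignores tta and returns noise of the right shape.
out = autoregressive_rollout(lambda mov_in, tta: torch.randn(2, 512),
                             init_in=torch.zeros(2, 512), n_steps=5,
                             m_lens=torch.tensor([20, 20]))
print(out.shape)  # torch.Size([2, 5, 512])
```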
self.mov_dec(fake_movements) + + mus_pri = torch.cat(mus_pri, dim=0) + logvars_pri = torch.cat(logvars_pri, dim=0) + + return fake_motions, mus_pri, att_wgts + + def backward_G(self): + self.loss_mot_rec = self.l1_criterion(self.fake_motions, self.motions) + self.loss_mov_rec = self.l1_criterion(self.fake_movements, self.movements) + + self.loss_kld = self.kl_criterion(self.mus_post, self.logvars_post, self.mus_pri, self.logvars_pri) + + self.loss_gen = self.loss_mot_rec * self.opt.lambda_rec_mov + self.loss_mov_rec * self.opt.lambda_rec_mot + \ + self.loss_kld * self.opt.lambda_kld + loss_logs = OrderedDict({}) + loss_logs['loss_gen'] = self.loss_gen.item() + loss_logs['loss_mot_rec'] = self.loss_mot_rec.item() + loss_logs['loss_mov_rec'] = self.loss_mov_rec.item() + loss_logs['loss_kld'] = self.loss_kld.item() + + return loss_logs + # self.loss_gen = self.loss_rec_mov + + # self.loss_gen = self.loss_rec_mov * self.opt.lambda_rec_mov + self.loss_rec_mot + \ + # self.loss_kld * self.opt.lambda_kld + \ + # self.loss_mtgan_G * self.opt.lambda_gan_mt + self.loss_mvgan_G * self.opt.lambda_gan_mv + + + def update(self): + + self.zero_grad([self.opt_text_enc, self.opt_seq_dec, self.opt_seq_post, + self.opt_seq_pri, self.opt_att_layer, self.opt_mov_dec]) + # time2_0 = time.time() + # print("\t\t Zero Grad:%5f" % (time2_0 - time1)) + loss_logs = self.backward_G() + + # time2_1 = time.time() + # print("\t\t Backward_G :%5f" % (time2_1 - time2_0)) + self.loss_gen.backward() + + # time2_2 = time.time() + # print("\t\t Backward :%5f" % (time2_2 - time2_1)) + self.clip_norm([self.text_enc, self.seq_dec, self.seq_post, self.seq_pri, + self.att_layer, self.mov_dec]) + + # time2_3 = time.time() + # print("\t\t Clip Norm :%5f" % (time2_3 - time2_2)) + self.step([self.opt_text_enc, self.opt_seq_dec, self.opt_seq_post, + self.opt_seq_pri, self.opt_att_layer, self.opt_mov_dec]) + + # time2_4 = time.time() + # print("\t\t Step :%5f" % (time2_4 - time2_3)) + + # time2 = time.time() + # print("\t Update Generator Cost:%5f" % (time2 - time1)) + + # self.zero_grad([self.opt_att_layer]) + # self.backward_Att() + # self.loss_lgan_G_.backward() + # self.clip_norm([self.att_layer]) + # self.step([self.opt_att_layer]) + # # time3 = time.time() + # # print("\t Update Att Cost:%5f" % (time3 - time2)) + + # self.loss_gen += self.loss_lgan_G_ + + return loss_logs + + def to(self, device): + if self.opt.is_train: + self.gan_criterion.to(device) + self.mse_criterion.to(device) + self.l1_criterion.to(device) + self.seq_post.to(device) + self.mov_enc.to(device) + self.text_enc.to(device) + self.mov_dec.to(device) + self.seq_pri.to(device) + self.att_layer.to(device) + self.seq_dec.to(device) + + def train_mode(self): + if self.opt.is_train: + self.seq_post.train() + self.mov_enc.eval() + # self.motion_dis.train() + # self.movement_dis.train() + self.mov_dec.train() + self.text_enc.train() + self.seq_pri.train() + self.att_layer.train() + self.seq_dec.train() + + + def eval_mode(self): + if self.opt.is_train: + self.seq_post.eval() + self.mov_enc.eval() + # self.motion_dis.train() + # self.movement_dis.train() + self.mov_dec.eval() + self.text_enc.eval() + self.seq_pri.eval() + self.att_layer.eval() + self.seq_dec.eval() + + + def save(self, file_name, ep, total_it, sub_ep, sl_len): + state = { + # 'latent_dis': self.latent_dis.state_dict(), + # 'motion_dis': self.motion_dis.state_dict(), + 'text_enc': self.text_enc.state_dict(), + 'seq_post': self.seq_post.state_dict(), + 'att_layer': self.att_layer.state_dict(), + 'seq_dec': 
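`update` follows the same recipe used throughout this file: zero the optimizers, compute the weighted loss, backpropagate, clip gradients to a max norm of 0.5, then step. (One thing worth double-checking: `loss_mot_rec` is weighted by `lambda_rec_mov` and `loss_mov_rec` by `lambda_rec_mot`, which reads as crossed relative to the names; it is harmless only if the two weights are equal.) A minimal sketch of the order of operations:

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_

# zero_grad -> loss -> backward -> clip -> step, with the same 0.5 max norm.
model = nn.Linear(16, 16)
opt = optim.Adam(model.parameters(), lr=2e-4)

x, target = torch.randn(8, 16), torch.randn(8, 16)
opt.zero_grad()
loss = nn.SmoothL1Loss()(model(x), target)
loss.backward()
clip_grad_norm_(model.parameters(), 0.5)
opt.step()
print(loss.item())
```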
self.seq_dec.state_dict(), + 'seq_pri': self.seq_pri.state_dict(), + 'mov_enc': self.mov_enc.state_dict(), + 'mov_dec': self.mov_dec.state_dict(), + + # 'opt_motion_dis': self.opt_motion_dis.state_dict(), + 'opt_mov_dec': self.opt_mov_dec.state_dict(), + 'opt_text_enc': self.opt_text_enc.state_dict(), + 'opt_seq_pri': self.opt_seq_pri.state_dict(), + 'opt_att_layer': self.opt_att_layer.state_dict(), + 'opt_seq_post': self.opt_seq_post.state_dict(), + 'opt_seq_dec': self.opt_seq_dec.state_dict(), + # 'opt_movement_dis': self.opt_movement_dis.state_dict(), + + 'ep': ep, + 'total_it': total_it, + 'sub_ep': sub_ep, + 'sl_len': sl_len + } + torch.save(state, file_name) + return + + def load(self, model_dir): + checkpoint = torch.load(model_dir, map_location=self.device) + if self.opt.is_train: + self.seq_post.load_state_dict(checkpoint['seq_post']) + # self.opt_latent_dis.load_state_dict(checkpoint['opt_latent_dis']) + + self.opt_text_enc.load_state_dict(checkpoint['opt_text_enc']) + self.opt_seq_post.load_state_dict(checkpoint['opt_seq_post']) + self.opt_att_layer.load_state_dict(checkpoint['opt_att_layer']) + self.opt_seq_pri.load_state_dict(checkpoint['opt_seq_pri']) + self.opt_seq_dec.load_state_dict(checkpoint['opt_seq_dec']) + self.opt_mov_dec.load_state_dict(checkpoint['opt_mov_dec']) + + self.text_enc.load_state_dict(checkpoint['text_enc']) + self.mov_dec.load_state_dict(checkpoint['mov_dec']) + self.seq_pri.load_state_dict(checkpoint['seq_pri']) + self.att_layer.load_state_dict(checkpoint['att_layer']) + self.seq_dec.load_state_dict(checkpoint['seq_dec']) + self.mov_enc.load_state_dict(checkpoint['mov_enc']) + + return checkpoint['ep'], checkpoint['total_it'], checkpoint['sub_ep'], checkpoint['sl_len'] + + def train(self, train_dataset, val_dataset, plot_eval): + self.to(self.device) + + self.opt_text_enc = optim.Adam(self.text_enc.parameters(), lr=self.opt.lr) + self.opt_seq_post = optim.Adam(self.seq_post.parameters(), lr=self.opt.lr) + self.opt_seq_pri = optim.Adam(self.seq_pri.parameters(), lr=self.opt.lr) + self.opt_att_layer = optim.Adam(self.att_layer.parameters(), lr=self.opt.lr) + self.opt_seq_dec = optim.Adam(self.seq_dec.parameters(), lr=self.opt.lr) + + self.opt_mov_dec = optim.Adam(self.mov_dec.parameters(), lr=self.opt.lr*0.1) + + epoch = 0 + it = 0 + if self.opt.dataset_name == 't2m': + schedule_len = 10 + elif self.opt.dataset_name == 'kit': + schedule_len = 6 + sub_ep = 0 + + if self.opt.is_continue: + model_dir = pjoin(self.opt.model_dir, 'latest.tar') + epoch, it, sub_ep, schedule_len = self.load(model_dir) + + invalid = True + start_time = time.time() + val_loss = 0 + is_continue_and_first = self.opt.is_continue + while invalid: + train_dataset.reset_max_len(schedule_len * self.opt.unit_length) + val_dataset.reset_max_len(schedule_len * self.opt.unit_length) + + train_loader = DataLoader(train_dataset, batch_size=self.opt.batch_size, drop_last=True, num_workers=4, + shuffle=True, collate_fn=collate_fn, pin_memory=True) + val_loader = DataLoader(val_dataset, batch_size=self.opt.batch_size, drop_last=True, num_workers=4, + shuffle=True, collate_fn=collate_fn, pin_memory=True) + print("Max_Length:%03d Training Split:%05d Validation Split:%04d" % (schedule_len, len(train_loader), len(val_loader))) + + min_val_loss = np.inf + stop_cnt = 0 + logs = OrderedDict() + for sub_epoch in range(sub_ep, self.opt.max_sub_epoch): + self.train_mode() + + if is_continue_and_first: + sub_ep = 0 + is_continue_and_first = False + + tf_ratio = self.opt.tf_ratio + + time1 = time.time() + 
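`train` runs a length curriculum: it caps the dataset at `schedule_len * unit_length` frames, rebuilds the loaders, trains until validation stalls, then grows `schedule_len` (up to 49). A toy, runnable illustration with a stand-in dataset that exposes the same `reset_max_len` hook (the dataset class below is hypothetical):

```python
import torch
from torch.utils.data import DataLoader, Dataset

class DummyMotionDataset(Dataset):
    """Stand-in for the repo's dataset: reset_max_len caps clip length in frames."""
    def __init__(self, n=64, dim=8):
        self.data = [torch.randn(200, dim) for _ in range(n)]
        self.max_len = 200

    def reset_max_len(self, max_len):
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data[i][:self.max_len]

# Length curriculum: start with short clips and grow schedule_len per stage.
dataset = DummyMotionDataset()
unit_length = 4
for schedule_len in range(10, 13):          # the repo grows this up to 49
    dataset.reset_max_len(schedule_len * unit_length)
    loader = DataLoader(dataset, batch_size=16, shuffle=True, drop_last=True)
    batch = next(iter(loader))
    print(schedule_len, batch.shape)        # (16, schedule_len * 4, 8)
```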
for i, batch_data in enumerate(train_loader): + time2 = time.time() + self.forward(batch_data, tf_ratio, schedule_len) + time3 = time.time() + log_dict = self.update() + for k, v in log_dict.items(): + if k not in logs: + logs[k] = v + else: + logs[k] += v + time4 = time.time() + + + it += 1 + if it % self.opt.log_every == 0: + mean_loss = OrderedDict({'val_loss': val_loss}) + self.logger.scalar_summary('val_loss', val_loss, it) + self.logger.scalar_summary('scheduled_length', schedule_len, it) + + for tag, value in logs.items(): + self.logger.scalar_summary(tag, value/self.opt.log_every, it) + mean_loss[tag] = value / self.opt.log_every + logs = OrderedDict() + print_current_loss(start_time, it, mean_loss, epoch, sub_epoch=sub_epoch, inner_iter=i, + tf_ratio=tf_ratio, sl_steps=schedule_len) + + if it % self.opt.save_latest == 0: + self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it, sub_epoch, schedule_len) + + time5 = time.time() + # print("Data Loader Time: %5f s" % ((time2 - time1))) + # print("Forward Time: %5f s" % ((time3 - time2))) + # print("Update Time: %5f s" % ((time4 - time3))) + # print('Per Iteration: %5f s' % ((time5 - time1))) + time1 = time5 + + self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it, sub_epoch, schedule_len) + + epoch += 1 + if epoch % self.opt.save_every_e == 0: + self.save(pjoin(self.opt.model_dir, 'E%03d_SE%02d_SL%02d.tar'%(epoch, sub_epoch, schedule_len)), + epoch, total_it=it, sub_ep=sub_epoch, sl_len=schedule_len) + + print('Validation time:') + + loss_mot_rec = 0 + loss_mov_rec = 0 + loss_kld = 0 + val_loss = 0 + with torch.no_grad(): + for i, batch_data in enumerate(val_loader): + self.forward(batch_data, 0, schedule_len) + self.backward_G() + loss_mot_rec += self.loss_mot_rec.item() + loss_mov_rec += self.loss_mov_rec.item() + loss_kld += self.loss_kld.item() + val_loss += self.loss_gen.item() + + loss_mot_rec /= len(val_loader) + 1 + loss_mov_rec /= len(val_loader) + 1 + loss_kld /= len(val_loader) + 1 + val_loss /= len(val_loader) + 1 + print('Validation Loss: %.5f Movement Recon Loss: %.5f Motion Recon Loss: %.5f KLD Loss: %.5f:' % + (val_loss, loss_mov_rec, loss_mot_rec, loss_kld)) + + if epoch % self.opt.eval_every_e == 0: + reco_data = self.fake_motions[:4] + with torch.no_grad(): + self.forward(batch_data, 0, schedule_len, eval_mode=True) + fake_data = self.fake_motions[:4] + gt_data = self.motions[:4] + data = torch.cat([fake_data, reco_data, gt_data], dim=0).cpu().numpy() + captions = self.caption[:4] * 3 + save_dir = pjoin(self.opt.eval_dir, 'E%03d_SE%02d_SL%02d'%(epoch, sub_epoch, schedule_len)) + os.makedirs(save_dir, exist_ok=True) + plot_eval(data, save_dir, captions) + + # if cl_ratio == 1: + if val_loss < min_val_loss: + min_val_loss = val_loss + stop_cnt = 0 + elif stop_cnt < self.opt.early_stop_count: + stop_cnt += 1 + elif stop_cnt >= self.opt.early_stop_count: + break + if val_loss - min_val_loss >= 0.1: + break + + schedule_len += 1 + + if schedule_len > 49: + invalid = False + + +class LengthEstTrainer(object): + + def __init__(self, args, estimator): + self.opt = args + self.estimator = estimator + self.device = args.device + + if args.is_train: + # self.motion_dis + self.logger = Logger(args.log_dir) + self.mul_cls_criterion = torch.nn.CrossEntropyLoss() + + def resume(self, model_dir): + checkpoints = torch.load(model_dir, map_location=self.device) + self.estimator.load_state_dict(checkpoints['estimator']) + self.opt_estimator.load_state_dict(checkpoints['opt_estimator']) + return checkpoints['epoch'], 
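Each curriculum stage stops early when the validation loss has not improved for `early_stop_count` sub-epochs, or when it drifts 0.1 above the best value seen so far. A compact sketch of that rule:

```python
import math

def early_stop_demo(val_losses, patience=3, blowup_margin=0.1):
    """Sketch of the per-stage early-stopping rule: stop when the loss has not
    improved for `patience` epochs or rises `blowup_margin` above the best."""
    best = math.inf
    stale = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best, stale = loss, 0
        else:
            stale += 1
        if stale >= patience or loss - best >= blowup_margin:
            return epoch
    return len(val_losses) - 1

print(early_stop_demo([0.9, 0.8, 0.79, 0.81, 0.82, 0.83, 0.84]))  # stops at epoch 5
```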
checkpoints['iter'] + + def save(self, model_dir, epoch, niter): + state = { + 'estimator': self.estimator.state_dict(), + 'opt_estimator': self.opt_estimator.state_dict(), + 'epoch': epoch, + 'niter': niter, + } + torch.save(state, model_dir) + + @staticmethod + def zero_grad(opt_list): + for opt in opt_list: + opt.zero_grad() + + @staticmethod + def clip_norm(network_list): + for network in network_list: + clip_grad_norm_(network.parameters(), 0.5) + + @staticmethod + def step(opt_list): + for opt in opt_list: + opt.step() + + def train(self, train_dataloader, val_dataloader): + self.estimator.to(self.device) + + self.opt_estimator = optim.Adam(self.estimator.parameters(), lr=self.opt.lr) + + epoch = 0 + it = 0 + + if self.opt.is_continue: + model_dir = pjoin(self.opt.model_dir, 'latest.tar') + epoch, it = self.resume(model_dir) + + start_time = time.time() + total_iters = self.opt.max_epoch * len(train_dataloader) + print('Iters Per Epoch, Training: %04d, Validation: %03d' % (len(train_dataloader), len(val_dataloader))) + val_loss = 0 + min_val_loss = np.inf + logs = OrderedDict({'loss': 0}) + while epoch < self.opt.max_epoch: + # time0 = time.time() + for i, batch_data in enumerate(train_dataloader): + self.estimator.train() + + word_emb, pos_ohot, _, cap_lens, _, m_lens = batch_data + word_emb = word_emb.detach().to(self.device).float() + pos_ohot = pos_ohot.detach().to(self.device).float() + + pred_dis = self.estimator(word_emb, pos_ohot, cap_lens) + + self.zero_grad([self.opt_estimator]) + + gt_labels = m_lens // self.opt.unit_length + gt_labels = gt_labels.long().to(self.device) + # print(gt_labels) + # print(pred_dis) + loss = self.mul_cls_criterion(pred_dis, gt_labels) + + loss.backward() + + self.clip_norm([self.estimator]) + self.step([self.opt_estimator]) + + logs['loss'] += loss.item() + + it += 1 + if it % self.opt.log_every == 0: + mean_loss = OrderedDict({'val_loss': val_loss}) + self.logger.scalar_summary('val_loss', val_loss, it) + + for tag, value in logs.items(): + self.logger.scalar_summary(tag, value / self.opt.log_every, it) + mean_loss[tag] = value / self.opt.log_every + logs = OrderedDict({'loss': 0}) + print_current_loss_decomp(start_time, it, total_iters, mean_loss, epoch, i) + + if it % self.opt.save_latest == 0: + self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it) + + self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it) + + epoch += 1 + if epoch % self.opt.save_every_e == 0: + self.save(pjoin(self.opt.model_dir, 'E%04d.tar' % (epoch)), epoch, it) + + print('Validation time:') + + val_loss = 0 + with torch.no_grad(): + for i, batch_data in enumerate(val_dataloader): + word_emb, pos_ohot, _, cap_lens, _, m_lens = batch_data + word_emb = word_emb.detach().to(self.device).float() + pos_ohot = pos_ohot.detach().to(self.device).float() + + pred_dis = self.estimator(word_emb, pos_ohot, cap_lens) + + gt_labels = m_lens // self.opt.unit_length + gt_labels = gt_labels.long().to(self.device) + loss = self.mul_cls_criterion(pred_dis, gt_labels) + + val_loss += loss.item() + + val_loss = val_loss / (len(val_dataloader) + 1) + print('Validation Loss: %.5f' % (val_loss)) + + if val_loss < min_val_loss: + self.save(pjoin(self.opt.model_dir, 'finest.tar'), epoch, it) + min_val_loss = val_loss + + +class TextMotionMatchTrainer(object): + + def __init__(self, args, text_encoder, motion_encoder, movement_encoder): + self.opt = args + self.text_encoder = text_encoder + self.motion_encoder = motion_encoder + self.movement_encoder = movement_encoder + self.device 
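`LengthEstTrainer` casts motion-length prediction as classification: the target class is the number of `unit_length` snippets, `m_lens // unit_length`, scored with cross-entropy. A small worked example of the label construction:

```python
import torch
import torch.nn as nn

# Motion length is bucketed into unit_length snippets, matching
# num_classes = 200 // unit_length in get_opt.
unit_length = 4
num_classes = 200 // unit_length            # 50 buckets

m_lens = torch.tensor([196, 120, 64, 32])   # lengths in frames
gt_labels = (m_lens // unit_length).long()  # tensor([49, 30, 16,  8])

pred_logits = torch.randn(4, num_classes)   # estimator output, shape (B, num_classes)
loss = nn.CrossEntropyLoss()(pred_logits, gt_labels)
print(gt_labels.tolist(), loss.item())
```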
= args.device + + if args.is_train: + # self.motion_dis + self.logger = Logger(args.log_dir) + self.contrastive_loss = ContrastiveLoss(self.opt.negative_margin) + + def resume(self, model_dir): + checkpoints = torch.load(model_dir, map_location=self.device) + self.text_encoder.load_state_dict(checkpoints['text_encoder']) + self.motion_encoder.load_state_dict(checkpoints['motion_encoder']) + self.movement_encoder.load_state_dict(checkpoints['movement_encoder']) + + self.opt_text_encoder.load_state_dict(checkpoints['opt_text_encoder']) + self.opt_motion_encoder.load_state_dict(checkpoints['opt_motion_encoder']) + return checkpoints['epoch'], checkpoints['iter'] + + def save(self, model_dir, epoch, niter): + state = { + 'text_encoder': self.text_encoder.state_dict(), + 'motion_encoder': self.motion_encoder.state_dict(), + 'movement_encoder': self.movement_encoder.state_dict(), + + 'opt_text_encoder': self.opt_text_encoder.state_dict(), + 'opt_motion_encoder': self.opt_motion_encoder.state_dict(), + 'epoch': epoch, + 'iter': niter, + } + torch.save(state, model_dir) + + @staticmethod + def zero_grad(opt_list): + for opt in opt_list: + opt.zero_grad() + + @staticmethod + def clip_norm(network_list): + for network in network_list: + clip_grad_norm_(network.parameters(), 0.5) + + @staticmethod + def step(opt_list): + for opt in opt_list: + opt.step() + + def to(self, device): + self.text_encoder.to(device) + self.motion_encoder.to(device) + self.movement_encoder.to(device) + + def train_mode(self): + self.text_encoder.train() + self.motion_encoder.train() + self.movement_encoder.eval() + + def forward(self, batch_data): + word_emb, pos_ohot, caption, cap_lens, motions, m_lens, _ = batch_data + word_emb = word_emb.detach().to(self.device).float() + pos_ohot = pos_ohot.detach().to(self.device).float() + motions = motions.detach().to(self.device).float() + + # Sort the length of motions in descending order, (length of text has been sorted) + self.align_idx = np.argsort(m_lens.data.tolist())[::-1].copy() + # print(self.align_idx) + # print(m_lens[self.align_idx]) + motions = motions[self.align_idx] + m_lens = m_lens[self.align_idx] + + '''Movement Encoding''' + movements = self.movement_encoder(motions[..., :-4]).detach() + m_lens = m_lens // self.opt.unit_length + self.motion_embedding = self.motion_encoder(movements, m_lens) + + '''Text Encoding''' + # time0 = time.time() + # text_input = torch.cat([word_emb, pos_ohot], dim=-1) + self.text_embedding = self.text_encoder(word_emb, pos_ohot, cap_lens) + self.text_embedding = self.text_embedding.clone()[self.align_idx] + + + def backward(self): + + batch_size = self.text_embedding.shape[0] + '''Positive pairs''' + pos_labels = torch.zeros(batch_size).to(self.text_embedding.device) + self.loss_pos = self.contrastive_loss(self.text_embedding, self.motion_embedding, pos_labels) + + '''Negative Pairs, shifting index''' + neg_labels = torch.ones(batch_size).to(self.text_embedding.device) + shift = np.random.randint(0, batch_size-1) + new_idx = np.arange(shift, batch_size + shift) % batch_size + self.mis_motion_embedding = self.motion_embedding.clone()[new_idx] + self.loss_neg = self.contrastive_loss(self.text_embedding, self.mis_motion_embedding, neg_labels) + self.loss = self.loss_pos + self.loss_neg + + loss_logs = OrderedDict({}) + loss_logs['loss'] = self.loss.item() + loss_logs['loss_pos'] = self.loss_pos.item() + loss_logs['loss_neg'] = self.loss_neg.item() + return loss_logs + + + def update(self): + + self.zero_grad([self.opt_motion_encoder, 
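`TextMotionMatchTrainer` scores matched text/motion pairs with label 0 and mismatched pairs, built by shifting the motion embeddings by a random offset, with label 1. The `ContrastiveLoss` class itself lives in `modules.py`; the margin-based formulation below is a common stand-in and may differ in detail from the repo's version:

```python
import numpy as np
import torch
import torch.nn.functional as F

def contrastive_loss(emb1, emb2, labels, margin=10.0):
    """Hedged sketch of a margin-based contrastive loss
    (label 0 = matched pair, label 1 = mismatched pair)."""
    d = F.pairwise_distance(emb1, emb2)
    return torch.mean((1 - labels) * d.pow(2) +
                      labels * torch.clamp(margin - d, min=0.0).pow(2))

# Negatives come from shifting the motion embeddings, as in backward();
# a shift of at least 1 guarantees the negatives never equal the positives.
text_emb = torch.randn(8, 512)
motion_emb = torch.randn(8, 512)

pos = contrastive_loss(text_emb, motion_emb, torch.zeros(8))
shift = np.random.randint(1, 8)
neg_idx = np.arange(shift, 8 + shift) % 8
neg = contrastive_loss(text_emb, motion_emb[neg_idx], torch.ones(8))
print((pos + neg).item())
```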
self.opt_text_encoder]) + loss_logs = self.backward() + self.loss.backward() + self.clip_norm([self.text_encoder, self.motion_encoder]) + self.step([self.opt_text_encoder, self.opt_motion_encoder]) + + return loss_logs + + + def train(self, train_dataloader, val_dataloader): + self.to(self.device) + + self.opt_motion_encoder = optim.Adam(self.motion_encoder.parameters(), lr=self.opt.lr) + self.opt_text_encoder = optim.Adam(self.text_encoder.parameters(), lr=self.opt.lr) + + epoch = 0 + it = 0 + + if self.opt.is_continue: + model_dir = pjoin(self.opt.model_dir, 'latest.tar') + epoch, it = self.resume(model_dir) + + start_time = time.time() + total_iters = self.opt.max_epoch * len(train_dataloader) + print('Iters Per Epoch, Training: %04d, Validation: %03d' % (len(train_dataloader), len(val_dataloader))) + val_loss = 0 + logs = OrderedDict() + + min_val_loss = np.inf + while epoch < self.opt.max_epoch: + # time0 = time.time() + for i, batch_data in enumerate(train_dataloader): + self.train_mode() + + self.forward(batch_data) + # time3 = time.time() + log_dict = self.update() + for k, v in log_dict.items(): + if k not in logs: + logs[k] = v + else: + logs[k] += v + + + it += 1 + if it % self.opt.log_every == 0: + mean_loss = OrderedDict({'val_loss': val_loss}) + self.logger.scalar_summary('val_loss', val_loss, it) + + for tag, value in logs.items(): + self.logger.scalar_summary(tag, value / self.opt.log_every, it) + mean_loss[tag] = value / self.opt.log_every + logs = OrderedDict() + print_current_loss_decomp(start_time, it, total_iters, mean_loss, epoch, i) + + if it % self.opt.save_latest == 0: + self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it) + + self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it) + + epoch += 1 + if epoch % self.opt.save_every_e == 0: + self.save(pjoin(self.opt.model_dir, 'E%04d.tar' % (epoch)), epoch, it) + + print('Validation time:') + + loss_pos_pair = 0 + loss_neg_pair = 0 + val_loss = 0 + with torch.no_grad(): + for i, batch_data in enumerate(val_dataloader): + self.forward(batch_data) + self.backward() + loss_pos_pair += self.loss_pos.item() + loss_neg_pair += self.loss_neg.item() + val_loss += self.loss.item() + + loss_pos_pair /= len(val_dataloader) + 1 + loss_neg_pair /= len(val_dataloader) + 1 + val_loss /= len(val_dataloader) + 1 + print('Validation Loss: %.5f Positive Loss: %.5f Negative Loss: %.5f' % + (val_loss, loss_pos_pair, loss_neg_pair)) + + if val_loss < min_val_loss: + self.save(pjoin(self.opt.model_dir, 'finest.tar'), epoch, it) + min_val_loss = val_loss + + if epoch % self.opt.eval_every_e == 0: + pos_dist = F.pairwise_distance(self.text_embedding, self.motion_embedding) + neg_dist = F.pairwise_distance(self.text_embedding, self.mis_motion_embedding) + + pos_str = ' '.join(['%.3f' % (pos_dist[i]) for i in range(pos_dist.shape[0])]) + neg_str = ' '.join(['%.3f' % (neg_dist[i]) for i in range(neg_dist.shape[0])]) + + save_path = pjoin(self.opt.eval_dir, 'E%03d.txt' % (epoch)) + with cs.open(save_path, 'w') as f: + f.write('Positive Pairs Distance\n') + f.write(pos_str + '\n') + f.write('Negative Pairs Distance\n') + f.write(neg_str + '\n') diff --git a/main/data_loaders/humanml/scripts/motion_process.py b/main/data_loaders/humanml/scripts/motion_process.py new file mode 100644 index 0000000000000000000000000000000000000000..704ce20b9f00859058da641f7a17188d30e6a05a --- /dev/null +++ b/main/data_loaders/humanml/scripts/motion_process.py @@ -0,0 +1,529 @@ +from os.path import join as pjoin + +from 
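`foot_detect` marks a foot as in contact whenever its squared displacement between consecutive frames falls below `feet_thre`. A compact sketch of the same idea (note that `.astype(np.float)` in the original is removed in NumPy >= 1.24; plain `float` or `np.float64` behaves the same):

```python
import numpy as np

def foot_contact(positions, foot_ids, thres=0.002):
    """Velocity-based foot-contact detection, as in foot_detect():
    a foot is 'in contact' when its squared frame-to-frame displacement
    falls below a small threshold."""
    vel_sq = np.sum((positions[1:, foot_ids] - positions[:-1, foot_ids]) ** 2, axis=-1)
    return (vel_sq < thres).astype(float)   # (seq_len - 1, len(foot_ids))

# Toy check: a static joint is always in contact, a sliding one never is.
pos = np.zeros((10, 22, 3))
pos[:, 5, 0] = np.arange(10) * 0.1          # joint 5 slides along x
contacts = foot_contact(pos, foot_ids=[5, 6])
print(contacts.mean(axis=0))                # [0. 1.]: joint 5 moving, joint 6 static
```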
data_loaders.humanml.common.skeleton import Skeleton +import numpy as np +import os +from data_loaders.humanml.common.quaternion import * +from data_loaders.humanml.utils.paramUtil import * + +import torch +from tqdm import tqdm + +# positions (batch, joint_num, 3) +def uniform_skeleton(positions, target_offset): + src_skel = Skeleton(n_raw_offsets, kinematic_chain, 'cpu') + src_offset = src_skel.get_offsets_joints(torch.from_numpy(positions[0])) + src_offset = src_offset.numpy() + tgt_offset = target_offset.numpy() + # print(src_offset) + # print(tgt_offset) + '''Calculate Scale Ratio as the ratio of legs''' + src_leg_len = np.abs(src_offset[l_idx1]).max() + np.abs(src_offset[l_idx2]).max() + tgt_leg_len = np.abs(tgt_offset[l_idx1]).max() + np.abs(tgt_offset[l_idx2]).max() + + scale_rt = tgt_leg_len / src_leg_len + # print(scale_rt) + src_root_pos = positions[:, 0] + tgt_root_pos = src_root_pos * scale_rt + + '''Inverse Kinematics''' + quat_params = src_skel.inverse_kinematics_np(positions, face_joint_indx) + # print(quat_params.shape) + + '''Forward Kinematics''' + src_skel.set_offset(target_offset) + new_joints = src_skel.forward_kinematics_np(quat_params, tgt_root_pos) + return new_joints + + +def extract_features(positions, feet_thre, n_raw_offsets, kinematic_chain, face_joint_indx, fid_r, fid_l): + global_positions = positions.copy() + """ Get Foot Contacts """ + + def foot_detect(positions, thres): + velfactor, heightfactor = np.array([thres, thres]), np.array([3.0, 2.0]) + + feet_l_x = (positions[1:, fid_l, 0] - positions[:-1, fid_l, 0]) ** 2 + feet_l_y = (positions[1:, fid_l, 1] - positions[:-1, fid_l, 1]) ** 2 + feet_l_z = (positions[1:, fid_l, 2] - positions[:-1, fid_l, 2]) ** 2 + # feet_l_h = positions[:-1,fid_l,1] + # feet_l = (((feet_l_x + feet_l_y + feet_l_z) < velfactor) & (feet_l_h < heightfactor)).astype(np.float) + feet_l = ((feet_l_x + feet_l_y + feet_l_z) < velfactor).astype(np.float) + + feet_r_x = (positions[1:, fid_r, 0] - positions[:-1, fid_r, 0]) ** 2 + feet_r_y = (positions[1:, fid_r, 1] - positions[:-1, fid_r, 1]) ** 2 + feet_r_z = (positions[1:, fid_r, 2] - positions[:-1, fid_r, 2]) ** 2 + # feet_r_h = positions[:-1,fid_r,1] + # feet_r = (((feet_r_x + feet_r_y + feet_r_z) < velfactor) & (feet_r_h < heightfactor)).astype(np.float) + feet_r = (((feet_r_x + feet_r_y + feet_r_z) < velfactor)).astype(np.float) + return feet_l, feet_r + + # + feet_l, feet_r = foot_detect(positions, feet_thre) + # feet_l, feet_r = foot_detect(positions, 0.002) + + '''Quaternion and Cartesian representation''' + r_rot = None + + def get_rifke(positions): + '''Local pose''' + positions[..., 0] -= positions[:, 0:1, 0] + positions[..., 2] -= positions[:, 0:1, 2] + '''All pose face Z+''' + positions = qrot_np(np.repeat(r_rot[:, None], positions.shape[1], axis=1), positions) + return positions + + def get_quaternion(positions): + skel = Skeleton(n_raw_offsets, kinematic_chain, "cpu") + # (seq_len, joints_num, 4) + quat_params = skel.inverse_kinematics_np(positions, face_joint_indx, smooth_forward=False) + + '''Fix Quaternion Discontinuity''' + quat_params = qfix(quat_params) + # (seq_len, 4) + r_rot = quat_params[:, 0].copy() + # print(r_rot[0]) + '''Root Linear Velocity''' + # (seq_len - 1, 3) + velocity = (positions[1:, 0] - positions[:-1, 0]).copy() + # print(r_rot.shape, velocity.shape) + velocity = qrot_np(r_rot[1:], velocity) + '''Root Angular Velocity''' + # (seq_len - 1, 4) + r_velocity = qmul_np(r_rot[1:], qinv_np(r_rot[:-1])) + quat_params[1:, 0] = r_velocity + # (seq_len, 
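The feature vector assembled here is root data (yaw rate, planar velocity, height), rotation-invariant joint positions, 6D joint rotations, local joint velocities, and four foot-contact flags. The dimension arithmetic below explains the 263-D (HumanML3D) and 251-D (KIT) pose vectors referenced elsewhere in this diff:

```python
def motion_feature_dim(joints_num):
    """Per-frame feature size: root (rot vel 1 + lin vel 2 + height 1)
    + RIC positions + 6D rotations + local velocities + 4 foot contacts."""
    root = 1 + 2 + 1
    ric = (joints_num - 1) * 3
    rot = (joints_num - 1) * 6
    vel = joints_num * 3
    feet = 4
    return root + ric + rot + vel + feet

print(motion_feature_dim(22))  # 263 (HumanML3D / t2m)
print(motion_feature_dim(21))  # 251 (KIT)
```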
joints_num, 4) + return quat_params, r_velocity, velocity, r_rot + + def get_cont6d_params(positions): + skel = Skeleton(n_raw_offsets, kinematic_chain, "cpu") + # (seq_len, joints_num, 4) + quat_params = skel.inverse_kinematics_np(positions, face_joint_indx, smooth_forward=True) + + '''Quaternion to continuous 6D''' + cont_6d_params = quaternion_to_cont6d_np(quat_params) + # (seq_len, 4) + r_rot = quat_params[:, 0].copy() + # print(r_rot[0]) + '''Root Linear Velocity''' + # (seq_len - 1, 3) + velocity = (positions[1:, 0] - positions[:-1, 0]).copy() + # print(r_rot.shape, velocity.shape) + velocity = qrot_np(r_rot[1:], velocity) + '''Root Angular Velocity''' + # (seq_len - 1, 4) + r_velocity = qmul_np(r_rot[1:], qinv_np(r_rot[:-1])) + # (seq_len, joints_num, 4) + return cont_6d_params, r_velocity, velocity, r_rot + + cont_6d_params, r_velocity, velocity, r_rot = get_cont6d_params(positions) + positions = get_rifke(positions) + + # trejec = np.cumsum(np.concatenate([np.array([[0, 0, 0]]), velocity], axis=0), axis=0) + # r_rotations, r_pos = recover_ric_glo_np(r_velocity, velocity[:, [0, 2]]) + + # plt.plot(positions_b[:, 0, 0], positions_b[:, 0, 2], marker='*') + # plt.plot(ground_positions[:, 0, 0], ground_positions[:, 0, 2], marker='o', color='r') + # plt.plot(trejec[:, 0], trejec[:, 2], marker='^', color='g') + # plt.plot(r_pos[:, 0], r_pos[:, 2], marker='s', color='y') + # plt.xlabel('x') + # plt.ylabel('z') + # plt.axis('equal') + # plt.show() + + '''Root height''' + root_y = positions[:, 0, 1:2] + + '''Root rotation and linear velocity''' + # (seq_len-1, 1) rotation velocity along y-axis + # (seq_len-1, 2) linear velovity on xz plane + r_velocity = np.arcsin(r_velocity[:, 2:3]) + l_velocity = velocity[:, [0, 2]] + # print(r_velocity.shape, l_velocity.shape, root_y.shape) + root_data = np.concatenate([r_velocity, l_velocity, root_y[:-1]], axis=-1) + + '''Get Joint Rotation Representation''' + # (seq_len, (joints_num-1) *6) quaternion for skeleton joints + rot_data = cont_6d_params[:, 1:].reshape(len(cont_6d_params), -1) + + '''Get Joint Rotation Invariant Position Represention''' + # (seq_len, (joints_num-1)*3) local joint position + ric_data = positions[:, 1:].reshape(len(positions), -1) + + '''Get Joint Velocity Representation''' + # (seq_len-1, joints_num*3) + local_vel = qrot_np(np.repeat(r_rot[:-1, None], global_positions.shape[1], axis=1), + global_positions[1:] - global_positions[:-1]) + local_vel = local_vel.reshape(len(local_vel), -1) + + data = root_data + data = np.concatenate([data, ric_data[:-1]], axis=-1) + data = np.concatenate([data, rot_data[:-1]], axis=-1) + # print(dataset.shape, local_vel.shape) + data = np.concatenate([data, local_vel], axis=-1) + data = np.concatenate([data, feet_l, feet_r], axis=-1) + + return data + + +def process_file(positions, feet_thre): + # (seq_len, joints_num, 3) + # '''Down Sample''' + # positions = positions[::ds_num] + + '''Uniform Skeleton''' + positions = uniform_skeleton(positions, tgt_offsets) + + '''Put on Floor''' + floor_height = positions.min(axis=0).min(axis=0)[1] + positions[:, :, 1] -= floor_height + # print(floor_height) + + # plot_3d_motion("./positions_1.mp4", kinematic_chain, positions, 'title', fps=20) + + '''XZ at origin''' + root_pos_init = positions[0] + root_pose_init_xz = root_pos_init[0] * np.array([1, 0, 1]) + positions = positions - root_pose_init_xz + + # '''Move the first pose to origin ''' + # root_pos_init = positions[0] + # positions = positions - root_pos_init[0] + + '''All initially face Z+''' + r_hip, 
l_hip, sdr_r, sdr_l = face_joint_indx + across1 = root_pos_init[r_hip] - root_pos_init[l_hip] + across2 = root_pos_init[sdr_r] - root_pos_init[sdr_l] + across = across1 + across2 + across = across / np.sqrt((across ** 2).sum(axis=-1))[..., np.newaxis] + + # forward (3,), rotate around y-axis + forward_init = np.cross(np.array([[0, 1, 0]]), across, axis=-1) + # forward (3,) + forward_init = forward_init / np.sqrt((forward_init ** 2).sum(axis=-1))[..., np.newaxis] + + # print(forward_init) + + target = np.array([[0, 0, 1]]) + root_quat_init = qbetween_np(forward_init, target) + root_quat_init = np.ones(positions.shape[:-1] + (4,)) * root_quat_init + + positions_b = positions.copy() + + positions = qrot_np(root_quat_init, positions) + + # plot_3d_motion("./positions_2.mp4", kinematic_chain, positions, 'title', fps=20) + + '''New ground truth positions''' + global_positions = positions.copy() + + # plt.plot(positions_b[:, 0, 0], positions_b[:, 0, 2], marker='*') + # plt.plot(positions[:, 0, 0], positions[:, 0, 2], marker='o', color='r') + # plt.xlabel('x') + # plt.ylabel('z') + # plt.axis('equal') + # plt.show() + + """ Get Foot Contacts """ + + def foot_detect(positions, thres): + velfactor, heightfactor = np.array([thres, thres]), np.array([3.0, 2.0]) + + feet_l_x = (positions[1:, fid_l, 0] - positions[:-1, fid_l, 0]) ** 2 + feet_l_y = (positions[1:, fid_l, 1] - positions[:-1, fid_l, 1]) ** 2 + feet_l_z = (positions[1:, fid_l, 2] - positions[:-1, fid_l, 2]) ** 2 + # feet_l_h = positions[:-1,fid_l,1] + # feet_l = (((feet_l_x + feet_l_y + feet_l_z) < velfactor) & (feet_l_h < heightfactor)).astype(np.float) + feet_l = ((feet_l_x + feet_l_y + feet_l_z) < velfactor).astype(np.float) + + feet_r_x = (positions[1:, fid_r, 0] - positions[:-1, fid_r, 0]) ** 2 + feet_r_y = (positions[1:, fid_r, 1] - positions[:-1, fid_r, 1]) ** 2 + feet_r_z = (positions[1:, fid_r, 2] - positions[:-1, fid_r, 2]) ** 2 + # feet_r_h = positions[:-1,fid_r,1] + # feet_r = (((feet_r_x + feet_r_y + feet_r_z) < velfactor) & (feet_r_h < heightfactor)).astype(np.float) + feet_r = (((feet_r_x + feet_r_y + feet_r_z) < velfactor)).astype(np.float) + return feet_l, feet_r + # + feet_l, feet_r = foot_detect(positions, feet_thre) + # feet_l, feet_r = foot_detect(positions, 0.002) + + '''Quaternion and Cartesian representation''' + r_rot = None + + def get_rifke(positions): + '''Local pose''' + positions[..., 0] -= positions[:, 0:1, 0] + positions[..., 2] -= positions[:, 0:1, 2] + '''All pose face Z+''' + positions = qrot_np(np.repeat(r_rot[:, None], positions.shape[1], axis=1), positions) + return positions + + def get_quaternion(positions): + skel = Skeleton(n_raw_offsets, kinematic_chain, "cpu") + # (seq_len, joints_num, 4) + quat_params = skel.inverse_kinematics_np(positions, face_joint_indx, smooth_forward=False) + + '''Fix Quaternion Discontinuity''' + quat_params = qfix(quat_params) + # (seq_len, 4) + r_rot = quat_params[:, 0].copy() + # print(r_rot[0]) + '''Root Linear Velocity''' + # (seq_len - 1, 3) + velocity = (positions[1:, 0] - positions[:-1, 0]).copy() + # print(r_rot.shape, velocity.shape) + velocity = qrot_np(r_rot[1:], velocity) + '''Root Angular Velocity''' + # (seq_len - 1, 4) + r_velocity = qmul_np(r_rot[1:], qinv_np(r_rot[:-1])) + quat_params[1:, 0] = r_velocity + # (seq_len, joints_num, 4) + return quat_params, r_velocity, velocity, r_rot + + def get_cont6d_params(positions): + skel = Skeleton(n_raw_offsets, kinematic_chain, "cpu") + # (seq_len, joints_num, 4) + quat_params = skel.inverse_kinematics_np(positions, 
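The "all initially face Z+" step derives the facing direction from the hip and shoulder vectors and rotates the whole clip so that it points along +Z. The repo does this with `qbetween_np`/`qrot_np`; the sketch below uses an equivalent plain yaw rotation about Y, so the helper name and the toy skeleton are illustrative only:

```python
import numpy as np

def face_z_plus(positions, r_hip, l_hip, sdr_r, sdr_l):
    """Sketch of the 'all initially face Z+' step using a plain yaw rotation
    instead of the repo's quaternion helpers."""
    root_pose = positions[0]
    across = (root_pose[r_hip] - root_pose[l_hip]) + (root_pose[sdr_r] - root_pose[sdr_l])
    forward = np.cross([0, 1, 0], across)                 # horizontal facing direction
    forward = forward / np.linalg.norm(forward)
    yaw = np.arctan2(forward[0], forward[2])              # angle between forward and +Z
    c, s = np.cos(-yaw), np.sin(-yaw)
    rot_y = np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]])  # rotate about Y by -yaw
    return positions @ rot_y.T

# Toy skeleton whose hips/shoulders span the Z axis (so it faces along X).
pos = np.zeros((1, 4, 3))
pos[0, 0], pos[0, 1] = [0, 0, -0.1], [0, 0, 0.1]          # r_hip, l_hip
pos[0, 2], pos[0, 3] = [0, 1, -0.2], [0, 1, 0.2]          # sdr_r, sdr_l
aligned = face_z_plus(pos, 0, 1, 2, 3)
print(np.round(aligned[0, 0], 3))                          # hips now offset along x
```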
face_joint_indx, smooth_forward=True) + + '''Quaternion to continuous 6D''' + cont_6d_params = quaternion_to_cont6d_np(quat_params) + # (seq_len, 4) + r_rot = quat_params[:, 0].copy() + # print(r_rot[0]) + '''Root Linear Velocity''' + # (seq_len - 1, 3) + velocity = (positions[1:, 0] - positions[:-1, 0]).copy() + # print(r_rot.shape, velocity.shape) + velocity = qrot_np(r_rot[1:], velocity) + '''Root Angular Velocity''' + # (seq_len - 1, 4) + r_velocity = qmul_np(r_rot[1:], qinv_np(r_rot[:-1])) + # (seq_len, joints_num, 4) + return cont_6d_params, r_velocity, velocity, r_rot + + cont_6d_params, r_velocity, velocity, r_rot = get_cont6d_params(positions) + positions = get_rifke(positions) + + # trejec = np.cumsum(np.concatenate([np.array([[0, 0, 0]]), velocity], axis=0), axis=0) + # r_rotations, r_pos = recover_ric_glo_np(r_velocity, velocity[:, [0, 2]]) + + # plt.plot(positions_b[:, 0, 0], positions_b[:, 0, 2], marker='*') + # plt.plot(ground_positions[:, 0, 0], ground_positions[:, 0, 2], marker='o', color='r') + # plt.plot(trejec[:, 0], trejec[:, 2], marker='^', color='g') + # plt.plot(r_pos[:, 0], r_pos[:, 2], marker='s', color='y') + # plt.xlabel('x') + # plt.ylabel('z') + # plt.axis('equal') + # plt.show() + + '''Root height''' + root_y = positions[:, 0, 1:2] + + '''Root rotation and linear velocity''' + # (seq_len-1, 1) rotation velocity along y-axis + # (seq_len-1, 2) linear velovity on xz plane + r_velocity = np.arcsin(r_velocity[:, 2:3]) + l_velocity = velocity[:, [0, 2]] + # print(r_velocity.shape, l_velocity.shape, root_y.shape) + root_data = np.concatenate([r_velocity, l_velocity, root_y[:-1]], axis=-1) + + '''Get Joint Rotation Representation''' + # (seq_len, (joints_num-1) *6) quaternion for skeleton joints + rot_data = cont_6d_params[:, 1:].reshape(len(cont_6d_params), -1) + + '''Get Joint Rotation Invariant Position Represention''' + # (seq_len, (joints_num-1)*3) local joint position + ric_data = positions[:, 1:].reshape(len(positions), -1) + + '''Get Joint Velocity Representation''' + # (seq_len-1, joints_num*3) + local_vel = qrot_np(np.repeat(r_rot[:-1, None], global_positions.shape[1], axis=1), + global_positions[1:] - global_positions[:-1]) + local_vel = local_vel.reshape(len(local_vel), -1) + + data = root_data + data = np.concatenate([data, ric_data[:-1]], axis=-1) + data = np.concatenate([data, rot_data[:-1]], axis=-1) + # print(dataset.shape, local_vel.shape) + data = np.concatenate([data, local_vel], axis=-1) + data = np.concatenate([data, feet_l, feet_r], axis=-1) + + return data, global_positions, positions, l_velocity + + +# Recover global angle and positions for rotation dataset +# root_rot_velocity (B, seq_len, 1) +# root_linear_velocity (B, seq_len, 2) +# root_y (B, seq_len, 1) +# ric_data (B, seq_len, (joint_num - 1)*3) +# rot_data (B, seq_len, (joint_num - 1)*6) +# local_velocity (B, seq_len, joint_num*3) +# foot contact (B, seq_len, 4) +def recover_root_rot_pos(data): + rot_vel = data[..., 0] + r_rot_ang = torch.zeros_like(rot_vel).to(data.device) + '''Get Y-axis rotation from rotation velocity''' + r_rot_ang[..., 1:] = rot_vel[..., :-1] + r_rot_ang = torch.cumsum(r_rot_ang, dim=-1) + + r_rot_quat = torch.zeros(data.shape[:-1] + (4,)).to(data.device) + r_rot_quat[..., 0] = torch.cos(r_rot_ang) + r_rot_quat[..., 2] = torch.sin(r_rot_ang) + + r_pos = torch.zeros(data.shape[:-1] + (3,)).to(data.device) + r_pos[..., 1:, [0, 2]] = data[..., :-1, 1:3] + '''Add Y-axis rotation to root position''' + r_pos = qrot(qinv(r_rot_quat), r_pos) + + r_pos = 
torch.cumsum(r_pos, dim=-2) + + r_pos[..., 1] = data[..., 3] + return r_rot_quat, r_pos + + +def recover_from_rot(data, joints_num, skeleton): + r_rot_quat, r_pos = recover_root_rot_pos(data) + + r_rot_cont6d = quaternion_to_cont6d(r_rot_quat) + + start_indx = 1 + 2 + 1 + (joints_num - 1) * 3 + end_indx = start_indx + (joints_num - 1) * 6 + cont6d_params = data[..., start_indx:end_indx] + # print(r_rot_cont6d.shape, cont6d_params.shape, r_pos.shape) + cont6d_params = torch.cat([r_rot_cont6d, cont6d_params], dim=-1) + cont6d_params = cont6d_params.view(-1, joints_num, 6) + + positions = skeleton.forward_kinematics_cont6d(cont6d_params, r_pos) + + return positions + +def recover_rot(data): + # dataset [bs, seqlen, 263/251] HumanML/KIT + joints_num = 22 if data.shape[-1] == 263 else 21 + r_rot_quat, r_pos = recover_root_rot_pos(data) + r_pos_pad = torch.cat([r_pos, torch.zeros_like(r_pos)], dim=-1).unsqueeze(-2) + r_rot_cont6d = quaternion_to_cont6d(r_rot_quat) + start_indx = 1 + 2 + 1 + (joints_num - 1) * 3 + end_indx = start_indx + (joints_num - 1) * 6 + cont6d_params = data[..., start_indx:end_indx] + cont6d_params = torch.cat([r_rot_cont6d, cont6d_params], dim=-1) + cont6d_params = cont6d_params.view(-1, joints_num, 6) + cont6d_params = torch.cat([cont6d_params, r_pos_pad], dim=-2) + return cont6d_params + + +def recover_from_ric(data, joints_num): + r_rot_quat, r_pos = recover_root_rot_pos(data) + positions = data[..., 4:(joints_num - 1) * 3 + 4] + positions = positions.view(positions.shape[:-1] + (-1, 3)) + + '''Add Y-axis rotation to local joints''' + positions = qrot(qinv(r_rot_quat[..., None, :]).expand(positions.shape[:-1] + (4,)), positions) + + '''Add root XZ to joints''' + positions[..., 0] += r_pos[..., 0:1] + positions[..., 2] += r_pos[..., 2:3] + + '''Concate root and joints''' + positions = torch.cat([r_pos.unsqueeze(-2), positions], dim=-2) + + return positions +''' +For Text2Motion Dataset +''' +''' +if __name__ == "__main__": + example_id = "000021" + # Lower legs + l_idx1, l_idx2 = 5, 8 + # Right/Left foot + fid_r, fid_l = [8, 11], [7, 10] + # Face direction, r_hip, l_hip, sdr_r, sdr_l + face_joint_indx = [2, 1, 17, 16] + # l_hip, r_hip + r_hip, l_hip = 2, 1 + joints_num = 22 + # ds_num = 8 + data_dir = '../dataset/pose_data_raw/joints/' + save_dir1 = '../dataset/pose_data_raw/new_joints/' + save_dir2 = '../dataset/pose_data_raw/new_joint_vecs/' + + n_raw_offsets = torch.from_numpy(t2m_raw_offsets) + kinematic_chain = t2m_kinematic_chain + + # Get offsets of target skeleton + example_data = np.load(os.path.join(data_dir, example_id + '.npy')) + example_data = example_data.reshape(len(example_data), -1, 3) + example_data = torch.from_numpy(example_data) + tgt_skel = Skeleton(n_raw_offsets, kinematic_chain, 'cpu') + # (joints_num, 3) + tgt_offsets = tgt_skel.get_offsets_joints(example_data[0]) + # print(tgt_offsets) + + source_list = os.listdir(data_dir) + frame_num = 0 + for source_file in tqdm(source_list): + source_data = np.load(os.path.join(data_dir, source_file))[:, :joints_num] + try: + dataset, ground_positions, positions, l_velocity = process_file(source_data, 0.002) + rec_ric_data = recover_from_ric(torch.from_numpy(dataset).unsqueeze(0).float(), joints_num) + np.save(pjoin(save_dir1, source_file), rec_ric_data.squeeze().numpy()) + np.save(pjoin(save_dir2, source_file), dataset) + frame_num += dataset.shape[0] + except Exception as e: + print(source_file) + print(e) + + print('Total clips: %d, Frames: %d, Duration: %fm' % + (len(source_list), frame_num, frame_num 
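`recover_root_rot_pos` inverts the root representation by integrating: cumulative-sum the yaw rate to get the heading, rotate each frame's local XZ velocity back into world coordinates, and cumulative-sum again to get the trajectory. A hedged sketch of that structure (sign conventions simplified relative to the quaternion helpers):

```python
import torch

def recover_root_traj(rot_vel, local_xz_vel):
    """Integrate yaw rate, rotate local XZ velocities into the world frame,
    then integrate those to recover the root trajectory."""
    yaw = torch.cumsum(torch.cat([torch.zeros(1), rot_vel[:-1]]), dim=0)  # (T,)
    c, s = torch.cos(yaw), torch.sin(yaw)
    world_vx = c * local_xz_vel[:, 0] + s * local_xz_vel[:, 1]
    world_vz = -s * local_xz_vel[:, 0] + c * local_xz_vel[:, 1]
    traj = torch.cumsum(torch.stack([world_vx, world_vz], dim=-1), dim=0)
    return traj                                                           # (T, 2)

# Constant forward velocity with zero turning: the root walks a straight line,
# advancing 0.1 along z each frame while x stays 0.
T = 5
print(recover_root_traj(torch.zeros(T), torch.tensor([[0.0, 0.1]] * T)))
```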
/ 20 / 60)) +''' + +if __name__ == "__main__": + example_id = "03950_gt" + # Lower legs + l_idx1, l_idx2 = 17, 18 + # Right/Left foot + fid_r, fid_l = [14, 15], [19, 20] + # Face direction, r_hip, l_hip, sdr_r, sdr_l + face_joint_indx = [11, 16, 5, 8] + # l_hip, r_hip + r_hip, l_hip = 11, 16 + joints_num = 21 + # ds_num = 8 + data_dir = '../dataset/kit_mocap_dataset/joints/' + save_dir1 = '../dataset/kit_mocap_dataset/new_joints/' + save_dir2 = '../dataset/kit_mocap_dataset/new_joint_vecs/' + + n_raw_offsets = torch.from_numpy(kit_raw_offsets) + kinematic_chain = kit_kinematic_chain + + '''Get offsets of target skeleton''' + example_data = np.load(os.path.join(data_dir, example_id + '.npy')) + example_data = example_data.reshape(len(example_data), -1, 3) + example_data = torch.from_numpy(example_data) + tgt_skel = Skeleton(n_raw_offsets, kinematic_chain, 'cpu') + # (joints_num, 3) + tgt_offsets = tgt_skel.get_offsets_joints(example_data[0]) + # print(tgt_offsets) + + source_list = os.listdir(data_dir) + frame_num = 0 + '''Read source dataset''' + for source_file in tqdm(source_list): + source_data = np.load(os.path.join(data_dir, source_file))[:, :joints_num] + try: + name = ''.join(source_file[:-7].split('_')) + '.npy' + data, ground_positions, positions, l_velocity = process_file(source_data, 0.05) + rec_ric_data = recover_from_ric(torch.from_numpy(data).unsqueeze(0).float(), joints_num) + if np.isnan(rec_ric_data.numpy()).any(): + print(source_file) + continue + np.save(pjoin(save_dir1, name), rec_ric_data.squeeze().numpy()) + np.save(pjoin(save_dir2, name), data) + frame_num += data.shape[0] + except Exception as e: + print(source_file) + print(e) + + print('Total clips: %d, Frames: %d, Duration: %fm' % + (len(source_list), frame_num, frame_num / 12.5 / 60)) \ No newline at end of file diff --git a/main/data_loaders/humanml/utils/get_opt.py b/main/data_loaders/humanml/utils/get_opt.py new file mode 100644 index 0000000000000000000000000000000000000000..c331b4dde8cc71c2ce33916945d75a43fc32308f --- /dev/null +++ b/main/data_loaders/humanml/utils/get_opt.py @@ -0,0 +1,81 @@ +import os +from argparse import Namespace +import re +from os.path import join as pjoin +from data_loaders.humanml.utils.word_vectorizer import POS_enumerator + + +def is_float(numStr): + flag = False + numStr = str(numStr).strip().lstrip('-').lstrip('+') # ๅŽป้™คๆญฃๆ•ฐ(+)ใ€่ดŸๆ•ฐ(-)็ฌฆๅท + try: + reg = re.compile(r'^[-+]?[0-9]+\.[0-9]+$') + res = reg.match(str(numStr)) + if res: + flag = True + except Exception as ex: + print("is_float() - error: " + str(ex)) + return flag + + +def is_number(numStr): + flag = False + numStr = str(numStr).strip().lstrip('-').lstrip('+') # ๅŽป้™คๆญฃๆ•ฐ(+)ใ€่ดŸๆ•ฐ(-)็ฌฆๅท + if str(numStr).isdigit(): + flag = True + return flag + + +def get_opt(opt_path, device): + opt = Namespace() + opt_dict = vars(opt) + + skip = ('-------------- End ----------------', + '------------ Options -------------', + '\n') + print('Reading', opt_path) + with open(opt_path) as f: + for line in f: + if line.strip() not in skip: + # print(line.strip()) + key, value = line.strip().split(': ') + if value in ('True', 'False'): + opt_dict[key] = bool(value) + elif is_float(value): + opt_dict[key] = float(value) + elif is_number(value): + opt_dict[key] = int(value) + else: + opt_dict[key] = str(value) + + # print(opt) + opt_dict['which_epoch'] = 'latest' + opt.save_root = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name) + opt.model_dir = pjoin(opt.save_root, 'model') + opt.meta_dir = pjoin(opt.save_root, 
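One detail worth checking in `get_opt`: `bool(value)` on strings read from the options file is always `True`, because any non-empty string is truthy, so `'False'` does not round-trip; comparing against `'True'` does. A two-line demonstration:

```python
# bool() of a non-empty string is always True, so the 'True'/'False' branch
# above turns the literal string 'False' into True; `value == 'True'` does not.
for value in ('True', 'False'):
    print(value, bool(value), value == 'True')
# True True True
# False True False
```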
'meta') + + if opt.dataset_name == 't2m': + opt.data_root = './dataset/HumanML3D' + opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs') + opt.text_dir = pjoin(opt.data_root, 'texts') + opt.joints_num = 22 + opt.dim_pose = 263 + opt.max_motion_length = 196 + elif opt.dataset_name == 'kit': + opt.data_root = './dataset/KIT-ML' + opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs') + opt.text_dir = pjoin(opt.data_root, 'texts') + opt.joints_num = 21 + opt.dim_pose = 251 + opt.max_motion_length = 196 + else: + raise KeyError('Dataset not recognized') + + opt.dim_word = 300 + opt.num_classes = 200 // opt.unit_length + opt.dim_pos_ohot = len(POS_enumerator) + opt.is_train = False + opt.is_continue = False + opt.device = device + + return opt \ No newline at end of file diff --git a/main/data_loaders/humanml/utils/metrics.py b/main/data_loaders/humanml/utils/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..4d6c68192ccecc3a1d739d2e4e53ffb12800efc1 --- /dev/null +++ b/main/data_loaders/humanml/utils/metrics.py @@ -0,0 +1,146 @@ +import numpy as np +from scipy import linalg + + +# (X - X_train)*(X - X_train) = -2X*X_train + X*X + X_train*X_train +def euclidean_distance_matrix(matrix1, matrix2): + """ + Params: + -- matrix1: N1 x D + -- matrix2: N2 x D + Returns: + -- dist: N1 x N2 + dist[i, j] == distance(matrix1[i], matrix2[j]) + """ + assert matrix1.shape[1] == matrix2.shape[1] + d1 = -2 * np.dot(matrix1, matrix2.T) # shape (num_test, num_train) + d2 = np.sum(np.square(matrix1), axis=1, keepdims=True) # shape (num_test, 1) + d3 = np.sum(np.square(matrix2), axis=1) # shape (num_train, ) + dists = np.sqrt(d1 + d2 + d3) # broadcasting + return dists + +def calculate_top_k(mat, top_k): + size = mat.shape[0] + gt_mat = np.expand_dims(np.arange(size), 1).repeat(size, 1) + bool_mat = (mat == gt_mat) + correct_vec = False + top_k_list = [] + for i in range(top_k): +# print(correct_vec, bool_mat[:, i]) + correct_vec = (correct_vec | bool_mat[:, i]) + # print(correct_vec) + top_k_list.append(correct_vec[:, None]) + top_k_mat = np.concatenate(top_k_list, axis=1) + return top_k_mat + + +def calculate_R_precision(embedding1, embedding2, top_k, sum_all=False): + dist_mat = euclidean_distance_matrix(embedding1, embedding2) + argmax = np.argsort(dist_mat, axis=1) + top_k_mat = calculate_top_k(argmax, top_k) + if sum_all: + return top_k_mat.sum(axis=0) + else: + return top_k_mat + + +def calculate_matching_score(embedding1, embedding2, sum_all=False): + assert len(embedding1.shape) == 2 + assert embedding1.shape[0] == embedding2.shape[0] + assert embedding1.shape[1] == embedding2.shape[1] + + dist = linalg.norm(embedding1 - embedding2, axis=1) + if sum_all: + return dist.sum(axis=0) + else: + return dist + + + +def calculate_activation_statistics(activations): + """ + Params: + -- activation: num_samples x dim_feat + Returns: + -- mu: dim_feat + -- sigma: dim_feat x dim_feat + """ + mu = np.mean(activations, axis=0) + cov = np.cov(activations, rowvar=False) + return mu, cov + + +def calculate_diversity(activation, diversity_times): + assert len(activation.shape) == 2 + assert activation.shape[0] > diversity_times + num_samples = activation.shape[0] + + first_indices = np.random.choice(num_samples, diversity_times, replace=False) + second_indices = np.random.choice(num_samples, diversity_times, replace=False) + dist = linalg.norm(activation[first_indices] - activation[second_indices], axis=1) + return dist.mean() + + +def calculate_multimodality(activation, 
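`euclidean_distance_matrix` uses the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b to compute all pairwise distances in one shot. The variant below adds a clamp at zero before the square root (tiny negative values can appear from round-off; the repo's version omits this) and checks the result against a naive loop:

```python
import numpy as np

def euclidean_distance_matrix(m1, m2):
    # Vectorized pairwise distances via ||a||^2 + ||b||^2 - 2 a.b,
    # clipped at 0 before the sqrt to guard against round-off.
    d1 = -2 * np.dot(m1, m2.T)
    d2 = np.sum(np.square(m1), axis=1, keepdims=True)
    d3 = np.sum(np.square(m2), axis=1)
    return np.sqrt(np.maximum(d1 + d2 + d3, 0.0))

a, b = np.random.randn(5, 16), np.random.randn(7, 16)
naive = np.array([[np.linalg.norm(x - y) for y in b] for x in a])
print(np.allclose(euclidean_distance_matrix(a, b), naive))  # True
```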
multimodality_times): + assert len(activation.shape) == 3 + assert activation.shape[1] > multimodality_times + num_per_sent = activation.shape[1] + + first_dices = np.random.choice(num_per_sent, multimodality_times, replace=False) + second_dices = np.random.choice(num_per_sent, multimodality_times, replace=False) + dist = linalg.norm(activation[:, first_dices] - activation[:, second_dices], axis=2) + return dist.mean() + + +def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): + """Numpy implementation of the Frechet Distance. + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + Stable version by Dougal J. Sutherland. + Params: + -- mu1 : Numpy array containing the activations of a layer of the + inception net (like returned by the function 'get_predictions') + for generated samples. + -- mu2 : The sample mean over activations, precalculated on an + representative dataset set. + -- sigma1: The covariance matrix over activations for generated samples. + -- sigma2: The covariance matrix over activations, precalculated on an + representative dataset set. + Returns: + -- : The Frechet Distance. + """ + + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert mu1.shape == mu2.shape, \ + 'Training and test mean vectors have different lengths' + assert sigma1.shape == sigma2.shape, \ + 'Training and test covariances have different dimensions' + + diff = mu1 - mu2 + + # Product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = ('fid calculation produces singular product; ' + 'adding %s to diagonal of cov estimates') % eps + print(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError('Imaginary component {}'.format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + return (diff.dot(diff) + np.trace(sigma1) + + np.trace(sigma2) - 2 * tr_covmean) \ No newline at end of file diff --git a/main/data_loaders/humanml/utils/paramUtil.py b/main/data_loaders/humanml/utils/paramUtil.py new file mode 100644 index 0000000000000000000000000000000000000000..a9f1708b85ca80a9051cb3675cec9b999a0d0e2b --- /dev/null +++ b/main/data_loaders/humanml/utils/paramUtil.py @@ -0,0 +1,63 @@ +import numpy as np + +# Define a kinematic tree for the skeletal struture +kit_kinematic_chain = [[0, 11, 12, 13, 14, 15], [0, 16, 17, 18, 19, 20], [0, 1, 2, 3, 4], [3, 5, 6, 7], [3, 8, 9, 10]] + +kit_raw_offsets = np.array( + [ + [0, 0, 0], + [0, 1, 0], + [0, 1, 0], + [0, 1, 0], + [0, 1, 0], + [1, 0, 0], + [0, -1, 0], + [0, -1, 0], + [-1, 0, 0], + [0, -1, 0], + [0, -1, 0], + [1, 0, 0], + [0, -1, 0], + [0, -1, 0], + [0, 0, 1], + [0, 0, 1], + [-1, 0, 0], + [0, -1, 0], + [0, -1, 0], + [0, 0, 1], + [0, 0, 1] + ] +) + +t2m_raw_offsets = np.array([[0,0,0], + [1,0,0], + [-1,0,0], + [0,1,0], + [0,-1,0], + [0,-1,0], + [0,1,0], + [0,-1,0], + [0,-1,0], + [0,1,0], + [0,0,1], + [0,0,1], + [0,1,0], + [1,0,0], + [-1,0,0], + [0,0,1], + [0,-1,0], + [0,-1,0], + [0,-1,0], + [0,-1,0], + [0,-1,0], + [0,-1,0]]) + +t2m_kinematic_chain = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 
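`calculate_frechet_distance` is the standard FID formula d^2 = ||mu1 - mu2||^2 + Tr(C1 + C2 - 2 sqrt(C1 C2)). A bare-bones sketch without the eps/imaginary-part guards of the full implementation, sanity-checked on the trivial case of a distribution against itself:

```python
import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2):
    """d^2 = ||mu1 - mu2||^2 + Tr(C1 + C2 - 2 sqrt(C1 C2)); minimal sketch."""
    diff = mu1 - mu2
    covmean = linalg.sqrtm(sigma1.dot(sigma2)).real
    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)

x = np.random.randn(1000, 8)
mu, cov = np.mean(x, axis=0), np.cov(x, rowvar=False)
print(abs(frechet_distance(mu, cov, mu, cov)) < 1e-4)  # distance to itself ~ 0
```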
17, 19, 21], [9, 13, 16, 18, 20]] +t2m_left_hand_chain = [[20, 22, 23, 24], [20, 34, 35, 36], [20, 25, 26, 27], [20, 31, 32, 33], [20, 28, 29, 30]] +t2m_right_hand_chain = [[21, 43, 44, 45], [21, 46, 47, 48], [21, 40, 41, 42], [21, 37, 38, 39], [21, 49, 50, 51]] + + +kit_tgt_skel_id = '03950' + +t2m_tgt_skel_id = '000021' + diff --git a/main/data_loaders/humanml/utils/plot_script.py b/main/data_loaders/humanml/utils/plot_script.py new file mode 100644 index 0000000000000000000000000000000000000000..7594750d27ebac87a971f99323eaa1d31e5541f7 --- /dev/null +++ b/main/data_loaders/humanml/utils/plot_script.py @@ -0,0 +1,132 @@ +import math +import numpy as np +import matplotlib +import matplotlib.pyplot as plt +from mpl_toolkits.mplot3d import Axes3D +from matplotlib.animation import FuncAnimation, FFMpegFileWriter +from mpl_toolkits.mplot3d.art3d import Poly3DCollection +import mpl_toolkits.mplot3d.axes3d as p3 +# import cv2 +from textwrap import wrap + + +def list_cut_average(ll, intervals): + if intervals == 1: + return ll + + bins = math.ceil(len(ll) * 1.0 / intervals) + ll_new = [] + for i in range(bins): + l_low = intervals * i + l_high = l_low + intervals + l_high = l_high if l_high < len(ll) else len(ll) + ll_new.append(np.mean(ll[l_low:l_high])) + return ll_new + + +def plot_3d_motion(save_path, kinematic_tree, joints, title, dataset, figsize=(3, 3), fps=120, radius=3, + vis_mode='default', gt_frames=[]): + matplotlib.use('Agg') + + title = '\n'.join(wrap(title, 20)) + + def init(): + ax.set_xlim3d([-radius / 2, radius / 2]) + ax.set_ylim3d([0, radius]) + ax.set_zlim3d([-radius / 3., radius * 2 / 3.]) + # print(title) + fig.suptitle(title, fontsize=10) + ax.grid(b=False) + + def plot_xzPlane(minx, maxx, miny, minz, maxz): + ## Plot a plane XZ + verts = [ + [minx, miny, minz], + [minx, miny, maxz], + [maxx, miny, maxz], + [maxx, miny, minz] + ] + xz_plane = Poly3DCollection([verts]) + xz_plane.set_facecolor((0.5, 0.5, 0.5, 0.5)) + ax.add_collection3d(xz_plane) + + # return ax + + # (seq_len, joints_num, 3) + data = joints.copy().reshape(len(joints), -1, 3) + + # preparation related to specific datasets + if dataset == 'kit': + data *= 0.003 # scale for visualization + elif dataset == 'humanml': + data *= 1.3 # scale for visualization + elif dataset in ['humanact12', 'uestc']: + data *= -1.5 # reverse axes, scale for visualization + + fig = plt.figure(figsize=figsize) + plt.tight_layout() + ax = p3.Axes3D(fig) + init() + MINS = data.min(axis=0).min(axis=0) + MAXS = data.max(axis=0).max(axis=0) + colors_blue = ["#4D84AA", "#5B9965", "#61CEB9", "#34C1E2", "#80B79A"] # GT color + colors_orange = ["#DD5A37", "#D69E00", "#B75A39", "#FF6D00", "#DDB50E"] # Generation color + colors = colors_orange + if vis_mode == 'upper_body': # lower body taken fixed to input motion + colors[0] = colors_blue[0] + colors[1] = colors_blue[1] + elif vis_mode == 'gt': + colors = colors_blue + + frame_number = data.shape[0] + # print(dataset.shape) + + height_offset = MINS[1] + data[:, :, 1] -= height_offset + trajec = data[:, 0, [0, 2]] + + data[..., 0] -= data[:, 0:1, 0] + data[..., 2] -= data[:, 0:1, 2] + + # print(trajec.shape) + + def update(index): + # print(index) + ax.lines = [] + ax.collections = [] + ax.view_init(elev=120, azim=-90) + ax.dist = 7.5 + # ax = + plot_xzPlane(MINS[0] - trajec[index, 0], MAXS[0] - trajec[index, 0], 0, MINS[2] - trajec[index, 1], + MAXS[2] - trajec[index, 1]) + # ax.scatter(dataset[index, :22, 0], dataset[index, :22, 1], dataset[index, :22, 2], color='black', s=3) + + # if 
index > 1: + # ax.plot3D(trajec[:index, 0] - trajec[index, 0], np.zeros_like(trajec[:index, 0]), + # trajec[:index, 1] - trajec[index, 1], linewidth=1.0, + # color='blue') + # # ax = plot_xzPlane(ax, MINS[0], MAXS[0], 0, MINS[2], MAXS[2]) + + used_colors = colors_blue if index in gt_frames else colors + for i, (chain, color) in enumerate(zip(kinematic_tree, used_colors)): + if i < 5: + linewidth = 4.0 + else: + linewidth = 2.0 + ax.plot3D(data[index, chain, 0], data[index, chain, 1], data[index, chain, 2], linewidth=linewidth, + color=color) + # print(trajec[:index, 0].shape) + + plt.axis('off') + ax.set_xticklabels([]) + ax.set_yticklabels([]) + ax.set_zticklabels([]) + + ani = FuncAnimation(fig, update, frames=frame_number, interval=1000 / fps, repeat=False) + + # writer = FFMpegFileWriter(fps=fps) + ani.save(save_path, fps=fps) + # ani = FuncAnimation(fig, update, frames=frame_number, interval=1000 / fps, repeat=False, init_func=init) + # ani.save(save_path, writer='pillow', fps=1000 / fps) + + plt.close() diff --git a/main/data_loaders/humanml/utils/utils.py b/main/data_loaders/humanml/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..eb26cb77b6387ee2e376864a8682b35bd0b5dc95 --- /dev/null +++ b/main/data_loaders/humanml/utils/utils.py @@ -0,0 +1,168 @@ +import os +import numpy as np +# import cv2 +from PIL import Image +from data_loaders.humanml.utils import paramUtil +import math +import time +import matplotlib.pyplot as plt +from scipy.ndimage import gaussian_filter + + +def mkdir(path): + if not os.path.exists(path): + os.makedirs(path) + +COLORS = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], + [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], + [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] + +MISSING_VALUE = -1 + +def save_image(image_numpy, image_path): + img_pil = Image.fromarray(image_numpy) + img_pil.save(image_path) + + +def save_logfile(log_loss, save_path): + with open(save_path, 'wt') as f: + for k, v in log_loss.items(): + w_line = k + for digit in v: + w_line += ' %.3f' % digit + f.write(w_line + '\n') + + +def print_current_loss(start_time, niter_state, losses, epoch=None, sub_epoch=None, + inner_iter=None, tf_ratio=None, sl_steps=None): + + def as_minutes(s): + m = math.floor(s / 60) + s -= m * 60 + return '%dm %ds' % (m, s) + + def time_since(since, percent): + now = time.time() + s = now - since + es = s / percent + rs = es - s + return '%s (- %s)' % (as_minutes(s), as_minutes(rs)) + + if epoch is not None: + print('epoch: %3d niter: %6d sub_epoch: %2d inner_iter: %4d' % (epoch, niter_state, sub_epoch, inner_iter), end=" ") + + # message = '%s niter: %d completed: %3d%%)' % (time_since(start_time, niter_state / total_niters), + # niter_state, niter_state / total_niters * 100) + now = time.time() + message = '%s'%(as_minutes(now - start_time)) + + for k, v in losses.items(): + message += ' %s: %.4f ' % (k, v) + message += ' sl_length:%2d tf_ratio:%.2f'%(sl_steps, tf_ratio) + print(message) + +def print_current_loss_decomp(start_time, niter_state, total_niters, losses, epoch=None, inner_iter=None): + + def as_minutes(s): + m = math.floor(s / 60) + s -= m * 60 + return '%dm %ds' % (m, s) + + def time_since(since, percent): + now = time.time() + s = now - since + es = s / percent + rs = es - s + return '%s (- %s)' % (as_minutes(s), as_minutes(rs)) + + print('epoch: %03d inner_iter: %5d' % (epoch, inner_iter), end=" ") + # 
now = time.time() + message = '%s niter: %07d completed: %3d%%)'%(time_since(start_time, niter_state / total_niters), niter_state, niter_state / total_niters * 100) + for k, v in losses.items(): + message += ' %s: %.4f ' % (k, v) + print(message) + + +def compose_gif_img_list(img_list, fp_out, duration): + img, *imgs = [Image.fromarray(np.array(image)) for image in img_list] + img.save(fp=fp_out, format='GIF', append_images=imgs, optimize=False, + save_all=True, loop=0, duration=duration) + + +def save_images(visuals, image_path): + if not os.path.exists(image_path): + os.makedirs(image_path) + + for i, (label, img_numpy) in enumerate(visuals.items()): + img_name = '%d_%s.jpg' % (i, label) + save_path = os.path.join(image_path, img_name) + save_image(img_numpy, save_path) + + +def save_images_test(visuals, image_path, from_name, to_name): + if not os.path.exists(image_path): + os.makedirs(image_path) + + for i, (label, img_numpy) in enumerate(visuals.items()): + img_name = "%s_%s_%s" % (from_name, to_name, label) + save_path = os.path.join(image_path, img_name) + save_image(img_numpy, save_path) + + +def compose_and_save_img(img_list, save_dir, img_name, col=4, row=1, img_size=(256, 200)): + # print(col, row) + compose_img = compose_image(img_list, col, row, img_size) + if not os.path.exists(save_dir): + os.makedirs(save_dir) + img_path = os.path.join(save_dir, img_name) + # print(img_path) + compose_img.save(img_path) + + +def compose_image(img_list, col, row, img_size): + to_image = Image.new('RGB', (col * img_size[0], row * img_size[1])) + for y in range(0, row): + for x in range(0, col): + from_img = Image.fromarray(img_list[y * col + x]) + # print((x * img_size[0], y*img_size[1], + # (x + 1) * img_size[0], (y + 1) * img_size[1])) + paste_area = (x * img_size[0], y*img_size[1], + (x + 1) * img_size[0], (y + 1) * img_size[1]) + to_image.paste(from_img, paste_area) + # to_image[y*img_size[1]:(y + 1) * img_size[1], x * img_size[0] :(x + 1) * img_size[0]] = from_img + return to_image + + +def plot_loss_curve(losses, save_path, intervals=500): + plt.figure(figsize=(10, 5)) + plt.title("Loss During Training") + for key in losses.keys(): + plt.plot(list_cut_average(losses[key], intervals), label=key) + plt.xlabel("Iterations/" + str(intervals)) + plt.ylabel("Loss") + plt.legend() + plt.savefig(save_path) + plt.show() + + +def list_cut_average(ll, intervals): + if intervals == 1: + return ll + + bins = math.ceil(len(ll) * 1.0 / intervals) + ll_new = [] + for i in range(bins): + l_low = intervals * i + l_high = l_low + intervals + l_high = l_high if l_high < len(ll) else len(ll) + ll_new.append(np.mean(ll[l_low:l_high])) + return ll_new + + +def motion_temporal_filter(motion, sigma=1): + motion = motion.reshape(motion.shape[0], -1) + # print(motion.shape)โ€จ + for i in range(motion.shape[1]): + motion[:, i] = gaussian_filter(motion[:, i], sigma=sigma, mode="nearest") + return motion.reshape(motion.shape[0], -1, 3) + diff --git a/main/data_loaders/humanml/utils/word_vectorizer.py b/main/data_loaders/humanml/utils/word_vectorizer.py new file mode 100644 index 0000000000000000000000000000000000000000..68c5956ff39f840d03c9a352e65291d26e2dfbd4 --- /dev/null +++ b/main/data_loaders/humanml/utils/word_vectorizer.py @@ -0,0 +1,80 @@ +import numpy as np +import pickle +from os.path import join as pjoin + +POS_enumerator = { + 'VERB': 0, + 'NOUN': 1, + 'DET': 2, + 'ADP': 3, + 'NUM': 4, + 'AUX': 5, + 'PRON': 6, + 'ADJ': 7, + 'ADV': 8, + 'Loc_VIP': 9, + 'Body_VIP': 10, + 'Obj_VIP': 11, + 'Act_VIP': 12, + 
'Desc_VIP': 13, + 'OTHER': 14, +} + +Loc_list = ('left', 'right', 'clockwise', 'counterclockwise', 'anticlockwise', 'forward', 'back', 'backward', + 'up', 'down', 'straight', 'curve') + +Body_list = ('arm', 'chin', 'foot', 'feet', 'face', 'hand', 'mouth', 'leg', 'waist', 'eye', 'knee', 'shoulder', 'thigh') + +Obj_List = ('stair', 'dumbbell', 'chair', 'window', 'floor', 'car', 'ball', 'handrail', 'baseball', 'basketball') + +Act_list = ('walk', 'run', 'swing', 'pick', 'bring', 'kick', 'put', 'squat', 'throw', 'hop', 'dance', 'jump', 'turn', + 'stumble', 'dance', 'stop', 'sit', 'lift', 'lower', 'raise', 'wash', 'stand', 'kneel', 'stroll', + 'rub', 'bend', 'balance', 'flap', 'jog', 'shuffle', 'lean', 'rotate', 'spin', 'spread', 'climb') + +Desc_list = ('slowly', 'carefully', 'fast', 'careful', 'slow', 'quickly', 'happy', 'angry', 'sad', 'happily', + 'angrily', 'sadly') + +VIP_dict = { + 'Loc_VIP': Loc_list, + 'Body_VIP': Body_list, + 'Obj_VIP': Obj_List, + 'Act_VIP': Act_list, + 'Desc_VIP': Desc_list, +} + + +class WordVectorizer(object): + def __init__(self, meta_root, prefix): + vectors = np.load(pjoin(meta_root, '%s_data.npy'%prefix)) + words = pickle.load(open(pjoin(meta_root, '%s_words.pkl'%prefix), 'rb')) + word2idx = pickle.load(open(pjoin(meta_root, '%s_idx.pkl'%prefix), 'rb')) + self.word2vec = {w: vectors[word2idx[w]] for w in words} + + def _get_pos_ohot(self, pos): + pos_vec = np.zeros(len(POS_enumerator)) + if pos in POS_enumerator: + pos_vec[POS_enumerator[pos]] = 1 + else: + pos_vec[POS_enumerator['OTHER']] = 1 + return pos_vec + + def __len__(self): + return len(self.word2vec) + + def __getitem__(self, item): + word, pos = item.split('/') + if word in self.word2vec: + word_vec = self.word2vec[word] + vip_pos = None + for key, values in VIP_dict.items(): + if word in values: + vip_pos = key + break + if vip_pos is not None: + pos_vec = self._get_pos_ohot(vip_pos) + else: + pos_vec = self._get_pos_ohot(pos) + else: + word_vec = self.word2vec['unk'] + pos_vec = self._get_pos_ohot('OTHER') + return word_vec, pos_vec \ No newline at end of file diff --git a/main/data_loaders/humanml_utils.py b/main/data_loaders/humanml_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..83ab7cb7916fe6c2bfbd8325ec83c87f4e6fceb9 --- /dev/null +++ b/main/data_loaders/humanml_utils.py @@ -0,0 +1,54 @@ +import numpy as np + +HML_JOINT_NAMES = [ + 'pelvis', + 'left_hip', + 'right_hip', + 'spine1', + 'left_knee', + 'right_knee', + 'spine2', + 'left_ankle', + 'right_ankle', + 'spine3', + 'left_foot', + 'right_foot', + 'neck', + 'left_collar', + 'right_collar', + 'head', + 'left_shoulder', + 'right_shoulder', + 'left_elbow', + 'right_elbow', + 'left_wrist', + 'right_wrist', +] + +NUM_HML_JOINTS = len(HML_JOINT_NAMES) # 22 SMPLH body joints + +HML_LOWER_BODY_JOINTS = [HML_JOINT_NAMES.index(name) for name in ['pelvis', 'left_hip', 'right_hip', 'left_knee', 'right_knee', 'left_ankle', 'right_ankle', 'left_foot', 'right_foot',]] +SMPL_UPPER_BODY_JOINTS = [i for i in range(len(HML_JOINT_NAMES)) if i not in HML_LOWER_BODY_JOINTS] + + +# Recover global angle and positions for rotation data +# root_rot_velocity (B, seq_len, 1) +# root_linear_velocity (B, seq_len, 2) +# root_y (B, seq_len, 1) +# ric_data (B, seq_len, (joint_num - 1)*3) +# rot_data (B, seq_len, (joint_num - 1)*6) +# local_velocity (B, seq_len, joint_num*3) +# foot contact (B, seq_len, 4) +HML_ROOT_BINARY = np.array([True] + [False] * (NUM_HML_JOINTS-1)) +HML_ROOT_MASK = np.concatenate(([True]*(1+2+1), + 
HML_ROOT_BINARY[1:].repeat(3), + HML_ROOT_BINARY[1:].repeat(6), + HML_ROOT_BINARY.repeat(3), + [False] * 4)) +HML_LOWER_BODY_JOINTS_BINARY = np.array([i in HML_LOWER_BODY_JOINTS for i in range(NUM_HML_JOINTS)]) +HML_LOWER_BODY_MASK = np.concatenate(([True]*(1+2+1), + HML_LOWER_BODY_JOINTS_BINARY[1:].repeat(3), + HML_LOWER_BODY_JOINTS_BINARY[1:].repeat(6), + HML_LOWER_BODY_JOINTS_BINARY.repeat(3), + [True]*4)) +HML_UPPER_BODY_MASK = ~HML_LOWER_BODY_MASK \ No newline at end of file diff --git a/main/data_loaders/tensors.py b/main/data_loaders/tensors.py new file mode 100644 index 0000000000000000000000000000000000000000..b176fa2a20112254a3f223bd47e5fb7e26b288f0 --- /dev/null +++ b/main/data_loaders/tensors.py @@ -0,0 +1,70 @@ +import pdb + +import torch + +def lengths_to_mask(lengths, max_len): + # max_len = max(lengths) + mask = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len) < lengths.unsqueeze(1) + return mask + + +def collate_tensors(batch): + dims = batch[0].dim() + max_size = [max([b.size(i) for b in batch]) for i in range(dims)] + size = (len(batch),) + tuple(max_size) + canvas = batch[0].new_zeros(size=size) + for i, b in enumerate(batch): + sub_tensor = canvas[i] + for d in range(dims): + sub_tensor = sub_tensor.narrow(d, 0, b.size(d)) + sub_tensor.add_(b) + return canvas + + +def collate(batch): + notnone_batches = [b for b in batch if b is not None] + databatch = [b['inp'] for b in notnone_batches] + if 'lengths' in notnone_batches[0]: + lenbatch = [b['lengths'] for b in notnone_batches] + else: + lenbatch = [len(b['inp'][0][0]) for b in notnone_batches] + + + databatchTensor = collate_tensors(databatch) + lenbatchTensor = torch.as_tensor(lenbatch) + maskbatchTensor = lengths_to_mask(lenbatchTensor, databatchTensor.shape[-1]).unsqueeze(1).unsqueeze(1) # unqueeze for broadcasting + + motion = databatchTensor + cond = {'y': {'mask': maskbatchTensor, 'lengths': lenbatchTensor}} + + if 'text' in notnone_batches[0]: + textbatch = [b['text'] for b in notnone_batches] + cond['y'].update({'text': textbatch}) + + if 'tokens' in notnone_batches[0]: + textbatch = [b['tokens'] for b in notnone_batches] + cond['y'].update({'tokens': textbatch}) + + if 'action' in notnone_batches[0]: + actionbatch = [b['action'] for b in notnone_batches] + cond['y'].update({'action': torch.as_tensor(actionbatch).unsqueeze(1)}) + + # collate action textual names + if 'action_text' in notnone_batches[0]: + action_text = [b['action_text']for b in notnone_batches] + cond['y'].update({'action_text': action_text}) + + return motion, cond + +# an adapter to our collate func +def t2m_collate(batch): + # batch.sort(key=lambda x: x[3], reverse=True) + adapted_batch = [{ + 'inp': torch.tensor(b[4].T).float().unsqueeze(1), # [seqlen, J] -> [J, 1, seqlen] + 'text': b[2], #b[0]['caption'] + 'tokens': b[6], + 'lengths': b[5], + } for b in batch] + return collate(adapted_batch) + + diff --git a/main/dataset/README.md b/main/dataset/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e51284f58429966777ca38b455822873dfa8b4fc --- /dev/null +++ b/main/dataset/README.md @@ -0,0 +1,6 @@ +## Data + +* Data dirs should be placed here. + +* The `opt` files are configurations for how to read the data according to [text-to-motion](https://github.com/EricGuo5513/text-to-motion). +* The `*_mean.npy` and `*_std.npy` files, are stats used for evaluation only, according to [text-to-motion](https://github.com/EricGuo5513/text-to-motion). 
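+
+For reference, the stats are per-dimension feature means and standard deviations matching `dim_pose` (263 for `t2m`, 251 for `kit`) and are typically applied as a z-normalization of the motion features, following the text-to-motion convention. A minimal sketch of that usage (paths and function names are illustrative, not part of this repo's scripts):
+
+```python
+import numpy as np
+
+# Assumed per-dimension stats for the t2m (HumanML3D) features; dim_pose = 263.
+mean = np.load('./dataset/t2m_mean.npy')   # expected shape: (263,)
+std = np.load('./dataset/t2m_std.npy')     # expected shape: (263,)
+
+def normalize(motion):
+    """motion: (seq_len, 263) raw features -> z-normalized features."""
+    return (motion - mean) / std
+
+def denormalize(motion):
+    """Invert the normalization, e.g. before computing evaluation metrics."""
+    return motion * std + mean
+```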
\ No newline at end of file diff --git a/main/dataset/humanml_opt.txt b/main/dataset/humanml_opt.txt new file mode 100644 index 0000000000000000000000000000000000000000..718bce2887942d2237c8a561ec996792be329314 --- /dev/null +++ b/main/dataset/humanml_opt.txt @@ -0,0 +1,54 @@ +------------ Options ------------- +batch_size: 32 +checkpoints_dir: ./checkpoints +dataset_name: t2m +decomp_name: Decomp_SP001_SM001_H512 +dim_att_vec: 512 +dim_dec_hidden: 1024 +dim_movement2_dec_hidden: 512 +dim_movement_dec_hidden: 512 +dim_movement_enc_hidden: 512 +dim_movement_latent: 512 +dim_msd_hidden: 512 +dim_pos_hidden: 1024 +dim_pri_hidden: 1024 +dim_seq_de_hidden: 512 +dim_seq_en_hidden: 512 +dim_text_hidden: 512 +dim_z: 128 +early_stop_count: 3 +estimator_mod: bigru +eval_every_e: 5 +feat_bias: 5 +fixed_steps: 5 +gpu_id: 3 +input_z: False +is_continue: True +is_train: True +lambda_fake: 10 +lambda_gan_l: 0.1 +lambda_gan_mt: 0.1 +lambda_gan_mv: 0.1 +lambda_kld: 0.01 +lambda_rec: 1 +lambda_rec_init: 1 +lambda_rec_mot: 1 +lambda_rec_mov: 1 +log_every: 50 +lr: 0.0002 +max_sub_epoch: 50 +max_text_len: 20 +n_layers_dec: 1 +n_layers_msd: 2 +n_layers_pos: 1 +n_layers_pri: 1 +n_layers_seq_de: 2 +n_layers_seq_en: 1 +name: Comp_v6_KLD01 +num_experts: 4 +save_every_e: 10 +save_latest: 500 +text_enc_mod: bigru +tf_ratio: 0.4 +unit_length: 4 +-------------- End ---------------- diff --git a/main/dataset/kit_mean.npy b/main/dataset/kit_mean.npy new file mode 100644 index 0000000000000000000000000000000000000000..c1f076c473eaabf4e6c0144d3e6db8b6a3c7e976 --- /dev/null +++ b/main/dataset/kit_mean.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9e23fac51db2215ab5666324226be48f27efd6a6e7b22ebd17c28e0f056a7c22 +size 2136 diff --git a/main/dataset/kit_opt.txt b/main/dataset/kit_opt.txt new file mode 100644 index 0000000000000000000000000000000000000000..805c8ecb40568f1d06c6b0b266d7bffc81d69b3c --- /dev/null +++ b/main/dataset/kit_opt.txt @@ -0,0 +1,54 @@ +------------ Options ------------- +batch_size: 32 +checkpoints_dir: ./checkpoints +dataset_name: kit +decomp_name: Decomp_SP001_SM001_H512 +dim_att_vec: 512 +dim_dec_hidden: 1024 +dim_movement2_dec_hidden: 512 +dim_movement_dec_hidden: 512 +dim_movement_enc_hidden: 512 +dim_movement_latent: 512 +dim_msd_hidden: 512 +dim_pos_hidden: 1024 +dim_pri_hidden: 1024 +dim_seq_de_hidden: 512 +dim_seq_en_hidden: 512 +dim_text_hidden: 512 +dim_z: 128 +early_stop_count: 3 +estimator_mod: bigru +eval_every_e: 5 +feat_bias: 5 +fixed_steps: 5 +gpu_id: 2 +input_z: False +is_continue: True +is_train: True +lambda_fake: 10 +lambda_gan_l: 0.1 +lambda_gan_mt: 0.1 +lambda_gan_mv: 0.1 +lambda_kld: 0.005 +lambda_rec: 1 +lambda_rec_init: 1 +lambda_rec_mot: 1 +lambda_rec_mov: 1 +log_every: 50 +lr: 0.0002 +max_sub_epoch: 50 +max_text_len: 20 +n_layers_dec: 1 +n_layers_msd: 2 +n_layers_pos: 1 +n_layers_pri: 1 +n_layers_seq_de: 2 +n_layers_seq_en: 1 +name: Comp_v6_KLD005 +num_experts: 4 +save_every_e: 10 +save_latest: 500 +text_enc_mod: bigru +tf_ratio: 0.4 +unit_length: 4 +-------------- End ---------------- diff --git a/main/dataset/kit_std.npy b/main/dataset/kit_std.npy new file mode 100644 index 0000000000000000000000000000000000000000..02a4c81095a331998ae0c95e3b01dc48c6d37b77 --- /dev/null +++ b/main/dataset/kit_std.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:296a60656cea07e65ee64512d73d47c0412df0698b35194116330661be32fa90 +size 2136 diff --git a/main/dataset/t2m_mean.npy b/main/dataset/t2m_mean.npy new file mode 100644 index 
0000000000000000000000000000000000000000..6c57414d9cf6242bb4b4bab4c33df5e2cc9d2f91 --- /dev/null +++ b/main/dataset/t2m_mean.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0bdb5ba69a3a9e34d71990db15bc535ebc024c8d95ddb5574196f96058faa7d3 +size 2232 diff --git a/main/dataset/t2m_std.npy b/main/dataset/t2m_std.npy new file mode 100644 index 0000000000000000000000000000000000000000..93c6b7ae4c2fa23dd21c10a27da1b6966168b35b --- /dev/null +++ b/main/dataset/t2m_std.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6a5f7d60301c9465972fc225f8ad0ee8f957e7720431189123eb6d15873a9557 +size 2232 diff --git a/main/diffusion/fp16_util.py b/main/diffusion/fp16_util.py new file mode 100644 index 0000000000000000000000000000000000000000..1ccb93e4843b6257c3151b763356ef501f1acec8 --- /dev/null +++ b/main/diffusion/fp16_util.py @@ -0,0 +1,236 @@ +""" +Helpers to train with 16-bit precision. +""" + +import numpy as np +import torch as th +import torch.nn as nn +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + +from diffusion import logger + +INITIAL_LOG_LOSS_SCALE = 20.0 + + +def convert_module_to_f16(l): + """ + Convert primitive modules to float16. + """ + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + + +def convert_module_to_f32(l): + """ + Convert primitive modules to float32, undoing convert_module_to_f16(). + """ + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + l.weight.data = l.weight.data.float() + if l.bias is not None: + l.bias.data = l.bias.data.float() + + +def make_master_params(param_groups_and_shapes): + """ + Copy model parameters into a (differently-shaped) list of full-precision + parameters. + """ + master_params = [] + for param_group, shape in param_groups_and_shapes: + master_param = nn.Parameter( + _flatten_dense_tensors( + [param.detach().float() for (_, param) in param_group] + ).view(shape) + ) + master_param.requires_grad = True + master_params.append(master_param) + return master_params + + +def model_grads_to_master_grads(param_groups_and_shapes, master_params): + """ + Copy the gradients from the model parameters into the master parameters + from make_master_params(). + """ + for master_param, (param_group, shape) in zip( + master_params, param_groups_and_shapes + ): + master_param.grad = _flatten_dense_tensors( + [param_grad_or_zeros(param) for (_, param) in param_group] + ).view(shape) + + +def master_params_to_model_params(param_groups_and_shapes, master_params): + """ + Copy the master parameter data back into the model parameters. + """ + # Without copying to a list, if a generator is passed, this will + # silently not copy any parameters. 
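+ # Each flattened fp32 master parameter is unflattened back into its original
+ # per-tensor shapes and copied into the corresponding (possibly fp16) model
+ # parameters, e.g. after an optimizer step on the master copies.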
+ for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes): + for (_, param), unflat_master_param in zip( + param_group, unflatten_master_params(param_group, master_param.view(-1)) + ): + param.detach().copy_(unflat_master_param) + + +def unflatten_master_params(param_group, master_param): + return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group]) + + +def get_param_groups_and_shapes(named_model_params): + named_model_params = list(named_model_params) + scalar_vector_named_params = ( + [(n, p) for (n, p) in named_model_params if p.ndim <= 1], + (-1), + ) + matrix_named_params = ( + [(n, p) for (n, p) in named_model_params if p.ndim > 1], + (1, -1), + ) + return [scalar_vector_named_params, matrix_named_params] + + +def master_params_to_state_dict( + model, param_groups_and_shapes, master_params, use_fp16 +): + if use_fp16: + state_dict = model.state_dict() + for master_param, (param_group, _) in zip( + master_params, param_groups_and_shapes + ): + for (name, _), unflat_master_param in zip( + param_group, unflatten_master_params(param_group, master_param.view(-1)) + ): + assert name in state_dict + state_dict[name] = unflat_master_param + else: + state_dict = model.state_dict() + for i, (name, _value) in enumerate(model.named_parameters()): + assert name in state_dict + state_dict[name] = master_params[i] + return state_dict + + +def state_dict_to_master_params(model, state_dict, use_fp16): + if use_fp16: + named_model_params = [ + (name, state_dict[name]) for name, _ in model.named_parameters() + ] + param_groups_and_shapes = get_param_groups_and_shapes(named_model_params) + master_params = make_master_params(param_groups_and_shapes) + else: + master_params = [state_dict[name] for name, _ in model.named_parameters()] + return master_params + + +def zero_master_grads(master_params): + for param in master_params: + param.grad = None + + +def zero_grad(model_params): + for param in model_params: + # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group + if param.grad is not None: + param.grad.detach_() + param.grad.zero_() + + +def param_grad_or_zeros(param): + if param.grad is not None: + return param.grad.data.detach() + else: + return th.zeros_like(param) + + +class MixedPrecisionTrainer: + def __init__( + self, + *, + model, + use_fp16=False, + fp16_scale_growth=1e-3, + initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE, + ): + self.model = model + self.use_fp16 = use_fp16 + self.fp16_scale_growth = fp16_scale_growth + + self.model_params = list(self.model.parameters()) + self.master_params = self.model_params + self.param_groups_and_shapes = None + self.lg_loss_scale = initial_lg_loss_scale + + if self.use_fp16: + self.param_groups_and_shapes = get_param_groups_and_shapes( + self.model.named_parameters() + ) + self.master_params = make_master_params(self.param_groups_and_shapes) + self.model.convert_to_fp16() + + def zero_grad(self): + zero_grad(self.model_params) + + def backward(self, loss: th.Tensor): + if self.use_fp16: + loss_scale = 2 ** self.lg_loss_scale + (loss * loss_scale).backward() + else: + loss.backward() + + def optimize(self, opt: th.optim.Optimizer): + if self.use_fp16: + return self._optimize_fp16(opt) + else: + return self._optimize_normal(opt) + + def _optimize_fp16(self, opt: th.optim.Optimizer): + logger.logkv_mean("lg_loss_scale", self.lg_loss_scale) + model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params) + grad_norm, param_norm = 
self._compute_norms(grad_scale=2 ** self.lg_loss_scale) + if check_overflow(grad_norm): + self.lg_loss_scale -= 1 + logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}") + zero_master_grads(self.master_params) + return False + + logger.logkv_mean("grad_norm", grad_norm) + logger.logkv_mean("param_norm", param_norm) + + self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale)) + opt.step() + zero_master_grads(self.master_params) + master_params_to_model_params(self.param_groups_and_shapes, self.master_params) + self.lg_loss_scale += self.fp16_scale_growth + return True + + def _optimize_normal(self, opt: th.optim.Optimizer): + grad_norm, param_norm = self._compute_norms() + logger.logkv_mean("grad_norm", grad_norm) + logger.logkv_mean("param_norm", param_norm) + opt.step() + return True + + def _compute_norms(self, grad_scale=1.0): + grad_norm = 0.0 + param_norm = 0.0 + for p in self.master_params: + with th.no_grad(): + param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2 + if p.grad is not None: + grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2 + return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm) + + def master_params_to_state_dict(self, master_params): + return master_params_to_state_dict( + self.model, self.param_groups_and_shapes, master_params, self.use_fp16 + ) + + def state_dict_to_master_params(self, state_dict): + return state_dict_to_master_params(self.model, state_dict, self.use_fp16) + + +def check_overflow(value): + return (value == float("inf")) or (value == -float("inf")) or (value != value) diff --git a/main/diffusion/gaussian_diffusion.py b/main/diffusion/gaussian_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..50d397c622ebea6372665eb22341b96ea8cf001d --- /dev/null +++ b/main/diffusion/gaussian_diffusion.py @@ -0,0 +1,1613 @@ +# This code is based on https://github.com/openai/guided-diffusion +""" +This code started out as a PyTorch port of Ho et al's diffusion models: +https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py + +Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules. +""" + +import enum +import math +import pdb + +import numpy as np +import torch +import torch as th +from copy import deepcopy +from diffusion.nn import mean_flat, sum_flat +from diffusion.losses import normal_kl, discretized_gaussian_log_likelihood +from data_loaders.humanml.scripts import motion_process + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps, scale_betas=1.): + """ + Get a pre-defined beta schedule for the given name. + + The beta schedule library consists of beta schedules which remain similar + in the limit of num_diffusion_timesteps. + Beta schedules may be added, but should not be removed or changed once + they are committed to maintain backwards compatibility. + """ + if schedule_name == "linear": + # Linear schedule from Ho et al, extended to work for any number of + # diffusion steps. 
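+ # With scale_betas = 1 and num_diffusion_timesteps = 1000 this is the original
+ # DDPM schedule: betas linearly spaced from 1e-4 to 0.02. For other step counts
+ # the endpoints are rescaled by 1000 / T so the cumulative noise stays comparable.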
+ scale = scale_betas * 1000 / num_diffusion_timesteps + beta_start = scale * 0.0001 + beta_end = scale * 0.02 + return np.linspace( + beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64 + ) + elif schedule_name == "cosine": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +class ModelMeanType(enum.Enum): + """ + Which type of output the model predicts. + """ + + PREVIOUS_X = enum.auto() # the model predicts x_{t-1} + START_X = enum.auto() # the model predicts x_0 + EPSILON = enum.auto() # the model predicts epsilon + + +class ModelVarType(enum.Enum): + """ + What is used as the model's output variance. + + The LEARNED_RANGE option has been added to allow the model to predict + values between FIXED_SMALL and FIXED_LARGE, making its job easier. + """ + + LEARNED = enum.auto() + FIXED_SMALL = enum.auto() + FIXED_LARGE = enum.auto() + LEARNED_RANGE = enum.auto() + + +class LossType(enum.Enum): + MSE = enum.auto() # use raw MSE loss (and KL when learning variances) + RESCALED_MSE = ( + enum.auto() + ) # use raw MSE loss (with RESCALED_KL when learning variances) + KL = enum.auto() # use the variational lower-bound + RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB + + def is_vb(self): + return self == LossType.KL or self == LossType.RESCALED_KL + + +class GaussianDiffusion: + """ + Utilities for training and sampling diffusion models. + + Ported directly from here, and then adapted over time to further experimentation. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 + + :param betas: a 1-D numpy array of betas for each diffusion timestep, + starting at T and going to 1. + :param model_mean_type: a ModelMeanType determining what the model outputs. + :param model_var_type: a ModelVarType determining how variance is output. + :param loss_type: a LossType determining the loss function to use. + :param rescale_timesteps: if True, pass floating point timesteps into the + model so that they are always scaled like in the + original paper (0 to 1000). 
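+ The additional lambda_* arguments weight optional geometric losses (e.g. on joint
+ positions, velocities and foot contact); they are only supported together with
+ LossType.MSE.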
+ """ + + def __init__( + self, + *, + betas, + model_mean_type, + model_var_type, + loss_type, + rescale_timesteps=False, + lambda_rcxyz=0., + lambda_vel=0., + lambda_pose=1., + lambda_orient=1., + lambda_loc=1., + data_rep='rot6d', + lambda_root_vel=0., + lambda_vel_rcxyz=0., + lambda_fc=0., + ): + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + self.loss_type = loss_type + self.rescale_timesteps = rescale_timesteps + self.data_rep = data_rep + + if data_rep != 'rot_vel' and lambda_pose != 1.: + raise ValueError('lambda_pose is relevant only when training on velocities!') + self.lambda_pose = lambda_pose + self.lambda_orient = lambda_orient + self.lambda_loc = lambda_loc + + self.lambda_rcxyz = lambda_rcxyz + self.lambda_vel = lambda_vel + self.lambda_root_vel = lambda_root_vel + self.lambda_vel_rcxyz = lambda_vel_rcxyz + self.lambda_fc = lambda_fc + + if self.lambda_rcxyz > 0. or self.lambda_vel > 0. or self.lambda_root_vel > 0. or \ + self.lambda_vel_rcxyz > 0. or self.lambda_fc > 0.: + assert self.loss_type == LossType.MSE, 'Geometric losses are supported by MSE loss type only!' + + # Use float64 for accuracy. + betas = np.array(betas, dtype=np.float64) + self.betas = betas + assert len(betas.shape) == 1, "betas must be 1-D" + assert (betas > 0).all() and (betas <= 1).all() + + self.num_timesteps = int(betas.shape[0]) + + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) + self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) + assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + self.posterior_variance = ( + betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + # log calculation clipped because the posterior variance is 0 at the + # beginning of the diffusion chain. + self.posterior_log_variance_clipped = np.log( + np.append(self.posterior_variance[1], self.posterior_variance[1:]) + ) + self.posterior_mean_coef1 = ( + betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + self.posterior_mean_coef2 = ( + (1.0 - self.alphas_cumprod_prev) + * np.sqrt(alphas) + / (1.0 - self.alphas_cumprod) + ) + + self.l2_loss = lambda a, b: (a - b) ** 2 # th.nn.MSELoss(reduction='none') # must be None for handling mask later on. 
+ self.smooth_l1_loss = th.nn.SmoothL1Loss(reduction='none') + + def masked_l2(self, a, b, mask): + # assuming a.shape == b.shape == bs, J, Jdim, seqlen + # assuming mask.shape == bs, 1, 1, seqlen + # loss = self.l2_loss(a, b) # 20221217 + loss = self.smooth_l1_loss(a, b) + loss = sum_flat(loss * mask.float()) # gives \sigma_euclidean over unmasked elements + n_entries = a.shape[1] * a.shape[2] + non_zero_elements = sum_flat(mask) * n_entries + # print('mask', mask.shape) + # print('non_zero_elements', non_zero_elements) + # print('loss', loss) + mse_loss_val = loss / non_zero_elements + # print('mse_loss_val', mse_loss_val) + return mse_loss_val + + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + ) + variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = _extract_into_tensor( + self.log_one_minus_alphas_cumprod, t, x_start.shape + ) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the dataset for a given number of diffusion steps. + + In other words, sample from q(x_t | x_0). + + :param x_start: the initial dataset batch. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :param noise: if specified, the split-out normal noise. + :return: A noisy version of x_start. + """ + if noise is None: + noise = th.randn_like(x_start) + assert noise.shape == x_start.shape + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) + * noise + ) + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior: + + q(x_{t-1} | x_t, x_0) + + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = _extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape + ) + assert ( + posterior_mean.shape[0] + == posterior_variance.shape[0] + == posterior_log_variance_clipped.shape[0] + == x_start.shape[0] + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance( + self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None + ): + """ + Apply the model to get p(x_{t-1} | x_t), as well as a prediction of + the initial x, x_0. + + :param model: the model, which takes a signal and a batch of timesteps + as input. + :param x: the [N x C x ...] tensor at time t. + :param t: a 1-D Tensor of timesteps. + :param clip_denoised: if True, clip the denoised signal into [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. Applies before + clip_denoised. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict with the following keys: + - 'mean': the model mean output. 
+ - 'variance': the model variance output. + - 'log_variance': the log of 'variance'. + - 'pred_xstart': the prediction for x_0. + """ + if model_kwargs is None: + model_kwargs = {} + + B, C = x.shape[:2] + assert t.shape == (B,) + model_output = model(x, self._scale_timesteps(t), **model_kwargs) + + if 'inpainting_mask' in model_kwargs['y'].keys() and 'inpainted_motion' in model_kwargs['y'].keys(): + inpainting_mask, inpainted_motion = model_kwargs['y']['inpainting_mask'], model_kwargs['y']['inpainted_motion'] + assert self.model_mean_type == ModelMeanType.START_X, 'This feature supports only X_start pred for mow!' + assert model_output.shape == inpainting_mask.shape == inpainted_motion.shape + model_output = (model_output * ~inpainting_mask) + (inpainted_motion * inpainting_mask) + # print('model_output', model_output.shape, model_output) + # print('inpainting_mask', inpainting_mask.shape, inpainting_mask[0,0,0,:]) + # print('inpainted_motion', inpainted_motion.shape, inpainted_motion) + + if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: + assert model_output.shape == (B, C * 2, *x.shape[2:]) + model_output, model_var_values = th.split(model_output, C, dim=1) + if self.model_var_type == ModelVarType.LEARNED: + model_log_variance = model_var_values + model_variance = th.exp(model_log_variance) + else: + min_log = _extract_into_tensor( + self.posterior_log_variance_clipped, t, x.shape + ) + max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) + # The model_var_values is [-1, 1] for [min_var, max_var]. + frac = (model_var_values + 1) / 2 + model_log_variance = frac * max_log + (1 - frac) * min_log + model_variance = th.exp(model_log_variance) + else: + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so + # to get a better decoder log likelihood. + ModelVarType.FIXED_LARGE: ( + np.append(self.posterior_variance[1], self.betas[1:]), + np.log(np.append(self.posterior_variance[1], self.betas[1:])), + ), + ModelVarType.FIXED_SMALL: ( + self.posterior_variance, + self.posterior_log_variance_clipped, + ), + }[self.model_var_type] + # print('model_variance', model_variance) + # print('model_log_variance',model_log_variance) + # print('self.posterior_variance', self.posterior_variance) + # print('self.posterior_log_variance_clipped', self.posterior_log_variance_clipped) + # print('self.model_var_type', self.model_var_type) + + + model_variance = _extract_into_tensor(model_variance, t, x.shape) + model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + # print('clip_denoised', clip_denoised) + return x.clamp(-1, 1) + return x + + if self.model_mean_type == ModelMeanType.PREVIOUS_X: + pred_xstart = process_xstart( + self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output) + ) + model_mean = model_output + elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]: # THIS IS US! 
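+ # START_X: the network predicts x_0 directly. EPSILON: x_0 is recovered as
+ # x_0 = sqrt(1/alpha_bar_t) * x_t - sqrt(1/alpha_bar_t - 1) * eps.
+ # Either way, the sampling mean comes from the posterior q(x_{t-1} | x_t, x_0).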
+ if self.model_mean_type == ModelMeanType.START_X: + pred_xstart = process_xstart(model_output) + else: + pred_xstart = process_xstart( + self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) + ) + model_mean, _, _ = self.q_posterior_mean_variance( + x_start=pred_xstart, x_t=x, t=t + ) + else: + raise NotImplementedError(self.model_mean_type) + + assert ( + model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape + ) + return { + "mean": model_mean, + "variance": model_variance, + "log_variance": model_log_variance, + "pred_xstart": pred_xstart, + } + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + ) + + def _predict_xstart_from_xprev(self, x_t, t, xprev): + assert x_t.shape == xprev.shape + return ( # (xprev - coef2*x_t) / coef1 + _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev + - _extract_into_tensor( + self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape + ) + * x_t + ) + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - pred_xstart + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def _scale_timesteps(self, t): + if self.rescale_timesteps: + return t.float() * (1000.0 / self.num_timesteps) + return t + + def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute the mean for the previous step, given a function cond_fn that + computes the gradient of a conditional log probability with respect to + x. In particular, cond_fn computes grad(log(p(y|x))), and we want to + condition on y. + + This uses the conditioning strategy from Sohl-Dickstein et al. (2015). + """ + gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs) + new_mean = ( + p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() + ) + return new_mean + + def condition_mean_with_grad(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute the mean for the previous step, given a function cond_fn that + computes the gradient of a conditional log probability with respect to + x. In particular, cond_fn computes grad(log(p(y|x))), and we want to + condition on y. + + This uses the conditioning strategy from Sohl-Dickstein et al. (2015). + """ + gradient = cond_fn(x, t, p_mean_var, **model_kwargs) + new_mean = ( + p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() + ) + return new_mean + + def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute what the p_mean_variance output would have been, should the + model's score function be conditioned by cond_fn. + + See condition_mean() for details on cond_fn. + + Unlike condition_mean(), this instead uses the conditioning strategy + from Song et al (2020). 
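+ Concretely, the predicted epsilon is shifted by
+ -sqrt(1 - alpha_bar_t) * cond_fn(x, t) before x_0 and the posterior mean
+ are recomputed.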
+ """ + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + + eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) + eps = eps - (1 - alpha_bar).sqrt() * cond_fn( + x, self._scale_timesteps(t), **model_kwargs + ) + + out = p_mean_var.copy() + out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) + out["mean"], _, _ = self.q_posterior_mean_variance( + x_start=out["pred_xstart"], x_t=x, t=t + ) + return out + + def condition_score_with_grad(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute what the p_mean_variance output would have been, should the + model's score function be conditioned by cond_fn. + + See condition_mean() for details on cond_fn. + + Unlike condition_mean(), this instead uses the conditioning strategy + from Song et al (2020). + """ + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + + eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) + eps = eps - (1 - alpha_bar).sqrt() * cond_fn( + x, t, p_mean_var, **model_kwargs + ) + + out = p_mean_var.copy() + out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) + out["mean"], _, _ = self.q_posterior_mean_variance( + x_start=out["pred_xstart"], x_t=x, t=t + ) + return out + + def p_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + const_noise=False, + ): + """ + Sample x_{t-1} from the model at the given timestep. + + :param model: the model to sample from. + :param x: the current tensor at x_{t-1}. + :param t: the value of t, starting at 0 for the first diffusion step. + :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - 'sample': a random sample from the model. + - 'pred_xstart': a prediction of x_0. + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) # 'mean' ๏ผˆ1, 135, 1, 240๏ผ‰, 'variance', 'log_variance', 'pred_xstart' + noise = th.randn_like(x) + # print('const_noise', const_noise) + if const_noise: + noise = noise[[0]].repeat(x.shape[0], 1, 1, 1) + + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + if cond_fn is not None: + out["mean"] = self.condition_mean( + cond_fn, out, x, t, model_kwargs=model_kwargs + ) + # print('mean', out["mean"].shape, out["mean"]) + # print('log_variance', out["log_variance"].shape, out["log_variance"]) + # print('nonzero_mask', nonzero_mask.shape, nonzero_mask) + sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def p_sample_with_grad( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + ): + """ + Sample x_{t-1} from the model at the given timestep. + + :param model: the model to sample from. + :param x: the current tensor at x_{t-1}. + :param t: the value of t, starting at 0 for the first diffusion step. + :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. 
+ :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - 'sample': a random sample from the model. + - 'pred_xstart': a prediction of x_0. + """ + with th.enable_grad(): + x = x.detach().requires_grad_() + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + noise = th.randn_like(x) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + if cond_fn is not None: + out["mean"] = self.condition_mean_with_grad( + cond_fn, out, x, t, model_kwargs=model_kwargs + ) + sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"].detach()} + + def p_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + skip_timesteps=0, + init_image=None, + randomize_class=False, + cond_fn_with_grad=False, + dump_steps=None, + const_noise=False, + ): + """ + Generate samples from the model. + + :param model: the model module. + :param shape: the shape of the samples, (N, C, H, W). + :param noise: if specified, the noise from the encoder to sample. + Should be of the same shape as `shape`. + :param clip_denoised: if True, clip x_start predictions to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param device: if specified, the device to create the samples on. + If not specified, use a model parameter's device. + :param progress: if True, show a tqdm progress bar. + :param const_noise: If True, will noise all samples with the same noise throughout sampling + :return: a non-differentiable batch of samples. + """ + final = None + if dump_steps is not None: + dump = [] + + for i, sample in enumerate(self.p_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + skip_timesteps=skip_timesteps, + init_image=init_image, + randomize_class=randomize_class, + cond_fn_with_grad=cond_fn_with_grad, + const_noise=const_noise, + )): + if dump_steps is not None and i in dump_steps: + dump.append(deepcopy(sample["sample"])) + final = sample + if dump_steps is not None: + return dump + return final["sample"] + + def p_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + skip_timesteps=0, + init_image=None, + randomize_class=False, + cond_fn_with_grad=False, + const_noise=False, + ): + """ + Generate samples from the model and yield intermediate samples from + each timestep of diffusion. + + Arguments are the same as p_sample_loop(). 
+ Returns a generator over dicts, where each dict is the return value of + p_sample(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + + if skip_timesteps and init_image is None: + init_image = th.zeros_like(img) + + indices = list(range(self.num_timesteps - skip_timesteps))[::-1] + + if init_image is not None: + my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0] + img = self.q_sample(init_image, my_t, img) + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + if randomize_class and 'y' in model_kwargs: + model_kwargs['y'] = th.randint(low=0, high=model.num_classes, + size=model_kwargs['y'].shape, + device=model_kwargs['y'].device) + with th.no_grad(): + sample_fn = self.p_sample_with_grad if cond_fn_with_grad else self.p_sample + out = sample_fn( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + const_noise=const_noise, + ) + yield out + img = out["sample"] + + def ddim_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t-1} from the model using DDIM. + + Same usage as p_sample(). + """ + out_orig = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out_orig, x, t, model_kwargs=model_kwargs) + else: + out = out_orig + + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) + + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) + sigma = ( + eta + * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) + * th.sqrt(1 - alpha_bar / alpha_bar_prev) + ) + # Equation 12. + noise = th.randn_like(x) + mean_pred = ( + out["pred_xstart"] * th.sqrt(alpha_bar_prev) + + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps + ) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + sample = mean_pred + nonzero_mask * sigma * noise + return {"sample": sample, "pred_xstart": out_orig["pred_xstart"]} + + def ddim_sample_with_grad( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t-1} from the model using DDIM. + + Same usage as p_sample(). + """ + with th.enable_grad(): + x = x.detach().requires_grad_() + out_orig = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score_with_grad(cond_fn, out_orig, x, t, + model_kwargs=model_kwargs) + else: + out = out_orig + + out["pred_xstart"] = out["pred_xstart"].detach() + + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. 
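+ # eps = (sqrt(1/alpha_bar_t) * x_t - x0_hat) / sqrt(1/alpha_bar_t - 1).
+ # The DDIM update below is then
+ # x_{t-1} = sqrt(alpha_bar_{t-1}) * x0_hat + sqrt(1 - alpha_bar_{t-1} - sigma^2) * eps,
+ # which is fully deterministic when eta = 0.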
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) + + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) + sigma = ( + eta + * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) + * th.sqrt(1 - alpha_bar / alpha_bar_prev) + ) + # Equation 12. + noise = th.randn_like(x) + mean_pred = ( + out["pred_xstart"] * th.sqrt(alpha_bar_prev) + + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps + ) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + sample = mean_pred + nonzero_mask * sigma * noise + return {"sample": sample, "pred_xstart": out_orig["pred_xstart"].detach()} + + def ddim_reverse_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t+1} from the model using DDIM reverse ODE. + """ + assert eta == 0.0, "Reverse ODE only for deterministic path" + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x + - out["pred_xstart"] + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) + alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) + + # Equation 12. reversed + mean_pred = ( + out["pred_xstart"] * th.sqrt(alpha_bar_next) + + th.sqrt(1 - alpha_bar_next) * eps + ) + + return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} + + def ddim_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + skip_timesteps=0, + init_image=None, + randomize_class=False, + cond_fn_with_grad=False, + dump_steps=None, + const_noise=False, + ): + """ + Generate samples from the model using DDIM. + + Same usage as p_sample_loop(). + """ + if dump_steps is not None: + raise NotImplementedError() + if const_noise == True: + raise NotImplementedError() + + final = None + for sample in self.ddim_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + eta=eta, + skip_timesteps=skip_timesteps, + init_image=init_image, + randomize_class=randomize_class, + cond_fn_with_grad=cond_fn_with_grad, + ): + final = sample + return final["sample"] + + def ddim_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + skip_timesteps=0, + init_image=None, + randomize_class=False, + cond_fn_with_grad=False, + ): + """ + Use DDIM to sample from the model and yield intermediate samples from + each timestep of DDIM. + + Same usage as p_sample_loop_progressive(). 
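+ If skip_timesteps > 0, the first skip_timesteps steps are skipped and sampling
+ starts from a partially noised init_image (zeros are used when no init_image
+ is provided).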
+ """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + + if skip_timesteps and init_image is None: + init_image = th.zeros_like(img) + + indices = list(range(self.num_timesteps - skip_timesteps))[::-1] + + if init_image is not None: + my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0] + img = self.q_sample(init_image, my_t, img) + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + if randomize_class and 'y' in model_kwargs: + model_kwargs['y'] = th.randint(low=0, high=model.num_classes, + size=model_kwargs['y'].shape, + device=model_kwargs['y'].device) + with th.no_grad(): + sample_fn = self.ddim_sample_with_grad if cond_fn_with_grad else self.ddim_sample + out = sample_fn( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + eta=eta, + ) + yield out + img = out["sample"] + + def plms_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + cond_fn_with_grad=False, + order=2, + old_out=None, + ): + """ + Sample x_{t-1} from the model using Pseudo Linear Multistep. + + Same usage as p_sample(). + """ + if not int(order) or not 1 <= order <= 4: + raise ValueError('order is invalid (should be int from 1-4).') + + def get_model_output(x, t): + with th.set_grad_enabled(cond_fn_with_grad and cond_fn is not None): + x = x.detach().requires_grad_() if cond_fn_with_grad else x + out_orig = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + if cond_fn_with_grad: + out = self.condition_score_with_grad(cond_fn, out_orig, x, t, model_kwargs=model_kwargs) + x = x.detach() + else: + out = self.condition_score(cond_fn, out_orig, x, t, model_kwargs=model_kwargs) + else: + out = out_orig + + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. 
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) + return eps, out, out_orig + + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) + eps, out, out_orig = get_model_output(x, t) + + if order > 1 and old_out is None: + # Pseudo Improved Euler + old_eps = [eps] + mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev) * eps + eps_2, _, _ = get_model_output(mean_pred, t - 1) + eps_prime = (eps + eps_2) / 2 + pred_prime = self._predict_xstart_from_eps(x, t, eps_prime) + mean_pred = pred_prime * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev) * eps_prime + else: + # Pseudo Linear Multistep (Adams-Bashforth) + old_eps = old_out["old_eps"] + old_eps.append(eps) + cur_order = min(order, len(old_eps)) + if cur_order == 1: + eps_prime = old_eps[-1] + elif cur_order == 2: + eps_prime = (3 * old_eps[-1] - old_eps[-2]) / 2 + elif cur_order == 3: + eps_prime = (23 * old_eps[-1] - 16 * old_eps[-2] + 5 * old_eps[-3]) / 12 + elif cur_order == 4: + eps_prime = (55 * old_eps[-1] - 59 * old_eps[-2] + 37 * old_eps[-3] - 9 * old_eps[-4]) / 24 + else: + raise RuntimeError('cur_order is invalid.') + pred_prime = self._predict_xstart_from_eps(x, t, eps_prime) + mean_pred = pred_prime * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev) * eps_prime + + if len(old_eps) >= order: + old_eps.pop(0) + + nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + sample = mean_pred * nonzero_mask + out["pred_xstart"] * (1 - nonzero_mask) + + return {"sample": sample, "pred_xstart": out_orig["pred_xstart"], "old_eps": old_eps} + + def plms_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + skip_timesteps=0, + init_image=None, + randomize_class=False, + cond_fn_with_grad=False, + order=2, + ): + """ + Generate samples from the model using Pseudo Linear Multistep. + + Same usage as p_sample_loop(). + """ + final = None + for sample in self.plms_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + skip_timesteps=skip_timesteps, + init_image=init_image, + randomize_class=randomize_class, + cond_fn_with_grad=cond_fn_with_grad, + order=order, + ): + final = sample + return final["sample"] + + def plms_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + skip_timesteps=0, + init_image=None, + randomize_class=False, + cond_fn_with_grad=False, + order=2, + ): + """ + Use PLMS to sample from the model and yield intermediate samples from each + timestep of PLMS. + + Same usage as p_sample_loop_progressive(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + + if skip_timesteps and init_image is None: + init_image = th.zeros_like(img) + + indices = list(range(self.num_timesteps - skip_timesteps))[::-1] + + if init_image is not None: + my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0] + img = self.q_sample(init_image, my_t, img) + + if progress: + # Lazy import so that we don't depend on tqdm. 
+ from tqdm.auto import tqdm + + indices = tqdm(indices) + + old_out = None + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + if randomize_class and 'y' in model_kwargs: + model_kwargs['y'] = th.randint(low=0, high=model.num_classes, + size=model_kwargs['y'].shape, + device=model_kwargs['y'].device) + with th.no_grad(): + out = self.plms_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + cond_fn_with_grad=cond_fn_with_grad, + order=order, + old_out=old_out, + ) + yield out + old_out = out + img = out["sample"] + + def _vb_terms_bpd( + self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None + ): + """ + Get a term for the variational lower-bound. + + The resulting units are bits (rather than nats, as one might expect). + This allows for comparison to other papers. + + :return: a dict with the following keys: + - 'output': a shape [N] tensor of NLLs or KLs. + - 'pred_xstart': the x_0 predictions. + """ + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + ) + out = self.p_mean_variance( + model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs + ) + kl = normal_kl( + true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] + ) + kl = mean_flat(kl) / np.log(2.0) + + decoder_nll = -discretized_gaussian_log_likelihood( + x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] + ) + assert decoder_nll.shape == x_start.shape + decoder_nll = mean_flat(decoder_nll) / np.log(2.0) + + # At the first timestep return the decoder NLL, + # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) + output = th.where((t == 0), decoder_nll, kl) + return {"output": output, "pred_xstart": out["pred_xstart"]} + + def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, dataset=None): + """ + Compute training losses for a single timestep. + + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param t: a batch of timestep indices. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param noise: if specified, the specific Gaussian noise to try to remove. + :return: a dict with the key "loss" containing a tensor of shape [N]. + Some mean or variance settings may also have other keys. 
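+ In this codebase the dominant term is "rot_mse", a masked L2 between the
+ model output and the target selected by model_mean_type (the posterior
+ mean, x_start, or the noise); a "vb" term is added when the variance is
+ learned, and the lambda_vel / lambda_rcxyz / lambda_fc terms are only
+ computed when their weights are non-zero.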
+ """ + + # enc = model.model._modules['module'] + enc = model.model + mask = model_kwargs['y']['mask'] + # get_xyz = lambda sample: enc.rot2xyz(sample, mask=None, pose_rep=enc.pose_rep, translation=enc.translation, + # glob=enc.glob, + # # jointstype='vertices', # 3.4 iter/sec # USED ALSO IN MotionCLIP + # jointstype='smpl', # 3.4 iter/sec + # vertstrans=False) + + if model_kwargs is None: + model_kwargs = {} + if noise is None: + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start, t, noise=noise) # torch.Size([64, 251, 1, 196]), add noisy + + terms = {} + + if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: # LossType.MSE + terms["loss"] = self._vb_terms_bpd( + model=model, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + model_kwargs=model_kwargs, + )["output"] + if self.loss_type == LossType.RESCALED_KL: + terms["loss"] *= self.num_timesteps + elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: + model_output = model(x_t, self._scale_timesteps(t), **model_kwargs) + + if self.model_var_type in [ # ModelVarType.FIXED_SMALL: 2 + ModelVarType.LEARNED, + ModelVarType.LEARNED_RANGE, + ]: + B, C = x_t.shape[:2] + assert model_output.shape == (B, C * 2, *x_t.shape[2:]) + model_output, model_var_values = th.split(model_output, C, dim=1) + # Learn the variance using the variational bound, but don't let + # it affect our mean prediction. + frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) + terms["vb"] = self._vb_terms_bpd( + model=lambda *args, r=frozen_out: r, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + )["output"] + if self.loss_type == LossType.RESCALED_MSE: + # Divide by 1000 for equivalence with initial implementation. + # Without a factor of 1/1000, the VB term hurts the MSE term. 
+ terms["vb"] *= self.num_timesteps / 1000.0 + + target = { + ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + )[0], + ModelMeanType.START_X: x_start, + ModelMeanType.EPSILON: noise, + }[self.model_mean_type] # ModelMeanType.START_X: 2 + assert model_output.shape == target.shape == x_start.shape # [bs, njoints, nfeats, nframes] + + # pdb.set_trace() # target (2, 135, 1, 240) + + terms["rot_mse"] = self.masked_l2(target, model_output, mask) # mean_flat(rot_mse) # [64, 251, 1, 196], -, [64, 1, 1, 196] + + target_xyz, model_output_xyz = None, None + + if self.lambda_rcxyz > 0.: # 0.0 + target_xyz = get_xyz(target) # [bs, nvertices(vertices)/njoints(smpl), 3, nframes] + model_output_xyz = get_xyz(model_output) # [bs, nvertices, 3, nframes] + terms["rcxyz_mse"] = self.masked_l2(target_xyz, model_output_xyz, mask) # mean_flat((target_xyz - model_output_xyz) ** 2) + + if self.lambda_vel_rcxyz > 0.: # 0.0 + if self.data_rep == 'rot6d' and dataset.dataname in ['humanact12', 'uestc']: + target_xyz = get_xyz(target) if target_xyz is None else target_xyz + model_output_xyz = get_xyz(model_output) if model_output_xyz is None else model_output_xyz + target_xyz_vel = (target_xyz[:, :, :, 1:] - target_xyz[:, :, :, :-1]) + model_output_xyz_vel = (model_output_xyz[:, :, :, 1:] - model_output_xyz[:, :, :, :-1]) + terms["vel_xyz_mse"] = self.masked_l2(target_xyz_vel, model_output_xyz_vel, mask[:, :, :, 1:]) + + if self.lambda_fc > 0.: # 0.0 + torch.autograd.set_detect_anomaly(True) + if self.data_rep == 'rot6d' and dataset.dataname in ['humanact12', 'uestc']: + target_xyz = get_xyz(target) if target_xyz is None else target_xyz + model_output_xyz = get_xyz(model_output) if model_output_xyz is None else model_output_xyz + # 'L_Ankle', # 7, 'R_Ankle', # 8 , 'L_Foot', # 10, 'R_Foot', # 11 + l_ankle_idx, r_ankle_idx, l_foot_idx, r_foot_idx = 7, 8, 10, 11 + relevant_joints = [l_ankle_idx, l_foot_idx, r_ankle_idx, r_foot_idx] + gt_joint_xyz = target_xyz[:, relevant_joints, :, :] # [BatchSize, 4, 3, Frames] + gt_joint_vel = torch.linalg.norm(gt_joint_xyz[:, :, :, 1:] - gt_joint_xyz[:, :, :, :-1], axis=2) # [BatchSize, 4, Frames] + fc_mask = torch.unsqueeze((gt_joint_vel <= 0.01), dim=2).repeat(1, 1, 3, 1) + pred_joint_xyz = model_output_xyz[:, relevant_joints, :, :] # [BatchSize, 4, 3, Frames] + pred_vel = pred_joint_xyz[:, :, :, 1:] - pred_joint_xyz[:, :, :, :-1] + pred_vel[~fc_mask] = 0 + terms["fc"] = self.masked_l2(pred_vel, + torch.zeros(pred_vel.shape, device=pred_vel.device), + mask[:, :, :, 1:]) + if self.lambda_vel > 0.: # 0.0 + target_vel = (target[..., 1:] - target[..., :-1]) + model_output_vel = (model_output[..., 1:] - model_output[..., :-1]) + terms["vel_mse"] = self.masked_l2(target_vel[:, :-1, :, :], # Remove last joint, is the root location! + model_output_vel[:, :-1, :, :], + mask[:, :, :, 1:]) # mean_flat((target_vel - model_output_vel) ** 2) + + terms["loss"] = terms["rot_mse"] + terms.get('vb', 0.) 
+\ + (self.lambda_vel * terms.get('vel_mse', 0.)) +\ + (self.lambda_rcxyz * terms.get('rcxyz_mse', 0.)) + \ + (self.lambda_fc * terms.get('fc', 0.)) + + else: + raise NotImplementedError(self.loss_type) + + return terms + + def fc_loss_rot_repr(self, gt_xyz, pred_xyz, mask): + def to_np_cpu(x): + return x.detach().cpu().numpy() + """ + pose_xyz: SMPL batch tensor of shape: [BatchSize, 24, 3, Frames] + """ + # 'L_Ankle', # 7, 'R_Ankle', # 8 , 'L_Foot', # 10, 'R_Foot', # 11 + + l_ankle_idx, r_ankle_idx = 7, 8 + l_foot_idx, r_foot_idx = 10, 11 + """ Contact calculated by 'Kfir Method' Commented code)""" + # contact_signal = torch.zeros((pose_xyz.shape[0], pose_xyz.shape[3], 2), device=pose_xyz.device) # [BatchSize, Frames, 2] + # left_xyz = 0.5 * (pose_xyz[:, l_ankle_idx, :, :] + pose_xyz[:, l_foot_idx, :, :]) # [BatchSize, 3, Frames] + # right_xyz = 0.5 * (pose_xyz[:, r_ankle_idx, :, :] + pose_xyz[:, r_foot_idx, :, :]) + # left_z, right_z = left_xyz[:, 2, :], right_xyz[:, 2, :] # [BatchSize, Frames] + # left_velocity = torch.linalg.norm(left_xyz[:, :, 2:] - left_xyz[:, :, :-2], axis=1) # [BatchSize, Frames] + # right_velocity = torch.linalg.norm(left_xyz[:, :, 2:] - left_xyz[:, :, :-2], axis=1) + # + # left_z_mask = left_z <= torch.mean(torch.sort(left_z)[0][:, :left_z.shape[1] // 5], axis=-1) + # left_z_mask = torch.stack([left_z_mask, left_z_mask], dim=-1) # [BatchSize, Frames, 2] + # left_z_mask[:, :, 1] = False # Blank right side + # contact_signal[left_z_mask] = 0.4 + # + # right_z_mask = right_z <= torch.mean(torch.sort(right_z)[0][:, :right_z.shape[1] // 5], axis=-1) + # right_z_mask = torch.stack([right_z_mask, right_z_mask], dim=-1) # [BatchSize, Frames, 2] + # right_z_mask[:, :, 0] = False # Blank left side + # contact_signal[right_z_mask] = 0.4 + # contact_signal[left_z <= (torch.mean(torch.sort(left_z)[:left_z.shape[0] // 5]) + 20), 0] = 1 + # contact_signal[right_z <= (torch.mean(torch.sort(right_z)[:right_z.shape[0] // 5]) + 20), 1] = 1 + + # plt.plot(to_np_cpu(left_z[0]), label='left_z') + # plt.plot(to_np_cpu(left_velocity[0]), label='left_velocity') + # plt.plot(to_np_cpu(contact_signal[0, :, 0]), label='left_fc') + # plt.grid() + # plt.legend() + # plt.show() + # plt.plot(to_np_cpu(right_z[0]), label='right_z') + # plt.plot(to_np_cpu(right_velocity[0]), label='right_velocity') + # plt.plot(to_np_cpu(contact_signal[0, :, 1]), label='right_fc') + # plt.grid() + # plt.legend() + # plt.show() + + gt_joint_xyz = gt_xyz[:, [l_ankle_idx, l_foot_idx, r_ankle_idx, r_foot_idx], :, :] # [BatchSize, 4, 3, Frames] + gt_joint_vel = torch.linalg.norm(gt_joint_xyz[:, :, :, 1:] - gt_joint_xyz[:, :, :, :-1], axis=2) # [BatchSize, 4, Frames] + fc_mask = (gt_joint_vel <= 0.01) + pred_joint_xyz = pred_xyz[:, [l_ankle_idx, l_foot_idx, r_ankle_idx, r_foot_idx], :, :] # [BatchSize, 4, 3, Frames] + pred_joint_vel = torch.linalg.norm(pred_joint_xyz[:, :, :, 1:] - pred_joint_xyz[:, :, :, :-1], axis=2) # [BatchSize, 4, Frames] + pred_joint_vel[~fc_mask] = 0 # Blank non-contact velocities frames. 
[BS,4,FRAMES] + pred_joint_vel = torch.unsqueeze(pred_joint_vel, dim=2) + + """DEBUG CODE""" + # print(f'mask: {mask.shape}') + # print(f'pred_joint_vel: {pred_joint_vel.shape}') + # plt.title(f'Joint: {joint_idx}') + # plt.plot(to_np_cpu(gt_joint_vel[0]), label='velocity') + # plt.plot(to_np_cpu(fc_mask[0]), label='fc') + # plt.grid() + # plt.legend() + # plt.show() + return self.masked_l2(pred_joint_vel, torch.zeros(pred_joint_vel.shape, device=pred_joint_vel.device), + mask[:, :, :, 1:]) + # TODO - NOT USED YET, JUST COMMITING TO NOT DELETE THIS AND KEEP INITIAL IMPLEMENTATION, NOT DONE! + def foot_contact_loss_humanml3d(self, target, model_output): + # root_rot_velocity (B, seq_len, 1) + # root_linear_velocity (B, seq_len, 2) + # root_y (B, seq_len, 1) + # ric_data (B, seq_len, (joint_num - 1)*3) , XYZ + # rot_data (B, seq_len, (joint_num - 1)*6) , 6D + # local_velocity (B, seq_len, joint_num*3) , XYZ + # foot contact (B, seq_len, 4) , + + target_fc = target[:, -4:, :, :] + root_rot_velocity = target[:, :1, :, :] + root_linear_velocity = target[:, 1:3, :, :] + root_y = target[:, 3:4, :, :] + ric_data = target[:, 4:67, :, :] # 4+(3*21)=67 + rot_data = target[:, 67:193, :, :] # 67+(6*21)=193 + local_velocity = target[:, 193:259, :, :] # 193+(3*22)=259 + contact = target[:, 259:, :, :] # 193+(3*22)=259 + contact_mask_gt = contact > 0.5 # contact mask order for indexes are fid_l [7, 10], fid_r [8, 11] + vel_lf_7 = local_velocity[:, 7 * 3:8 * 3, :, :] + vel_rf_8 = local_velocity[:, 8 * 3:9 * 3, :, :] + vel_lf_10 = local_velocity[:, 10 * 3:11 * 3, :, :] + vel_rf_11 = local_velocity[:, 11 * 3:12 * 3, :, :] + + calc_vel_lf_7 = ric_data[:, 6 * 3:7 * 3, :, 1:] - ric_data[:, 6 * 3:7 * 3, :, :-1] + calc_vel_rf_8 = ric_data[:, 7 * 3:8 * 3, :, 1:] - ric_data[:, 7 * 3:8 * 3, :, :-1] + calc_vel_lf_10 = ric_data[:, 9 * 3:10 * 3, :, 1:] - ric_data[:, 9 * 3:10 * 3, :, :-1] + calc_vel_rf_11 = ric_data[:, 10 * 3:11 * 3, :, 1:] - ric_data[:, 10 * 3:11 * 3, :, :-1] + + # vel_foots = torch.stack([vel_lf_7, vel_lf_10, vel_rf_8, vel_rf_11], dim=1) + for chosen_vel_foot_calc, chosen_vel_foot, joint_idx, contact_mask_idx in zip( + [calc_vel_lf_7, calc_vel_rf_8, calc_vel_lf_10, calc_vel_rf_11], + [vel_lf_7, vel_lf_10, vel_rf_8, vel_rf_11], + [7, 10, 8, 11], + [0, 1, 2, 3]): + tmp_mask_gt = contact_mask_gt[:, contact_mask_idx, :, :].cpu().detach().numpy().reshape(-1).astype(int) + chosen_vel_norm = np.linalg.norm(chosen_vel_foot.cpu().detach().numpy().reshape((3, -1)), axis=0) + chosen_vel_calc_norm = np.linalg.norm(chosen_vel_foot_calc.cpu().detach().numpy().reshape((3, -1)), + axis=0) + + print(tmp_mask_gt.shape) + print(chosen_vel_foot.shape) + print(chosen_vel_calc_norm.shape) + import matplotlib.pyplot as plt + plt.plot(tmp_mask_gt, label='FC mask') + plt.plot(chosen_vel_norm, label='Vel. XYZ norm (from vector)') + plt.plot(chosen_vel_calc_norm, label='Vel. XYZ norm (calculated diff XYZ)') + + plt.title(f'FC idx {contact_mask_idx}, Joint Index {joint_idx}') + plt.legend() + plt.show() + # print(vel_foots.shape) + return 0 + # TODO - NOT USED YET, JUST COMMITING TO NOT DELETE THIS AND KEEP INITIAL IMPLEMENTATION, NOT DONE! 
+ def velocity_consistency_loss_humanml3d(self, target, model_output): + # root_rot_velocity (B, seq_len, 1) + # root_linear_velocity (B, seq_len, 2) + # root_y (B, seq_len, 1) + # ric_data (B, seq_len, (joint_num - 1)*3) , XYZ + # rot_data (B, seq_len, (joint_num - 1)*6) , 6D + # local_velocity (B, seq_len, joint_num*3) , XYZ + # foot contact (B, seq_len, 4) , + + target_fc = target[:, -4:, :, :] + root_rot_velocity = target[:, :1, :, :] + root_linear_velocity = target[:, 1:3, :, :] + root_y = target[:, 3:4, :, :] + ric_data = target[:, 4:67, :, :] # 4+(3*21)=67 + rot_data = target[:, 67:193, :, :] # 67+(6*21)=193 + local_velocity = target[:, 193:259, :, :] # 193+(3*22)=259 + contact = target[:, 259:, :, :] # 193+(3*22)=259 + + calc_vel_from_xyz = ric_data[:, :, :, 1:] - ric_data[:, :, :, :-1] + velocity_from_vector = local_velocity[:, 3:, :, 1:] # Slicing out root + r_rot_quat, r_pos = motion_process.recover_root_rot_pos(target.permute(0, 2, 3, 1).type(th.FloatTensor)) + print(f'r_rot_quat: {r_rot_quat.shape}') + print(f'calc_vel_from_xyz: {calc_vel_from_xyz.shape}') + calc_vel_from_xyz = calc_vel_from_xyz.permute(0, 2, 3, 1) + calc_vel_from_xyz = calc_vel_from_xyz.reshape((1, 1, -1, 21, 3)).type(th.FloatTensor) + r_rot_quat_adapted = r_rot_quat[..., :-1, None, :].repeat((1,1,1,21,1)).to(calc_vel_from_xyz.device) + print(f'calc_vel_from_xyz: {calc_vel_from_xyz.shape} , {calc_vel_from_xyz.device}') + print(f'r_rot_quat_adapted: {r_rot_quat_adapted.shape}, {r_rot_quat_adapted.device}') + + calc_vel_from_xyz = motion_process.qrot(r_rot_quat_adapted, calc_vel_from_xyz) + calc_vel_from_xyz = calc_vel_from_xyz.reshape((1, 1, -1, 21 * 3)) + calc_vel_from_xyz = calc_vel_from_xyz.permute(0, 3, 1, 2) + print(f'calc_vel_from_xyz: {calc_vel_from_xyz.shape} , {calc_vel_from_xyz.device}') + + import matplotlib.pyplot as plt + for i in range(21): + plt.plot(np.linalg.norm(calc_vel_from_xyz[:,i*3:(i+1)*3,:,:].cpu().detach().numpy().reshape((3, -1)), axis=0), label='Calc Vel') + plt.plot(np.linalg.norm(velocity_from_vector[:,i*3:(i+1)*3,:,:].cpu().detach().numpy().reshape((3, -1)), axis=0), label='Vector Vel') + plt.title(f'Joint idx: {i}') + plt.legend() + plt.show() + print(calc_vel_from_xyz.shape) + print(velocity_from_vector.shape) + diff = calc_vel_from_xyz-velocity_from_vector + print(np.linalg.norm(diff.cpu().detach().numpy().reshape((63, -1)), axis=0)) + + return 0 + + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + + This term can't be optimized, as it only depends on the encoder. + + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl( + mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 + ) + return mean_flat(kl_prior) / np.log(2.0) + + def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): + """ + Compute the entire variational lower-bound, measured in bits-per-dim, + as well as other related quantities. + + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param clip_denoised: if True, clip denoised samples. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. 
+ + :return: a dict containing the following keys: + - total_bpd: the total variational lower-bound, per batch element. + - prior_bpd: the prior term in the lower-bound. + - vb: an [N x T] tensor of terms in the lower-bound. + - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. + - mse: an [N x T] tensor of epsilon MSEs for each timestep. + """ + device = x_start.device + batch_size = x_start.shape[0] + + vb = [] + xstart_mse = [] + mse = [] + for t in list(range(self.num_timesteps))[::-1]: + t_batch = th.tensor([t] * batch_size, device=device) + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) + # Calculate VLB term at the current timestep + with th.no_grad(): + out = self._vb_terms_bpd( + model, + x_start=x_start, + x_t=x_t, + t=t_batch, + clip_denoised=clip_denoised, + model_kwargs=model_kwargs, + ) + vb.append(out["output"]) + xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) + eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) + mse.append(mean_flat((eps - noise) ** 2)) + + vb = th.stack(vb, dim=1) + xstart_mse = th.stack(xstart_mse, dim=1) + mse = th.stack(mse, dim=1) + + prior_bpd = self._prior_bpd(x_start) + total_bpd = vb.sum(dim=1) + prior_bpd + return { + "total_bpd": total_bpd, + "prior_bpd": prior_bpd, + "vb": vb, + "xstart_mse": xstart_mse, + "mse": mse, + } + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. 
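+ For example, with arr a 1-D schedule array such as alphas_cumprod of
+ shape [T], timesteps of shape [N] and broadcast_shape = (N, C, 1, F),
+ the result has shape (N, C, 1, F), with every element of batch item i
+ equal to arr[timesteps[i]].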
+ """ + res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res.expand(broadcast_shape) diff --git a/main/diffusion/logger.py b/main/diffusion/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..b1d856dcfea6b56a2ee8d37b286887430dbfac30 --- /dev/null +++ b/main/diffusion/logger.py @@ -0,0 +1,495 @@ +""" +Logger copied from OpenAI baselines to avoid extra RL-based dependencies: +https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py +""" + +import os +import sys +import shutil +import os.path as osp +import json +import time +import datetime +import tempfile +import warnings +from collections import defaultdict +from contextlib import contextmanager + +DEBUG = 10 +INFO = 20 +WARN = 30 +ERROR = 40 + +DISABLED = 50 + + +class KVWriter(object): + def writekvs(self, kvs): + raise NotImplementedError + + +class SeqWriter(object): + def writeseq(self, seq): + raise NotImplementedError + + +class HumanOutputFormat(KVWriter, SeqWriter): + def __init__(self, filename_or_file): + if isinstance(filename_or_file, str): + self.file = open(filename_or_file, "wt") + self.own_file = True + else: + assert hasattr(filename_or_file, "read"), ( + "expected file or str, got %s" % filename_or_file + ) + self.file = filename_or_file + self.own_file = False + + def writekvs(self, kvs): + # Create strings for printing + key2str = {} + for (key, val) in sorted(kvs.items()): + if hasattr(val, "__float__"): + valstr = "%-8.3g" % val + else: + valstr = str(val) + key2str[self._truncate(key)] = self._truncate(valstr) + + # Find max widths + if len(key2str) == 0: + print("WARNING: tried to write empty key-value dict") + return + else: + keywidth = max(map(len, key2str.keys())) + valwidth = max(map(len, key2str.values())) + + # Write out the data + dashes = "-" * (keywidth + valwidth + 7) + lines = [dashes] + for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()): + lines.append( + "| %s%s | %s%s |" + % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val))) + ) + lines.append(dashes) + self.file.write("\n".join(lines) + "\n") + + # Flush the output to the file + self.file.flush() + + def _truncate(self, s): + maxlen = 30 + return s[: maxlen - 3] + "..." 
if len(s) > maxlen else s + + def writeseq(self, seq): + seq = list(seq) + for (i, elem) in enumerate(seq): + self.file.write(elem) + if i < len(seq) - 1: # add space unless this is the last one + self.file.write(" ") + self.file.write("\n") + self.file.flush() + + def close(self): + if self.own_file: + self.file.close() + + +class JSONOutputFormat(KVWriter): + def __init__(self, filename): + self.file = open(filename, "wt") + + def writekvs(self, kvs): + for k, v in sorted(kvs.items()): + if hasattr(v, "dtype"): + kvs[k] = float(v) + self.file.write(json.dumps(kvs) + "\n") + self.file.flush() + + def close(self): + self.file.close() + + +class CSVOutputFormat(KVWriter): + def __init__(self, filename): + self.file = open(filename, "w+t") + self.keys = [] + self.sep = "," + + def writekvs(self, kvs): + # Add our current row to the history + extra_keys = list(kvs.keys() - self.keys) + extra_keys.sort() + if extra_keys: + self.keys.extend(extra_keys) + self.file.seek(0) + lines = self.file.readlines() + self.file.seek(0) + for (i, k) in enumerate(self.keys): + if i > 0: + self.file.write(",") + self.file.write(k) + self.file.write("\n") + for line in lines[1:]: + self.file.write(line[:-1]) + self.file.write(self.sep * len(extra_keys)) + self.file.write("\n") + for (i, k) in enumerate(self.keys): + if i > 0: + self.file.write(",") + v = kvs.get(k) + if v is not None: + self.file.write(str(v)) + self.file.write("\n") + self.file.flush() + + def close(self): + self.file.close() + + +class TensorBoardOutputFormat(KVWriter): + """ + Dumps key/value pairs into TensorBoard's numeric format. + """ + + def __init__(self, dir): + os.makedirs(dir, exist_ok=True) + self.dir = dir + self.step = 1 + prefix = "events" + path = osp.join(osp.abspath(dir), prefix) + import tensorflow as tf + from tensorflow.python import pywrap_tensorflow + from tensorflow.core.util import event_pb2 + from tensorflow.python.util import compat + + self.tf = tf + self.event_pb2 = event_pb2 + self.pywrap_tensorflow = pywrap_tensorflow + self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path)) + + def writekvs(self, kvs): + def summary_val(k, v): + kwargs = {"tag": k, "simple_value": float(v)} + return self.tf.Summary.Value(**kwargs) + + summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()]) + event = self.event_pb2.Event(wall_time=time.time(), summary=summary) + event.step = ( + self.step + ) # is there any reason why you'd want to specify the step? 
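+ # (self.step is incremented after each dump below, so successive dumps land
+ # on consecutive TensorBoard global steps.)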
+ self.writer.WriteEvent(event) + self.writer.Flush() + self.step += 1 + + def close(self): + if self.writer: + self.writer.Close() + self.writer = None + + +def make_output_format(format, ev_dir, log_suffix=""): + os.makedirs(ev_dir, exist_ok=True) + if format == "stdout": + return HumanOutputFormat(sys.stdout) + elif format == "log": + return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix)) + elif format == "json": + return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix)) + elif format == "csv": + return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix)) + elif format == "tensorboard": + return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix)) + else: + raise ValueError("Unknown format specified: %s" % (format,)) + + +# ================================================================ +# API +# ================================================================ + + +def logkv(key, val): + """ + Log a value of some diagnostic + Call this once for each diagnostic quantity, each iteration + If called many times, last value will be used. + """ + get_current().logkv(key, val) + + +def logkv_mean(key, val): + """ + The same as logkv(), but if called many times, values averaged. + """ + get_current().logkv_mean(key, val) + + +def logkvs(d): + """ + Log a dictionary of key-value pairs + """ + for (k, v) in d.items(): + logkv(k, v) + + +def dumpkvs(): + """ + Write all of the diagnostics from the current iteration + """ + return get_current().dumpkvs() + + +def getkvs(): + return get_current().name2val + + +def log(*args, level=INFO): + """ + Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). + """ + get_current().log(*args, level=level) + + +def debug(*args): + log(*args, level=DEBUG) + + +def info(*args): + log(*args, level=INFO) + + +def warn(*args): + log(*args, level=WARN) + + +def error(*args): + log(*args, level=ERROR) + + +def set_level(level): + """ + Set logging threshold on current logger. + """ + get_current().set_level(level) + + +def set_comm(comm): + get_current().set_comm(comm) + + +def get_dir(): + """ + Get directory that log files are being written to. + will be None if there is no output directory (i.e., if you didn't call start) + """ + return get_current().get_dir() + + +record_tabular = logkv +dump_tabular = dumpkvs + + +@contextmanager +def profile_kv(scopename): + logkey = "wait_" + scopename + tstart = time.time() + try: + yield + finally: + get_current().name2val[logkey] += time.time() - tstart + + +def profile(n): + """ + Usage: + @profile("my_func") + def my_func(): code + """ + + def decorator_with_name(func): + def func_wrapper(*args, **kwargs): + with profile_kv(n): + return func(*args, **kwargs) + + return func_wrapper + + return decorator_with_name + + +# ================================================================ +# Backend +# ================================================================ + + +def get_current(): + if Logger.CURRENT is None: + _configure_default_logger() + + return Logger.CURRENT + + +class Logger(object): + DEFAULT = None # A logger with no output files. 
(See right below class definition) + # So that you can still log to the terminal without setting up any output files + CURRENT = None # Current logger being used by the free functions above + + def __init__(self, dir, output_formats, comm=None): + self.name2val = defaultdict(float) # values this iteration + self.name2cnt = defaultdict(int) + self.level = INFO + self.dir = dir + self.output_formats = output_formats + self.comm = comm + + # Logging API, forwarded + # ---------------------------------------- + def logkv(self, key, val): + self.name2val[key] = val + + def logkv_mean(self, key, val): + oldval, cnt = self.name2val[key], self.name2cnt[key] + self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1) + self.name2cnt[key] = cnt + 1 + + def dumpkvs(self): + if self.comm is None: + d = self.name2val + else: + d = mpi_weighted_mean( + self.comm, + { + name: (val, self.name2cnt.get(name, 1)) + for (name, val) in self.name2val.items() + }, + ) + if self.comm.rank != 0: + d["dummy"] = 1 # so we don't get a warning about empty dict + out = d.copy() # Return the dict for unit testing purposes + for fmt in self.output_formats: + if isinstance(fmt, KVWriter): + fmt.writekvs(d) + self.name2val.clear() + self.name2cnt.clear() + return out + + def log(self, *args, level=INFO): + if self.level <= level: + self._do_log(args) + + # Configuration + # ---------------------------------------- + def set_level(self, level): + self.level = level + + def set_comm(self, comm): + self.comm = comm + + def get_dir(self): + return self.dir + + def close(self): + for fmt in self.output_formats: + fmt.close() + + # Misc + # ---------------------------------------- + def _do_log(self, args): + for fmt in self.output_formats: + if isinstance(fmt, SeqWriter): + fmt.writeseq(map(str, args)) + + +def get_rank_without_mpi_import(): + # check environment variables here instead of importing mpi4py + # to avoid calling MPI_Init() when this module is imported + for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]: + if varname in os.environ: + return int(os.environ[varname]) + return 0 + + +def mpi_weighted_mean(comm, local_name2valcount): + """ + Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110 + Perform a weighted average over dicts that are each on a different node + Input: local_name2valcount: dict mapping key -> (value, count) + Returns: key -> mean + """ + all_name2valcount = comm.gather(local_name2valcount) + if comm.rank == 0: + name2sum = defaultdict(float) + name2count = defaultdict(float) + for n2vc in all_name2valcount: + for (name, (val, count)) in n2vc.items(): + try: + val = float(val) + except ValueError: + if comm.rank == 0: + warnings.warn( + "WARNING: tried to compute mean on non-float {}={}".format( + name, val + ) + ) + else: + name2sum[name] += val * count + name2count[name] += count + return {name: name2sum[name] / name2count[name] for name in name2sum} + else: + return {} + + +def configure(dir=None, format_strs=None, comm=None, log_suffix=""): + """ + If comm is provided, average all numerical stats across that comm + """ + if dir is None: + dir = os.getenv("OPENAI_LOGDIR") + if dir is None: + dir = osp.join( + tempfile.gettempdir(), + datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"), + ) + assert isinstance(dir, str) + dir = os.path.expanduser(dir) + os.makedirs(os.path.expanduser(dir), exist_ok=True) + + rank = get_rank_without_mpi_import() + if rank > 0: + log_suffix = log_suffix + "-rank%03i" % 
rank + + if format_strs is None: + if rank == 0: + format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",") + else: + format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",") + format_strs = filter(None, format_strs) + output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs] + + Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm) + if output_formats: + log("Logging to %s" % dir) + + +def _configure_default_logger(): + configure() + Logger.DEFAULT = Logger.CURRENT + + +def reset(): + if Logger.CURRENT is not Logger.DEFAULT: + Logger.CURRENT.close() + Logger.CURRENT = Logger.DEFAULT + log("Reset logger") + + +@contextmanager +def scoped_configure(dir=None, format_strs=None, comm=None): + prevlogger = Logger.CURRENT + configure(dir=dir, format_strs=format_strs, comm=comm) + try: + yield + finally: + Logger.CURRENT.close() + Logger.CURRENT = prevlogger + diff --git a/main/diffusion/losses.py b/main/diffusion/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..e3fded1953584eaaf183f3d2399be545a5003e0a --- /dev/null +++ b/main/diffusion/losses.py @@ -0,0 +1,77 @@ +# This code is based on https://github.com/openai/guided-diffusion +""" +Helpers for various likelihood-based losses. These are ported from the original +Ho et al. diffusion models codebase: +https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py +""" + +import numpy as np +import torch as th + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + Compute the KL divergence between two gaussians. + + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, th.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for th.exp(). + logvar1, logvar2 = [ + x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + th.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * th.exp(-logvar2) + ) + + +def approx_standard_normal_cdf(x): + """ + A fast approximation of the cumulative distribution function of the + standard normal. + """ + return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) + + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a Gaussian distribution discretizing to a + given image. + + :param x: the target images. It is assumed that this was uint8 values, + rescaled to the range [-1, 1]. + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). 
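+ The likelihood uses bins of width 2/255 (one uint8 level after rescaling
+ to [-1, 1]); values below -0.999 or above 0.999 fall back to the open CDF
+ tails instead of a finite bin.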
+ """ + assert x.shape == means.shape == log_scales.shape + centered_x = x - means + inv_stdv = th.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 1.0 / 255.0) + cdf_plus = approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - 1.0 / 255.0) + cdf_min = approx_standard_normal_cdf(min_in) + log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = th.where( + x < -0.999, + log_cdf_plus, + th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), + ) + assert log_probs.shape == x.shape + return log_probs diff --git a/main/diffusion/nn.py b/main/diffusion/nn.py new file mode 100644 index 0000000000000000000000000000000000000000..41c18e7dd3d8cae1e719638e87c27f718f6a94e6 --- /dev/null +++ b/main/diffusion/nn.py @@ -0,0 +1,197 @@ +# This code is based on https://github.com/openai/guided-diffusion +""" +Various utilities for neural networks. +""" + +import math + +import torch as th +import torch.nn as nn + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * th.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def update_ema(target_params, source_params, rate=0.99): + """ + Update target parameters to be closer to those of source parameters using + an exponential moving average. + + :param target_params: the target parameter sequence. + :param source_params: the source parameter sequence. + :param rate: the EMA rate (closer to 1 means slower). + """ + for targ, src in zip(target_params, source_params): + targ.detach().mul_(rate).add_(src, alpha=1 - rate) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + +def sum_flat(tensor): + """ + Take the sum over all non-batch dimensions. + """ + return tensor.sum(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +def timestep_embedding(timesteps, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. 
+ These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + half = dim // 2 + freqs = th.exp( + -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) + if dim % 2: + embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(th.autograd.Function): + @staticmethod + @th.cuda.amp.custom_fwd + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_length = length + ctx.save_for_backward(*args) + with th.no_grad(): + output_tensors = ctx.run_function(*args[:length]) + return output_tensors + + @staticmethod + @th.cuda.amp.custom_bwd + def backward(ctx, *output_grads): + args = list(ctx.saved_tensors) + + # Filter for inputs that require grad. If none, exit early. + input_indices = [i for (i, x) in enumerate(args) if x.requires_grad] + if not input_indices: + return (None, None) + tuple(None for _ in args) + + with th.enable_grad(): + for i in input_indices: + if i < ctx.input_length: + # Not sure why the OAI code does this little + # dance. It might not be necessary. + args[i] = args[i].detach().requires_grad_() + args[i] = args[i].view_as(args[i]) + output_tensors = ctx.run_function(*args[:ctx.input_length]) + + if isinstance(output_tensors, th.Tensor): + output_tensors = [output_tensors] + + # Filter for outputs that require grad. If none, exit early. + out_and_grads = [(o, g) for (o, g) in zip(output_tensors, output_grads) if o.requires_grad] + if not out_and_grads: + return (None, None) + tuple(None for _ in args) + + # Compute gradients on the filtered tensors. + computed_grads = th.autograd.grad( + [o for (o, g) in out_and_grads], + [args[i] for i in input_indices], + [g for (o, g) in out_and_grads] + ) + + # Reassemble the complete gradient tuple. + input_grads = [None for _ in args] + for (i, g) in zip(input_indices, computed_grads): + input_grads[i] = g + return (None, None) + tuple(input_grads) diff --git a/main/diffusion/resample.py b/main/diffusion/resample.py new file mode 100644 index 0000000000000000000000000000000000000000..c82eccdcd47c468d41e7cbe02de6a731f2c9bf81 --- /dev/null +++ b/main/diffusion/resample.py @@ -0,0 +1,154 @@ +from abc import ABC, abstractmethod + +import numpy as np +import torch as th +import torch.distributed as dist + + +def create_named_schedule_sampler(name, diffusion): + """ + Create a ScheduleSampler from a library of pre-defined samplers. + + :param name: the name of the sampler. + :param diffusion: the diffusion object to sample for. 
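+ :return: a ScheduleSampler. "uniform" and "loss-second-moment" are the
+ supported names. Typical use (batch_size and device come from the
+ training loop):
+ sampler = create_named_schedule_sampler("uniform", diffusion)
+ t, weights = sampler.sample(batch_size, device)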
+ """ + if name == "uniform": + return UniformSampler(diffusion) + elif name == "loss-second-moment": + return LossSecondMomentResampler(diffusion) + else: + raise NotImplementedError(f"unknown schedule sampler: {name}") + + +class ScheduleSampler(ABC): + """ + A distribution over timesteps in the diffusion process, intended to reduce + variance of the objective. + + By default, samplers perform unbiased importance sampling, in which the + objective's mean is unchanged. + However, subclasses may override sample() to change how the resampled + terms are reweighted, allowing for actual changes in the objective. + """ + + @abstractmethod + def weights(self): + """ + Get a numpy array of weights, one per diffusion step. + + The weights needn't be normalized, but must be positive. + """ + + def sample(self, batch_size, device): + """ + Importance-sample timesteps for a batch. + + :param batch_size: the number of timesteps. + :param device: the torch device to save to. + :return: a tuple (timesteps, weights): + - timesteps: a tensor of timestep indices. + - weights: a tensor of weights to scale the resulting losses. + """ + w = self.weights() + p = w / np.sum(w) + indices_np = np.random.choice(len(p), size=(batch_size,), p=p) + indices = th.from_numpy(indices_np).long().to(device) + weights_np = 1 / (len(p) * p[indices_np]) + weights = th.from_numpy(weights_np).float().to(device) + return indices, weights + + +class UniformSampler(ScheduleSampler): + def __init__(self, diffusion): + self.diffusion = diffusion + self._weights = np.ones([diffusion.num_timesteps]) + + def weights(self): + return self._weights + + +class LossAwareSampler(ScheduleSampler): + def update_with_local_losses(self, local_ts, local_losses): + """ + Update the reweighting using losses from a model. + + Call this method from each rank with a batch of timesteps and the + corresponding losses for each of those timesteps. + This method will perform synchronization to make sure all of the ranks + maintain the exact same reweighting. + + :param local_ts: an integer Tensor of timesteps. + :param local_losses: a 1D Tensor of losses. + """ + batch_sizes = [ + th.tensor([0], dtype=th.int32, device=local_ts.device) + for _ in range(dist.get_world_size()) + ] + dist.all_gather( + batch_sizes, + th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), + ) + + # Pad all_gather batches to be the maximum batch size. + batch_sizes = [x.item() for x in batch_sizes] + max_bs = max(batch_sizes) + + timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] + loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] + dist.all_gather(timestep_batches, local_ts) + dist.all_gather(loss_batches, local_losses) + timesteps = [ + x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] + ] + losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] + self.update_with_all_losses(timesteps, losses) + + @abstractmethod + def update_with_all_losses(self, ts, losses): + """ + Update the reweighting using losses from a model. + + Sub-classes should override this method to update the reweighting + using losses from the model. + + This method directly updates the reweighting without synchronizing + between workers. It is called by update_with_local_losses from all + ranks with identical arguments. Thus, it should have deterministic + behavior to maintain state across workers. + + :param ts: a list of int timesteps. + :param losses: a list of float losses, one per timestep. 
+ """ + + +class LossSecondMomentResampler(LossAwareSampler): + def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): + self.diffusion = diffusion + self.history_per_term = history_per_term + self.uniform_prob = uniform_prob + self._loss_history = np.zeros( + [diffusion.num_timesteps, history_per_term], dtype=np.float64 + ) + self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) + + def weights(self): + if not self._warmed_up(): + return np.ones([self.diffusion.num_timesteps], dtype=np.float64) + weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) + weights /= np.sum(weights) + weights *= 1 - self.uniform_prob + weights += self.uniform_prob / len(weights) + return weights + + def update_with_all_losses(self, ts, losses): + for t, loss in zip(ts, losses): + if self._loss_counts[t] == self.history_per_term: + # Shift out the oldest loss term. + self._loss_history[t, :-1] = self._loss_history[t, 1:] + self._loss_history[t, -1] = loss + else: + self._loss_history[t, self._loss_counts[t]] = loss + self._loss_counts[t] += 1 + + def _warmed_up(self): + return (self._loss_counts == self.history_per_term).all() diff --git a/main/diffusion/respace.py b/main/diffusion/respace.py new file mode 100644 index 0000000000000000000000000000000000000000..13a3c0667029b75aa82202ef709fc7cb2fb337f4 --- /dev/null +++ b/main/diffusion/respace.py @@ -0,0 +1,129 @@ +# This code is based on https://github.com/openai/guided-diffusion +import numpy as np +import torch as th + +from .gaussian_diffusion import GaussianDiffusion + + +def space_timesteps(num_timesteps, section_counts): + """ + Create a list of timesteps to use from an original diffusion process, + given the number of timesteps we want to take from equally-sized portions + of the original process. + + For example, if there's 300 timesteps and the section counts are [10,15,20] + then the first 100 timesteps are strided to be 10 timesteps, the second 100 + are strided to be 15 timesteps, and the final 100 are strided to be 20. + + If the stride is a string starting with "ddim", then the fixed striding + from the DDIM paper is used, and only one section is allowed. + + :param num_timesteps: the number of diffusion steps in the original + process to divide up. + :param section_counts: either a list of numbers, or a string containing + comma-separated numbers, indicating the step count + per section. As a special case, use "ddimN" where N + is a number of steps to use the striding from the + DDIM paper. + :return: a set of diffusion steps from the original process to use. 
+ """ + if isinstance(section_counts, str): + if section_counts.startswith("ddim"): + desired_count = int(section_counts[len("ddim") :]) + for i in range(1, num_timesteps): + if len(range(0, num_timesteps, i)) == desired_count: + return set(range(0, num_timesteps, i)) + raise ValueError( + f"cannot create exactly {num_timesteps} steps with an integer stride" + ) + section_counts = [int(x) for x in section_counts.split(",")] + size_per = num_timesteps // len(section_counts) + extra = num_timesteps % len(section_counts) + start_idx = 0 + all_steps = [] + for i, section_count in enumerate(section_counts): + size = size_per + (1 if i < extra else 0) + if size < section_count: + raise ValueError( + f"cannot divide section of {size} steps into {section_count}" + ) + if section_count <= 1: + frac_stride = 1 + else: + frac_stride = (size - 1) / (section_count - 1) + cur_idx = 0.0 + taken_steps = [] + for _ in range(section_count): + taken_steps.append(start_idx + round(cur_idx)) + cur_idx += frac_stride + all_steps += taken_steps + start_idx += size + return set(all_steps) + + +class SpacedDiffusion(GaussianDiffusion): + """ + A diffusion process which can skip steps in a base diffusion process. + + :param use_timesteps: a collection (sequence or set) of timesteps from the + original diffusion process to retain. + :param kwargs: the kwargs to create the base diffusion process. + """ + + def __init__(self, use_timesteps, **kwargs): + self.use_timesteps = set(use_timesteps) + self.timestep_map = [] + self.original_num_steps = len(kwargs["betas"]) + + base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa + last_alpha_cumprod = 1.0 + new_betas = [] + for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): + if i in self.use_timesteps: + new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = alpha_cumprod + self.timestep_map.append(i) + kwargs["betas"] = np.array(new_betas) + super().__init__(**kwargs) + + def p_mean_variance( + self, model, *args, **kwargs + ): # pylint: disable=signature-differs + return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) + + def training_losses( + self, model, *args, **kwargs + ): # pylint: disable=signature-differs + return super().training_losses(self._wrap_model(model), *args, **kwargs) + + def condition_mean(self, cond_fn, *args, **kwargs): + return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) + + def condition_score(self, cond_fn, *args, **kwargs): + return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) + + def _wrap_model(self, model): + if isinstance(model, _WrappedModel): + return model + return _WrappedModel( + model, self.timestep_map, self.rescale_timesteps, self.original_num_steps + ) + + def _scale_timesteps(self, t): + # Scaling is done by the wrapped model. 
+ return t + + +class _WrappedModel: + def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): + self.model = model + self.timestep_map = timestep_map + self.rescale_timesteps = rescale_timesteps + self.original_num_steps = original_num_steps + + def __call__(self, x, ts, **kwargs): + map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) + new_ts = map_tensor[ts] + if self.rescale_timesteps: + new_ts = new_ts.float() * (1000.0 / self.original_num_steps) + return self.model(x, new_ts, **kwargs) diff --git a/main/eval/a2m/__init__.py b/main/eval/a2m/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/main/eval/a2m/action2motion/accuracy.py b/main/eval/a2m/action2motion/accuracy.py new file mode 100644 index 0000000000000000000000000000000000000000..31ac61bc1a2791587e78febb64f2df73227c2111 --- /dev/null +++ b/main/eval/a2m/action2motion/accuracy.py @@ -0,0 +1,14 @@ +import torch + + +def calculate_accuracy(model, motion_loader, num_labels, classifier, device): + confusion = torch.zeros(num_labels, num_labels, dtype=torch.long) + with torch.no_grad(): + for batch in motion_loader: + batch_prob = classifier(batch["output_xyz"], lengths=batch["lengths"]) + batch_pred = batch_prob.max(dim=1).indices + for label, pred in zip(batch["y"], batch_pred): + confusion[label][pred] += 1 + + accuracy = torch.trace(confusion)/torch.sum(confusion) + return accuracy.item(), confusion diff --git a/main/eval/a2m/action2motion/diversity.py b/main/eval/a2m/action2motion/diversity.py new file mode 100644 index 0000000000000000000000000000000000000000..c20110803756bb60fb299a2f773608ccea4faeca --- /dev/null +++ b/main/eval/a2m/action2motion/diversity.py @@ -0,0 +1,66 @@ +import torch +import numpy as np + + +#adapted from action2motion +def calculate_diversity(activations): + diversity_times = 200 + num_motions = len(activations) + + diversity = 0 + + first_indices = np.random.randint(0, num_motions, diversity_times) + second_indices = np.random.randint(0, num_motions, diversity_times) + for first_idx, second_idx in zip(first_indices, second_indices): + diversity += torch.dist(activations[first_idx, :], + activations[second_idx, :]) + diversity /= diversity_times + return diversity + +# from action2motion +def calculate_diversity_multimodality(activations, labels, num_labels, unconstrained = False): + diversity_times = 200 + multimodality_times = 20 + if not unconstrained: + labels = labels.long() + num_motions = activations.shape[0] # len(labels) + + diversity = 0 + + first_indices = np.random.randint(0, num_motions, diversity_times) + second_indices = np.random.randint(0, num_motions, diversity_times) + for first_idx, second_idx in zip(first_indices, second_indices): + diversity += torch.dist(activations[first_idx, :], + activations[second_idx, :]) + diversity /= diversity_times + + if not unconstrained: + multimodality = 0 + label_quotas = np.zeros(num_labels) + label_quotas[labels.unique()] = multimodality_times # if a label does not appear in batch, its quota remains zero + while np.any(label_quotas > 0): + # print(label_quotas) + first_idx = np.random.randint(0, num_motions) + first_label = labels[first_idx] + if not label_quotas[first_label]: + continue + + second_idx = np.random.randint(0, num_motions) + second_label = labels[second_idx] + while first_label != second_label: + second_idx = np.random.randint(0, num_motions) + second_label = labels[second_idx] + + 
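+ # A second sample with the same label was found: consume one slot of that
+ # label's quota and accumulate the intra-class (multimodality) distance.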
label_quotas[first_label] -= 1 + + first_activation = activations[first_idx, :] + second_activation = activations[second_idx, :] + multimodality += torch.dist(first_activation, + second_activation) + + multimodality /= (multimodality_times * num_labels) + else: + multimodality = torch.tensor(np.nan) + + return diversity.item(), multimodality.item() + diff --git a/main/eval/a2m/action2motion/evaluate.py b/main/eval/a2m/action2motion/evaluate.py new file mode 100644 index 0000000000000000000000000000000000000000..75f1325ea4fd3d2acecc868905bf93a7c40e27cd --- /dev/null +++ b/main/eval/a2m/action2motion/evaluate.py @@ -0,0 +1,84 @@ +import torch +import numpy as np +from .models import load_classifier, load_classifier_for_fid +from .accuracy import calculate_accuracy +from .fid import calculate_fid +from .diversity import calculate_diversity_multimodality + + +class A2MEvaluation: + def __init__(self, device): + dataset_opt = {"input_size_raw": 72, "joints_num": 24, "num_classes": 12} + + self.input_size_raw = dataset_opt["input_size_raw"] + self.num_classes = dataset_opt["num_classes"] + self.device = device + + self.gru_classifier_for_fid = load_classifier_for_fid(self.input_size_raw, self.num_classes, device).eval() + self.gru_classifier = load_classifier(self.input_size_raw, self.num_classes, device).eval() + + def compute_features(self, model, motionloader): + # calculate_activations_labels function from action2motion + activations = [] + labels = [] + with torch.no_grad(): + for idx, batch in enumerate(motionloader): + activations.append(self.gru_classifier_for_fid(batch["output_xyz"], lengths=batch["lengths"])) + if model.cond_mode != 'no_cond': + labels.append(batch["y"]) + activations = torch.cat(activations, dim=0) + if model.cond_mode != 'no_cond': + labels = torch.cat(labels, dim=0) + return activations, labels + + @staticmethod + def calculate_activation_statistics(activations): + activations = activations.cpu().numpy() + mu = np.mean(activations, axis=0) + sigma = np.cov(activations, rowvar=False) + return mu, sigma + + def evaluate(self, model, loaders): + + def print_logs(metric, key): + print(f"Computing action2motion {metric} on the {key} loader ...") + + metrics = {} + + computedfeats = {} + for key, loader in loaders.items(): + metric = "accuracy" + print_logs(metric, key) + mkey = f"{metric}_{key}" + if model.cond_mode != 'no_cond': + metrics[mkey], _ = calculate_accuracy(model, loader, + self.num_classes, + self.gru_classifier, self.device) + else: + metrics[mkey] = np.nan + + # features for diversity + print_logs("features", key) + feats, labels = self.compute_features(model, loader) + print_logs("stats", key) + stats = self.calculate_activation_statistics(feats) + + computedfeats[key] = {"feats": feats, + "labels": labels, + "stats": stats} + + print_logs("diversity", key) + ret = calculate_diversity_multimodality(feats, labels, self.num_classes, unconstrained=(model.cond_mode=='no_cond')) + metrics[f"diversity_{key}"], metrics[f"multimodality_{key}"] = ret + + # taking the stats of the ground truth and remove it from the computed feats + gtstats = computedfeats["gt"]["stats"] + # computing fid + for key, loader in computedfeats.items(): + metric = "fid" + mkey = f"{metric}_{key}" + + stats = computedfeats[key]["stats"] + metrics[mkey] = float(calculate_fid(gtstats, stats)) + + return metrics diff --git a/main/eval/a2m/action2motion/fid.py b/main/eval/a2m/action2motion/fid.py new file mode 100644 index 
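The diversity statistic in `calculate_diversity_multimodality` is just the average distance between randomly paired feature activations; multimodality repeats the same idea restricted to pairs sharing an action label. A toy sketch of the diversity half, with made-up shapes and counts:

```python
import torch
import numpy as np

# Illustrative only: 500 fake motions with 64-d classifier features.
torch.manual_seed(0)
np.random.seed(0)

activations = torch.randn(500, 64)
diversity_times = 200

first = np.random.randint(0, len(activations), diversity_times)
second = np.random.randint(0, len(activations), diversity_times)

# mean L2 distance over random pairs; higher = more spread-out generations
diversity = sum(torch.dist(activations[i], activations[j]) for i, j in zip(first, second))
diversity /= diversity_times
print(float(diversity))
```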
0000000000000000000000000000000000000000..4302e6b2808e87e3d9d4d8080db89d1b03e45f85 --- /dev/null +++ b/main/eval/a2m/action2motion/fid.py @@ -0,0 +1,61 @@ +import numpy as np +from scipy import linalg + + +# from action2motion +def calculate_fid(statistics_1, statistics_2): + return calculate_frechet_distance(statistics_1[0], statistics_1[1], + statistics_2[0], statistics_2[1]) + + +def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): + """Numpy implementation of the Frechet Distance. + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + Stable version by Dougal J. Sutherland. + Params: + -- mu1 : Numpy array containing the activations of a layer of the + inception net (like returned by the function 'get_predictions') + for generated samples. + -- mu2 : The sample mean over activations, precalculated on an + representative data set. + -- sigma1: The covariance matrix over activations for generated samples. + -- sigma2: The covariance matrix over activations, precalculated on an + representative data set. + Returns: + -- : The Frechet Distance. + """ + + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert mu1.shape == mu2.shape, \ + 'Training and test mean vectors have different lengths' + assert sigma1.shape == sigma2.shape, \ + 'Training and test covariances have different dimensions' + + diff = mu1 - mu2 + + # Product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = ('fid calculation produces singular product; ' + 'adding %s to diagonal of cov estimates') % eps + print(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError('Imaginary component {}'.format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean) diff --git a/main/eval/a2m/action2motion/models.py b/main/eval/a2m/action2motion/models.py new file mode 100644 index 0000000000000000000000000000000000000000..4299ce0d4b6e006e2d169acf1b26d03f2c0ca258 --- /dev/null +++ b/main/eval/a2m/action2motion/models.py @@ -0,0 +1,130 @@ +import torch +import torch.nn as nn + + +# adapted from action2motion to take inputs of different lengths +class MotionDiscriminator(nn.Module): + def __init__(self, input_size, hidden_size, hidden_layer, device, output_size=12, use_noise=None): + super(MotionDiscriminator, self).__init__() + self.device = device + + self.input_size = input_size + self.hidden_size = hidden_size + self.hidden_layer = hidden_layer + self.use_noise = use_noise + + self.recurrent = nn.GRU(input_size, hidden_size, hidden_layer) + self.linear1 = nn.Linear(hidden_size, 30) + self.linear2 = nn.Linear(30, output_size) + + def forward(self, motion_sequence, lengths=None, hidden_unit=None): + # dim (motion_length, num_samples, hidden_size) + bs, njoints, nfeats, num_frames = motion_sequence.shape + motion_sequence = motion_sequence.reshape(bs, njoints*nfeats, num_frames) + motion_sequence = motion_sequence.permute(2, 0, 1) + if hidden_unit is None: + # motion_sequence = motion_sequence.permute(1, 0, 2) + 
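The Frechet distance above only needs the per-set mean and covariance of classifier activations. A small self-contained sanity check on synthetic features (not repo data) that mirrors the formula d^2 = ||mu1 - mu2||^2 + Tr(C1 + C2 - 2*sqrt(C1*C2)):

```python
import numpy as np
from scipy import linalg

# Two Gaussian feature clouds; the second has a shifted mean, so FID > 0.
rng = np.random.default_rng(0)
feats_a = rng.normal(0.0, 1.0, size=(1000, 16))
feats_b = rng.normal(0.5, 1.0, size=(1000, 16))

def stats(feats):
    return np.mean(feats, axis=0), np.cov(feats, rowvar=False)

mu1, s1 = stats(feats_a)
mu2, s2 = stats(feats_b)

covmean, _ = linalg.sqrtm(s1.dot(s2), disp=False)   # may come back complex
fid = np.sum((mu1 - mu2) ** 2) + np.trace(s1 + s2 - 2 * covmean.real)
print(fid)   # roughly 16 * 0.25 = 4 for this toy mean shift
```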
hidden_unit = self.initHidden(motion_sequence.size(1), self.hidden_layer) + gru_o, _ = self.recurrent(motion_sequence.float(), hidden_unit) + + # select the last valid, instead of: gru_o[-1, :, :] + out = gru_o[tuple(torch.stack((lengths-1, torch.arange(bs, device=self.device))))] + + # dim (num_samples, 30) + lin1 = self.linear1(out) + lin1 = torch.tanh(lin1) + # dim (num_samples, output_size) + lin2 = self.linear2(lin1) + return lin2 + + def initHidden(self, num_samples, layer): + return torch.randn(layer, num_samples, self.hidden_size, device=self.device, requires_grad=False) + + +class MotionDiscriminatorForFID(MotionDiscriminator): + def forward(self, motion_sequence, lengths=None, hidden_unit=None): + # dim (motion_length, num_samples, hidden_size) + bs, njoints, nfeats, num_frames = motion_sequence.shape + motion_sequence = motion_sequence.reshape(bs, njoints*nfeats, num_frames) + motion_sequence = motion_sequence.permute(2, 0, 1) + if hidden_unit is None: + # motion_sequence = motion_sequence.permute(1, 0, 2) + hidden_unit = self.initHidden(motion_sequence.size(1), self.hidden_layer) + gru_o, _ = self.recurrent(motion_sequence.float(), hidden_unit) + + # select the last valid, instead of: gru_o[-1, :, :] + out = gru_o[tuple(torch.stack((lengths-1, torch.arange(bs, device=self.device))))] + + # dim (num_samples, 30) + lin1 = self.linear1(out) + lin1 = torch.tanh(lin1) + return lin1 + + +model_path = "./assets/actionrecognition/humanact12_gru.tar" + + +def load_classifier(input_size_raw, num_classes, device): + model = torch.load(model_path, map_location=device) + classifier = MotionDiscriminator(input_size_raw, 128, 2, device=device, output_size=num_classes).to(device) + classifier.load_state_dict(model["model"]) + classifier.eval() + return classifier + + +def load_classifier_for_fid(input_size_raw, num_classes, device): + model = torch.load(model_path, map_location=device) + classifier = MotionDiscriminatorForFID(input_size_raw, 128, 2, device=device, output_size=num_classes).to(device) + classifier.load_state_dict(model["model"]) + classifier.eval() + return classifier + + +def test(): + from src.datasets.ntu13 import NTU13 + import src.utils.fixseed # noqa + + classifier = load_classifier("ntu13", input_size_raw=54, num_classes=13, device="cuda").eval() + params = {"pose_rep": "rot6d", + "translation": True, + "glob": True, + "jointstype": "a2m", + "vertstrans": True, + "num_frames": 60, + "sampling": "conseq", + "sampling_step": 1} + dataset = NTU13(**params) + + from src.models.rotation2xyz import Rotation2xyz + rot2xyz = Rotation2xyz(device="cuda") + confusion_xyz = torch.zeros(13, 13, dtype=torch.long) + confusion = torch.zeros(13, 13, dtype=torch.long) + + for i in range(1000): + dataset.pose_rep = "xyz" + data = dataset[i][0].to("cuda") + data = data[None] + + dataset.pose_rep = params["pose_rep"] + x = dataset[i][0].to("cuda")[None] + mask = torch.ones(1, x.shape[-1], dtype=bool, device="cuda") + lengths = mask.sum(1) + + xyz_t = rot2xyz(x, mask, **params) + + predicted_cls_xyz = classifier(data, lengths=lengths).argmax().item() + predicted_cls = classifier(xyz_t, lengths=lengths).argmax().item() + + gt_cls = dataset[i][1] + + confusion_xyz[gt_cls][predicted_cls_xyz] += 1 + confusion[gt_cls][predicted_cls] += 1 + + accuracy_xyz = torch.trace(confusion_xyz)/torch.sum(confusion_xyz).item() + accuracy = torch.trace(confusion)/torch.sum(confusion).item() + + print(f"accuracy: {accuracy:.1%}, accuracy_xyz: {accuracy_xyz:.1%}") + + +if __name__ == "__main__": + test() diff 
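The least obvious line in both discriminators above is the gather that picks the last *valid* GRU output per sequence instead of `gru_o[-1, :, :]`, so padded frames do not leak into the features. A minimal sketch of that indexing with made-up sizes:

```python
import torch

# (time, batch, hidden) output, as nn.GRU returns it, plus per-sequence lengths.
T, B, H = 10, 4, 8
gru_out = torch.randn(T, B, H)
lengths = torch.tensor([10, 7, 3, 5])    # valid frames per sequence

# index [length - 1, batch_index] rather than blindly taking the last timestep
last_valid = gru_out[lengths - 1, torch.arange(B)]
print(last_valid.shape)                  # torch.Size([4, 8])

assert torch.equal(last_valid[1], gru_out[6, 1])   # sequence 1 ends at frame index 6
```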
--git a/main/eval/a2m/gru_eval.py b/main/eval/a2m/gru_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..3226739972064de8b1942fdd4e8a00047c903be9 --- /dev/null +++ b/main/eval/a2m/gru_eval.py @@ -0,0 +1,131 @@ +import copy +import os + +import numpy as np +from tqdm import tqdm +import torch +import functools +from torch.utils.data import DataLoader + +from utils.fixseed import fixseed +from data_loaders.tensors import collate +from eval.a2m.action2motion.evaluate import A2MEvaluation +from eval.unconstrained.evaluate import evaluate_unconstrained_metrics +from .tools import save_metrics, format_metrics +from utils import dist_util + +num_samples_unconstrained = 1000 + +class NewDataloader: + def __init__(self, mode, model, diffusion, dataiterator, device, unconstrained, num_samples: int=-1): + assert mode in ["gen", "gt"] + self.batches = [] + sample_fn = diffusion.p_sample_loop + with torch.no_grad(): + for motions, model_kwargs in tqdm(dataiterator, desc=f"Construct dataloader: {mode}.."): + motions = motions.to(device) + if num_samples != -1 and len(self.batches) * dataiterator.batch_size > num_samples: + continue # do not break because it confuses the multiple loaders + batch = dict() + if mode == "gen": + sample = sample_fn(model, motions.shape, clip_denoised=False, model_kwargs=model_kwargs) + batch['output'] = sample + elif mode == "gt": + batch["output"] = motions + + # mask = torch.ones([batch["output"].shape[0], batch["output"].shape[-1]], dtype=bool).to(device) # batch_size x num_frames + max_n_frames = model_kwargs['y']['lengths'].max() + mask = model_kwargs['y']['mask'].reshape(dataiterator.batch_size, max_n_frames).bool() + batch["output_xyz"] = model.rot2xyz(x=batch["output"], mask=mask, pose_rep='rot6d', glob=True, + translation=True, jointstype='smpl', vertstrans=True, betas=None, + beta=0, glob_rot=None, get_rotations_back=False) + batch["lengths"] = model_kwargs['y']['lengths'].to(device) + if unconstrained: # proceed only if not running unconstrained + batch["y"] = model_kwargs['y']['action'].squeeze().long().cpu() # using torch.long so lengths/action will be used as indices + self.batches.append(batch) + + num_samples_last_batch = num_samples % dataiterator.batch_size + if num_samples_last_batch > 0: + for k, v in self.batches[-1].items(): + self.batches[-1][k] = v[:num_samples_last_batch] + + def __iter__(self): + return iter(self.batches) + +def evaluate(args, model, diffusion, data): + num_frames = 60 + + # fix parameters for action2motion evaluation + args.num_frames = num_frames + args.jointstype = "smpl" + args.vertstrans = True + + device = dist_util.dev() + + model.eval() + + a2mevaluation = A2MEvaluation(device=device) + a2mmetrics = {} + + datasetGT1 = copy.deepcopy(data) + datasetGT2 = copy.deepcopy(data) + + allseeds = list(range(args.num_seeds)) + + try: + for index, seed in enumerate(allseeds): + print(f"Evaluation number: {index+1}/{args.num_seeds}") + fixseed(seed) + + datasetGT1.reset_shuffle() + datasetGT1.shuffle() + + datasetGT2.reset_shuffle() + datasetGT2.shuffle() + + dataiterator = DataLoader(datasetGT1, batch_size=args.batch_size, + shuffle=False, num_workers=8, collate_fn=collate) + dataiterator2 = DataLoader(datasetGT2, batch_size=args.batch_size, + shuffle=False, num_workers=8, collate_fn=collate) + + new_data_loader = functools.partial(NewDataloader, model=model, diffusion=diffusion, device=device, + unconstrained=args.unconstrained, num_samples=args.num_samples) + motionloader = new_data_loader(mode="gen", 
dataiterator=dataiterator) + gt_motionloader = new_data_loader("gt", dataiterator=dataiterator) + gt_motionloader2 = new_data_loader("gt", dataiterator=dataiterator2) + + # Action2motionEvaluation + loaders = {"gen": motionloader, + "gt": gt_motionloader, + "gt2": gt_motionloader2} + + a2mmetrics[seed] = a2mevaluation.evaluate(model, loaders) + + del loaders + + if args.unconstrained: # unconstrained + dataset_unconstrained = copy.deepcopy(data) + dataset_unconstrained.reset_shuffle() + dataset_unconstrained.shuffle() + dataiterator_unconstrained = DataLoader(dataset_unconstrained, batch_size=args.batch_size, + shuffle=False, num_workers=8, collate_fn=collate) + motionloader_unconstrained = new_data_loader(mode="gen", dataiterator=dataiterator_unconstrained, num_samples=num_samples_unconstrained) + + generated_motions = [] + for motion in motionloader_unconstrained: + idx = [15, 12, 16, 18, 20, 17, 19, 21, 0, 1, 4, 7, 2, 5, 8] + motion = motion['output_xyz'][:, idx, :, :] + generated_motions.append(motion.cpu().numpy()) + generated_motions = np.concatenate(generated_motions) + unconstrained_metrics = evaluate_unconstrained_metrics(generated_motions, device, fast=True) + unconstrained_metrics = {k+'_unconstrained': v for k, v in unconstrained_metrics.items()} + + except KeyboardInterrupt: + string = "Saving the evaluation before exiting.." + print(string) + + metrics = {"feats": {key: [format_metrics(a2mmetrics[seed])[key] for seed in a2mmetrics.keys()] for key in a2mmetrics[allseeds[0]]}} + if args.unconstrained: + metrics["feats"] = {**metrics["feats"], **unconstrained_metrics} + + return metrics diff --git a/main/eval/a2m/recognition/models/stgcn.py b/main/eval/a2m/recognition/models/stgcn.py new file mode 100644 index 0000000000000000000000000000000000000000..1784ea247bf18484a11eb8663575b8d3153a58a7 --- /dev/null +++ b/main/eval/a2m/recognition/models/stgcn.py @@ -0,0 +1,219 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .stgcnutils.tgcn import ConvTemporalGraphical +from .stgcnutils.graph import Graph + +__all__ = ["STGCN"] + + +class STGCN(nn.Module): + r"""Spatial temporal graph convolutional networks. + Args: + in_channels (int): Number of channels in the input data + num_class (int): Number of classes for the classification task + graph_args (dict): The arguments for building the graph + edge_importance_weighting (bool): If ``True``, adds a learnable + importance weighting to the edges of the graph + **kwargs (optional): Other parameters for graph convolution units + Shape: + - Input: :math:`(N, in_channels, T_{in}, V_{in}, M_{in})` + - Output: :math:`(N, num_class)` where + :math:`N` is a batch size, + :math:`T_{in}` is a length of input sequence, + :math:`V_{in}` is the number of graph nodes, + :math:`M_{in}` is the number of instance in a frame. 
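At the end of `evaluate`, the per-seed metric dictionaries are flipped into lists keyed by metric name so that means and confidence intervals can be taken over seeds later. A toy illustration with invented numbers:

```python
# One formatted metrics dict per seed -> {metric_name: [one value per seed]}.
a2mmetrics = {
    0: {"fid_gen": 0.71, "accuracy_gen": 0.91},   # made-up values
    1: {"fid_gen": 0.68, "accuracy_gen": 0.93},
}
seeds = sorted(a2mmetrics)
feats = {key: [a2mmetrics[seed][key] for seed in seeds] for key in a2mmetrics[seeds[0]]}
print(feats)   # {'fid_gen': [0.71, 0.68], 'accuracy_gen': [0.91, 0.93]}
```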
+ """ + + def __init__(self, in_channels, num_class, graph_args, + edge_importance_weighting, device, **kwargs): + super().__init__() + + self.device = device + self.num_class = num_class + + self.losses = ["accuracy", "cross_entropy", "mixed"] + self.criterion = torch.nn.CrossEntropyLoss(reduction='mean') + + # load graph + self.graph = Graph(**graph_args) + A = torch.tensor(self.graph.A, dtype=torch.float32, requires_grad=False) + self.register_buffer('A', A) + + # build networks + spatial_kernel_size = A.size(0) + temporal_kernel_size = 9 + kernel_size = (temporal_kernel_size, spatial_kernel_size) + self.data_bn = nn.BatchNorm1d(in_channels * A.size(1)) + kwargs0 = {k: v for k, v in kwargs.items() if k != 'dropout'} + self.st_gcn_networks = nn.ModuleList(( + st_gcn(in_channels, 64, kernel_size, 1, residual=False, **kwargs0), + st_gcn(64, 64, kernel_size, 1, **kwargs), + st_gcn(64, 64, kernel_size, 1, **kwargs), + st_gcn(64, 64, kernel_size, 1, **kwargs), + st_gcn(64, 128, kernel_size, 2, **kwargs), + st_gcn(128, 128, kernel_size, 1, **kwargs), + st_gcn(128, 128, kernel_size, 1, **kwargs), + st_gcn(128, 256, kernel_size, 2, **kwargs), + st_gcn(256, 256, kernel_size, 1, **kwargs), + st_gcn(256, 256, kernel_size, 1, **kwargs), + )) + + # initialize parameters for edge importance weighting + if edge_importance_weighting: + self.edge_importance = nn.ParameterList([ + nn.Parameter(torch.ones(self.A.size())) + for i in self.st_gcn_networks + ]) + else: + self.edge_importance = [1] * len(self.st_gcn_networks) + + # fcn for prediction + self.fcn = nn.Conv2d(256, num_class, kernel_size=1) + + def forward(self, batch): + # TODO: use mask + # Received batch["x"] as + # Batch(48), Joints(23), Quat(4), Time(157 + # Expecting: + # Batch, Quat:4, Time, Joints, 1 + x = batch["output"].permute(0, 2, 3, 1).unsqueeze(4).contiguous() + + # data normalization + N, C, T, V, M = x.size() + x = x.permute(0, 4, 3, 1, 2).contiguous() + x = x.view(N * M, V * C, T) + x = self.data_bn(x) + x = x.view(N, M, V, C, T) + x = x.permute(0, 1, 3, 4, 2).contiguous() + x = x.view(N * M, C, T, V) + + # forward + for gcn, importance in zip(self.st_gcn_networks, self.edge_importance): + x, _ = gcn(x, self.A * importance) + + # compute feature + # _, c, t, v = x.size() + # features = x.view(N, M, c, t, v).permute(0, 2, 3, 4, 1) + # batch["features"] = features + + # global pooling + x = F.avg_pool2d(x, x.size()[2:]) + x = x.view(N, M, -1, 1, 1).mean(dim=1) + + # features + batch["features"] = x.squeeze() + + # prediction + x = self.fcn(x) + x = x.view(x.size(0), -1) + batch["yhat"] = x + return batch + + def compute_accuracy(self, batch): + confusion = torch.zeros(self.num_class, self.num_class, dtype=int) + yhat = batch["yhat"].max(dim=1).indices + ygt = batch["y"] + for label, pred in zip(ygt, yhat): + confusion[label][pred] += 1 + accuracy = torch.trace(confusion)/torch.sum(confusion) + return accuracy + + def compute_loss(self, batch): + cross_entropy = self.criterion(batch["yhat"], batch["y"]) + mixed_loss = cross_entropy + + acc = self.compute_accuracy(batch) + losses = {"cross_entropy": cross_entropy.item(), + "mixed": mixed_loss.item(), + "accuracy": acc.item()} + return mixed_loss, losses + + +class st_gcn(nn.Module): + r"""Applies a spatial temporal graph convolution over an input graph sequence. 
+ Args: + in_channels (int): Number of channels in the input sequence data + out_channels (int): Number of channels produced by the convolution + kernel_size (tuple): Size of the temporal convolving kernel and graph convolving kernel + stride (int, optional): Stride of the temporal convolution. Default: 1 + dropout (int, optional): Dropout rate of the final output. Default: 0 + residual (bool, optional): If ``True``, applies a residual mechanism. Default: ``True`` + Shape: + - Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format + - Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format + - Output[0]: Outpu graph sequence in :math:`(N, out_channels, T_{out}, V)` format + - Output[1]: Graph adjacency matrix for output data in :math:`(K, V, V)` format + where + :math:`N` is a batch size, + :math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]`, + :math:`T_{in}/T_{out}` is a length of input/output sequence, + :math:`V` is the number of graph nodes. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + dropout=0, + residual=True): + super().__init__() + + assert len(kernel_size) == 2 + assert kernel_size[0] % 2 == 1 + padding = ((kernel_size[0] - 1) // 2, 0) + + self.gcn = ConvTemporalGraphical(in_channels, out_channels, + kernel_size[1]) + + self.tcn = nn.Sequential( + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + nn.Conv2d( + out_channels, + out_channels, + (kernel_size[0], 1), + (stride, 1), + padding, + ), + nn.BatchNorm2d(out_channels), + nn.Dropout(dropout, inplace=True), + ) + + if not residual: + self.residual = lambda x: 0 + + elif (in_channels == out_channels) and (stride == 1): + self.residual = lambda x: x + + else: + self.residual = nn.Sequential( + nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=(stride, 1)), + nn.BatchNorm2d(out_channels), + ) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x, A): + + res = self.residual(x) + x, A = self.gcn(x, A) + x = self.tcn(x) + res + + return self.relu(x), A + + +if __name__ == "__main__": + model = STGCN(in_channels=3, num_class=60, edge_importance_weighting=True, graph_args={"layout": "smpl_noglobal", "strategy": "spatial"}) + # Batch, in_channels, time, vertices, M + inp = torch.rand(10, 3, 16, 23, 1) + out = model(inp) + print(out.shape) + import pdb + pdb.set_trace() diff --git a/main/eval/a2m/recognition/models/stgcnutils/graph.py b/main/eval/a2m/recognition/models/stgcnutils/graph.py new file mode 100644 index 0000000000000000000000000000000000000000..35ea9461f5cf2730cfaa9f07d35eea6e31a49c11 --- /dev/null +++ b/main/eval/a2m/recognition/models/stgcnutils/graph.py @@ -0,0 +1,178 @@ +import numpy as np +import pickle as pkl + +from utils.config import SMPL_KINTREE_PATH + + +class Graph: + """ The Graph to model the skeletons extracted by the openpose + Args: + strategy (string): must be one of the follow candidates + - uniform: Uniform Labeling + - distance: Distance Partitioning + - spatial: Spatial Configuration + For more information, please refer to the section 'Partition Strategies' + in our paper (https://arxiv.org/abs/1801.07455). + layout (string): must be one of the follow candidates + - openpose: Is consists of 18 joints. For more information, please + refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose#output + - ntu-rgb+d: Is consists of 25 joints. 
For more information, please + refer to https://github.com/shahroudy/NTURGB-D + - smpl: Consists of 24/23 joints with without global rotation. + max_hop (int): the maximal distance between two connected nodes + dilation (int): controls the spacing between the kernel points + """ + + def __init__(self, + layout='openpose', + strategy='uniform', + kintree_path=SMPL_KINTREE_PATH, + max_hop=1, + dilation=1): + self.max_hop = max_hop + self.dilation = dilation + + self.kintree_path = kintree_path + + self.get_edge(layout) + self.hop_dis = get_hop_distance( + self.num_node, self.edge, max_hop=max_hop) + self.get_adjacency(strategy) + + def __str__(self): + return self.A + + def get_edge(self, layout): + if layout == 'openpose': + self.num_node = 18 + self_link = [(i, i) for i in range(self.num_node)] + neighbor_link = [(4, 3), (3, 2), (7, 6), (6, 5), (13, 12), (12, + 11), + (10, 9), (9, 8), (11, 5), (8, 2), (5, 1), (2, 1), + (0, 1), (15, 0), (14, 0), (17, 15), (16, 14)] + self.edge = self_link + neighbor_link + self.center = 1 + elif layout == 'smpl': + self.num_node = 24 + self_link = [(i, i) for i in range(self.num_node)] + kt = pkl.load(open(self.kintree_path, "rb")) + neighbor_link = [(k, kt[1][i + 1]) for i, k in enumerate(kt[0][1:])] + self.edge = self_link + neighbor_link + self.center = 0 + elif layout == 'smpl_noglobal': + self.num_node = 23 + self_link = [(i, i) for i in range(self.num_node)] + kt = pkl.load(open(self.kintree_path, "rb")) + neighbor_link = [(k, kt[1][i + 1]) for i, k in enumerate(kt[0][1:])] + # remove the root joint + neighbor_1base = [n for n in neighbor_link if n[0] != 0 and n[1] != 0] + neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base] + self.edge = self_link + neighbor_link + self.center = 0 + elif layout == 'ntu-rgb+d': + self.num_node = 25 + self_link = [(i, i) for i in range(self.num_node)] + neighbor_1base = [(1, 2), (2, 21), (3, 21), (4, 3), (5, 21), + (6, 5), (7, 6), (8, 7), (9, 21), (10, 9), + (11, 10), (12, 11), (13, 1), (14, 13), (15, 14), + (16, 15), (17, 1), (18, 17), (19, 18), (20, 19), + (22, 23), (23, 8), (24, 25), (25, 12)] + neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base] + self.edge = self_link + neighbor_link + self.center = 21 - 1 + elif layout == 'ntu_edge': + self.num_node = 24 + self_link = [(i, i) for i in range(self.num_node)] + neighbor_1base = [(1, 2), (3, 2), (4, 3), (5, 2), (6, 5), (7, 6), + (8, 7), (9, 2), (10, 9), (11, 10), (12, 11), + (13, 1), (14, 13), (15, 14), (16, 15), (17, 1), + (18, 17), (19, 18), (20, 19), (21, 22), (22, 8), + (23, 24), (24, 12)] + neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base] + self.edge = self_link + neighbor_link + self.center = 2 + # elif layout=='customer settings' + # pass + else: + raise NotImplementedError("This Layout is not supported") + + def get_adjacency(self, strategy): + valid_hop = range(0, self.max_hop + 1, self.dilation) + adjacency = np.zeros((self.num_node, self.num_node)) + for hop in valid_hop: + adjacency[self.hop_dis == hop] = 1 + normalize_adjacency = normalize_digraph(adjacency) + + if strategy == 'uniform': + A = np.zeros((1, self.num_node, self.num_node)) + A[0] = normalize_adjacency + self.A = A + elif strategy == 'distance': + A = np.zeros((len(valid_hop), self.num_node, self.num_node)) + for i, hop in enumerate(valid_hop): + A[i][self.hop_dis == hop] = normalize_adjacency[self.hop_dis == hop] + self.A = A + elif strategy == 'spatial': + A = [] + for hop in valid_hop: + a_root = np.zeros((self.num_node, self.num_node)) + a_close = 
np.zeros((self.num_node, self.num_node)) + a_further = np.zeros((self.num_node, self.num_node)) + for i in range(self.num_node): + for j in range(self.num_node): + if self.hop_dis[j, i] == hop: + if self.hop_dis[j, self.center] == self.hop_dis[ + i, self.center]: + a_root[j, i] = normalize_adjacency[j, i] + elif self.hop_dis[j, self. + center] > self.hop_dis[i, self. + center]: + a_close[j, i] = normalize_adjacency[j, i] + else: + a_further[j, i] = normalize_adjacency[j, i] + if hop == 0: + A.append(a_root) + else: + A.append(a_root + a_close) + A.append(a_further) + A = np.stack(A) + self.A = A + else: + raise NotImplementedError("This Strategy is not supported") + + +def get_hop_distance(num_node, edge, max_hop=1): + A = np.zeros((num_node, num_node)) + for i, j in edge: + A[j, i] = 1 + A[i, j] = 1 + + # compute hop steps + hop_dis = np.zeros((num_node, num_node)) + np.inf + transfer_mat = [np.linalg.matrix_power(A, d) for d in range(max_hop + 1)] + arrive_mat = (np.stack(transfer_mat) > 0) + for d in range(max_hop, -1, -1): + hop_dis[arrive_mat[d]] = d + return hop_dis + + +def normalize_digraph(A): + Dl = np.sum(A, 0) + num_node = A.shape[0] + Dn = np.zeros((num_node, num_node)) + for i in range(num_node): + if Dl[i] > 0: + Dn[i, i] = Dl[i]**(-1) + AD = np.dot(A, Dn) + return AD + + +def normalize_undigraph(A): + Dl = np.sum(A, 0) + num_node = A.shape[0] + Dn = np.zeros((num_node, num_node)) + for i in range(num_node): + if Dl[i] > 0: + Dn[i, i] = Dl[i]**(-0.5) + DAD = np.dot(np.dot(Dn, A), Dn) + return DAD diff --git a/main/eval/a2m/recognition/models/stgcnutils/tgcn.py b/main/eval/a2m/recognition/models/stgcnutils/tgcn.py new file mode 100644 index 0000000000000000000000000000000000000000..ebc5db2b97cf711eb0ce5c47905efa52a86d5eb4 --- /dev/null +++ b/main/eval/a2m/recognition/models/stgcnutils/tgcn.py @@ -0,0 +1,64 @@ +# The based unit of graph convolutional networks. + +import torch +import torch.nn as nn + + +class ConvTemporalGraphical(nn.Module): + + r"""The basic module for applying a graph convolution. + Args: + in_channels (int): Number of channels in the input sequence data + out_channels (int): Number of channels produced by the convolution + kernel_size (int): Size of the graph convolving kernel + t_kernel_size (int): Size of the temporal convolving kernel + t_stride (int, optional): Stride of the temporal convolution. Default: 1 + t_padding (int, optional): Temporal zero-padding added to both sides of + the input. Default: 0 + t_dilation (int, optional): Spacing between temporal kernel elements. + Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the output. + Default: ``True`` + Shape: + - Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format + - Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format + - Output[0]: Outpu graph sequence in :math:`(N, out_channels, T_{out}, V)` format + - Output[1]: Graph adjacency matrix for output data in :math:`(K, V, V)` format + where + :math:`N` is a batch size, + :math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]`, + :math:`T_{in}/T_{out}` is a length of input/output sequence, + :math:`V` is the number of graph nodes. 
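The `Graph` utilities boil down to: build a symmetric adjacency with self-links, derive hop distances from powers of A, threshold them, and column-normalize with D^-1 * A. A tiny sketch on a three-joint chain (illustrative only, not the SMPL kinematic tree):

```python
import numpy as np

num_node = 3
edges = [(0, 0), (1, 1), (2, 2), (0, 1), (1, 2)]   # self-links + two "bones"

A = np.zeros((num_node, num_node))
for i, j in edges:
    A[i, j] = A[j, i] = 1

# hop distance via powers of A, as in get_hop_distance with max_hop=1
hop = np.full((num_node, num_node), np.inf)
reach = [np.linalg.matrix_power(A, d) > 0 for d in range(2)]
for d in range(1, -1, -1):
    hop[reach[d]] = d

adjacency = (hop <= 1).astype(float)
Dl = adjacency.sum(0)
normalized = adjacency @ np.diag(1.0 / Dl)   # D^-1 * A: each column sums to 1
print(hop)
print(normalized.sum(0))                     # [1. 1. 1.]
```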
+ """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + t_kernel_size=1, + t_stride=1, + t_padding=0, + t_dilation=1, + bias=True): + super().__init__() + + self.kernel_size = kernel_size + self.conv = nn.Conv2d( + in_channels, + out_channels * kernel_size, + kernel_size=(t_kernel_size, 1), + padding=(t_padding, 0), + stride=(t_stride, 1), + dilation=(t_dilation, 1), + bias=bias) + + def forward(self, x, A): + assert A.size(0) == self.kernel_size + + x = self.conv(x) + + n, kc, t, v = x.size() + x = x.view(n, self.kernel_size, kc//self.kernel_size, t, v) + x = torch.einsum('nkctv,kvw->nctw', (x, A)) + + return x.contiguous(), A diff --git a/main/eval/a2m/stgcn/accuracy.py b/main/eval/a2m/stgcn/accuracy.py new file mode 100644 index 0000000000000000000000000000000000000000..e707ef9a48bd8d9cf940fc475ba2aa93776855f5 --- /dev/null +++ b/main/eval/a2m/stgcn/accuracy.py @@ -0,0 +1,14 @@ +import torch + + +def calculate_accuracy(model, motion_loader, num_labels, classifier, device): + confusion = torch.zeros(num_labels, num_labels, dtype=torch.long) + with torch.no_grad(): + for batch in motion_loader: + batch_prob = classifier(batch)["yhat"] + batch_pred = batch_prob.max(dim=1).indices + for label, pred in zip(batch["y"], batch_pred): + confusion[label][pred] += 1 + + accuracy = torch.trace(confusion)/torch.sum(confusion) + return accuracy.item(), confusion diff --git a/main/eval/a2m/stgcn/diversity.py b/main/eval/a2m/stgcn/diversity.py new file mode 100644 index 0000000000000000000000000000000000000000..5136a41d28ee22620f3496f51e9104ed4d9872bc --- /dev/null +++ b/main/eval/a2m/stgcn/diversity.py @@ -0,0 +1,54 @@ +import torch +import numpy as np + + +# from action2motion +def calculate_diversity_multimodality(activations, labels, num_labels, seed=None, unconstrained = False): + diversity_times = 200 + multimodality_times = 20 + if not unconstrained: + labels = labels.long() + num_motions = activations.shape[0] # len(labels) + + diversity = 0 + + if seed is not None: + np.random.seed(seed) + + first_indices = np.random.randint(0, num_motions, diversity_times) + second_indices = np.random.randint(0, num_motions, diversity_times) + for first_idx, second_idx in zip(first_indices, second_indices): + diversity += torch.dist(activations[first_idx, :], + activations[second_idx, :]) + diversity /= diversity_times + + if not unconstrained: + multimodality = 0 + label_quotas = np.zeros(num_labels) + label_quotas[labels.unique()] = multimodality_times # if a label does not appear in batch, its quota remains zero + while np.any(label_quotas > 0): + # print(label_quotas) + first_idx = np.random.randint(0, num_motions) + first_label = labels[first_idx] + if not label_quotas[first_label]: + continue + + second_idx = np.random.randint(0, num_motions) + second_label = labels[second_idx] + while first_label != second_label: + second_idx = np.random.randint(0, num_motions) + second_label = labels[second_idx] + + label_quotas[first_label] -= 1 + + first_activation = activations[first_idx, :] + second_activation = activations[second_idx, :] + multimodality += torch.dist(first_activation, + second_activation) + + multimodality /= (multimodality_times * num_labels) + else: + multimodality = torch.tensor(np.nan) + + return diversity.item(), multimodality.item() + diff --git a/main/eval/a2m/stgcn/evaluate.py b/main/eval/a2m/stgcn/evaluate.py new file mode 100644 index 0000000000000000000000000000000000000000..ed8170128d88836eccae48597116bc850e42f306 --- /dev/null +++ 
b/main/eval/a2m/stgcn/evaluate.py @@ -0,0 +1,108 @@ +import torch +import numpy as np +from .accuracy import calculate_accuracy +from .fid import calculate_fid +from .diversity import calculate_diversity_multimodality + +from eval.a2m.recognition.models.stgcn import STGCN + + +class Evaluation: + def __init__(self, dataname, parameters, device, seed=None): + layout = "smpl" # if parameters["glob"] else "smpl_noglobal" + model = STGCN(in_channels=parameters["nfeats"], + num_class=parameters["num_classes"], + graph_args={"layout": layout, "strategy": "spatial"}, + edge_importance_weighting=True, + device=device) + + model = model.to(device) + + model_path = "./assets/actionrecognition/uestc_rot6d_stgcn.tar" + + state_dict = torch.load(model_path, map_location=device) + model.load_state_dict(state_dict) + model.eval() + + self.num_classes = parameters["num_classes"] + self.model = model + + self.dataname = dataname + self.device = device + + self.seed = seed + + def compute_features(self, model, motionloader): + # calculate_activations_labels function from action2motion + activations = [] + labels = [] + with torch.no_grad(): + for idx, batch in enumerate(motionloader): + activations.append(self.model(batch)["features"]) + if model.cond_mode != 'no_cond': + labels.append(batch["y"]) + activations = torch.cat(activations, dim=0) + if model.cond_mode != 'no_cond': + labels = torch.cat(labels, dim=0) + return activations, labels + + @staticmethod + def calculate_activation_statistics(activations): + activations = activations.cpu().numpy() + mu = np.mean(activations, axis=0) + sigma = np.cov(activations, rowvar=False) + return mu, sigma + + def evaluate(self, model, loaders): + def print_logs(metric, key): + print(f"Computing stgcn {metric} on the {key} loader ...") + + metrics_all = {} + for sets in ["train", "test"]: + computedfeats = {} + metrics = {} + for key, loaderSets in loaders.items(): + loader = loaderSets[sets] + + metric = "accuracy" + mkey = f"{metric}_{key}" + if model.cond_mode != 'no_cond': + print_logs(metric, key) + metrics[mkey], _ = calculate_accuracy(model, loader, + self.num_classes, + self.model, self.device) + else: + metrics[mkey] = np.nan + + # features for diversity + print_logs("features", key) + feats, labels = self.compute_features(model, loader) + print_logs("stats", key) + stats = self.calculate_activation_statistics(feats) + + computedfeats[key] = {"feats": feats, + "labels": labels, + "stats": stats} + + print_logs("diversity", key) + ret = calculate_diversity_multimodality(feats, labels, self.num_classes, + seed=self.seed, unconstrained=(model.cond_mode=='no_cond')) + metrics[f"diversity_{key}"], metrics[f"multimodality_{key}"] = ret + + # taking the stats of the ground truth and remove it from the computed feats + gtstats = computedfeats["gt"]["stats"] + # computing fid + for key, loader in computedfeats.items(): + metric = "fid" + mkey = f"{metric}_{key}" + + stats = computedfeats[key]["stats"] + metrics[mkey] = float(calculate_fid(gtstats, stats)) + + metrics_all[sets] = metrics + + metrics = {} + for sets in ["train", "test"]: + for key in metrics_all[sets]: + metrics[f"{key}_{sets}"] = metrics_all[sets][key] + return metrics diff --git a/main/eval/a2m/stgcn/fid.py b/main/eval/a2m/stgcn/fid.py new file mode 100644 index 0000000000000000000000000000000000000000..4302e6b2808e87e3d9d4d8080db89d1b03e45f85 --- /dev/null +++ b/main/eval/a2m/stgcn/fid.py @@ -0,0 +1,61 @@ +import numpy as np +from scipy import linalg + + +# from action2motion +def 
calculate_fid(statistics_1, statistics_2): + return calculate_frechet_distance(statistics_1[0], statistics_1[1], + statistics_2[0], statistics_2[1]) + + +def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): + """Numpy implementation of the Frechet Distance. + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + Stable version by Dougal J. Sutherland. + Params: + -- mu1 : Numpy array containing the activations of a layer of the + inception net (like returned by the function 'get_predictions') + for generated samples. + -- mu2 : The sample mean over activations, precalculated on an + representative data set. + -- sigma1: The covariance matrix over activations for generated samples. + -- sigma2: The covariance matrix over activations, precalculated on an + representative data set. + Returns: + -- : The Frechet Distance. + """ + + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert mu1.shape == mu2.shape, \ + 'Training and test mean vectors have different lengths' + assert sigma1.shape == sigma2.shape, \ + 'Training and test covariances have different dimensions' + + diff = mu1 - mu2 + + # Product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = ('fid calculation produces singular product; ' + 'adding %s to diagonal of cov estimates') % eps + print(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError('Imaginary component {}'.format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean) diff --git a/main/eval/a2m/stgcn_eval.py b/main/eval/a2m/stgcn_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..ba8426a2c1136093fcf81b204fa0b86cfbd5b182 --- /dev/null +++ b/main/eval/a2m/stgcn_eval.py @@ -0,0 +1,147 @@ +import copy +import torch +from tqdm import tqdm +import functools + +from utils.fixseed import fixseed + +from eval.a2m.stgcn.evaluate import Evaluation as STGCNEvaluation +from torch.utils.data import DataLoader +from data_loaders.tensors import collate + + +from .tools import format_metrics +import utils.rotation_conversions as geometry +from utils import dist_util + + +def convert_x_to_rot6d(x, pose_rep): + # convert rotation to rot6d + if pose_rep == "rotvec": + x = geometry.matrix_to_rotation_6d(geometry.axis_angle_to_matrix(x)) + elif pose_rep == "rotmat": + x = x.reshape(*x.shape[:-1], 3, 3) + x = geometry.matrix_to_rotation_6d(x) + elif pose_rep == "rotquat": + x = geometry.matrix_to_rotation_6d(geometry.quaternion_to_matrix(x)) + elif pose_rep == "rot6d": + x = x + else: + raise NotImplementedError("No geometry for this one.") + return x + + +class NewDataloader: + def __init__(self, mode, model, diffusion, dataiterator, device, cond_mode, dataset, num_samples): + assert mode in ["gen", "gt"] + + self.batches = [] + sample_fn = diffusion.p_sample_loop + + with torch.no_grad(): + for motions, model_kwargs in tqdm(dataiterator, desc=f"Construct dataloader: {mode}.."): + motions = motions.to(device) + if num_samples != -1 and 
len(self.batches) * dataiterator.batch_size > num_samples: + continue # do not break because it confuses the multiple loaders + batch = dict() + if mode == "gen": + sample = sample_fn(model, motions.shape, clip_denoised=False, model_kwargs=model_kwargs) + batch['output'] = sample + elif mode == "gt": + batch['output'] = motions + + max_n_frames = model_kwargs['y']['lengths'].max() + mask = model_kwargs['y']['mask'].reshape(dataiterator.batch_size, max_n_frames).bool() + batch["output_xyz"] = model.rot2xyz(x=batch["output"], mask=mask, pose_rep='rot6d', glob=True, + translation=True, jointstype='smpl', vertstrans=True, betas=None, + beta=0, glob_rot=None, get_rotations_back=False) + if model.translation: + # the stgcn model expects rotations only + batch["output"] = batch["output"][:, :-1] + + batch["lengths"] = model_kwargs['y']['lengths'].to(device) + # using torch.long so lengths/action will be used as indices + if cond_mode != 'no_cond': # proceed only if not running unconstrained + batch["y"] = model_kwargs['y']['action'].squeeze().long().cpu() # using torch.long so lengths/action will be used as indices + self.batches.append(batch) + + num_samples_last_batch = num_samples % dataiterator.batch_size + if num_samples_last_batch > 0: + for k, v in self.batches[-1].items(): + self.batches[-1][k] = v[:num_samples_last_batch] + + + def __iter__(self): + return iter(self.batches) + + +def evaluate(args, model, diffusion, data): + torch.multiprocessing.set_sharing_strategy('file_system') + + bs = args.batch_size + args.num_classes = 40 + args.nfeats = 6 + args.njoint = 25 + device = dist_util.dev() + + + recogparameters = args.__dict__.copy() + recogparameters["pose_rep"] = "rot6d" + recogparameters["nfeats"] = 6 + + # Action2motionEvaluation + stgcnevaluation = STGCNEvaluation(args.dataset, recogparameters, device) + + stgcn_metrics = {} + + data_types = ['train', 'test'] + datasetGT = {'train': [data], 'test': [copy.deepcopy(data)]} + + for key in data_types: + datasetGT[key][0].split = key + + compute_gt_gt = False + if compute_gt_gt: + for key in data_types: + datasetGT[key].append(copy.deepcopy(datasetGT[key][0])) + + model.eval() + + allseeds = list(range(args.num_seeds)) + + for index, seed in enumerate(allseeds): + print(f"Evaluation number: {index + 1}/{args.num_seeds}") + fixseed(seed) + for key in data_types: + for data in datasetGT[key]: + data.reset_shuffle() + data.shuffle() + + dataiterator = {key: [DataLoader(data, batch_size=bs, shuffle=False, num_workers=8, collate_fn=collate) + for data in datasetGT[key]] + for key in data_types} + + new_data_loader = functools.partial(NewDataloader, model=model, diffusion=diffusion, device=device, + cond_mode=args.cond_mode, dataset=args.dataset, num_samples=args.num_samples) + gtLoaders = {key: new_data_loader(mode="gt", dataiterator=dataiterator[key][0]) + for key in ["train", "test"]} + + if compute_gt_gt: + gtLoaders2 = {key: new_data_loader(mode="gt", dataiterator=dataiterator[key][0]) + for key in ["train", "test"]} + + genLoaders = {key: new_data_loader(mode="gen", dataiterator=dataiterator[key][0]) + for key in ["train", "test"]} + + loaders = {"gen": genLoaders, + "gt": gtLoaders} + + if compute_gt_gt: + loaders["gt2"] = gtLoaders2 + + stgcn_metrics[seed] = stgcnevaluation.evaluate(model, loaders) + del loaders + + metrics = {"feats": {key: [format_metrics(stgcn_metrics[seed])[key] for seed in allseeds] for key in stgcn_metrics[allseeds[0]]}} + + return metrics diff --git a/main/eval/a2m/tools.py b/main/eval/a2m/tools.py new file 
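`convert_x_to_rot6d` above relies on the repo's `utils.rotation_conversions` (PyTorch3D-style helpers). As a rough illustration of the 6D rotation representation itself (Zhou et al., CVPR 2019), the hypothetical helper below keeps the first two rows of each 3x3 rotation matrix; it is a stand-in for that module, not a drop-in replacement:

```python
import torch

def matrix_to_rot6d(rotmat: torch.Tensor) -> torch.Tensor:
    # Illustrative stand-in: continuous 6D representation = first two rows
    # of the rotation matrix, flattened.
    return rotmat[..., :2, :].reshape(*rotmat.shape[:-2], 6)

rotmat = torch.eye(3).expand(4, 24, 3, 3)     # a batch of identity rotations
print(matrix_to_rot6d(rotmat).shape)          # torch.Size([4, 24, 6])
```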
mode 100644 index 0000000000000000000000000000000000000000..f362244e374f74465c605148cb1c150115ab71dd --- /dev/null +++ b/main/eval/a2m/tools.py @@ -0,0 +1,19 @@ +import yaml + + +def format_metrics(metrics, formatter="{:.6}"): + newmetrics = {} + for key, val in metrics.items(): + newmetrics[key] = formatter.format(val) + return newmetrics + + +def save_metrics(path, metrics): + with open(path, "w") as yfile: + yaml.dump(metrics, yfile) + + +def load_metrics(path): + with open(path, "r") as yfile: + string = yfile.read() + return yaml.load(string, yaml.loader.BaseLoader) diff --git a/main/eval/eval_humanact12_uestc.py b/main/eval/eval_humanact12_uestc.py new file mode 100644 index 0000000000000000000000000000000000000000..a6fa2b2189447c4d66f6e6c982d8ede80b49a6e9 --- /dev/null +++ b/main/eval/eval_humanact12_uestc.py @@ -0,0 +1,80 @@ +""" +Generate a large batch of image samples from a model and save them as a large +numpy array. This can be used to produce samples for FID evaluation. +""" +import os +import torch +import re + +from utils import dist_util +from model.cfg_sampler import ClassifierFreeSampleModel +from data_loaders.get_data import get_dataset_loader +from eval.a2m.tools import save_metrics +from utils.parser_util import evaluation_parser +from utils.fixseed import fixseed +from utils.model_util import create_model_and_diffusion, load_model_wo_clip + + +def evaluate(args, model, diffusion, data): + scale = None + if args.guidance_param != 1: + model = ClassifierFreeSampleModel(model) # wrapping model with the classifier-free sampler + scale = { + 'action': torch.ones(args.batch_size) * args.guidance_param, + } + model.to(dist_util.dev()) + model.eval() # disable random masking + + + folder, ckpt_name = os.path.split(args.model_path) + if args.dataset == "humanact12": + from eval.a2m.gru_eval import evaluate + eval_results = evaluate(args, model, diffusion, data) + elif args.dataset == "uestc": + from eval.a2m.stgcn_eval import evaluate + eval_results = evaluate(args, model, diffusion, data) + else: + raise NotImplementedError("This dataset is not supported.") + + # save results + iter = int(re.findall('\d+', ckpt_name)[0]) + scale = 1 if scale is None else scale['action'][0].item() + scale = str(scale).replace('.', 'p') + metricname = "evaluation_results_iter{}_samp{}_scale{}_a2m.yaml".format(iter, args.num_samples, scale) + evalpath = os.path.join(folder, metricname) + print(f"Saving evaluation: {evalpath}") + save_metrics(evalpath, eval_results) + + return eval_results + + +def main(): + args = evaluation_parser() + fixseed(args.seed) + dist_util.setup_dist(args.device) + + print(f'Eval mode [{args.eval_mode}]') + assert args.eval_mode in ['debug', 'full'], f'eval_mode {args.eval_mode} is not supported for dataset {args.dataset}' + if args.eval_mode == 'debug': + args.num_samples = 10 + args.num_seeds = 2 + else: + args.num_samples = 1000 + args.num_seeds = 20 + + data_loader = get_dataset_loader(name=args.dataset, num_frames=60, batch_size=args.batch_size,) + + print("creating model and diffusion...") + model, diffusion = create_model_and_diffusion(args, data_loader) + + print(f"Loading checkpoints from [{args.model_path}]...") + state_dict = torch.load(args.model_path, map_location='cpu') + load_model_wo_clip(model, state_dict) + + eval_results = evaluate(args, model, diffusion, data_loader.dataset) + + fid_to_print = {k : sum([float(vv) for vv in v])/len(v) for k, v in eval_results['feats'].items() if 'fid' in k and 'gen' in k} + print(fid_to_print) + +if __name__ == 
'__main__': + main() diff --git a/main/eval/eval_humanml.py b/main/eval/eval_humanml.py new file mode 100644 index 0000000000000000000000000000000000000000..b341e9ed3ab8345c0b424df99617dba8771294db --- /dev/null +++ b/main/eval/eval_humanml.py @@ -0,0 +1,304 @@ +from utils.parser_util import evaluation_parser +from utils.fixseed import fixseed +from datetime import datetime +from data_loaders.humanml.motion_loaders.model_motion_loaders import get_mdm_loader # get_motion_loader +from data_loaders.humanml.utils.metrics import * +from data_loaders.humanml.networks.evaluator_wrapper import EvaluatorMDMWrapper +from collections import OrderedDict +from data_loaders.humanml.scripts.motion_process import * +from data_loaders.humanml.utils.utils import * +from utils.model_util import create_model_and_diffusion, load_model_wo_clip + +from diffusion import logger +from utils import dist_util +from data_loaders.get_data import get_dataset_loader +from model.cfg_sampler import ClassifierFreeSampleModel + +torch.multiprocessing.set_sharing_strategy('file_system') + +def evaluate_matching_score(eval_wrapper, motion_loaders, file): + match_score_dict = OrderedDict({}) + R_precision_dict = OrderedDict({}) + activation_dict = OrderedDict({}) + print('========== Evaluating Matching Score ==========') + for motion_loader_name, motion_loader in motion_loaders.items(): + all_motion_embeddings = [] + score_list = [] + all_size = 0 + matching_score_sum = 0 + top_k_count = 0 + # print(motion_loader_name) + with torch.no_grad(): + for idx, batch in enumerate(motion_loader): + word_embeddings, pos_one_hots, _, sent_lens, motions, m_lens, _ = batch + text_embeddings, motion_embeddings = eval_wrapper.get_co_embeddings( + word_embs=word_embeddings, + pos_ohot=pos_one_hots, + cap_lens=sent_lens, + motions=motions, + m_lens=m_lens + ) + dist_mat = euclidean_distance_matrix(text_embeddings.cpu().numpy(), + motion_embeddings.cpu().numpy()) + matching_score_sum += dist_mat.trace() + + argsmax = np.argsort(dist_mat, axis=1) + top_k_mat = calculate_top_k(argsmax, top_k=3) + top_k_count += top_k_mat.sum(axis=0) + + all_size += text_embeddings.shape[0] + + all_motion_embeddings.append(motion_embeddings.cpu().numpy()) + + all_motion_embeddings = np.concatenate(all_motion_embeddings, axis=0) + matching_score = matching_score_sum / all_size + R_precision = top_k_count / all_size + match_score_dict[motion_loader_name] = matching_score + R_precision_dict[motion_loader_name] = R_precision + activation_dict[motion_loader_name] = all_motion_embeddings + + print(f'---> [{motion_loader_name}] Matching Score: {matching_score:.4f}') + print(f'---> [{motion_loader_name}] Matching Score: {matching_score:.4f}', file=file, flush=True) + + line = f'---> [{motion_loader_name}] R_precision: ' + for i in range(len(R_precision)): + line += '(top %d): %.4f ' % (i+1, R_precision[i]) + print(line) + print(line, file=file, flush=True) + + return match_score_dict, R_precision_dict, activation_dict + + +def evaluate_fid(eval_wrapper, groundtruth_loader, activation_dict, file): + eval_dict = OrderedDict({}) + gt_motion_embeddings = [] + print('========== Evaluating FID ==========') + with torch.no_grad(): + for idx, batch in enumerate(groundtruth_loader): + _, _, _, sent_lens, motions, m_lens, _ = batch + motion_embeddings = eval_wrapper.get_motion_embeddings( + motions=motions, + m_lens=m_lens + ) + gt_motion_embeddings.append(motion_embeddings.cpu().numpy()) + gt_motion_embeddings = np.concatenate(gt_motion_embeddings, axis=0) + gt_mu, gt_cov = 
calculate_activation_statistics(gt_motion_embeddings) + + # print(gt_mu) + for model_name, motion_embeddings in activation_dict.items(): + mu, cov = calculate_activation_statistics(motion_embeddings) + # print(mu) + fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov) + print(f'---> [{model_name}] FID: {fid:.4f}') + print(f'---> [{model_name}] FID: {fid:.4f}', file=file, flush=True) + eval_dict[model_name] = fid + return eval_dict + + +def evaluate_diversity(activation_dict, file, diversity_times): + eval_dict = OrderedDict({}) + print('========== Evaluating Diversity ==========') + for model_name, motion_embeddings in activation_dict.items(): + diversity = calculate_diversity(motion_embeddings, diversity_times) + eval_dict[model_name] = diversity + print(f'---> [{model_name}] Diversity: {diversity:.4f}') + print(f'---> [{model_name}] Diversity: {diversity:.4f}', file=file, flush=True) + return eval_dict + + +def evaluate_multimodality(eval_wrapper, mm_motion_loaders, file, mm_num_times): + eval_dict = OrderedDict({}) + print('========== Evaluating MultiModality ==========') + for model_name, mm_motion_loader in mm_motion_loaders.items(): + mm_motion_embeddings = [] + with torch.no_grad(): + for idx, batch in enumerate(mm_motion_loader): + # (1, mm_replications, dim_pos) + motions, m_lens = batch + motion_embedings = eval_wrapper.get_motion_embeddings(motions[0], m_lens[0]) + mm_motion_embeddings.append(motion_embedings.unsqueeze(0)) + if len(mm_motion_embeddings) == 0: + multimodality = 0 + else: + mm_motion_embeddings = torch.cat(mm_motion_embeddings, dim=0).cpu().numpy() + multimodality = calculate_multimodality(mm_motion_embeddings, mm_num_times) + print(f'---> [{model_name}] Multimodality: {multimodality:.4f}') + print(f'---> [{model_name}] Multimodality: {multimodality:.4f}', file=file, flush=True) + eval_dict[model_name] = multimodality + return eval_dict + + +def get_metric_statistics(values, replication_times): + mean = np.mean(values, axis=0) + std = np.std(values, axis=0) + conf_interval = 1.96 * std / np.sqrt(replication_times) + return mean, conf_interval + + +def evaluation(eval_wrapper, gt_loader, eval_motion_loaders, log_file, replication_times, diversity_times, mm_num_times, run_mm=False): + with open(log_file, 'w') as f: + all_metrics = OrderedDict({'Matching Score': OrderedDict({}), + 'R_precision': OrderedDict({}), + 'FID': OrderedDict({}), + 'Diversity': OrderedDict({}), + 'MultiModality': OrderedDict({})}) + for replication in range(replication_times): + motion_loaders = {} + mm_motion_loaders = {} + motion_loaders['ground truth'] = gt_loader + for motion_loader_name, motion_loader_getter in eval_motion_loaders.items(): + motion_loader, mm_motion_loader = motion_loader_getter() + motion_loaders[motion_loader_name] = motion_loader + mm_motion_loaders[motion_loader_name] = mm_motion_loader + + print(f'==================== Replication {replication} ====================') + print(f'==================== Replication {replication} ====================', file=f, flush=True) + print(f'Time: {datetime.now()}') + print(f'Time: {datetime.now()}', file=f, flush=True) + mat_score_dict, R_precision_dict, acti_dict = evaluate_matching_score(eval_wrapper, motion_loaders, f) + + print(f'Time: {datetime.now()}') + print(f'Time: {datetime.now()}', file=f, flush=True) + fid_score_dict = evaluate_fid(eval_wrapper, gt_loader, acti_dict, f) + + print(f'Time: {datetime.now()}') + print(f'Time: {datetime.now()}', file=f, flush=True) + div_score_dict = evaluate_diversity(acti_dict, f, 
diversity_times) + + if run_mm: + print(f'Time: {datetime.now()}') + print(f'Time: {datetime.now()}', file=f, flush=True) + mm_score_dict = evaluate_multimodality(eval_wrapper, mm_motion_loaders, f, mm_num_times) + + print(f'!!! DONE !!!') + print(f'!!! DONE !!!', file=f, flush=True) + + for key, item in mat_score_dict.items(): + if key not in all_metrics['Matching Score']: + all_metrics['Matching Score'][key] = [item] + else: + all_metrics['Matching Score'][key] += [item] + + for key, item in R_precision_dict.items(): + if key not in all_metrics['R_precision']: + all_metrics['R_precision'][key] = [item] + else: + all_metrics['R_precision'][key] += [item] + + for key, item in fid_score_dict.items(): + if key not in all_metrics['FID']: + all_metrics['FID'][key] = [item] + else: + all_metrics['FID'][key] += [item] + + for key, item in div_score_dict.items(): + if key not in all_metrics['Diversity']: + all_metrics['Diversity'][key] = [item] + else: + all_metrics['Diversity'][key] += [item] + if run_mm: + for key, item in mm_score_dict.items(): + if key not in all_metrics['MultiModality']: + all_metrics['MultiModality'][key] = [item] + else: + all_metrics['MultiModality'][key] += [item] + + + # print(all_metrics['Diversity']) + mean_dict = {} + for metric_name, metric_dict in all_metrics.items(): + print('========== %s Summary ==========' % metric_name) + print('========== %s Summary ==========' % metric_name, file=f, flush=True) + for model_name, values in metric_dict.items(): + # print(metric_name, model_name) + mean, conf_interval = get_metric_statistics(np.array(values), replication_times) + mean_dict[metric_name + '_' + model_name] = mean + # print(mean, mean.dtype) + if isinstance(mean, np.float64) or isinstance(mean, np.float32): + print(f'---> [{model_name}] Mean: {mean:.4f} CInterval: {conf_interval:.4f}') + print(f'---> [{model_name}] Mean: {mean:.4f} CInterval: {conf_interval:.4f}', file=f, flush=True) + elif isinstance(mean, np.ndarray): + line = f'---> [{model_name}]' + for i in range(len(mean)): + line += '(top %d) Mean: %.4f CInt: %.4f;' % (i+1, mean[i], conf_interval[i]) + print(line) + print(line, file=f, flush=True) + return mean_dict + + +if __name__ == '__main__': + args = evaluation_parser() + fixseed(args.seed) + args.batch_size = 32 # This must be 32! Don't change it! otherwise it will cause a bug in R precision calc! 
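`get_metric_statistics` reports the mean over replications together with a 1.96 * std / sqrt(n) normal-approximation confidence interval, which is what the summary lines print. A quick check with invented per-replication FID values:

```python
import numpy as np

replication_times = 5
fid_values = np.array([0.52, 0.48, 0.55, 0.50, 0.47])   # made-up numbers

mean = np.mean(fid_values, axis=0)
conf_interval = 1.96 * np.std(fid_values, axis=0) / np.sqrt(replication_times)
print(f"FID {mean:.4f} +/- {conf_interval:.4f}")
```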
+ name = os.path.basename(os.path.dirname(args.model_path)) + niter = os.path.basename(args.model_path).replace('model', '').replace('.pt', '') + log_file = os.path.join(os.path.dirname(args.model_path), 'eval_humanml_{}_{}'.format(name, niter)) + if args.guidance_param != 1.: + log_file += f'_gscale{args.guidance_param}' + log_file += f'_{args.eval_mode}' + log_file += '.log' + + print(f'Will save to log file [{log_file}]') + + print(f'Eval mode [{args.eval_mode}]') + if args.eval_mode == 'debug': + num_samples_limit = 1000 # None means no limit (eval over all dataset) + run_mm = False + mm_num_samples = 0 + mm_num_repeats = 0 + mm_num_times = 0 + diversity_times = 300 + replication_times = 5 # about 3 Hrs + elif args.eval_mode == 'wo_mm': + num_samples_limit = 1000 + run_mm = False + mm_num_samples = 0 + mm_num_repeats = 0 + mm_num_times = 0 + diversity_times = 300 + replication_times = 20 # about 12 Hrs + elif args.eval_mode == 'mm_short': + num_samples_limit = 1000 + run_mm = True + mm_num_samples = 100 + mm_num_repeats = 30 + mm_num_times = 10 + diversity_times = 300 + replication_times = 5 # about 15 Hrs + else: + raise ValueError() + + + dist_util.setup_dist(args.device) + logger.configure() + + logger.log("creating data loader...") + split = 'test' + gt_loader = get_dataset_loader(name=args.dataset, batch_size=args.batch_size, num_frames=None, split=split, hml_mode='gt') + gen_loader = get_dataset_loader(name=args.dataset, batch_size=args.batch_size, num_frames=None, split=split, hml_mode='eval') + num_actions = gen_loader.dataset.num_actions + + logger.log("Creating model and diffusion...") + model, diffusion = create_model_and_diffusion(args, gen_loader) + + logger.log(f"Loading checkpoints from [{args.model_path}]...") + state_dict = torch.load(args.model_path, map_location='cpu') + load_model_wo_clip(model, state_dict) + + if args.guidance_param != 1: + model = ClassifierFreeSampleModel(model) # wrapping model with the classifier-free sampler + model.to(dist_util.dev()) + model.eval() # disable random masking + + eval_motion_loaders = { + ################ + ## HumanML3D Dataset## + ################ + 'vald': lambda: get_mdm_loader( + model, diffusion, args.batch_size, + gen_loader, mm_num_samples, mm_num_repeats, gt_loader.dataset.opt.max_motion_length, num_samples_limit, args.guidance_param + ) + } + + eval_wrapper = EvaluatorMDMWrapper(args.dataset, dist_util.dev()) + evaluation(eval_wrapper, gt_loader, eval_motion_loaders, log_file, replication_times, diversity_times, mm_num_times, run_mm=run_mm) diff --git a/main/eval/unconstrained/evaluate.py b/main/eval/unconstrained/evaluate.py new file mode 100644 index 0000000000000000000000000000000000000000..0559628e829b8bf2d7e23f207e10428502ea1e5b --- /dev/null +++ b/main/eval/unconstrained/evaluate.py @@ -0,0 +1,111 @@ +from eval.unconstrained.models.stgcn import STGCN +import pandas as pd +import os.path as osp +import os +import datetime + +import torch + +from torch.utils.data import DataLoader +import numpy as np +import sys as _sys +from eval.a2m.action2motion.fid import calculate_fid +from eval.a2m.action2motion.diversity import calculate_diversity +from eval.unconstrained.metrics.kid import calculate_kid +from eval.unconstrained.metrics.precision_recall import precision_and_recall +from matplotlib import pyplot as plt + +TEST = False + + +def initialize_model(device, modelpath): + num_classes = 12 + model = STGCN(in_channels=3, + num_class=num_classes, + graph_args={"layout": 'openpose', "strategy": "spatial"}, + 
edge_importance_weighting=True, + device=device) + model = model.to(device) + state_dict = torch.load(modelpath, map_location=device) + model.load_state_dict(state_dict) + model.eval() + return model + +def calculate_activation_statistics(activations): + activations = activations.cpu().detach().numpy() + mu = np.mean(activations, axis=0) + sigma = np.cov(activations, rowvar=False) + return mu, sigma + + +def compute_features(model, iterator, device): + activations = [] + predictions = [] + with torch.no_grad(): + for i, batch in enumerate(iterator): + batch_for_model = {} + batch_for_model['x'] = batch.to(device).float() + model(batch_for_model) + activations.append(batch_for_model['features']) + predictions.append(batch_for_model['yhat']) + # labels.append(batch_for_model['y']) + activations = torch.cat(activations, dim=0) + predictions = torch.cat(predictions, dim=0) + return activations, predictions + + +def evaluate_unconstrained_metrics(generated_motions, device, fast): + + act_rec_model_path = './assets/actionrecognition/humanact12_gru_modi_struct.pth.tar' + dataset_path = './dataset/HumanAct12Poses/humanact12_modi_struct.npy' + + # initialize model + act_rec_model = initialize_model(device, act_rec_model_path) + + generated_motions -= generated_motions[:, 8:9, :, :] # locate root joint of all frames at origin + + iterator_generated = DataLoader(generated_motions, batch_size=64, shuffle=False, num_workers=8) + + # compute features of generated motions + generated_features, generated_predictions = compute_features(act_rec_model, iterator_generated, device=device) + generated_stats = calculate_activation_statistics(generated_features) + + + # dataset motions + motion_data_raw = np.load(dataset_path, allow_pickle=True) + motion_data = motion_data_raw[:, :15] # data has 16 joints for back compitability with older formats + motion_data -= motion_data[:, 8:9, :, :] # locate root joint of all frames at origin + iterator_dataset = DataLoader(motion_data, batch_size=64, shuffle=False, num_workers=8) + + # compute features of dataset motions + dataset_features, dataset_predictions = compute_features(act_rec_model, iterator_dataset, device=device) + real_stats = calculate_activation_statistics(dataset_features) + + print("evaluation resutls:\n") + + fid = calculate_fid(generated_stats, real_stats) + print(f"FID score: {fid}\n") + + print("calculating KID...") + kid = calculate_kid(dataset_features.cpu(), generated_features.cpu()) + (m, s) = kid + print('KID : %.3f (%.3f)\n' % (m, s)) + + dataset_diversity = calculate_diversity(dataset_features) + generated_diversity = calculate_diversity(generated_features) + print(f"Diversity of generated motions: {generated_diversity}") + print(f"Diversity of dataset motions: {dataset_diversity}\n") + + if fast: + print("Skipping precision-recall calculation\n") + precision = recall = None + else: + print("calculating precision recall...") + precision, recall = precision_and_recall(generated_features, dataset_features) + print(f"precision: {precision}") + print(f"recall: {recall}\n") + + metrics = {'fid': fid, 'kid': kid[0], 'diversity_gen': generated_diversity.cpu().item(), 'diversity_gt': dataset_diversity.cpu().item(), + 'precision': precision, 'recall':recall} + return metrics + diff --git a/main/eval/unconstrained/metrics/kid.py b/main/eval/unconstrained/metrics/kid.py new file mode 100644 index 0000000000000000000000000000000000000000..f56c63f34953cdd8304711d8b88c8009a1bc19fc --- /dev/null +++ b/main/eval/unconstrained/metrics/kid.py @@ -0,0 +1,136 @@ 
+import torch +import numpy as np +from tqdm import tqdm +from sklearn.metrics.pairwise import polynomial_kernel +import sys + +# from: https://github.com/abdulfatir/gan-metrics-pytorch/blob/master/kid_score.py +def polynomial_mmd_averages(codes_g, codes_r, n_subsets=50, subset_size=1000, + ret_var=True, output=sys.stdout, **kernel_args): + m = min(codes_g.shape[0], codes_r.shape[0]) + mmds = np.zeros(n_subsets) + if ret_var: + vars = np.zeros(n_subsets) + choice = np.random.choice + + replace = subset_size < len(codes_g) + with tqdm(range(n_subsets), desc='MMD', file=output, disable=True) as bar: + for i in bar: + g = codes_g[choice(len(codes_g), subset_size, replace=replace)] + r = codes_r[choice(len(codes_r), subset_size, replace=replace)] + o = polynomial_mmd(g, r, **kernel_args, var_at_m=m, ret_var=ret_var) + if ret_var: + mmds[i], vars[i] = o + else: + mmds[i] = o + bar.set_postfix({'mean': mmds[:i+1].mean()}) + return (mmds, vars) if ret_var else mmds + +def polynomial_mmd(codes_g, codes_r, degree=3, gamma=None, coef0=1, + var_at_m=None, ret_var=True): + # use k(x, y) = (gamma + coef0)^degree + # default gamma is 1 / dim + X = codes_g + Y = codes_r + + K_XX = polynomial_kernel(X, degree=degree, gamma=gamma, coef0=coef0) + K_YY = polynomial_kernel(Y, degree=degree, gamma=gamma, coef0=coef0) + K_XY = polynomial_kernel(X, Y, degree=degree, gamma=gamma, coef0=coef0) + + return _mmd2_and_variance(K_XX, K_XY, K_YY, + var_at_m=var_at_m, ret_var=ret_var) + +def _mmd2_and_variance(K_XX, K_XY, K_YY, unit_diagonal=False, + mmd_est='unbiased', block_size=1024, + var_at_m=None, ret_var=True): + # based on + # https://github.com/dougalsutherland/opt-mmd/blob/master/two_sample/mmd.py + # but changed to not compute the full kernel matrix at once + m = K_XX.shape[0] + assert K_XX.shape == (m, m) + assert K_XY.shape == (m, m) + assert K_YY.shape == (m, m) + if var_at_m is None: + var_at_m = m + + # Get the various sums of kernels that we'll use + # Kts drop the diagonal, but we don't need to compute them explicitly + if unit_diagonal: + diag_X = diag_Y = 1 + sum_diag_X = sum_diag_Y = m + sum_diag2_X = sum_diag2_Y = m + else: + diag_X = np.diagonal(K_XX) + diag_Y = np.diagonal(K_YY) + + sum_diag_X = diag_X.sum() + sum_diag_Y = diag_Y.sum() + + sum_diag2_X = _sqn(diag_X) + sum_diag2_Y = _sqn(diag_Y) + + Kt_XX_sums = K_XX.sum(axis=1) - diag_X + Kt_YY_sums = K_YY.sum(axis=1) - diag_Y + K_XY_sums_0 = K_XY.sum(axis=0) + K_XY_sums_1 = K_XY.sum(axis=1) + + Kt_XX_sum = Kt_XX_sums.sum() + Kt_YY_sum = Kt_YY_sums.sum() + K_XY_sum = K_XY_sums_0.sum() + + if mmd_est == 'biased': + mmd2 = ((Kt_XX_sum + sum_diag_X) / (m * m) + + (Kt_YY_sum + sum_diag_Y) / (m * m) + - 2 * K_XY_sum / (m * m)) + else: + assert mmd_est in {'unbiased', 'u-statistic'} + mmd2 = (Kt_XX_sum + Kt_YY_sum) / (m * (m-1)) + if mmd_est == 'unbiased': + mmd2 -= 2 * K_XY_sum / (m * m) + else: + mmd2 -= 2 * (K_XY_sum - np.trace(K_XY)) / (m * (m-1)) + + if not ret_var: + return mmd2 + + Kt_XX_2_sum = _sqn(K_XX) - sum_diag2_X + Kt_YY_2_sum = _sqn(K_YY) - sum_diag2_Y + K_XY_2_sum = _sqn(K_XY) + + dot_XX_XY = Kt_XX_sums.dot(K_XY_sums_1) + dot_YY_YX = Kt_YY_sums.dot(K_XY_sums_0) + + m1 = m - 1 + m2 = m - 2 + zeta1_est = ( + 1 / (m * m1 * m2) * ( + _sqn(Kt_XX_sums) - Kt_XX_2_sum + _sqn(Kt_YY_sums) - Kt_YY_2_sum) + - 1 / (m * m1)**2 * (Kt_XX_sum**2 + Kt_YY_sum**2) + + 1 / (m * m * m1) * ( + _sqn(K_XY_sums_1) + _sqn(K_XY_sums_0) - 2 * K_XY_2_sum) + - 2 / m**4 * K_XY_sum**2 + - 2 / (m * m * m1) * (dot_XX_XY + dot_YY_YX) + + 2 / (m**3 * m1) * (Kt_XX_sum + 
Kt_YY_sum) * K_XY_sum + ) + zeta2_est = ( + 1 / (m * m1) * (Kt_XX_2_sum + Kt_YY_2_sum) + - 1 / (m * m1)**2 * (Kt_XX_sum**2 + Kt_YY_sum**2) + + 2 / (m * m) * K_XY_2_sum + - 2 / m**4 * K_XY_sum**2 + - 4 / (m * m * m1) * (dot_XX_XY + dot_YY_YX) + + 4 / (m**3 * m1) * (Kt_XX_sum + Kt_YY_sum) * K_XY_sum + ) + var_est = (4 * (var_at_m - 2) / (var_at_m * (var_at_m - 1)) * zeta1_est + + 2 / (var_at_m * (var_at_m - 1)) * zeta2_est) + + return mmd2, var_est + + +def _sqn(arr): + flat = np.ravel(arr) + return flat.dot(flat) + +def calculate_kid(real_activations, generated_activations): + kid_values = polynomial_mmd_averages(real_activations, generated_activations, n_subsets=100) + results = (kid_values[0].mean(), kid_values[0].std()) + return results diff --git a/main/eval/unconstrained/metrics/precision_recall.py b/main/eval/unconstrained/metrics/precision_recall.py new file mode 100644 index 0000000000000000000000000000000000000000..2edd4a9167ece8019fdf6bf5ac4729eb2b5124b0 --- /dev/null +++ b/main/eval/unconstrained/metrics/precision_recall.py @@ -0,0 +1,55 @@ +# based on https://github.com/blandocs/improved-precision-and-recall-metric-pytorch/blob/master/functions.py +import os, torch +import numpy as np +import torch.nn as nn +import torch.optim as optim +from tqdm import tqdm + +# self.batch_size = args.batch_size +# self.cpu = args.cpu +# self.data_size = args.data_size + +def precision_and_recall(generated_features, real_features): + k = 3 + + data_num = min(len(generated_features), len(real_features)) + print(f'data num: {data_num}') + + if data_num <= 0: + print("there is no data") + return + generated_features = generated_features[:data_num] + real_features = real_features[:data_num] + + # get precision and recall + precision = manifold_estimate(real_features, generated_features, k) + recall = manifold_estimate(generated_features, real_features, k) + + return precision, recall + +def manifold_estimate( A_features, B_features, k): + A_features = list(A_features) + B_features = list(B_features) + KNN_list_in_A = {} + for A in tqdm(A_features, ncols=80): + pairwise_distances = np.zeros(shape=(len(A_features))) + + for i, A_prime in enumerate(A_features): + d = torch.norm((A - A_prime), 2) + pairwise_distances[i] = d + + v = np.partition(pairwise_distances, k)[k] + KNN_list_in_A[A] = v + + n = 0 + + for B in tqdm(B_features, ncols=80): + for A_prime in A_features: + d = torch.norm((B - A_prime), 2) + if d <= KNN_list_in_A[A_prime]: + n += 1 + break + + return n / len(B_features) + + diff --git a/main/eval/unconstrained/models/stgcn.py b/main/eval/unconstrained/models/stgcn.py new file mode 100644 index 0000000000000000000000000000000000000000..641d8148e06a4618c5ced7ded32dc8cf9f9ba619 --- /dev/null +++ b/main/eval/unconstrained/models/stgcn.py @@ -0,0 +1,221 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from eval.a2m.recognition.models.stgcnutils.tgcn import ConvTemporalGraphical +from eval.unconstrained.models.stgcnutils.graph import Graph + +__all__ = ["STGCN"] + + +class STGCN(nn.Module): + r"""Spatial temporal graph convolutional networks. 
+ Args: + in_channels (int): Number of channels in the input data + num_class (int): Number of classes for the classification task + graph_args (dict): The arguments for building the graph + edge_importance_weighting (bool): If ``True``, adds a learnable + importance weighting to the edges of the graph + **kwargs (optional): Other parameters for graph convolution units + Shape: + - Input: :math:`(N, in_channels, T_{in}, V_{in}, M_{in})` + - Output: :math:`(N, num_class)` where + :math:`N` is a batch size, + :math:`T_{in}` is a length of input sequence, + :math:`V_{in}` is the number of graph nodes, + :math:`M_{in}` is the number of instance in a frame. + """ + + def __init__(self, in_channels, num_class, graph_args, + edge_importance_weighting, device, **kwargs): + super().__init__() + + self.device = device + self.num_class = num_class + + self.losses = ["accuracy", "cross_entropy", "mixed"] + self.criterion = torch.nn.CrossEntropyLoss(reduction='mean') + + # load graph + self.graph = Graph(**graph_args) + A = torch.tensor(self.graph.A, dtype=torch.float32, requires_grad=False) + self.register_buffer('A', A) + + # build networks + spatial_kernel_size = A.size(0) + temporal_kernel_size = 9 + kernel_size = (temporal_kernel_size, spatial_kernel_size) + self.data_bn = nn.BatchNorm1d(in_channels * A.size(1)) + # self.data_bn = nn.BatchNorm1d(in_channels * A.size(1), track_running_stats=False) + # self.data_bn = nn.InstanceNorm1d(in_channels * A.size(1)) + kwargs0 = {k: v for k, v in kwargs.items() if k != 'dropout'} + self.st_gcn_networks = nn.ModuleList(( + st_gcn(in_channels, 64, kernel_size, 1, residual=False, **kwargs0), + st_gcn(64, 64, kernel_size, 1, **kwargs), + st_gcn(64, 64, kernel_size, 1, **kwargs), + # st_gcn(64, 64, kernel_size, 1, **kwargs), + st_gcn(64, 128, kernel_size, 2, **kwargs), + st_gcn(128, 128, kernel_size, 1, **kwargs), + # st_gcn(128, 128, kernel_size, 1, **kwargs), + st_gcn(128, 256, kernel_size, 2, **kwargs), + # st_gcn(256, 256, kernel_size, 1, **kwargs), + # st_gcn(256, 256, kernel_size, 1, **kwargs), + )) + + # initialize parameters for edge importance weighting + if edge_importance_weighting: + self.edge_importance = nn.ParameterList([ + nn.Parameter(torch.ones(self.A.size())) + for i in self.st_gcn_networks + ]) + else: + self.edge_importance = [1] * len(self.st_gcn_networks) + + # fcn for prediction + self.fcn = nn.Conv2d(256, num_class, kernel_size=1) + + def forward(self, batch): + # TODO: use mask + # Received batch["x"] as + # Batch(48), Joints(23), Quat(4), Time(157 + # Expecting: + # Batch, Quat:4, Time, Joints, 1 + x = batch["x"].permute(0, 2, 3, 1).unsqueeze(4).contiguous() + + # data normalization + N, C, T, V, M = x.size() + x = x.permute(0, 4, 3, 1, 2).contiguous() + x = x.view(N * M, V * C, T) + x = self.data_bn(x) + x = x.view(N, M, V, C, T) + x = x.permute(0, 1, 3, 4, 2).contiguous() + x = x.view(N * M, C, T, V) + + # forward + for gcn, importance in zip(self.st_gcn_networks, self.edge_importance): + x, _ = gcn(x, self.A * importance) + + # compute feature + # _, c, t, v = x.size() + # features = x.view(N, M, c, t, v).permute(0, 2, 3, 4, 1) + # batch["features"] = features + + # global pooling + x = F.avg_pool2d(x, x.size()[2:]) + x = x.view(N, M, -1, 1, 1).mean(dim=1) + + # features + batch["features"] = x.squeeze() + + # prediction + x = self.fcn(x) + x = x.view(x.size(0), -1) + batch["yhat"] = x + return batch + + def compute_accuracy(self, batch): + confusion = torch.zeros(self.num_class, self.num_class, dtype=int) + yhat = 
batch["yhat"].max(dim=1).indices + ygt = batch["y"] + for label, pred in zip(ygt, yhat): + confusion[label][pred] += 1 + accuracy = torch.trace(confusion)/torch.sum(confusion) + return accuracy + + def compute_loss(self, batch): + cross_entropy = self.criterion(batch["yhat"], batch["y"]) + mixed_loss = cross_entropy + + acc = self.compute_accuracy(batch) + losses = {"cross_entropy": cross_entropy.item(), + "mixed": mixed_loss.item(), + "accuracy": acc.item()} + return mixed_loss, losses + + +class st_gcn(nn.Module): + r"""Applies a spatial temporal graph convolution over an input graph sequence. + Args: + in_channels (int): Number of channels in the input sequence data + out_channels (int): Number of channels produced by the convolution + kernel_size (tuple): Size of the temporal convolving kernel and graph convolving kernel + stride (int, optional): Stride of the temporal convolution. Default: 1 + dropout (int, optional): Dropout rate of the final output. Default: 0 + residual (bool, optional): If ``True``, applies a residual mechanism. Default: ``True`` + Shape: + - Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format + - Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format + - Output[0]: Outpu graph sequence in :math:`(N, out_channels, T_{out}, V)` format + - Output[1]: Graph adjacency matrix for output data in :math:`(K, V, V)` format + where + :math:`N` is a batch size, + :math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]`, + :math:`T_{in}/T_{out}` is a length of input/output sequence, + :math:`V` is the number of graph nodes. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + dropout=0, + residual=True): + super().__init__() + + assert len(kernel_size) == 2 + assert kernel_size[0] % 2 == 1 + padding = ((kernel_size[0] - 1) // 2, 0) + + self.gcn = ConvTemporalGraphical(in_channels, out_channels, + kernel_size[1]) + + self.tcn = nn.Sequential( + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + nn.Conv2d( + out_channels, + out_channels, + (kernel_size[0], 1), + (stride, 1), + padding, + ), + nn.BatchNorm2d(out_channels), + nn.Dropout(dropout, inplace=True), + ) + + if not residual: + self.residual = lambda x: 0 + + elif (in_channels == out_channels) and (stride == 1): + self.residual = lambda x: x + + else: + self.residual = nn.Sequential( + nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=(stride, 1)), + nn.BatchNorm2d(out_channels), + ) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x, A): + + res = self.residual(x) + x, A = self.gcn(x, A) + x = self.tcn(x) + res + + return self.relu(x), A + + +if __name__ == "__main__": + model = STGCN(in_channels=3, num_class=60, edge_importance_weighting=True, graph_args={"layout": "smpl_noglobal", "strategy": "spatial"}) + # Batch, in_channels, time, vertices, M + inp = torch.rand(10, 3, 16, 23, 1) + out = model(inp) + print(out.shape) + import pdb + pdb.set_trace() diff --git a/main/eval/unconstrained/models/stgcnutils/graph.py b/main/eval/unconstrained/models/stgcnutils/graph.py new file mode 100644 index 0000000000000000000000000000000000000000..8ad28036a7960339ef7c62bcc843b1a0700fd013 --- /dev/null +++ b/main/eval/unconstrained/models/stgcnutils/graph.py @@ -0,0 +1,185 @@ +import numpy as np +import pickle as pkl + +from utils.config import SMPL_KINTREE_PATH + + +class Graph: + """ The Graph to model the skeletons extracted by the openpose + Args: + strategy (string): must be one of the follow candidates + - 
uniform: Uniform Labeling + - distance: Distance Partitioning + - spatial: Spatial Configuration + For more information, please refer to the section 'Partition Strategies' + in our paper (https://arxiv.org/abs/1801.07455). + layout (string): must be one of the follow candidates + - openpose: Is consists of 18 joints. For more information, please + refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose#output + - ntu-rgb+d: Is consists of 25 joints. For more information, please + refer to https://github.com/shahroudy/NTURGB-D + - smpl: Consists of 24/23 joints with without global rotation. + max_hop (int): the maximal distance between two connected nodes + dilation (int): controls the spacing between the kernel points + """ + + def __init__(self, + layout='openpose', + strategy='uniform', + kintree_path=SMPL_KINTREE_PATH, + max_hop=1, + dilation=1): + self.max_hop = max_hop + self.dilation = dilation + + self.kintree_path = kintree_path + + self.get_edge(layout) + self.hop_dis = get_hop_distance( + self.num_node, self.edge, max_hop=max_hop) + self.get_adjacency(strategy) + + def __str__(self): + return self.A + + def get_edge(self, layout): + if layout == 'openpose': + # self.num_node = 18 + self.num_node = 15 + self_link = [(i, i) for i in range(self.num_node)] + # neighbor_link = [(4, 3), (3, 2), (7, 6), (6, 5), (13, 12), (12, + # 11), + # (10, 9), (9, 8), (11, 5), (8, 2), (5, 1), (2, 1), + # (0, 1), (15, 0), (14, 0), (17, 15), (16, 14)] + neighbor_link = [(4, 3), (3, 2), (2, 1), + (7, 6), (6, 5), (5, 1), + (1, 0), + (14, 13), (13, 12), (12, 8), + (11, 10), (10, 9), (9, 8), + (8, 1),] + self.edge = self_link + neighbor_link + self.center = 1 + elif layout == 'smpl': + self.num_node = 24 + self_link = [(i, i) for i in range(self.num_node)] + kt = pkl.load(open(self.kintree_path, "rb")) + neighbor_link = [(k, kt[1][i + 1]) for i, k in enumerate(kt[0][1:])] + self.edge = self_link + neighbor_link + self.center = 0 + elif layout == 'smpl_noglobal': + self.num_node = 23 + self_link = [(i, i) for i in range(self.num_node)] + kt = pkl.load(open(self.kintree_path, "rb")) + neighbor_link = [(k, kt[1][i + 1]) for i, k in enumerate(kt[0][1:])] + # remove the root joint + neighbor_1base = [n for n in neighbor_link if n[0] != 0 and n[1] != 0] + neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base] + self.edge = self_link + neighbor_link + self.center = 0 + elif layout == 'ntu-rgb+d': + self.num_node = 25 + self_link = [(i, i) for i in range(self.num_node)] + neighbor_1base = [(1, 2), (2, 21), (3, 21), (4, 3), (5, 21), + (6, 5), (7, 6), (8, 7), (9, 21), (10, 9), + (11, 10), (12, 11), (13, 1), (14, 13), (15, 14), + (16, 15), (17, 1), (18, 17), (19, 18), (20, 19), + (22, 23), (23, 8), (24, 25), (25, 12)] + neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base] + self.edge = self_link + neighbor_link + self.center = 21 - 1 + elif layout == 'ntu_edge': + self.num_node = 24 + self_link = [(i, i) for i in range(self.num_node)] + neighbor_1base = [(1, 2), (3, 2), (4, 3), (5, 2), (6, 5), (7, 6), + (8, 7), (9, 2), (10, 9), (11, 10), (12, 11), + (13, 1), (14, 13), (15, 14), (16, 15), (17, 1), + (18, 17), (19, 18), (20, 19), (21, 22), (22, 8), + (23, 24), (24, 12)] + neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base] + self.edge = self_link + neighbor_link + self.center = 2 + # elif layout=='customer settings' + # pass + else: + raise NotImplementedError("This Layout is not supported") + + def get_adjacency(self, strategy): + valid_hop = range(0, self.max_hop + 1, 
self.dilation) + adjacency = np.zeros((self.num_node, self.num_node)) + for hop in valid_hop: + adjacency[self.hop_dis == hop] = 1 + normalize_adjacency = normalize_digraph(adjacency) + + if strategy == 'uniform': + A = np.zeros((1, self.num_node, self.num_node)) + A[0] = normalize_adjacency + self.A = A + elif strategy == 'distance': + A = np.zeros((len(valid_hop), self.num_node, self.num_node)) + for i, hop in enumerate(valid_hop): + A[i][self.hop_dis == hop] = normalize_adjacency[self.hop_dis == hop] + self.A = A + elif strategy == 'spatial': + A = [] + for hop in valid_hop: + a_root = np.zeros((self.num_node, self.num_node)) + a_close = np.zeros((self.num_node, self.num_node)) + a_further = np.zeros((self.num_node, self.num_node)) + for i in range(self.num_node): + for j in range(self.num_node): + if self.hop_dis[j, i] == hop: + if self.hop_dis[j, self.center] == self.hop_dis[ + i, self.center]: + a_root[j, i] = normalize_adjacency[j, i] + elif self.hop_dis[j, self. + center] > self.hop_dis[i, self. + center]: + a_close[j, i] = normalize_adjacency[j, i] + else: + a_further[j, i] = normalize_adjacency[j, i] + if hop == 0: + A.append(a_root) + else: + A.append(a_root + a_close) + A.append(a_further) + A = np.stack(A) + self.A = A + else: + raise NotImplementedError("This Strategy is not supported") + + +def get_hop_distance(num_node, edge, max_hop=1): + A = np.zeros((num_node, num_node)) + for i, j in edge: + A[j, i] = 1 + A[i, j] = 1 + + # compute hop steps + hop_dis = np.zeros((num_node, num_node)) + np.inf + transfer_mat = [np.linalg.matrix_power(A, d) for d in range(max_hop + 1)] + arrive_mat = (np.stack(transfer_mat) > 0) + for d in range(max_hop, -1, -1): + hop_dis[arrive_mat[d]] = d + return hop_dis + + +def normalize_digraph(A): + Dl = np.sum(A, 0) + num_node = A.shape[0] + Dn = np.zeros((num_node, num_node)) + for i in range(num_node): + if Dl[i] > 0: + Dn[i, i] = Dl[i]**(-1) + AD = np.dot(A, Dn) + return AD + + +def normalize_undigraph(A): + Dl = np.sum(A, 0) + num_node = A.shape[0] + Dn = np.zeros((num_node, num_node)) + for i in range(num_node): + if Dl[i] > 0: + Dn[i, i] = Dl[i]**(-0.5) + DAD = np.dot(np.dot(Dn, A), Dn) + return DAD diff --git a/main/model/cfg_sampler.py b/main/model/cfg_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..88e1232221a492f6ba91d29bf09992d58ec3b8a3 --- /dev/null +++ b/main/model/cfg_sampler.py @@ -0,0 +1,32 @@ +import numpy as np +import torch +import torch.nn as nn +from copy import deepcopy + +# A wrapper model for Classifier-free guidance **SAMPLING** only +# https://arxiv.org/abs/2207.12598 +class ClassifierFreeSampleModel(nn.Module): + + def __init__(self, model): + super().__init__() + self.model = model # model is the actual model to run + + assert self.model.cond_mask_prob > 0, 'Cannot run a guided diffusion on a model that has not been trained with no conditions' + + # pointers to inner model + self.rot2xyz = self.model.rot2xyz + self.translation = self.model.translation + self.njoints = self.model.njoints + self.nfeats = self.model.nfeats + self.data_rep = self.model.data_rep + self.cond_mode = self.model.cond_mode + + def forward(self, x, timesteps, y=None): + cond_mode = self.model.cond_mode + assert cond_mode in ['text', 'action'] + y_uncond = deepcopy(y) + y_uncond['uncond'] = True + out = self.model(x, timesteps, y) + out_uncond = self.model(x, timesteps, y_uncond) + return out_uncond + (y['scale'].view(-1, 1, 1, 1) * (out - out_uncond)) + diff --git a/main/model/local_attention/__init__.py 
b/main/model/local_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..45dae0c529a608cd0104a1b65534ce197961c6da --- /dev/null +++ b/main/model/local_attention/__init__.py @@ -0,0 +1,2 @@ +from local_attention.local_attention import LocalAttention +from local_attention.transformer import LocalTransformer, LocalMHA diff --git a/main/model/local_attention/local_attention.py b/main/model/local_attention/local_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..5d99b81e1e8b517e070612c45b8c5617c583b879 --- /dev/null +++ b/main/model/local_attention/local_attention.py @@ -0,0 +1,199 @@ +import math + +import torch +from torch import nn, einsum +import torch.nn.functional as F + +from einops import rearrange, repeat, pack, unpack + +from local_attention.rotary import SinusoidalEmbeddings, apply_rotary_pos_emb + +# constant + +TOKEN_SELF_ATTN_VALUE = -5e4 + +# helper functions + +def exists(val): + return val is not None + +def default(value, d): + return d if not exists(value) else value + +def to(t): + return {'device': t.device, 'dtype': t.dtype} + +def max_neg_value(tensor): + return -torch.finfo(tensor.dtype).max + +def l2norm(tensor): + dtype = tensor.dtype + normed = F.normalize(tensor, dim = -1) + return normed.type(dtype) + +def pad_to_multiple(tensor, multiple, dim=-1, value=0): + seqlen = tensor.shape[dim] + m = seqlen / multiple + if m.is_integer(): + return False, tensor + remainder = math.ceil(m) * multiple - seqlen + pad_offset = (0,) * (-1 - dim) * 2 + return True, F.pad(tensor, (*pad_offset, 0, remainder), value = value) + +def look_around(x, backward = 1, forward = 0, pad_value = -1, dim = 2): + t = x.shape[1] + dims = (len(x.shape) - dim) * (0, 0) + padded_x = F.pad(x, (*dims, backward, forward), value = pad_value) + tensors = [padded_x[:, ind:(ind + t), ...] 
for ind in range(forward + backward + 1)] + return torch.cat(tensors, dim = dim) + +# main class + +class LocalAttention(nn.Module): + def __init__( + self, + window_size, + causal = False, + look_backward = 1, + look_forward = None, + dropout = 0., + shared_qk = False, + rel_pos_emb_config = None, + dim = None, + autopad = False, + exact_windowsize = False + ): + super().__init__() + look_forward = default(look_forward, 0 if causal else 1) + assert not (causal and look_forward > 0), 'you cannot look forward if causal' + + self.window_size = window_size + self.autopad = autopad + self.exact_windowsize = exact_windowsize + + self.causal = causal + + self.look_backward = look_backward + self.look_forward = look_forward + + self.dropout = nn.Dropout(dropout) + + self.shared_qk = shared_qk + + # relative positions + + # self.rel_pos = None + # if exists(rel_pos_emb_config) or exists(dim): # backwards compatible with old `rel_pos_emb_config` deprecated argument + # if exists(rel_pos_emb_config): + # dim = rel_pos_emb_config[0] + # self.rel_pos = SinusoidalEmbeddings(dim) + + def forward(self, q, k, v, packed_shape, mask = None, input_mask = None): + mask = default(mask, input_mask) + + autopad, pad_value, window_size, causal, look_backward, look_forward, shared_qk = self.autopad, -1, self.window_size, self.causal, self.look_backward, self.look_forward, self.shared_qk + + # https://github.com/arogozhnikov/einops/blob/master/docs/4-pack-and-unpack.ipynb + # (q, packed_shape), (k, _), (v, _) = map(lambda t: pack([t], '* n d'), (q, k, v)) # (2, 8, 2048, 64) -> (16, 2048, 64) + + # rotary embeddings + + # if exists(self.rel_pos): + # pos_emb = self.rel_pos(q) # (16, 2048, 64) + # q, k = apply_rotary_pos_emb(q, k, pos_emb) + + # auto padding + + # if autopad: + # orig_seq_len = q.shape[1] + # (needed_pad, q), (_, k), (_, v) = map(lambda t: pad_to_multiple(t, self.window_size, dim = -2), (q, k, v)) + + b, n, dim_head, device, dtype = *q.shape, q.device, q.dtype + scale = dim_head ** -0.5 + + if n % window_size != 0: + print('sequence length must be divisible by window size for local attention', n, window_size) + assert True + + windows = n // window_size + + if shared_qk: + k = l2norm(k) + + seq = torch.arange(n, device = device) + b_t = rearrange(seq, '(w n) -> 1 w n', w = windows, n = window_size) + + bq, bk, bv = map(lambda t: rearrange(t, 'b (w n) d -> b w n d', w = windows), (q, k, v)) + + look_around_kwargs = dict( + backward = look_backward, + forward = look_forward, + pad_value = pad_value + ) + + bk = look_around(bk, **look_around_kwargs) + bv = look_around(bv, **look_around_kwargs) + + bq_t = b_t + bq_k = look_around(b_t, **look_around_kwargs) + + bq_t = rearrange(bq_t, '... i -> ... i 1') + bq_k = rearrange(bq_k, '... j -> ... 
1 j') + + sim = einsum('b h i e, b h j e -> b h i j', bq, bk) * scale + + mask_value = max_neg_value(sim) + + if shared_qk: + self_mask = bq_t == bq_k + sim = sim.masked_fill(self_mask, TOKEN_SELF_ATTN_VALUE) + del self_mask + + if causal: + causal_mask = bq_t < bq_k + + if self.exact_windowsize: + max_causal_window_size = (self.window_size * self.look_backward) + causal_mask = causal_mask | (bq_t > (bq_k + max_causal_window_size)) + + sim = sim.masked_fill(causal_mask, mask_value) + del causal_mask + + # mask out padding value + + # if autopad and needed_pad: + # pad_mask = bq_k == pad_value + # sim = sim.masked_fill(pad_mask, mask_value) + # del pad_mask + + if exists(mask): + batch = mask.shape[0] + assert (b % batch) == 0 + + h = b // mask.shape[0] + + # if autopad: + # _, mask = pad_to_multiple(mask, window_size, dim = -1, value = False) + + mask = rearrange(mask, '... (w n) -> (...) w n', w = windows, n = window_size) + mask = look_around(mask, **{**look_around_kwargs, 'pad_value': False}) + mask = rearrange(mask, '... j -> ... 1 j') + mask = repeat(mask, 'b ... -> (b h) ...', h = h) + sim = sim.masked_fill(~mask, mask_value) + del mask + + # attention + + attn = sim.softmax(dim = -1) + attn = self.dropout(attn) + + # aggregation + + out = einsum('b h i j, b h j e -> b h i e', attn, bv) + out = rearrange(out, 'b w n d -> b (w n) d') + + # if autopad: + # out = out[:, :orig_seq_len, :] + + out, *_ = unpack(out, packed_shape, '* n d') + return out diff --git a/main/model/local_attention/rotary.py b/main/model/local_attention/rotary.py new file mode 100644 index 0000000000000000000000000000000000000000..66c02a77521077dbca48c78505436e17d80812bf --- /dev/null +++ b/main/model/local_attention/rotary.py @@ -0,0 +1,25 @@ +import torch +from torch import nn, einsum + +from einops import rearrange + +class SinusoidalEmbeddings(nn.Module): + def __init__(self, dim): + super().__init__() + inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + + def forward(self, x): + n = x.shape[-2] + t = torch.arange(n, device = x.device).type_as(self.inv_freq) + freqs = torch.einsum('i , j -> i j', t, self.inv_freq) + return torch.cat((freqs, freqs), dim=-1) + +def rotate_half(x): + x = rearrange(x, 'b ... (r d) -> b (...) 
r d', r = 2) + x1, x2 = x.unbind(dim = -2) + return torch.cat((-x2, x1), dim = -1) + +def apply_rotary_pos_emb(q, k, freqs): + q, k = map(lambda t: (t * freqs.cos()) + (rotate_half(t) * freqs.sin()), (q, k)) + return q, k diff --git a/main/model/local_attention/transformer.py b/main/model/local_attention/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..210fb48d3c2984243cd3b687d2716a9180be1922 --- /dev/null +++ b/main/model/local_attention/transformer.py @@ -0,0 +1,179 @@ +import torch +from torch import nn +import torch.nn.functional as F + +from einops import rearrange + +from local_attention.local_attention import LocalAttention + +# helper function + +def exists(val): + return val is not None + +def eval_decorator(fn): + def inner(model, *args, **kwargs): + was_training = model.training + model.eval() + out = fn(model, *args, **kwargs) + model.train(was_training) + return out + return inner + +# sampling functions + +def top_k(logits, thres = 0.9): + k = int((1 - thres) * logits.shape[-1]) + val, ind = torch.topk(logits, k) + probs = torch.full_like(logits, float('-inf')) + probs.scatter_(1, ind, val) + return probs + +# multi-head attention + +class LocalMHA(nn.Module): + def __init__( + self, + *, + dim, + window_size, + dim_head = 64, + heads = 8, + dropout = 0., + causal = False, + prenorm = False, + **kwargs + ): + super().__init__() + inner_dim = dim_head * heads + + self.norm = nn.LayerNorm(dim) if prenorm else None + + self.heads = heads + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.attn_fn = LocalAttention( + dim = dim_head, + window_size = window_size, + causal = causal, + autopad = True, + exact_windowsize = True, + **kwargs + ) + + self.to_out = nn.Linear(inner_dim, dim, bias = False) + + def forward(self, x, mask = None): + if exists(self.norm): + x = self.norm(x) + + q, k, v = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v)) + + out = self.attn_fn(q, k, v, mask = mask) + + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + +# feedforward + +class GEGLU(nn.Module): + def forward(self, x): + x, gate = x.chunk(2, dim = -1) + return x * F.gelu(gate) + +def FeedForward(dim, mult = 4, dropout = 0.): + inner_dim = int(dim * mult * 2 / 3) + + return nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, inner_dim * 2, bias = False), + GEGLU(), + nn.Dropout(dropout), + nn.Linear(inner_dim, dim, bias = False) + ) + +# main transformer class + +class LocalTransformer(nn.Module): + def __init__( + self, + *, + num_tokens, + max_seq_len, + dim, + depth, + causal = True, + local_attn_window_size = 512, + dim_head = 64, + heads = 8, + ff_mult = 4, + attn_dropout = 0., + ff_dropout = 0., + ignore_index = -1, + **kwargs + ): + super().__init__() + self.token_emb = nn.Embedding(num_tokens, dim) + self.pos_emb = nn.Embedding(max_seq_len, dim) + + self.max_seq_len = max_seq_len + self.layers = nn.ModuleList([]) + + for _ in range(depth): + self.layers.append(nn.ModuleList([ + LocalMHA(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, causal = causal, window_size = local_attn_window_size, prenorm = True, **kwargs), + FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout) + ])) + + self.ignore_index = ignore_index + self.to_logits = nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, num_tokens, bias = False) + ) + + @torch.no_grad() + @eval_decorator + def generate( + self, + prime, + seq_len, + temperature = 1., + 
filter_thres = 0.9, + **kwargs + ): + n, device = prime.shape[1], prime.device + + out = prime + + for _ in range(seq_len): + logits = self.forward(out[:, -self.max_seq_len:], **kwargs) + filtered_logits = top_k(logits[:, -1], thres = filter_thres) + probs = F.softmax(filtered_logits / temperature, dim = -1) + sampled = torch.multinomial(probs, 1) + out = torch.cat((out, sampled), dim = -1) + + return out[:, n:] + + def forward(self, x, mask = None, return_loss = False): + if return_loss: + x, labels = x[:, :-1], x[:, 1:] + + n, device = x.shape[1], x.device + x = self.token_emb(x) + + assert n <= self.max_seq_len + x = x + self.pos_emb(torch.arange(n, device = device)) + + for attn, ff in self.layers: + x = attn(x, mask = mask) + x + x = ff(x) + x + + logits = self.to_logits(x) + + if not return_loss: + return logits + + logits = rearrange(logits, 'b n c -> b c n') + loss = F.cross_entropy(logits, labels, ignore_index = self.ignore_index) + return loss diff --git a/main/model/mdm.py b/main/model/mdm.py new file mode 100644 index 0000000000000000000000000000000000000000..eba32deafe38a1c31ac5e750539e78fecffbdca0 --- /dev/null +++ b/main/model/mdm.py @@ -0,0 +1,575 @@ +import pdb + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from local_attention.rotary import SinusoidalEmbeddings, apply_rotary_pos_emb +from local_attention import LocalAttention + +class MDM(nn.Module): + def __init__(self, modeltype, njoints, nfeats, + latent_dim=256, ff_size=1024, num_layers=8, num_heads=4, dropout=0.1, + ablation=None, activation="gelu", legacy=False, data_rep='rot6d', dataset='amass', clip_dim=512, + arch='trans_enc', emb_trans_dec=False, audio_feat='', n_seed=1, cond_mode='', **kargs): + super().__init__() + + self.legacy = legacy + self.modeltype = modeltype + self.njoints = njoints + self.nfeats = nfeats + self.data_rep = data_rep + self.dataset = dataset + + self.latent_dim = latent_dim + + self.ff_size = ff_size + self.num_layers = num_layers + self.num_heads = num_heads + self.dropout = dropout + + self.ablation = ablation + self.activation = activation + self.clip_dim = clip_dim + self.action_emb = kargs.get('action_emb', None) + + self.input_feats = self.njoints * self.nfeats + + self.normalize_output = kargs.get('normalize_encoder_output', False) + + self.cond_mask_prob = kargs.get('cond_mask_prob', 0.) 
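+        # cond_mask_prob is the probability of dropping the conditioning for a whole
+        # sample during training: mask_cond() below draws a Bernoulli mask with this
+        # probability and zeroes the style/seed embedding for the selected samples.
+        # Training with randomly dropped conditions is what enables classifier-free
+        # guidance at sampling time, where ClassifierFreeSampleModel (model/cfg_sampler.py)
+        # combines a conditioned and an unconditioned pass as
+        #     out_uncond + scale * (out_cond - out_uncond).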
+ self.arch = arch + self.gru_emb_dim = self.latent_dim if self.arch == 'gru' else 0 + + self.audio_feat = audio_feat + if audio_feat == 'wav encoder': + self.audio_feat_dim = 32 + elif audio_feat == 'mfcc': + self.audio_feat_dim = 13 + elif self.audio_feat == 'wavlm': + print('USE WAVLM') + self.audio_feat_dim = 64 # Linear 1024 -> 64 + self.WavEncoder = WavEncoder() + + self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout) + self.emb_trans_dec = emb_trans_dec + + self.cond_mode = cond_mode + self.num_head = 8 + + if 'style2' not in self.cond_mode: + self.input_process = InputProcess(self.data_rep, self.input_feats + self.audio_feat_dim + self.gru_emb_dim, self.latent_dim) + + if self.arch == 'mytrans_enc': + print("MY TRANS_ENC init") + from mytransformer import TransformerEncoderLayer, TransformerEncoder + + self.embed_positions = RoFormerSinusoidalPositionalEmbedding(1536, self.latent_dim) + + seqTransEncoderLayer = TransformerEncoderLayer(d_model=self.latent_dim, + nhead=self.num_heads, + dim_feedforward=self.ff_size, + dropout=self.dropout, + activation=self.activation) + self.seqTransEncoder = TransformerEncoder(seqTransEncoderLayer, + num_layers=self.num_layers) + + elif self.arch == 'trans_enc': + print("TRANS_ENC init") + seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, + nhead=self.num_heads, + dim_feedforward=self.ff_size, + dropout=self.dropout, + activation=self.activation) + + self.seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer, + num_layers=self.num_layers) + elif self.arch == 'trans_dec': + print("TRANS_DEC init") + seqTransDecoderLayer = nn.TransformerDecoderLayer(d_model=self.latent_dim, + nhead=self.num_heads, + dim_feedforward=self.ff_size, + dropout=self.dropout, + activation=activation) + self.seqTransDecoder = nn.TransformerDecoder(seqTransDecoderLayer, + num_layers=self.num_layers) + elif self.arch == 'gru': + print("GRU init") + self.gru = nn.GRU(self.latent_dim, self.latent_dim, num_layers=self.num_layers, batch_first=False) + else: + raise ValueError('Please choose correct architecture [trans_enc, trans_dec, gru]') + + self.embed_timestep = TimestepEmbedder(self.latent_dim, self.sequence_pos_encoder) + self.n_seed = n_seed + if 'style1' in self.cond_mode: + print('EMBED STYLE BEGIN TOKEN') + if self.n_seed != 0: + self.style_dim = 64 + self.embed_style = nn.Linear(6, self.style_dim) + self.embed_text = nn.Linear(self.njoints * n_seed, self.latent_dim - self.style_dim) + else: + self.style_dim = self.latent_dim + self.embed_style = nn.Linear(6, self.style_dim) + + elif 'style2' in self.cond_mode: + print('EMBED STYLE ALL FRAMES') + self.style_dim = 64 + self.embed_style = nn.Linear(6, self.style_dim) + self.input_process = InputProcess(self.data_rep, self.input_feats + self.audio_feat_dim + self.gru_emb_dim + self.style_dim, + self.latent_dim) + if self.n_seed != 0: + self.embed_text = nn.Linear(self.njoints * n_seed, self.latent_dim) + elif self.n_seed != 0: + self.embed_text = nn.Linear(self.njoints * n_seed, self.latent_dim) + + self.output_process = OutputProcess(self.data_rep, self.input_feats, self.latent_dim, self.njoints, + self.nfeats) + + if 'cross_local_attention' in self.cond_mode: + self.rel_pos = SinusoidalEmbeddings(self.latent_dim // self.num_head) + self.input_process = InputProcess(self.data_rep, self.input_feats + self.gru_emb_dim, self.latent_dim) + self.cross_local_attention = LocalAttention( + dim=32, # dimension of each head (you need to pass this in for relative positional 
encoding) + window_size=11, # window size. 512 is optimal, but 256 or 128 yields good enough results + causal=True, # auto-regressive or not + look_backward=1, # each window looks at the window before + look_forward=0, # for non-auto-regressive case, will default to 1, so each window looks at the window before and after it + dropout=0.1, # post-attention dropout + exact_windowsize=False + # if this is set to true, in the causal setting, each query will see at maximum the number of keys equal to the window size + ) + self.input_process2 = nn.Linear(self.latent_dim * 2 + self.audio_feat_dim, self.latent_dim) + + if 'cross_local_attention2' in self.cond_mode: + print('Cross Local Attention2') + self.selfAttention = LinearTemporalCrossAttention(seq_len=0, latent_dim=256, text_latent_dim=256, num_head=8, dropout=0.1, time_embed_dim=0) + + if 'cross_local_attention3' in self.cond_mode: + print('Cross Local Attention3') + + if 'cross_local_attention4' in self.cond_mode: + print('Cross Local Attention4') + + def parameters_wo_clip(self): + return [p for name, p in self.named_parameters() if not name.startswith('clip_model.')] + + def mask_cond(self, cond, force_mask=False): + bs, d = cond.shape + if force_mask: + return torch.zeros_like(cond) + elif self.training and self.cond_mask_prob > 0.: + mask = torch.bernoulli(torch.ones(bs, device=cond.device) * self.cond_mask_prob).view(bs, 1) # 1-> use null_cond, 0-> use real cond + return cond * (1. - mask) + else: + return cond + + def forward(self, x, timesteps, y=None): + """ + x: [batch_size, njoints, nfeats, max_frames], denoted x_t in the paper + timesteps: [batch_size] (int) + seed: [batch_size, njoints, nfeats] + """ + + bs, njoints, nfeats, nframes = x.shape # 64, 251, 1, 196 + emb_t = self.embed_timestep(timesteps) # [1, bs, d], (1, 2, 256) + + force_mask = y.get('uncond', False) # False + if 'style1' in self.cond_mode: + embed_style = self.mask_cond(self.embed_style(y['style']), force_mask=force_mask) # (bs, 64) + if self.n_seed != 0: + embed_text = self.embed_text(self.mask_cond(y['seed'].squeeze(2).reshape(bs, -1), force_mask=force_mask)) # (bs, 256-64) + emb_1 = torch.cat((embed_style, embed_text), dim=1) + else: + emb_1 = embed_style + elif self.n_seed != 0: + emb_1 = self.embed_text(self.mask_cond(y['seed'].squeeze(2).reshape(bs, -1), force_mask=force_mask)) # z_tk + + if self.audio_feat == 'wavlm': + enc_text = self.WavEncoder(y['audio']).permute(1, 0, 2) + else: + enc_text = y['audio'] + + if 'cross_local_attention' in self.cond_mode: + if 'cross_local_attention3' in self.cond_mode: + x = x.reshape(bs, njoints * nfeats, 1, nframes) # [2, 135, 1, 240] + # self-attention + x_ = self.input_process(x) # [2, 135, 1, 240] -> [240, 2, 256] + + # local-cross-attention + packed_shape = [torch.Size([bs, self.num_head])] + xseq = torch.cat((x_, enc_text), axis=2) # [bs, d+joints*feat, 1, #frames], (240, 2, 32) + # all frames + embed_style_2 = (emb_1 + emb_t).repeat(nframes, 1, 1) # (bs, 64) -> (len, bs, 64) + xseq = torch.cat((embed_style_2, xseq), axis=2) # (seq, bs, dim) + xseq = self.input_process2(xseq) + xseq = xseq.permute(1, 0, 2) # (bs, len, dim) + xseq = xseq.view(bs, nframes, self.num_head, -1) + xseq = xseq.permute(0, 2, 1, 3) # Need (2, 8, 2048, 64) + xseq = xseq.reshape(bs * self.num_head, nframes, -1) + pos_emb = self.rel_pos(xseq) # (89, 32) + xseq, _ = apply_rotary_pos_emb(xseq, xseq, pos_emb) + xseq = self.cross_local_attention(xseq, xseq, xseq, packed_shape=packed_shape, + mask=y['mask_local']) # (2, 8, 2048, 64) + xseq = 
xseq.permute(0, 2, 1, 3) # (bs, len, 8, 64) + xseq = xseq.reshape(bs, nframes, -1) + xseq = xseq.permute(1, 0, 2) + + xseq = torch.cat((emb_1 + emb_t, xseq), axis=0) # [seqlen+1, bs, d] # [(1, 2, 256), (240, 2, 256)] -> (241, 2, 256) + xseq = xseq.permute(1, 0, 2) # (bs, len, dim) + xseq = xseq.view(bs, nframes + 1, self.num_head, -1) + xseq = xseq.permute(0, 2, 1, 3) # Need (2, 8, 2048, 64) + xseq = xseq.reshape(bs * self.num_head, nframes + 1, -1) + pos_emb = self.rel_pos(xseq) # (89, 32) + xseq, _ = apply_rotary_pos_emb(xseq, xseq, pos_emb) + xseq_rpe = xseq.reshape(bs, self.num_head, nframes + 1, -1) + xseq = xseq_rpe.permute(0, 2, 1, 3) # [seqlen+1, bs, d] + xseq = xseq.view(bs, nframes + 1, -1) + xseq = xseq.permute(1, 0, 2) + if 'cross_local_attention2' in self.cond_mode: + xseq = (self.selfAttention(xseq).permute(1, 0, 2))[1:] + else: + output = self.seqTransEncoder(xseq)[1:] + + elif 'cross_local_attention5' in self.cond_mode: + x = x.reshape(bs, njoints * nfeats, 1, nframes) # [2, 135, 1, 240] + # self-attention + x_ = self.input_process(x) # [2, 135, 1, 240] -> [240, 2, 256] + + # local-cross-attention + packed_shape = [torch.Size([bs, self.num_head])] + xseq = torch.cat((x_, enc_text), axis=2) # [bs, d+joints*feat, 1, #frames], (240, 2, 32) + # all frames + embed_style_2 = (emb_1 + emb_t).repeat(nframes, 1, 1) # (bs, 64) -> (len, bs, 64) + xseq = torch.cat((embed_style_2, xseq), axis=2) # (seq, bs, dim) + xseq = self.input_process2(xseq) + xseq = xseq.permute(1, 0, 2) # (bs, len, dim) + xseq = xseq.view(bs, nframes, self.num_head, -1) + xseq = xseq.permute(0, 2, 1, 3) # Need (2, 8, 2048, 64) + xseq = xseq.reshape(bs * self.num_head, nframes, -1) + pos_emb = self.rel_pos(xseq) # (89, 32) + xseq, _ = apply_rotary_pos_emb(xseq, xseq, pos_emb) + xseq = self.cross_local_attention(xseq, xseq, xseq, packed_shape=packed_shape, + mask=y['mask_local']) # (2, 8, 2048, 64) + xseq = xseq.permute(0, 2, 1, 3) # (bs, len, 8, 64) + xseq = xseq.reshape(bs, nframes, -1) + output = xseq.permute(1, 0, 2) + + else: + x = x.reshape(bs, njoints*nfeats, 1, nframes) # [2, 135, 1, 240] + # self-attention + x_ = self.input_process(x) # [2, 135, 1, 240] -> [240, 2, 256] + xseq = torch.cat((emb_1 + emb_t, x_), axis=0) # [seqlen+1, bs, d] # [(1, 2, 256), (240, 2, 256)] -> (241, 2, 256) + xseq = xseq.permute(1, 0, 2) # (bs, len, dim) + xseq = xseq.view(bs, nframes + 1, self.num_head, -1) + xseq = xseq.permute(0, 2, 1, 3) # Need (2, 8, 2048, 64) + xseq = xseq.reshape(bs*self.num_head, nframes + 1, -1) + pos_emb = self.rel_pos(xseq) # (89, 32) + xseq, _ = apply_rotary_pos_emb(xseq, xseq, pos_emb) + xseq_rpe = xseq.reshape(bs, self.num_head, nframes + 1, -1) + xseq = xseq_rpe.permute(0, 2, 1, 3) # [seqlen+1, bs, d] + xseq = xseq.view(bs, nframes + 1, -1) + xseq = xseq.permute(1, 0, 2) + if 'cross_local_attention2' in self.cond_mode: + xseq = (self.selfAttention(xseq).permute(1, 0, 2))[1:] + else: + xseq = self.seqTransEncoder(xseq)[1:] + + # local-cross-attention + packed_shape = [torch.Size([bs, self.num_head])] + xseq = torch.cat((xseq, enc_text), axis=2) #[bs, d+joints*feat, 1, #frames], (240, 2, 32) + # all frames + embed_style_2 = (emb_1 + emb_t).repeat(nframes, 1, 1) # (bs, 64) -> (len, bs, 64) + xseq = torch.cat((embed_style_2, xseq), axis=2) # (seq, bs, dim) + xseq = self.input_process2(xseq) + xseq = xseq.permute(1, 0, 2) # (bs, len, dim) + xseq = xseq.view(bs, nframes, self.num_head, -1) + xseq = xseq.permute(0, 2, 1, 3) # Need (2, 8, 2048, 64) + xseq = xseq.reshape(bs * self.num_head, nframes, -1) + 
pos_emb = self.rel_pos(xseq) # (89, 32) + xseq, _ = apply_rotary_pos_emb(xseq, xseq, pos_emb) + xseq = self.cross_local_attention(xseq, xseq, xseq, packed_shape=packed_shape, mask=y['mask_local']) # (2, 8, 2048, 64) + xseq = xseq.permute(0, 2, 1, 3) # (bs, len, 8, 64) + xseq = xseq.reshape(bs, nframes, -1) + output = xseq.permute(1, 0, 2) + + else: + if self.arch == 'trans_enc' or self.arch == 'trans_dec' or self.arch == 'conformers_enc' or self.arch == 'mytrans_enc': + x_reshaped = x.reshape(bs, njoints*nfeats, 1, nframes) # [2, 135, 1, 240] + enc_text_gru = enc_text.permute(1, 2, 0) # (240, 2, 32) -> (2, 32, 240) + enc_text_gru = enc_text_gru.reshape(bs, self.audio_feat_dim, 1, nframes) + x = torch.cat((x_reshaped, enc_text_gru), axis=1) #[bs, d+joints*feat, 1, #frames] + if 'style2' in self.cond_mode: + embed_style = self.mask_cond(self.embed_style(y['style']), force_mask=force_mask).repeat(nframes, 1, 1) # (#frames, bs, 64) + embed_style = embed_style.unsqueeze(2) + embed_style = embed_style.permute(1, 3, 2, 0) + x = torch.cat((x, embed_style), axis=1) # [bs, d+joints*feat, 1, #frames] + + if self.arch == 'gru': + x_reshaped = x.reshape(bs, njoints*nfeats, 1, nframes) # [2, 135, 1, 240] + emb_gru = emb.repeat(nframes, 1, 1) #[#frames, bs, d] + + enc_text_gru = enc_text.permute(1, 2, 0) # (240, 2, 32) -> (2, 32, 240) + enc_text_gru = enc_text_gru.reshape(bs, self.audio_feat_dim, 1, nframes) + + emb_gru = emb_gru.permute(1, 2, 0) #[bs, d, #frames] + emb_gru = emb_gru.reshape(bs, self.latent_dim, 1, nframes) #[bs, d, 1, #frames] + x = torch.cat((x_reshaped, emb_gru, enc_text_gru), axis=1) #[bs, d+joints*feat, 1, #frames] + + x = self.input_process(x) # [2, 135, 1, 240] -> [240, 2, 224] + + if self.arch == 'trans_enc': + # adding the timestep embed + # x = torch.cat((x, enc_text), axis=2) # [[240, 2, 224], (240, 2, 32)] -> (240, 2, 256) + xseq = torch.cat((emb, x), axis=0) # [seqlen+1, bs, d] # [(1, 2, 256), (240, 2, 256)] -> (241, 2, 256) + + xseq = self.sequence_pos_encoder(xseq) # [seqlen+1, bs, d] + output = self.seqTransEncoder(xseq)[1:] # , src_key_padding_mask=~maskseq) # [seqlen, bs, d] # -> (240, 2, 256) + + elif self.arch == 'trans_dec': + if self.emb_trans_dec: + xseq = torch.cat((emb, x), axis=0) + else: + xseq = x + xseq = self.sequence_pos_encoder(xseq) # [seqlen+1, bs, d] + if self.emb_trans_dec: + output = self.seqTransDecoder(tgt=xseq, memory=emb)[1:] # [seqlen, bs, d] # FIXME - maybe add a causal mask + else: + output = self.seqTransDecoder(tgt=xseq, memory=emb) + + elif self.arch == 'gru': + xseq = x + xseq = self.sequence_pos_encoder(xseq) # [seqlen, bs, d] + # pdb.set_trace() + output, _ = self.gru(xseq) + + elif self.arch == 'mytrans_enc': + # adding the timestep embed + # x = torch.cat((x, enc_text), axis=2) # [[240, 2, 224], (240, 2, 32)] -> (240, 2, 256) + xseq = torch.cat((emb, x), axis=0) # [seqlen+1, bs, d] # [(1, 2, 256), (240, 2, 256)] -> (241, 2, 256) + + sinusoidal_pos = self.embed_positions(xseq.shape[0], 0)[None, None, :, :].chunk(2, dim=-1) + xseq = self.apply_rotary(xseq.permute(1, 0, 2), sinusoidal_pos).squeeze(0).permute(1, 0, 2) + + output = self.seqTransEncoder(xseq)[1:] # , src_key_padding_mask=~maskseq) # [seqlen, bs, d] # -> (240, 2, 256) + + output = self.output_process(output) # [bs, njoints, nfeats, nframes] + return output + + + @staticmethod + def apply_rotary(x, sinusoidal_pos): + sin, cos = sinusoidal_pos + x1, x2 = x[..., 0::2], x[..., 1::2] + # ๅฆ‚ๆžœๆ˜ฏๆ—‹่ฝฌquery key็š„่ฏ๏ผŒไธ‹้ข่ฟ™ไธช็›ดๆŽฅcatๅฐฑ่กŒ๏ผŒๅ› 
ไธบ่ฆ่ฟ›่กŒ็Ÿฉ้˜ตไน˜ๆณ•๏ผŒๆœ€็ปˆไผšๅœจ่ฟ™ไธช็ปดๅบฆๆฑ‚ๅ’Œใ€‚๏ผˆๅช่ฆไฟๆŒqueryๅ’Œkey็š„ๆœ€ๅŽไธ€ไธชdim็š„ๆฏไธ€ไธชไฝ็ฝฎๅฏนๅบ”ไธŠๅฐฑๅฏไปฅ๏ผ‰ + # torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1) + # ๅฆ‚ๆžœๆ˜ฏๆ—‹่ฝฌvalue็š„่ฏ๏ผŒไธ‹้ข่ฟ™ไธชstackๅŽๅ†flattenๆ‰ๅฏไปฅ๏ผŒๅ› ไธบ่ฎญ็ปƒๅฅฝ็š„ๆจกๅž‹ๆœ€ๅŽไธ€ไธชdimๆ˜ฏไธคไธคไน‹้—ดไบคๆ›ฟ็š„ใ€‚ + return torch.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1).flatten(-2, -1) + + + +class PositionalEncoding(nn.Module): + def __init__(self, d_model, dropout=0.1, max_len=5000): + super(PositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + + pe = torch.zeros(max_len, d_model) # (5000, 128) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) # (5000, 1) + div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0).transpose(0, 1) + + self.register_buffer('pe', pe) + + def forward(self, x): + # not used in the final model + x = x + self.pe[:x.shape[0], :] + return self.dropout(x) + + +# Copied from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding with Marian->RoFormer +class RoFormerSinusoidalPositionalEmbedding(nn.Embedding): + """This module produces sinusoidal positional embeddings of any length.""" + + def __init__( + self, num_positions: int, embedding_dim: int + ): + super().__init__(num_positions, embedding_dim) + self.weight = self._init_weight(self.weight) + + @staticmethod + def _init_weight(out: nn.Parameter): + """ + Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in + the 2nd half of the vector. [dim // 2:] + """ + n_pos, dim = out.shape + position_enc = np.array( + [ + [pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] + for pos in range(n_pos) + ] + ) + out.requires_grad = False # set early to avoid an error in pytorch-1.8+ + sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1 + out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) + out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) + out.detach_() + return out + + @torch.no_grad() + def forward(self, seq_len: int, past_key_values_length: int = 0): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + positions = torch.arange( + past_key_values_length, + past_key_values_length + seq_len, + dtype=torch.long, + device=self.weight.device, + ) + return super().forward(positions) + + +class TimestepEmbedder(nn.Module): + def __init__(self, latent_dim, sequence_pos_encoder): + super().__init__() + self.latent_dim = latent_dim + self.sequence_pos_encoder = sequence_pos_encoder + + time_embed_dim = self.latent_dim + self.time_embed = nn.Sequential( + nn.Linear(self.latent_dim, time_embed_dim), + nn.SiLU(), + nn.Linear(time_embed_dim, time_embed_dim), + ) + + def forward(self, timesteps): + return self.time_embed(self.sequence_pos_encoder.pe[timesteps]).permute(1, 0, 2) + + +class InputProcess(nn.Module): + def __init__(self, data_rep, input_feats, latent_dim): + super().__init__() + self.data_rep = data_rep + self.input_feats = input_feats + self.latent_dim = latent_dim + self.poseEmbedding = nn.Linear(self.input_feats, self.latent_dim) + if self.data_rep == 'rot_vel': + self.velEmbedding = nn.Linear(self.input_feats, self.latent_dim) + + def forward(self, x): + bs, njoints, nfeats, nframes = x.shape + x = x.permute((3, 0, 1, 
2)).reshape(nframes, bs, njoints*nfeats) + + if self.data_rep in ['rot6d', 'xyz', 'hml_vec']: + x = self.poseEmbedding(x) # [seqlen, bs, d] + return x + elif self.data_rep == 'rot_vel': + first_pose = x[[0]] # [1, bs, 150] + first_pose = self.poseEmbedding(first_pose) # [1, bs, d] + vel = x[1:] # [seqlen-1, bs, 150] + vel = self.velEmbedding(vel) # [seqlen-1, bs, d] + return torch.cat((first_pose, vel), axis=0) # [seqlen, bs, d] + else: + raise ValueError + + +class OutputProcess(nn.Module): + def __init__(self, data_rep, input_feats, latent_dim, njoints, nfeats): + super().__init__() + self.data_rep = data_rep + self.input_feats = input_feats + self.latent_dim = latent_dim + self.njoints = njoints + self.nfeats = nfeats + self.poseFinal = nn.Linear(self.latent_dim, self.input_feats) + if self.data_rep == 'rot_vel': + self.velFinal = nn.Linear(self.latent_dim, self.input_feats) + + def forward(self, output): + nframes, bs, d = output.shape + if self.data_rep in ['rot6d', 'xyz', 'hml_vec']: + output = self.poseFinal(output) # [seqlen, bs, 150] + elif self.data_rep == 'rot_vel': + first_pose = output[[0]] # [1, bs, d] + first_pose = self.poseFinal(first_pose) # [1, bs, 150] + vel = output[1:] # [seqlen-1, bs, d] + vel = self.velFinal(vel) # [seqlen-1, bs, 150] + output = torch.cat((first_pose, vel), axis=0) # [seqlen, bs, 150] + else: + raise ValueError + output = output.reshape(nframes, bs, self.njoints, self.nfeats) + output = output.permute(1, 2, 3, 0) # [bs, njoints, nfeats, nframes] + return output + + +class LinearTemporalCrossAttention(nn.Module): + + def __init__(self, seq_len, latent_dim, text_latent_dim, num_head, dropout, time_embed_dim): + super().__init__() + self.num_head = num_head + self.norm = nn.LayerNorm(latent_dim) + self.text_norm = nn.LayerNorm(text_latent_dim) + self.query = nn.Linear(latent_dim, latent_dim) + self.key = nn.Linear(text_latent_dim, latent_dim) + self.value = nn.Linear(text_latent_dim, latent_dim) + self.dropout = nn.Dropout(dropout) + self.proj_out = nn.Linear(latent_dim, latent_dim) + + def forward(self, x, xf=None, emb=None): + """ + x: B, T, D , [240, 2, 256] + xf: B, N, L , [1, 2, 256] + """ + x = x.permute(1, 0, 2) + # xf = xf.permute(1, 0, 2) + B, T, D = x.shape + # N = xf.shape[1] + H = self.num_head + # B, T, D + query = self.query(self.norm(x)) + # B, N, D + key = self.key(self.text_norm(x)) + query = F.softmax(query.view(B, T, H, -1), dim=-1) + key = F.softmax(key.view(B, T, H, -1), dim=1) + # B, N, H, HD + value = self.value(self.text_norm(x)).view(B, T, H, -1) + # B, H, HD, HD + attention = torch.einsum('bnhd,bnhl->bhdl', key, value) + y = torch.einsum('bnhd,bhdl->bnhl', query, attention).reshape(B, T, D) + # y = x + self.proj_out(y, emb) + return y + + +class WavEncoder(nn.Module): + def __init__(self): + super().__init__() + self.audio_feature_map = nn.Linear(1024, 64) + + def forward(self, rep): + rep = self.audio_feature_map(rep) + return rep + + +if __name__ == '__main__': + ''' + cd ./main/model + python mdm.py + ''' + n_frames = 240 + + n_seed = 8 + + model = MDM(modeltype='', njoints=1140, nfeats=1, cond_mode = 'cross_local_attention5_style1', action_emb='tensor', audio_feat='mfcc', + arch='mytrans_enc', latent_dim=256, n_seed=n_seed, cond_mask_prob=0.1) + + x = torch.randn(2, 1140, 1, 88) + t = torch.tensor([12, 85]) + + model_kwargs_ = {'y': {}} + model_kwargs_['y']['mask'] = (torch.zeros([1, 1, 1, n_frames]) < 1) # [..., n_seed:] + model_kwargs_['y']['audio'] = torch.randn(2, 88, 13).permute(1, 0, 2) # [n_seed:, ...] 
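+ # the remaining conditions below are dummy tensors for this shape check: 'style' is a (bs, 6) style vector
+ # (at inference time presumably a one-hot over the six ZEGGS styles, as in the sampling example above),
+ # 'mask_local' is the boolean per-frame mask consumed by the cross-local attention, and 'seed' takes the
+ # first n_seed frames of x as seed gesture frames; the printed y should match the shape of x,
+ # i.e. (bs, njoints, nfeats, nframes) = (2, 1140, 1, 88)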
+ model_kwargs_['y']['style'] = torch.randn(2, 6) + model_kwargs_['y']['mask_local'] = torch.ones(2, 88).bool() + model_kwargs_['y']['seed'] = x[..., 0:n_seed] + y = model(x, t, model_kwargs_['y']) + print(y.shape) diff --git a/main/model/myactivation.py b/main/model/myactivation.py new file mode 100644 index 0000000000000000000000000000000000000000..8d83fe7727d555058b0fd93cb3d04958c8d89601 --- /dev/null +++ b/main/model/myactivation.py @@ -0,0 +1,1275 @@ +import warnings +from typing import Tuple, Optional + +import torch +from torch import Tensor +if float(torch.__version__.split('.')[0]) == 0 or (float(torch.__version__.split('.')[0]) == 1 and float(torch.__version__.split('.')[1])) < 9: + from torch.nn.modules.linear import _LinearWithBias +else: + from torch.nn.modules.linear import NonDynamicallyQuantizableLinear as _LinearWithBias +from torch.nn.init import xavier_uniform_ +from torch.nn.init import constant_ +from torch.nn.init import xavier_normal_ +from torch.nn.parameter import Parameter +from torch.nn.modules.module import Module +from torch.nn import functional as F + + +class Threshold(Module): + r"""Thresholds each element of the input Tensor. + + Threshold is defined as: + + .. math:: + y = + \begin{cases} + x, &\text{ if } x > \text{threshold} \\ + \text{value}, &\text{ otherwise } + \end{cases} + + Args: + threshold: The value to threshold at + value: The value to replace with + inplace: can optionally do the operation in-place. Default: ``False`` + + Shape: + - Input: :math:`(N, *)` where `*` means, any number of additional + dimensions + - Output: :math:`(N, *)`, same shape as the input + + Examples:: + + >>> m = nn.Threshold(0.1, 20) + >>> input = torch.randn(2) + >>> output = m(input) + """ + __constants__ = ['threshold', 'value', 'inplace'] + + threshold: float + value: float + inplace: bool + + def __init__(self, threshold: float, value: float, inplace: bool = False) -> None: + super(Threshold, self).__init__() + self.threshold = threshold + self.value = value + self.inplace = inplace + # TODO: check in THNN (if inplace == True, then assert value <= threshold) + + def forward(self, input: Tensor) -> Tensor: + return F.threshold(input, self.threshold, self.value, self.inplace) + + def extra_repr(self): + inplace_str = ', inplace=True' if self.inplace else '' + return 'threshold={}, value={}{}'.format( + self.threshold, self.value, inplace_str + ) + + +class ReLU(Module): + r"""Applies the rectified linear unit function element-wise: + + :math:`\text{ReLU}(x) = (x)^+ = \max(0, x)` + + Args: + inplace: can optionally do the operation in-place. Default: ``False`` + + Shape: + - Input: :math:`(N, *)` where `*` means, any number of additional + dimensions + - Output: :math:`(N, *)`, same shape as the input + + .. 
image:: ../scripts/activation_images/ReLU.png + + Examples:: + + >>> m = nn.ReLU() + >>> input = torch.randn(2) + >>> output = m(input) + + + An implementation of CReLU - https://arxiv.org/abs/1603.05201 + + >>> m = nn.ReLU() + >>> input = torch.randn(2).unsqueeze(0) + >>> output = torch.cat((m(input),m(-input))) + """ + __constants__ = ['inplace'] + inplace: bool + + def __init__(self, inplace: bool = False): + super(ReLU, self).__init__() + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + return F.relu(input, inplace=self.inplace) + + def extra_repr(self) -> str: + inplace_str = 'inplace=True' if self.inplace else '' + return inplace_str + + +class RReLU(Module): + r"""Applies the randomized leaky rectified liner unit function, element-wise, + as described in the paper: + + `Empirical Evaluation of Rectified Activations in Convolutional Network`_. + + The function is defined as: + + .. math:: + \text{RReLU}(x) = + \begin{cases} + x & \text{if } x \geq 0 \\ + ax & \text{ otherwise } + \end{cases} + + where :math:`a` is randomly sampled from uniform distribution + :math:`\mathcal{U}(\text{lower}, \text{upper})`. + + See: https://arxiv.org/pdf/1505.00853.pdf + + Args: + lower: lower bound of the uniform distribution. Default: :math:`\frac{1}{8}` + upper: upper bound of the uniform distribution. Default: :math:`\frac{1}{3}` + inplace: can optionally do the operation in-place. Default: ``False`` + + Shape: + - Input: :math:`(N, *)` where `*` means, any number of additional + dimensions + - Output: :math:`(N, *)`, same shape as the input + + Examples:: + + >>> m = nn.RReLU(0.1, 0.3) + >>> input = torch.randn(2) + >>> output = m(input) + + .. _`Empirical Evaluation of Rectified Activations in Convolutional Network`: + https://arxiv.org/abs/1505.00853 + """ + __constants__ = ['lower', 'upper', 'inplace'] + + lower: float + upper: float + inplace: bool + + def __init__( + self, + lower: float = 1. / 8, + upper: float = 1. / 3, + inplace: bool = False + ): + super(RReLU, self).__init__() + self.lower = lower + self.upper = upper + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + return F.rrelu(input, self.lower, self.upper, self.training, self.inplace) + + def extra_repr(self): + inplace_str = ', inplace=True' if self.inplace else '' + return 'lower={}, upper={}{}'.format(self.lower, self.upper, inplace_str) + + +class Hardtanh(Module): + r"""Applies the HardTanh function element-wise + + HardTanh is defined as: + + .. math:: + \text{HardTanh}(x) = \begin{cases} + 1 & \text{ if } x > 1 \\ + -1 & \text{ if } x < -1 \\ + x & \text{ otherwise } \\ + \end{cases} + + The range of the linear region :math:`[-1, 1]` can be adjusted using + :attr:`min_val` and :attr:`max_val`. + + Args: + min_val: minimum value of the linear region range. Default: -1 + max_val: maximum value of the linear region range. Default: 1 + inplace: can optionally do the operation in-place. Default: ``False`` + + Keyword arguments :attr:`min_value` and :attr:`max_value` + have been deprecated in favor of :attr:`min_val` and :attr:`max_val`. + + Shape: + - Input: :math:`(N, *)` where `*` means, any number of additional + dimensions + - Output: :math:`(N, *)`, same shape as the input + + .. 
image:: ../scripts/activation_images/Hardtanh.png + + Examples:: + + >>> m = nn.Hardtanh(-2, 2) + >>> input = torch.randn(2) + >>> output = m(input) + """ + __constants__ = ['min_val', 'max_val', 'inplace'] + + min_val: float + max_val: float + inplace: bool + + def __init__( + self, + min_val: float = -1., + max_val: float = 1., + inplace: bool = False, + min_value: Optional[float] = None, + max_value: Optional[float] = None + ) -> None: + super(Hardtanh, self).__init__() + if min_value is not None: + warnings.warn("keyword argument min_value is deprecated and rename to min_val") + min_val = min_value + if max_value is not None: + warnings.warn("keyword argument max_value is deprecated and rename to max_val") + max_val = max_value + + self.min_val = min_val + self.max_val = max_val + self.inplace = inplace + assert self.max_val > self.min_val + + def forward(self, input: Tensor) -> Tensor: + return F.hardtanh(input, self.min_val, self.max_val, self.inplace) + + def extra_repr(self) -> str: + inplace_str = ', inplace=True' if self.inplace else '' + return 'min_val={}, max_val={}{}'.format( + self.min_val, self.max_val, inplace_str + ) + + +class ReLU6(Hardtanh): + r"""Applies the element-wise function: + + .. math:: + \text{ReLU6}(x) = \min(\max(0,x), 6) + + Args: + inplace: can optionally do the operation in-place. Default: ``False`` + + Shape: + - Input: :math:`(N, *)` where `*` means, any number of additional + dimensions + - Output: :math:`(N, *)`, same shape as the input + + .. image:: ../scripts/activation_images/ReLU6.png + + Examples:: + + >>> m = nn.ReLU6() + >>> input = torch.randn(2) + >>> output = m(input) + """ + + def __init__(self, inplace: bool = False): + super(ReLU6, self).__init__(0., 6., inplace) + + def extra_repr(self) -> str: + inplace_str = 'inplace=True' if self.inplace else '' + return inplace_str + + +class Sigmoid(Module): + r"""Applies the element-wise function: + + .. math:: + \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)} + + + Shape: + - Input: :math:`(N, *)` where `*` means, any number of additional + dimensions + - Output: :math:`(N, *)`, same shape as the input + + .. image:: ../scripts/activation_images/Sigmoid.png + + Examples:: + + >>> m = nn.Sigmoid() + >>> input = torch.randn(2) + >>> output = m(input) + """ + + def forward(self, input: Tensor) -> Tensor: + return torch.sigmoid(input) + + +class Hardsigmoid(Module): + r"""Applies the element-wise function: + + .. math:: + \text{Hardsigmoid}(x) = \begin{cases} + 0 & \text{if~} x \le -3, \\ + 1 & \text{if~} x \ge +3, \\ + x / 6 + 1 / 2 & \text{otherwise} + \end{cases} + + Args: + inplace: can optionally do the operation in-place. Default: ``False`` + + Shape: + - Input: :math:`(N, *)` where `*` means, any number of additional + dimensions + - Output: :math:`(N, *)`, same shape as the input + + Examples:: + + >>> m = nn.Hardsigmoid() + >>> input = torch.randn(2) + >>> output = m(input) + """ + __constants__ = ['inplace'] + + inplace: bool + + def __init__(self, inplace : bool = False) -> None: + super(Hardsigmoid, self).__init__() + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + return F.hardsigmoid(input, self.inplace) + + +class Tanh(Module): + r"""Applies the element-wise function: + + .. math:: + \text{Tanh}(x) = \tanh(x) = \frac{\exp(x) - \exp(-x)} {\exp(x) + \exp(-x)} + + Shape: + - Input: :math:`(N, *)` where `*` means, any number of additional + dimensions + - Output: :math:`(N, *)`, same shape as the input + + .. 
image:: ../scripts/activation_images/Tanh.png + + Examples:: + + >>> m = nn.Tanh() + >>> input = torch.randn(2) + >>> output = m(input) + """ + + def forward(self, input: Tensor) -> Tensor: + return torch.tanh(input) + +class SiLU(Module): + r"""Applies the silu function, element-wise. + + .. math:: + \text{silu}(x) = x * \sigma(x), \text{where } \sigma(x) \text{ is the logistic sigmoid.} + + .. note:: + See `Gaussian Error Linear Units (GELUs) `_ + where the SiLU (Sigmoid Linear Unit) was originally coined, and see + `Sigmoid-Weighted Linear Units for Neural Network Function Approximation + in Reinforcement Learning `_ and `Swish: + a Self-Gated Activation Function `_ + where the SiLU was experimented with later. + + Shape: + - Input: :math:`(N, *)` where `*` means, any number of additional + dimensions + - Output: :math:`(N, *)`, same shape as the input + + Examples:: + + >>> m = nn.SiLU() + >>> input = torch.randn(2) + >>> output = m(input) + """ + __constants__ = ['inplace'] + inplace: bool + + def __init__(self, inplace: bool = False): + super(SiLU, self).__init__() + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + return F.silu(input, inplace=self.inplace) + + def extra_repr(self) -> str: + inplace_str = 'inplace=True' if self.inplace else '' + return inplace_str + +class Hardswish(Module): + r"""Applies the hardswish function, element-wise, as described in the paper: + + `Searching for MobileNetV3`_. + + .. math:: + \text{Hardswish}(x) = \begin{cases} + 0 & \text{if~} x \le -3, \\ + x & \text{if~} x \ge +3, \\ + x \cdot (x + 3) /6 & \text{otherwise} + \end{cases} + + Args: + inplace: can optionally do the operation in-place. Default: ``False`` + + Shape: + - Input: :math:`(N, *)` where `*` means, any number of additional + dimensions + - Output: :math:`(N, *)`, same shape as the input + + Examples:: + + >>> m = nn.Hardswish() + >>> input = torch.randn(2) + >>> output = m(input) + + .. _`Searching for MobileNetV3`: + https://arxiv.org/abs/1905.02244 + """ + __constants__ = ['inplace'] + + inplace: bool + + def __init__(self, inplace : bool = False) -> None: + super(Hardswish, self).__init__() + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + return F.hardswish(input, self.inplace) + + +class ELU(Module): + r"""Applies the element-wise function: + + .. math:: + \text{ELU}(x) = \begin{cases} + x, & \text{ if } x > 0\\ + \alpha * (\exp(x) - 1), & \text{ if } x \leq 0 + \end{cases} + + Args: + alpha: the :math:`\alpha` value for the ELU formulation. Default: 1.0 + inplace: can optionally do the operation in-place. Default: ``False`` + + Shape: + - Input: :math:`(N, *)` where `*` means, any number of additional + dimensions + - Output: :math:`(N, *)`, same shape as the input + + .. image:: ../scripts/activation_images/ELU.png + + Examples:: + + >>> m = nn.ELU() + >>> input = torch.randn(2) + >>> output = m(input) + """ + __constants__ = ['alpha', 'inplace'] + alpha: float + inplace: bool + + def __init__(self, alpha: float = 1., inplace: bool = False) -> None: + super(ELU, self).__init__() + self.alpha = alpha + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + return F.elu(input, self.alpha, self.inplace) + + def extra_repr(self) -> str: + inplace_str = ', inplace=True' if self.inplace else '' + return 'alpha={}{}'.format(self.alpha, inplace_str) + + +class CELU(Module): + r"""Applies the element-wise function: + + .. 
math:: + \text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1)) + + More details can be found in the paper `Continuously Differentiable Exponential Linear Units`_ . + + Args: + alpha: the :math:`\alpha` value for the CELU formulation. Default: 1.0 + inplace: can optionally do the operation in-place. Default: ``False`` + + Shape: + - Input: :math:`(N, *)` where `*` means, any number of additional + dimensions + - Output: :math:`(N, *)`, same shape as the input + + .. image:: ../scripts/activation_images/CELU.png + + Examples:: + + >>> m = nn.CELU() + >>> input = torch.randn(2) + >>> output = m(input) + + .. _`Continuously Differentiable Exponential Linear Units`: + https://arxiv.org/abs/1704.07483 + """ + __constants__ = ['alpha', 'inplace'] + alpha: float + inplace: bool + + def __init__(self, alpha: float = 1., inplace: bool = False) -> None: + super(CELU, self).__init__() + self.alpha = alpha + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + return F.celu(input, self.alpha, self.inplace) + + def extra_repr(self) -> str: + inplace_str = ', inplace=True' if self.inplace else '' + return 'alpha={}{}'.format(self.alpha, inplace_str) + + +class SELU(Module): + r"""Applied element-wise, as: + + .. math:: + \text{SELU}(x) = \text{scale} * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1))) + + with :math:`\alpha = 1.6732632423543772848170429916717` and + :math:`\text{scale} = 1.0507009873554804934193349852946`. + + More details can be found in the paper `Self-Normalizing Neural Networks`_ . + + Args: + inplace (bool, optional): can optionally do the operation in-place. Default: ``False`` + + Shape: + - Input: :math:`(N, *)` where `*` means, any number of additional + dimensions + - Output: :math:`(N, *)`, same shape as the input + + .. image:: ../scripts/activation_images/SELU.png + + Examples:: + + >>> m = nn.SELU() + >>> input = torch.randn(2) + >>> output = m(input) + + .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515 + """ + __constants__ = ['inplace'] + inplace: bool + + def __init__(self, inplace: bool = False) -> None: + super(SELU, self).__init__() + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + return F.selu(input, self.inplace) + + def extra_repr(self) -> str: + inplace_str = 'inplace=True' if self.inplace else '' + return inplace_str + + +class GLU(Module): + r"""Applies the gated linear unit function + :math:`{GLU}(a, b)= a \otimes \sigma(b)` where :math:`a` is the first half + of the input matrices and :math:`b` is the second half. + + Args: + dim (int): the dimension on which to split the input. Default: -1 + + Shape: + - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of additional + dimensions + - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2` + + Examples:: + + >>> m = nn.GLU() + >>> input = torch.randn(4, 2) + >>> output = m(input) + """ + __constants__ = ['dim'] + dim: int + + def __init__(self, dim: int = -1) -> None: + super(GLU, self).__init__() + self.dim = dim + + def forward(self, input: Tensor) -> Tensor: + return F.glu(input, self.dim) + + def extra_repr(self) -> str: + return 'dim={}'.format(self.dim) + + +class GELU(Module): + r"""Applies the Gaussian Error Linear Units function: + + .. math:: \text{GELU}(x) = x * \Phi(x) + + where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution. + + Shape: + - Input: :math:`(N, *)` where `*` means, any number of additional + dimensions + - Output: :math:`(N, *)`, same shape as the input + + .. 
image:: ../scripts/activation_images/GELU.png + + Examples:: + + >>> m = nn.GELU() + >>> input = torch.randn(2) + >>> output = m(input) + """ + def forward(self, input: Tensor) -> Tensor: + return F.gelu(input) + + +class Hardshrink(Module): + r"""Applies the hard shrinkage function element-wise: + + .. math:: + \text{HardShrink}(x) = + \begin{cases} + x, & \text{ if } x > \lambda \\ + x, & \text{ if } x < -\lambda \\ + 0, & \text{ otherwise } + \end{cases} + + Args: + lambd: the :math:`\lambda` value for the Hardshrink formulation. Default: 0.5 + + Shape: + - Input: :math:`(N, *)` where `*` means, any number of additional + dimensions + - Output: :math:`(N, *)`, same shape as the input + + .. image:: ../scripts/activation_images/Hardshrink.png + + Examples:: + + >>> m = nn.Hardshrink() + >>> input = torch.randn(2) + >>> output = m(input) + """ + __constants__ = ['lambd'] + lambd: float + + def __init__(self, lambd: float = 0.5) -> None: + super(Hardshrink, self).__init__() + self.lambd = lambd + + def forward(self, input: Tensor) -> Tensor: + return F.hardshrink(input, self.lambd) + + def extra_repr(self) -> str: + return '{}'.format(self.lambd) + + +class LeakyReLU(Module): + r"""Applies the element-wise function: + + .. math:: + \text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x) + + + or + + .. math:: + \text{LeakyRELU}(x) = + \begin{cases} + x, & \text{ if } x \geq 0 \\ + \text{negative\_slope} \times x, & \text{ otherwise } + \end{cases} + + Args: + negative_slope: Controls the angle of the negative slope. Default: 1e-2 + inplace: can optionally do the operation in-place. Default: ``False`` + + Shape: + - Input: :math:`(N, *)` where `*` means, any number of additional + dimensions + - Output: :math:`(N, *)`, same shape as the input + + .. image:: ../scripts/activation_images/LeakyReLU.png + + Examples:: + + >>> m = nn.LeakyReLU(0.1) + >>> input = torch.randn(2) + >>> output = m(input) + """ + __constants__ = ['inplace', 'negative_slope'] + inplace: bool + negative_slope: float + + def __init__(self, negative_slope: float = 1e-2, inplace: bool = False) -> None: + super(LeakyReLU, self).__init__() + self.negative_slope = negative_slope + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + return F.leaky_relu(input, self.negative_slope, self.inplace) + + def extra_repr(self) -> str: + inplace_str = ', inplace=True' if self.inplace else '' + return 'negative_slope={}{}'.format(self.negative_slope, inplace_str) + + +class LogSigmoid(Module): + r"""Applies the element-wise function: + + .. math:: + \text{LogSigmoid}(x) = \log\left(\frac{ 1 }{ 1 + \exp(-x)}\right) + + Shape: + - Input: :math:`(N, *)` where `*` means, any number of additional + dimensions + - Output: :math:`(N, *)`, same shape as the input + + .. image:: ../scripts/activation_images/LogSigmoid.png + + Examples:: + + >>> m = nn.LogSigmoid() + >>> input = torch.randn(2) + >>> output = m(input) + """ + + def forward(self, input: Tensor) -> Tensor: + return F.logsigmoid(input) + + +class Softplus(Module): + r"""Applies the element-wise function: + + .. math:: + \text{Softplus}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x)) + + SoftPlus is a smooth approximation to the ReLU function and can be used + to constrain the output of a machine to always be positive. + + For numerical stability the implementation reverts to the linear function + when :math:`input \times \beta > threshold`. + + Args: + beta: the :math:`\beta` value for the Softplus formulation. 
Default: 1 + threshold: values above this revert to a linear function. Default: 20 + + Shape: + - Input: :math:`(N, *)` where `*` means, any number of additional + dimensions + - Output: :math:`(N, *)`, same shape as the input + + .. image:: ../scripts/activation_images/Softplus.png + + Examples:: + + >>> m = nn.Softplus() + >>> input = torch.randn(2) + >>> output = m(input) + """ + __constants__ = ['beta', 'threshold'] + beta: int + threshold: int + + def __init__(self, beta: int = 1, threshold: int = 20) -> None: + super(Softplus, self).__init__() + self.beta = beta + self.threshold = threshold + + def forward(self, input: Tensor) -> Tensor: + return F.softplus(input, self.beta, self.threshold) + + def extra_repr(self) -> str: + return 'beta={}, threshold={}'.format(self.beta, self.threshold) + + +class Softshrink(Module): + r"""Applies the soft shrinkage function elementwise: + + .. math:: + \text{SoftShrinkage}(x) = + \begin{cases} + x - \lambda, & \text{ if } x > \lambda \\ + x + \lambda, & \text{ if } x < -\lambda \\ + 0, & \text{ otherwise } + \end{cases} + + Args: + lambd: the :math:`\lambda` (must be no less than zero) value for the Softshrink formulation. Default: 0.5 + + Shape: + - Input: :math:`(N, *)` where `*` means, any number of additional + dimensions + - Output: :math:`(N, *)`, same shape as the input + + .. image:: ../scripts/activation_images/Softshrink.png + + Examples:: + + >>> m = nn.Softshrink() + >>> input = torch.randn(2) + >>> output = m(input) + """ + __constants__ = ['lambd'] + lambd: float + + def __init__(self, lambd: float = 0.5) -> None: + super(Softshrink, self).__init__() + self.lambd = lambd + + def forward(self, input: Tensor) -> Tensor: + return F.softshrink(input, self.lambd) + + def extra_repr(self) -> str: + return str(self.lambd) + + +class MultiheadAttention(Module): + r"""Allows the model to jointly attend to information + from different representation subspaces. + See reference: Attention Is All You Need + + .. math:: + \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O + \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) + + Args: + embed_dim: total dimension of the model. + num_heads: parallel attention heads. + dropout: a Dropout layer on attn_output_weights. Default: 0.0. + bias: add bias as module parameter. Default: True. + add_bias_kv: add bias to the key and value sequences at dim=0. + add_zero_attn: add a new batch of zeros to the key and + value sequences at dim=1. + kdim: total number of features in key. Default: None. + vdim: total number of features in value. Default: None. + + Note: if kdim and vdim are None, they will be set to embed_dim such that + query, key, and value have the same number of features. 
+ + Examples:: + + >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value) + """ + bias_k: Optional[torch.Tensor] + bias_v: Optional[torch.Tensor] + + def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None): + super(MultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" + + if self._qkv_same_embed_dim is False: + self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim)) + self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim)) + self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim)) + self.register_parameter('in_proj_weight', None) + else: + self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim)) + self.register_parameter('q_proj_weight', None) + self.register_parameter('k_proj_weight', None) + self.register_parameter('v_proj_weight', None) + + if bias: + self.in_proj_bias = Parameter(torch.empty(3 * embed_dim)) + else: + self.register_parameter('in_proj_bias', None) + self.out_proj = _LinearWithBias(embed_dim, embed_dim) + + if add_bias_kv: + self.bias_k = Parameter(torch.empty(1, 1, embed_dim)) + self.bias_v = Parameter(torch.empty(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self._reset_parameters() + + def _reset_parameters(self): + if self._qkv_same_embed_dim: + xavier_uniform_(self.in_proj_weight) + else: + xavier_uniform_(self.q_proj_weight) + xavier_uniform_(self.k_proj_weight) + xavier_uniform_(self.v_proj_weight) + + if self.in_proj_bias is not None: + constant_(self.in_proj_bias, 0.) + constant_(self.out_proj.bias, 0.) + if self.bias_k is not None: + xavier_normal_(self.bias_k) + if self.bias_v is not None: + xavier_normal_(self.bias_v) + + def __setstate__(self, state): + # Support loading old MultiheadAttention checkpoints generated by v1.1.0 + if '_qkv_same_embed_dim' not in state: + state['_qkv_same_embed_dim'] = True + + super(MultiheadAttention, self).__setstate__(state) + + def forward(self, query, key, value, key_padding_mask=None, + need_weights=True, attn_mask=None): + # type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]] + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + See "Attention Is All You Need" for more details. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. When given a binary mask and a value is True, + the corresponding value on the attention layer will be ignored. When given + a byte mask and a value is non-zero, the corresponding value on the attention + layer will be ignored + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. 
+ + Shape: + - Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the position + with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + - Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. + """ + if not self._qkv_same_embed_dim: + return F.multi_head_attention_forward( + query, key, value, self.embed_dim, self.num_heads, + self.in_proj_weight, self.in_proj_bias, + self.bias_k, self.bias_v, self.add_zero_attn, + self.dropout, self.out_proj.weight, self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, need_weights=need_weights, + attn_mask=attn_mask, use_separate_proj_weight=True, + q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight, + v_proj_weight=self.v_proj_weight) + else: + return F.multi_head_attention_forward( + query, key, value, self.embed_dim, self.num_heads, + self.in_proj_weight, self.in_proj_bias, + self.bias_k, self.bias_v, self.add_zero_attn, + self.dropout, self.out_proj.weight, self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, need_weights=need_weights, + attn_mask=attn_mask) + + +class PReLU(Module): + r"""Applies the element-wise function: + + .. math:: + \text{PReLU}(x) = \max(0,x) + a * \min(0,x) + + or + + .. math:: + \text{PReLU}(x) = + \begin{cases} + x, & \text{ if } x \geq 0 \\ + ax, & \text{ otherwise } + \end{cases} + + Here :math:`a` is a learnable parameter. When called without arguments, `nn.PReLU()` uses a single + parameter :math:`a` across all input channels. If called with `nn.PReLU(nChannels)`, + a separate :math:`a` is used for each input channel. + + + .. note:: + weight decay should not be used when learning :math:`a` for good performance. + + .. note:: + Channel dim is the 2nd dim of input. When input has dims < 2, then there is + no channel dim and the number of channels = 1. + + Args: + num_parameters (int): number of :math:`a` to learn. + Although it takes an int as input, there is only two values are legitimate: + 1, or the number of channels at input. 
Default: 1 + init (float): the initial value of :math:`a`. Default: 0.25 + + Shape: + - Input: :math:`(N, *)` where `*` means, any number of additional + dimensions + - Output: :math:`(N, *)`, same shape as the input + + Attributes: + weight (Tensor): the learnable weights of shape (:attr:`num_parameters`). + + .. image:: ../scripts/activation_images/PReLU.png + + Examples:: + + >>> m = nn.PReLU() + >>> input = torch.randn(2) + >>> output = m(input) + """ + __constants__ = ['num_parameters'] + num_parameters: int + + def __init__(self, num_parameters: int = 1, init: float = 0.25) -> None: + self.num_parameters = num_parameters + super(PReLU, self).__init__() + self.weight = Parameter(torch.Tensor(num_parameters).fill_(init)) + + def forward(self, input: Tensor) -> Tensor: + return F.prelu(input, self.weight) + + def extra_repr(self) -> str: + return 'num_parameters={}'.format(self.num_parameters) + + +class Softsign(Module): + r"""Applies the element-wise function: + + .. math:: + \text{SoftSign}(x) = \frac{x}{ 1 + |x|} + + Shape: + - Input: :math:`(N, *)` where `*` means, any number of additional + dimensions + - Output: :math:`(N, *)`, same shape as the input + + .. image:: ../scripts/activation_images/Softsign.png + + Examples:: + + >>> m = nn.Softsign() + >>> input = torch.randn(2) + >>> output = m(input) + """ + + def forward(self, input: Tensor) -> Tensor: + return F.softsign(input) + + +class Tanhshrink(Module): + r"""Applies the element-wise function: + + .. math:: + \text{Tanhshrink}(x) = x - \tanh(x) + + Shape: + - Input: :math:`(N, *)` where `*` means, any number of additional + dimensions + - Output: :math:`(N, *)`, same shape as the input + + .. image:: ../scripts/activation_images/Tanhshrink.png + + Examples:: + + >>> m = nn.Tanhshrink() + >>> input = torch.randn(2) + >>> output = m(input) + """ + + def forward(self, input: Tensor) -> Tensor: + return F.tanhshrink(input) + + +class Softmin(Module): + r"""Applies the Softmin function to an n-dimensional input Tensor + rescaling them so that the elements of the n-dimensional output Tensor + lie in the range `[0, 1]` and sum to 1. + + Softmin is defined as: + + .. math:: + \text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)} + + Shape: + - Input: :math:`(*)` where `*` means, any number of additional + dimensions + - Output: :math:`(*)`, same shape as the input + + Arguments: + dim (int): A dimension along which Softmin will be computed (so every slice + along dim will sum to 1). + + Returns: + a Tensor of the same dimension and shape as the input, with + values in the range [0, 1] + + Examples:: + + >>> m = nn.Softmin() + >>> input = torch.randn(2, 3) + >>> output = m(input) + """ + __constants__ = ['dim'] + dim: Optional[int] + + def __init__(self, dim: Optional[int] = None) -> None: + super(Softmin, self).__init__() + self.dim = dim + + def __setstate__(self, state): + self.__dict__.update(state) + if not hasattr(self, 'dim'): + self.dim = None + + def forward(self, input: Tensor) -> Tensor: + return F.softmin(input, self.dim, _stacklevel=5) + + def extra_repr(self): + return 'dim={dim}'.format(dim=self.dim) + +class Softmax(Module): + r"""Applies the Softmax function to an n-dimensional input Tensor + rescaling them so that the elements of the n-dimensional output Tensor + lie in the range [0,1] and sum to 1. + + Softmax is defined as: + + .. math:: + \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)} + + When the input Tensor is a sparse tensor then the unspecifed + values are treated as ``-inf``. 
+ + Shape: + - Input: :math:`(*)` where `*` means, any number of additional + dimensions + - Output: :math:`(*)`, same shape as the input + + Returns: + a Tensor of the same dimension and shape as the input with + values in the range [0, 1] + + Arguments: + dim (int): A dimension along which Softmax will be computed (so every slice + along dim will sum to 1). + + .. note:: + This module doesn't work directly with NLLLoss, + which expects the Log to be computed between the Softmax and itself. + Use `LogSoftmax` instead (it's faster and has better numerical properties). + + Examples:: + + >>> m = nn.Softmax(dim=1) + >>> input = torch.randn(2, 3) + >>> output = m(input) + + """ + __constants__ = ['dim'] + dim: Optional[int] + + def __init__(self, dim: Optional[int] = None) -> None: + super(Softmax, self).__init__() + self.dim = dim + + def __setstate__(self, state): + self.__dict__.update(state) + if not hasattr(self, 'dim'): + self.dim = None + + def forward(self, input: Tensor) -> Tensor: + return F.softmax(input, self.dim, _stacklevel=5) + + def extra_repr(self) -> str: + return 'dim={dim}'.format(dim=self.dim) + + +class Softmax2d(Module): + r"""Applies SoftMax over features to each spatial location. + + When given an image of ``Channels x Height x Width``, it will + apply `Softmax` to each location :math:`(Channels, h_i, w_j)` + + Shape: + - Input: :math:`(N, C, H, W)` + - Output: :math:`(N, C, H, W)` (same shape as input) + + Returns: + a Tensor of the same dimension and shape as the input with + values in the range [0, 1] + + Examples:: + + >>> m = nn.Softmax2d() + >>> # you softmax over the 2nd dimension + >>> input = torch.randn(2, 3, 12, 13) + >>> output = m(input) + """ + + def forward(self, input: Tensor) -> Tensor: + assert input.dim() == 4, 'Softmax2d requires a 4D tensor as input' + return F.softmax(input, 1, _stacklevel=5) + + +class LogSoftmax(Module): + r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional + input Tensor. The LogSoftmax formulation can be simplified as: + + .. math:: + \text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right) + + Shape: + - Input: :math:`(*)` where `*` means, any number of additional + dimensions + - Output: :math:`(*)`, same shape as the input + + Arguments: + dim (int): A dimension along which LogSoftmax will be computed. 
+ + Returns: + a Tensor of the same dimension and shape as the input with + values in the range [-inf, 0) + + Examples:: + + >>> m = nn.LogSoftmax() + >>> input = torch.randn(2, 3) + >>> output = m(input) + """ + __constants__ = ['dim'] + dim: Optional[int] + + def __init__(self, dim: Optional[int] = None) -> None: + super(LogSoftmax, self).__init__() + self.dim = dim + + def __setstate__(self, state): + self.__dict__.update(state) + if not hasattr(self, 'dim'): + self.dim = None + + def forward(self, input: Tensor) -> Tensor: + return F.log_softmax(input, self.dim, _stacklevel=5) + + def extra_repr(self): + return 'dim={dim}'.format(dim=self.dim) diff --git a/main/model/mytransformer.py b/main/model/mytransformer.py new file mode 100644 index 0000000000000000000000000000000000000000..a3f432769e4cddeea86aff34e7e7a9d5cb9ca805 --- /dev/null +++ b/main/model/mytransformer.py @@ -0,0 +1,390 @@ +import copy +import pdb +from typing import Optional, Any + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn import functional as F +from torch.nn.modules.module import Module +from myactivation import MultiheadAttention +from torch.nn.modules.container import ModuleList +from torch.nn.init import xavier_uniform_ +from torch.nn.modules.dropout import Dropout +from torch.nn.modules.linear import Linear +from torch.nn.modules.normalization import LayerNorm + + +class Transformer(Module): + r"""A transformer model. User is able to modify the attributes as needed. The architecture + is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer, + Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and + Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information + Processing Systems, pages 6000-6010. Users can build the BERT(https://arxiv.org/abs/1810.04805) + model with corresponding parameters. + + Args: + d_model: the number of expected features in the encoder/decoder inputs (default=512). + nhead: the number of heads in the multiheadattention models (default=8). + num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6). + num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + activation: the activation function of encoder/decoder intermediate layer, relu or gelu (default=relu). + custom_encoder: custom encoder (default=None). + custom_decoder: custom decoder (default=None). 
+ + Examples:: + >>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12) + >>> src = torch.rand((10, 32, 512)) + >>> tgt = torch.rand((20, 32, 512)) + >>> out = transformer_model(src, tgt) + + Note: A full example to apply nn.Transformer module for the word language model is available in + https://github.com/pytorch/examples/tree/master/word_language_model + """ + + def __init__(self, d_model: int = 512, nhead: int = 8, num_encoder_layers: int = 6, + num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1, + activation: str = "relu", custom_encoder: Optional[Any] = None, custom_decoder: Optional[Any] = None) -> None: + super(Transformer, self).__init__() + + if custom_encoder is not None: + self.encoder = custom_encoder + else: + encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation) + encoder_norm = LayerNorm(d_model) + self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) + + if custom_decoder is not None: + self.decoder = custom_decoder + else: + decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, activation) + decoder_norm = LayerNorm(d_model) + self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm) + + self._reset_parameters() + + self.d_model = d_model + self.nhead = nhead + + def forward(self, src: Tensor, tgt: Tensor, src_mask: Optional[Tensor] = None, tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None) -> Tensor: + r"""Take in and process masked source/target sequences. + + Args: + src: the sequence to the encoder (required). + tgt: the sequence to the decoder (required). + src_mask: the additive mask for the src sequence (optional). + tgt_mask: the additive mask for the tgt sequence (optional). + memory_mask: the additive mask for the encoder output (optional). + src_key_padding_mask: the ByteTensor mask for src keys per batch (optional). + tgt_key_padding_mask: the ByteTensor mask for tgt keys per batch (optional). + memory_key_padding_mask: the ByteTensor mask for memory keys per batch (optional). + + Shape: + - src: :math:`(S, N, E)`. + - tgt: :math:`(T, N, E)`. + - src_mask: :math:`(S, S)`. + - tgt_mask: :math:`(T, T)`. + - memory_mask: :math:`(T, S)`. + - src_key_padding_mask: :math:`(N, S)`. + - tgt_key_padding_mask: :math:`(N, T)`. + - memory_key_padding_mask: :math:`(N, S)`. + + Note: [src/tgt/memory]_mask ensures that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + [src/tgt/memory]_key_padding_mask provides specified elements in the key to be ignored by + the attention. If a ByteTensor is provided, the non-zero positions will be ignored while the zero + positions will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + + - output: :math:`(T, N, E)`. 
+ + Note: Due to the multi-head attention architecture in the transformer model, + the output sequence length of a transformer is same as the input sequence + (i.e. target) length of the decode. + + where S is the source sequence length, T is the target sequence length, N is the + batch size, E is the feature number + + Examples: + >>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask) + """ + + if src.size(1) != tgt.size(1): + raise RuntimeError("the batch number of src and tgt must be equal") + + if src.size(2) != self.d_model or tgt.size(2) != self.d_model: + raise RuntimeError("the feature number of src and tgt must be equal to d_model") + + memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask) + output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask) + return output + + def generate_square_subsequent_mask(self, sz: int) -> Tensor: + r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf'). + Unmasked positions are filled with float(0.0). + """ + mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) + mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) + return mask + + def _reset_parameters(self): + r"""Initiate parameters in the transformer model.""" + + for p in self.parameters(): + if p.dim() > 1: + xavier_uniform_(p) + + +class TransformerEncoder(Module): + r"""TransformerEncoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the TransformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + norm: the layer normalization component (optional). + + Examples:: + >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8) + >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> out = transformer_encoder(src) + """ + __constants__ = ['norm'] + + def __init__(self, encoder_layer, num_layers, norm=None): + super(TransformerEncoder, self).__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required). + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + see the docs in Transformer class. + """ + output = src + + for mod in self.layers: + output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerDecoder(Module): + r"""TransformerDecoder is a stack of N decoder layers + + Args: + decoder_layer: an instance of the TransformerDecoderLayer() class (required). + num_layers: the number of sub-decoder-layers in the decoder (required). + norm: the layer normalization component (optional). 
+ + Examples:: + >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) + >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6) + >>> memory = torch.rand(10, 32, 512) + >>> tgt = torch.rand(20, 32, 512) + >>> out = transformer_decoder(tgt, memory) + """ + __constants__ = ['norm'] + + def __init__(self, decoder_layer, num_layers, norm=None): + super(TransformerDecoder, self).__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None) -> Tensor: + r"""Pass the inputs (and mask) through the decoder layer in turn. + + Args: + tgt: the sequence to the decoder (required). + memory: the sequence from the last layer of the encoder (required). + tgt_mask: the mask for the tgt sequence (optional). + memory_mask: the mask for the memory sequence (optional). + tgt_key_padding_mask: the mask for the tgt keys per batch (optional). + memory_key_padding_mask: the mask for the memory keys per batch (optional). + + Shape: + see the docs in Transformer class. + """ + output = tgt + + for mod in self.layers: + output = mod(output, memory, tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask) + + if self.norm is not None: + output = self.norm(output) + + return output + +class TransformerEncoderLayer(Module): + r"""TransformerEncoderLayer is made up of self-attn and feedforward network. + This standard encoder layer is based on the paper "Attention Is All You Need". + Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, + Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in + Neural Information Processing Systems, pages 6000-6010. Users may modify or implement + in a different way during application. + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + activation: the activation function of intermediate layer, relu or gelu (default=relu). + + Examples:: + >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> out = encoder_layer(src) + """ + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"): + super(TransformerEncoderLayer, self).__init__() + self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = Linear(d_model, dim_feedforward) + self.dropout = Dropout(dropout) + self.linear2 = Linear(dim_feedforward, d_model) + + self.norm1 = LayerNorm(d_model) + self.norm2 = LayerNorm(d_model) + self.dropout1 = Dropout(dropout) + self.dropout2 = Dropout(dropout) + + self.activation = _get_activation_fn(activation) + + def __setstate__(self, state): + if 'activation' not in state: + state['activation'] = F.relu + super(TransformerEncoderLayer, self).__setstate__(state) + + def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor: + r"""Pass the input through the encoder layer. 
+ + Args: + src: the sequence to the encoder layer (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + see the docs in Transformer class. + """ + # src: [81, 2, 128] + src2 = self.self_attn(src, src, src, attn_mask=src_mask, + key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + return src + + +class TransformerDecoderLayer(Module): + r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network. + This standard decoder layer is based on the paper "Attention Is All You Need". + Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, + Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in + Neural Information Processing Systems, pages 6000-6010. Users may modify or implement + in a different way during application. + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + activation: the activation function of intermediate layer, relu or gelu (default=relu). + + Examples:: + >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) + >>> memory = torch.rand(10, 32, 512) + >>> tgt = torch.rand(20, 32, 512) + >>> out = decoder_layer(tgt, memory) + """ + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"): + super(TransformerDecoderLayer, self).__init__() + self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout) + self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = Linear(d_model, dim_feedforward) + self.dropout = Dropout(dropout) + self.linear2 = Linear(dim_feedforward, d_model) + + self.norm1 = LayerNorm(d_model) + self.norm2 = LayerNorm(d_model) + self.norm3 = LayerNorm(d_model) + self.dropout1 = Dropout(dropout) + self.dropout2 = Dropout(dropout) + self.dropout3 = Dropout(dropout) + + self.activation = _get_activation_fn(activation) + + def __setstate__(self, state): + if 'activation' not in state: + state['activation'] = F.relu + super(TransformerDecoderLayer, self).__setstate__(state) + + def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None) -> Tensor: + r"""Pass the inputs (and mask) through the decoder layer. + + Args: + tgt: the sequence to the decoder layer (required). + memory: the sequence from the last layer of the encoder (required). + tgt_mask: the mask for the tgt sequence (optional). + memory_mask: the mask for the memory sequence (optional). + tgt_key_padding_mask: the mask for the tgt keys per batch (optional). + memory_key_padding_mask: the mask for the memory keys per batch (optional). + + Shape: + see the docs in Transformer class. 
+ """ + tgt2 = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + tgt2 = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt + + +def _get_clones(module, N): + return ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def _get_activation_fn(activation): + if activation == "relu": + return F.relu + elif activation == "gelu": + return F.gelu + + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) diff --git a/main/model/rotation2xyz.py b/main/model/rotation2xyz.py new file mode 100644 index 0000000000000000000000000000000000000000..9746c7d73f2e30bfb2495cb901f8422f04ecbf5b --- /dev/null +++ b/main/model/rotation2xyz.py @@ -0,0 +1,92 @@ +# This code is based on https://github.com/Mathux/ACTOR.git +import torch +import utils.rotation_conversions as geometry + + +from model.smpl import SMPL, JOINTSTYPE_ROOT +# from .get_model import JOINTSTYPES +JOINTSTYPES = ["a2m", "a2mpl", "smpl", "vibe", "vertices"] + + +class Rotation2xyz: + def __init__(self, device, dataset='amass'): + self.device = device + self.dataset = dataset + self.smpl_model = SMPL().eval().to(device) + + def __call__(self, x, mask, pose_rep, translation, glob, + jointstype, vertstrans, betas=None, beta=0, + glob_rot=None, get_rotations_back=False, **kwargs): + if pose_rep == "xyz": + return x + + if mask is None: + mask = torch.ones((x.shape[0], x.shape[-1]), dtype=bool, device=x.device) + + if not glob and glob_rot is None: + raise TypeError("You must specify global rotation if glob is False") + + if jointstype not in JOINTSTYPES: + raise NotImplementedError("This jointstype is not implemented.") + + if translation: + x_translations = x[:, -1, :3] + x_rotations = x[:, :-1] + else: + x_rotations = x + + x_rotations = x_rotations.permute(0, 3, 1, 2) + nsamples, time, njoints, feats = x_rotations.shape + + # Compute rotations (convert only masked sequences output) + if pose_rep == "rotvec": + rotations = geometry.axis_angle_to_matrix(x_rotations[mask]) + elif pose_rep == "rotmat": + rotations = x_rotations[mask].view(-1, njoints, 3, 3) + elif pose_rep == "rotquat": + rotations = geometry.quaternion_to_matrix(x_rotations[mask]) + elif pose_rep == "rot6d": + rotations = geometry.rotation_6d_to_matrix(x_rotations[mask]) + else: + raise NotImplementedError("No geometry for this one.") + + if not glob: + global_orient = torch.tensor(glob_rot, device=x.device) + global_orient = geometry.axis_angle_to_matrix(global_orient).view(1, 1, 3, 3) + global_orient = global_orient.repeat(len(rotations), 1, 1, 1) + else: + global_orient = rotations[:, 0] + rotations = rotations[:, 1:] + + if betas is None: + betas = torch.zeros([rotations.shape[0], self.smpl_model.num_betas], + dtype=rotations.dtype, device=rotations.device) + betas[:, 1] = beta + # import ipdb; ipdb.set_trace() + out = self.smpl_model(body_pose=rotations, global_orient=global_orient, betas=betas) + + # get the desirable joints + joints = out[jointstype] + + x_xyz = torch.empty(nsamples, time, joints.shape[1], 3, device=x.device, dtype=x.dtype) + x_xyz[~mask] = 0 + x_xyz[mask] = joints + + x_xyz = x_xyz.permute(0, 2, 3, 1).contiguous() + + # the first translation root at the origin on the 
prediction + if jointstype != "vertices": + rootindex = JOINTSTYPE_ROOT[jointstype] + x_xyz = x_xyz - x_xyz[:, [rootindex], :, :] + + if translation and vertstrans: + # the first translation root at the origin + x_translations = x_translations - x_translations[:, :, [0]] + + # add the translation to all the joints + x_xyz = x_xyz + x_translations[:, None, :, :] + + if get_rotations_back: + return x_xyz, rotations, global_orient + else: + return x_xyz diff --git a/main/model/smpl.py b/main/model/smpl.py new file mode 100644 index 0000000000000000000000000000000000000000..587f5419601a74df92c1e37263b28d4aa6a7c0a9 --- /dev/null +++ b/main/model/smpl.py @@ -0,0 +1,97 @@ +# This code is based on https://github.com/Mathux/ACTOR.git +import numpy as np +import torch + +import contextlib + +from smplx import SMPLLayer as _SMPLLayer +from smplx.lbs import vertices2joints + + +# action2motion_joints = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 21, 24, 38] +# change 0 and 8 +action2motion_joints = [8, 1, 2, 3, 4, 5, 6, 7, 0, 9, 10, 11, 12, 13, 14, 21, 24, 38] + +from utils.config import SMPL_MODEL_PATH, JOINT_REGRESSOR_TRAIN_EXTRA + +JOINTSTYPE_ROOT = {"a2m": 0, # action2motion + "smpl": 0, + "a2mpl": 0, # set(smpl, a2m) + "vibe": 8} # 0 is the 8 position: OP MidHip below + +JOINT_MAP = { + 'OP Nose': 24, 'OP Neck': 12, 'OP RShoulder': 17, + 'OP RElbow': 19, 'OP RWrist': 21, 'OP LShoulder': 16, + 'OP LElbow': 18, 'OP LWrist': 20, 'OP MidHip': 0, + 'OP RHip': 2, 'OP RKnee': 5, 'OP RAnkle': 8, + 'OP LHip': 1, 'OP LKnee': 4, 'OP LAnkle': 7, + 'OP REye': 25, 'OP LEye': 26, 'OP REar': 27, + 'OP LEar': 28, 'OP LBigToe': 29, 'OP LSmallToe': 30, + 'OP LHeel': 31, 'OP RBigToe': 32, 'OP RSmallToe': 33, 'OP RHeel': 34, + 'Right Ankle': 8, 'Right Knee': 5, 'Right Hip': 45, + 'Left Hip': 46, 'Left Knee': 4, 'Left Ankle': 7, + 'Right Wrist': 21, 'Right Elbow': 19, 'Right Shoulder': 17, + 'Left Shoulder': 16, 'Left Elbow': 18, 'Left Wrist': 20, + 'Neck (LSP)': 47, 'Top of Head (LSP)': 48, + 'Pelvis (MPII)': 49, 'Thorax (MPII)': 50, + 'Spine (H36M)': 51, 'Jaw (H36M)': 52, + 'Head (H36M)': 53, 'Nose': 24, 'Left Eye': 26, + 'Right Eye': 25, 'Left Ear': 28, 'Right Ear': 27 +} + +JOINT_NAMES = [ + 'OP Nose', 'OP Neck', 'OP RShoulder', + 'OP RElbow', 'OP RWrist', 'OP LShoulder', + 'OP LElbow', 'OP LWrist', 'OP MidHip', + 'OP RHip', 'OP RKnee', 'OP RAnkle', + 'OP LHip', 'OP LKnee', 'OP LAnkle', + 'OP REye', 'OP LEye', 'OP REar', + 'OP LEar', 'OP LBigToe', 'OP LSmallToe', + 'OP LHeel', 'OP RBigToe', 'OP RSmallToe', 'OP RHeel', + 'Right Ankle', 'Right Knee', 'Right Hip', + 'Left Hip', 'Left Knee', 'Left Ankle', + 'Right Wrist', 'Right Elbow', 'Right Shoulder', + 'Left Shoulder', 'Left Elbow', 'Left Wrist', + 'Neck (LSP)', 'Top of Head (LSP)', + 'Pelvis (MPII)', 'Thorax (MPII)', + 'Spine (H36M)', 'Jaw (H36M)', + 'Head (H36M)', 'Nose', 'Left Eye', + 'Right Eye', 'Left Ear', 'Right Ear' +] + + +# adapted from VIBE/SPIN to output smpl_joints, vibe joints and action2motion joints +class SMPL(_SMPLLayer): + """ Extension of the official SMPL implementation to support more joints """ + + def __init__(self, model_path=SMPL_MODEL_PATH, **kwargs): + kwargs["model_path"] = model_path + + # remove the verbosity for the 10-shapes beta parameters + with contextlib.redirect_stdout(None): + super(SMPL, self).__init__(**kwargs) + + J_regressor_extra = np.load(JOINT_REGRESSOR_TRAIN_EXTRA) + self.register_buffer('J_regressor_extra', torch.tensor(J_regressor_extra, dtype=torch.float32)) + vibe_indexes = np.array([JOINT_MAP[i] for i in 
JOINT_NAMES]) + a2m_indexes = vibe_indexes[action2motion_joints] + smpl_indexes = np.arange(24) + a2mpl_indexes = np.unique(np.r_[smpl_indexes, a2m_indexes]) + + self.maps = {"vibe": vibe_indexes, + "a2m": a2m_indexes, + "smpl": smpl_indexes, + "a2mpl": a2mpl_indexes} + + def forward(self, *args, **kwargs): + smpl_output = super(SMPL, self).forward(*args, **kwargs) + + extra_joints = vertices2joints(self.J_regressor_extra, smpl_output.vertices) + all_joints = torch.cat([smpl_output.joints, extra_joints], dim=1) + + output = {"vertices": smpl_output.vertices} + + for joinstype, indexes in self.maps.items(): + output[joinstype] = all_joints[:, indexes] + + return output \ No newline at end of file diff --git a/main/model/tisa.py b/main/model/tisa.py new file mode 100644 index 0000000000000000000000000000000000000000..a8074f23f7e415de9bf76f354b30614e2ae8ec0b --- /dev/null +++ b/main/model/tisa.py @@ -0,0 +1,118 @@ +import pdb + +import torch + +from torch import nn + + +class Tisa(nn.Module): + def __init__(self, num_attention_heads: int = 12, num_kernels: int = 5): + super().__init__() + self.num_attention_heads = num_attention_heads + self.num_kernels = num_kernels + + self.kernel_offsets = nn.Parameter( + torch.Tensor(self.num_kernels, self.num_attention_heads) + ) + self.kernel_amplitudes = nn.Parameter( + torch.Tensor(self.num_kernels, self.num_attention_heads) + ) + self.kernel_sharpness = nn.Parameter( + torch.Tensor(self.num_kernels, self.num_attention_heads) + ) + self._init_weights() + + def create_relative_offsets(self, seq_len: int): + """Creates offsets for all the relative distances between + -seq_len + 1 to seq_len - 1.""" + return torch.arange(-seq_len, seq_len + 1) + + def compute_positional_scores(self, relative_offsets): + """Takes seq_len and outputs position scores for each relative distance. + This implementation uses radial basis functions. Override this function to + use other scoring functions than the example in the paper.""" + rbf_scores = ( + self.kernel_amplitudes.unsqueeze(-1) + * torch.exp( + -torch.abs(self.kernel_sharpness.unsqueeze(-1)) + * ((self.kernel_offsets.unsqueeze(-1) - relative_offsets) ** 2) + ) + ).sum(axis=0) + return rbf_scores + + def scores_to_toeplitz_matrix(self, positional_scores, seq_len: int): + """Converts the TISA positional scores into the final matrix for the + self-attention equation. 
PRs with memory and/or speed optimizations are + welcome.""" + deformed_toeplitz = ( + ( + (torch.arange(0, -(seq_len ** 2), step=-1) + (seq_len - 1)).view( + seq_len, seq_len + ) + + (seq_len + 1) * torch.arange(seq_len).view(-1, 1) + ) + .view(-1) + .long() + .to(device=positional_scores.device) + ) + expanded_positional_scores = torch.take_along_dim( + positional_scores, deformed_toeplitz.view(1, -1), 1 + ).view(self.num_attention_heads, seq_len, seq_len) + return expanded_positional_scores + + def forward(self, seq_len: int): + """Computes the translation-invariant positional contribution to the + attention matrix in the self-attention module of transformer models.""" + if not self.num_kernels: + return torch.zeros((self.num_attention_heads, seq_len, seq_len)) + positional_scores_vector = self.compute_positional_scores( + self.create_relative_offsets(seq_len) + ) + positional_scores_matrix = self.scores_to_toeplitz_matrix( + positional_scores_vector, seq_len + ) + return positional_scores_matrix + + def visualize(self, seq_len: int = 10, attention_heads=None): + """Visualizes the TISA interpretability by plotting position scores as + a function of relative distance for each attention head.""" + if attention_heads is None: + attention_heads = list(range(self.num_attention_heads)) + import matplotlib.pyplot as plt + + x = self.create_relative_offsets(seq_len).detach().numpy() + y = ( + self.compute_positional_scores(self.create_relative_offsets(seq_len)) + .detach() + .numpy() + ) + for i in attention_heads: + plt.plot(x, y[i]) + plt.savefig('./pic-tisa.png') + plt.show() + + def _init_weights(self): + """Initialize the weights""" + ampl_init_mean = 0.1 + sharpness_init_mean = 0.1 + torch.nn.init.normal_(self.kernel_offsets, mean=0.0, std=5.0) + torch.nn.init.normal_( + self.kernel_amplitudes, mean=ampl_init_mean, std=0.1 * ampl_init_mean + ) + torch.nn.init.normal_( + self.kernel_sharpness, + mean=sharpness_init_mean, + std=0.1 * sharpness_init_mean, + ) + + +def main(): + tisa = Tisa() + positional_scores = tisa(20) + pdb.set_trace() + tisa.visualize(seq_len=20) + + +if __name__ == "__main__": + main() + diff --git a/main/mydiffusion_zeggs/0001-0933.mkv b/main/mydiffusion_zeggs/0001-0933.mkv new file mode 100644 index 0000000000000000000000000000000000000000..8ef3083f668fdb5300912efb3fda4e78120222f4 --- /dev/null +++ b/main/mydiffusion_zeggs/0001-0933.mkv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5732ea67802f3d5aab266eee443368c0bb0f1cf4fa4eac24e677a8ab8fdefc7f +size 1269291 diff --git a/main/mydiffusion_zeggs/0001-0933.mp4 b/main/mydiffusion_zeggs/0001-0933.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..21a34681cf8900a3710a2676d1e3c43e6415ea15 --- /dev/null +++ b/main/mydiffusion_zeggs/0001-0933.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3d96adab755d039f34c3106b3598f002ad037e9af6cdaed1a14fb8de8fb9c460 +size 1300351 diff --git a/main/mydiffusion_zeggs/015_Happy_4_x_1_0.wav b/main/mydiffusion_zeggs/015_Happy_4_x_1_0.wav new file mode 100644 index 0000000000000000000000000000000000000000..00f20de982b5ca32f8888851b51facbe171ce219 --- /dev/null +++ b/main/mydiffusion_zeggs/015_Happy_4_x_1_0.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:27c6952b7199eb0d4b5196a5a3be1564938680cfa8b891af2815a5c38021516c +size 4393112 diff --git a/main/mydiffusion_zeggs/WavLM/README.md b/main/mydiffusion_zeggs/WavLM/README.md new file mode 100644 index 
0000000000000000000000000000000000000000..493f4ab4b05851937d3e8742a73bd9af1a955e1d --- /dev/null +++ b/main/mydiffusion_zeggs/WavLM/README.md @@ -0,0 +1,125 @@ + +# WavLM + + + + + [**WavLM**](https://arxiv.org/pdf/2110.13900.pdf) : **WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing** + +Official PyTorch implementation and pretrained models of WavLM + +- Oct 2021: release preprint in [arXiv](https://arxiv.org/pdf/2110.13900.pdf) + +## Pre-Trained Models +Model | Pre-training Dataset | Fine-tuning Dataset | Model +|---|---|---|--- +WavLM Base | [960 hrs LibriSpeech](http://www.openslr.org/12)| - | [Azure Storage](https://msranlcmtteamdrive.blob.core.windows.net/share/wavlm/WavLM-Base.pt?sv=2020-04-08&st=2021-11-05T00%3A35%3A31Z&se=2022-11-06T00%3A35%3A00Z&sr=b&sp=r&sig=JljnRVzyHY6AjHzhVmHV5KyQQCvvGfgp9D2M02oGJBU%3D)
[Google Drive](https://drive.google.com/file/d/19-C7SMQvEFAYLG5uc47NX_MY03JCbI4x/view?usp=sharing) +WavLM Base+ | [60k hrs Libri-Light](https://github.com/facebookresearch/libri-light) + [10k hrs GigaSpeech](https://github.com/SpeechColab/GigaSpeech) + [24k hrs VoxPopuli](https://github.com/facebookresearch/voxpopuli/tree/main)| - | [Azure Storage](https://msranlcmtteamdrive.blob.core.windows.net/share/wavlm/WavLM-Base+.pt?sv=2020-04-08&st=2021-11-05T00%3A34%3A47Z&se=2022-10-06T00%3A34%3A00Z&sr=b&sp=r&sig=Gkf1IByHaIn1t%2FVEd9D6WHjZ3zu%2Fk5eSdoj21UytKro%3D)
[Google Drive](https://drive.google.com/file/d/1PlbT_9_B4F9BsD_ija84sUTVw7almNX8/view?usp=sharing) +WavLM Large | [60k hrs Libri-Light](https://github.com/facebookresearch/libri-light) + [10k hrs GigaSpeech](https://github.com/SpeechColab/GigaSpeech) + [24k hrs VoxPopuli](https://github.com/facebookresearch/voxpopuli/tree/main)| - | [Azure Storage](https://msranlcmtteamdrive.blob.core.windows.net/share/wavlm/WavLM-Large.pt?sv=2020-08-04&st=2021-11-22T10%3A03%3A53Z&se=2022-11-23T10%3A03%3A00Z&sr=b&sp=r&sig=3kB8dwTCyIS8YQ7gW5oXmDrXV%2FAaLmoxBS37oPpFsz4%3D)
[Google Drive](https://drive.google.com/file/d/1p8nbj16b7YA16sqPZ4E0JUL-oIDUBGwU/view?usp=sharing) + +## Load Pre-Trained Models for Inference + +```python +import torch +from WavLM import WavLM, WavLMConfig + +# load the pre-trained checkpoints +checkpoint = torch.load('/path/to/wavlm.pt') +cfg = WavLMConfig(checkpoint['cfg']) +model = WavLM(cfg) +model.load_state_dict(checkpoint['model']) +model.eval() + +# extract the representation of the last layer +wav_input_16khz = torch.randn(1,10000) +rep = model.extract_features(wav_input_16khz)[0] + +# extract the representation of each layer +wav_input_16khz = torch.randn(1,10000) +rep, layer_results = model.extract_features(wav_input_16khz, output_layer=model.cfg.encoder_layers, ret_layer_results=True)[0] +layer_reps = [x.transpose(0, 1) for x, _ in layer_results] +``` + + +## Universal Representation Evaluation on SUPERB +![alt text](WavLM_SUPERB_Results.png) + +![alt text](WavLM_SUPERB_Leaderboard.png) +## Downstream Task Performance +We also evaluate our models on typical speech processing benchmarks. +### Speaker Verification + +Evaluation on the [VoxCeleb](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/#:~:text=VoxCeleb%20is%20an%20audio%2Dvisual,interview%20videos%20uploaded%20to%20YouTube) + +| Model |Fix pre-train| Vox1-O | Vox1-E | Vox1-H | +| ------------- |------------- | ---------- | ---------- | ---------- | +| ECAPA-TDNN | - | 0.87 | 1.12 | 2.12 | +| HuBERT large | Yes| 0.888 |0.912| 1.853 | +| Wav2Vec2.0 (XLSR)| Yes | 0.915| 0.945 |1.895| +| UniSpeech-SAT large | Yes | 0.771 | 0.781| 1.669| +| WavLM large | Yes | 0.638 | 0.687| 1.457| +| HuBERT large | No| 0.585| 0.654 |1.342| +| Wav2Vec2.0 (XLSR) | No| 0.564| 0.605 |1.23| +| UniSpeech-SAT large | No | 0.564 | 0.561| 1.23 | +| **WavLM large** | No | **0.431** | **0.538**| **1.154** | + + + +### Speech Separation + +Evaluation on the [LibriCSS](https://github.com/chenzhuo1011/libri_css) +| Model |0S | 0L | OV10 | OV20 |OV30 |OV40 | +| ---------------- |------| ------ | ------ | ------ | ------ | ------ | +| [Conformer](https://ieeexplore.ieee.org/abstract/document/9413423/) (SOTA) | 4.5 | 4.4 |6.2 |8.5| 11 |12.6| +| HuBERT base | 4.7| 4.6 | 6.1 | 7.9| 10.6| 12.3| +| UniSpeech-SAT base | 4.4| 4.4 |5.4| 7.2| 9.2 |10.5| +| UniSpeech-SAT large | 4.3| 4.2 |5.0 |6.3| 8.2| 8.8| +| WavLM base+ | 4.5| 4.4 |5.6| 7.5| 9.4 |10.9| +| **WavLM large** | 4.2| 4.1 | 4.8 | 5.8 | 7.4| 8.5| + + +### Speaker Diarization + +Evaluation on the [CALLHOME](https://arxiv.org/pdf/1909.06247.pdf) +| Model |spk_2 |spk_3| spk_4| spk_5| spk_6| spk_all | +| ---------------- |------| ------ | ------ | ------ | ------ | ------ | +| [EEND-vector clustering](https://arxiv.org/pdf/2105.09040.pdf) | 7.96| 11.93 |16.38| 21.21| 23.1 |12.49| +| [EEND-EDA clustering](https://arxiv.org/abs/2107.01545) (SOTA) | 7.11| 11.88 |14.37| 25.95| 21.95 |11.84| +| HuBERT base| 7.93|12.07| 15.21 |19.59| 23.32| 12.63| +| HuBERT large| 7.39| 11.97| 15.76 |19.82| 22.10| 12.40| +| UniSpeech-SAT large| 5.93| 10.66| 12.9 |16.48| 23.25| 10.92| +| WavLM Base| 6.99| 11.12| 15.20 |16.48| 21.61| 11.75| +| **WavLM large** | 6.46| 10.69| 11.84 |12.89| 20.70| 10.35| + +### Speech Recognition +Evaluation on the [LibriSpeech](https://www.openslr.org/12) + +![alt text](WavLM_ASR.PNG) + + +## License +This project is licensed under the license found in the LICENSE file in the root directory of this source tree. +Portions of the source code are based on the [FAIRSEQ](https://github.com/pytorch/fairseq) project. 
+ +[Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct) + + +### Reference +If you find our work useful in your research, please cite the following paper: +``` latex +@article{Chen2021WavLM, + title = {WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing}, + author = {Sanyuan Chen and Chengyi Wang and Zhengyang Chen and Yu Wu and Shujie Liu and Zhuo Chen and Jinyu Li and Naoyuki Kanda and Takuya Yoshioka and Xiong Xiao and Jian Wu and Long Zhou and Shuo Ren and Yanmin Qian and Yao Qian and Jian Wu and Michael Zeng and Furu Wei}, + eprint={2110.13900}, + archivePrefix={arXiv}, + primaryClass={cs.CL}, + year={2021} +} +``` + + +### Contact Information + +For help or issues using WavLM models, please submit a GitHub issue. + +For other communications related to WavLM, please contact Yu Wu (`yuwu1@microsoft.com`). diff --git a/main/mydiffusion_zeggs/WavLM/WavLM.py b/main/mydiffusion_zeggs/WavLM/WavLM.py new file mode 100644 index 0000000000000000000000000000000000000000..b6b3193dd7a7267dd05301f8bfd7da01625240e2 --- /dev/null +++ b/main/mydiffusion_zeggs/WavLM/WavLM.py @@ -0,0 +1,743 @@ +# -------------------------------------------------------- +# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf) +# Github source: https://github.com/microsoft/unilm/tree/master/wavlm +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + +import math +import logging +from typing import List, Optional, Tuple + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import LayerNorm +from modules_WavLM import ( + Fp32GroupNorm, + Fp32LayerNorm, + GradMultiply, + MultiheadAttention, + SamePad, + init_bert_params, + get_activation_fn, + TransposeLast, + GLU_Linear, +) + +logger = logging.getLogger(__name__) + + +def compute_mask_indices( + shape: Tuple[int, int], + padding_mask: Optional[torch.Tensor], + mask_prob: float, + mask_length: int, + mask_type: str = "static", + mask_other: float = 0.0, + min_masks: int = 0, + no_overlap: bool = False, + min_space: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape + + Args: + shape: the shape for which to compute masks. + should be of size 2 where first element is batch size and 2nd is timesteps + padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements + mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by + number of timesteps divided by length of mask span to mask approximately this percentage of all elements. + however due to overlaps, the actual number will be smaller (unless no_overlap is True) + mask_type: how to compute mask lengths + static = fixed size + uniform = sample from uniform distribution [mask_other, mask_length*2] + normal = sample from normal distribution with mean mask_length and stdev mask_other. 
mask is min 1 element + poisson = sample from possion distribution with lambda = mask length + min_masks: minimum number of masked spans + no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping + min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans + """ + + bsz, all_sz = shape + mask = np.full((bsz, all_sz), False) + + all_num_mask = int( + # add a random number for probabilistic rounding + mask_prob * all_sz / float(mask_length) + + np.random.rand() + ) + + all_num_mask = max(min_masks, all_num_mask) + + mask_idcs = [] + for i in range(bsz): + if padding_mask is not None: + sz = all_sz - padding_mask[i].long().sum().item() + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + np.random.rand() + ) + num_mask = max(min_masks, num_mask) + else: + sz = all_sz + num_mask = all_num_mask + + if mask_type == "static": + lengths = np.full(num_mask, mask_length) + elif mask_type == "uniform": + lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask) + elif mask_type == "normal": + lengths = np.random.normal(mask_length, mask_other, size=num_mask) + lengths = [max(1, int(round(x))) for x in lengths] + elif mask_type == "poisson": + lengths = np.random.poisson(mask_length, size=num_mask) + lengths = [int(round(x)) for x in lengths] + else: + raise Exception("unknown mask selection " + mask_type) + + if sum(lengths) == 0: + lengths[0] = min(mask_length, sz - 1) + + if no_overlap: + mask_idc = [] + + def arrange(s, e, length, keep_length): + span_start = np.random.randint(s, e - length) + mask_idc.extend(span_start + i for i in range(length)) + + new_parts = [] + if span_start - s - min_space >= keep_length: + new_parts.append((s, span_start - min_space + 1)) + if e - span_start - keep_length - min_space > keep_length: + new_parts.append((span_start + length + min_space, e)) + return new_parts + + parts = [(0, sz)] + min_length = min(lengths) + for length in sorted(lengths, reverse=True): + lens = np.fromiter( + (e - s if e - s >= length + min_space else 0 for s, e in parts), + np.int, + ) + l_sum = np.sum(lens) + if l_sum == 0: + break + probs = lens / np.sum(lens) + c = np.random.choice(len(parts), p=probs) + s, e = parts.pop(c) + parts.extend(arrange(s, e, length, min_length)) + mask_idc = np.asarray(mask_idc) + else: + min_len = min(lengths) + if sz - min_len <= num_mask: + min_len = sz - num_mask - 1 + + mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) + + mask_idc = np.asarray( + [ + mask_idc[j] + offset + for j in range(len(mask_idc)) + for offset in range(lengths[j]) + ] + ) + + mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) + + min_len = min([len(m) for m in mask_idcs]) + for i, mask_idc in enumerate(mask_idcs): + if len(mask_idc) > min_len: + mask_idc = np.random.choice(mask_idc, min_len, replace=False) + mask[i, mask_idc] = True + + return mask + + +class WavLMConfig: + def __init__(self, cfg=None): + self.extractor_mode: str = "default" # mode for feature extractor. 
default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True) + self.encoder_layers: int = 12 # num encoder layers in the transformer + + self.encoder_embed_dim: int = 768 # encoder embedding dimension + self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN + self.encoder_attention_heads: int = 12 # num encoder attention heads + self.activation_fn: str = "gelu" # activation function to use + + self.layer_norm_first: bool = False # apply layernorm first in the transformer + self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...] + self.conv_bias: bool = False # include bias in conv encoder + self.feature_grad_mult: float = 1.0 # multiply feature extractor var grads by this + + self.normalize: bool = False # normalize input to have 0 mean and unit variance during training + + # dropouts + self.dropout: float = 0.1 # dropout probability for the transformer + self.attention_dropout: float = 0.1 # dropout probability for attention weights + self.activation_dropout: float = 0.0 # dropout probability after activation in FFN + self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer + self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr) + self.dropout_features: float = 0.0 # dropout to apply to the features (after feat extr) + + # masking + self.mask_length: int = 10 # mask length) + self.mask_prob: float = 0.65 # probability of replacing a token with mask + self.mask_selection: str = "static" # how to choose mask length + self.mask_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh + self.no_mask_overlap: bool = False # whether to allow masks to overlap + self.mask_min_space: int = 1 # min space between spans (if no overlap is enabled) + + # channel masking + self.mask_channel_length: int = 10 # length of the mask for features (channels) + self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0 + self.mask_channel_selection: str = "static" # how to choose mask length for channel masking + self.mask_channel_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices + self.no_mask_channel_overlap: bool = False # whether to allow channel masks to overlap + self.mask_channel_min_space: int = 1 # min space between spans (if no overlap is enabled) + + # positional embeddings + self.conv_pos: int = 128 # number of filters for convolutional positional embeddings + self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding + + # relative position embedding + self.relative_position_embedding: bool = False # apply relative position embedding + self.num_buckets: int = 320 # number of buckets for relative position embedding + self.max_distance: int = 1280 # maximum distance for relative position embedding + self.gru_rel_pos: bool = False # apply gated relative position embedding + + if cfg is not None: + self.update(cfg) + + def update(self, cfg: dict): + self.__dict__.update(cfg) + + +class WavLM(nn.Module): + def __init__( + self, + cfg: WavLMConfig, + ) -> None: + super().__init__() + logger.info(f"WavLM Config: {cfg.__dict__}") + + self.cfg = cfg + feature_enc_layers = eval(cfg.conv_feature_layers) + 
self.embed = feature_enc_layers[-1][0] + + self.feature_extractor = ConvFeatureExtractionModel( + conv_layers=feature_enc_layers, + dropout=0.0, + mode=cfg.extractor_mode, + conv_bias=cfg.conv_bias, + ) + + self.post_extract_proj = ( + nn.Linear(self.embed, cfg.encoder_embed_dim) + if self.embed != cfg.encoder_embed_dim + else None + ) + + self.mask_prob = cfg.mask_prob + self.mask_selection = cfg.mask_selection + self.mask_other = cfg.mask_other + self.mask_length = cfg.mask_length + self.no_mask_overlap = cfg.no_mask_overlap + self.mask_min_space = cfg.mask_min_space + + self.mask_channel_prob = cfg.mask_channel_prob + self.mask_channel_selection = cfg.mask_channel_selection + self.mask_channel_other = cfg.mask_channel_other + self.mask_channel_length = cfg.mask_channel_length + self.no_mask_channel_overlap = cfg.no_mask_channel_overlap + self.mask_channel_min_space = cfg.mask_channel_min_space + + self.dropout_input = nn.Dropout(cfg.dropout_input) + self.dropout_features = nn.Dropout(cfg.dropout_features) + + self.feature_grad_mult = cfg.feature_grad_mult + + self.mask_emb = nn.Parameter( + torch.FloatTensor(cfg.encoder_embed_dim).uniform_() + ) + + self.encoder = TransformerEncoder(cfg) + self.layer_norm = LayerNorm(self.embed) + + def apply_mask(self, x, padding_mask): + B, T, C = x.shape + if self.mask_prob > 0: + mask_indices = compute_mask_indices( + (B, T), + padding_mask, + self.mask_prob, + self.mask_length, + self.mask_selection, + self.mask_other, + min_masks=2, + no_overlap=self.no_mask_overlap, + min_space=self.mask_min_space, + ) + mask_indices = torch.from_numpy(mask_indices).to(x.device) + x[mask_indices] = self.mask_emb + else: + mask_indices = None + + if self.mask_channel_prob > 0: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + .to(x.device) + .unsqueeze(1) + .expand(-1, T, -1) + ) + x[mask_channel_indices] = 0 + + return x, mask_indices + + def forward_padding_mask( + self, features: torch.Tensor, padding_mask: torch.Tensor, + ) -> torch.Tensor: + extra = padding_mask.size(1) % features.size(1) + if extra > 0: + padding_mask = padding_mask[:, :-extra] + padding_mask = padding_mask.view( + padding_mask.size(0), features.size(1), -1 + ) + padding_mask = padding_mask.all(-1) + return padding_mask + + def extract_features( + self, + source: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + mask: bool = False, + ret_conv: bool = False, + output_layer: Optional[int] = None, + ret_layer_results: bool = False, + ): + + if self.feature_grad_mult > 0: + features = self.feature_extractor(source) + if self.feature_grad_mult != 1.0: + features = GradMultiply.apply(features, self.feature_grad_mult) + else: + with torch.no_grad(): + features = self.feature_extractor(source) + + features = features.transpose(1, 2) + features = self.layer_norm(features) + + if padding_mask is not None: + padding_mask = self.forward_padding_mask(features, padding_mask) + + if self.post_extract_proj is not None: + features = self.post_extract_proj(features) + + features = self.dropout_input(features) + + if mask: + x, mask_indices = self.apply_mask( + features, padding_mask + ) + else: + x = features + + # feature: (B, T, D), float + # target: (B, T), long + # x: (B, T, D), float + # padding_mask: (B, T), 
bool + # mask_indices: (B, T), bool + x, layer_results = self.encoder( + x, + padding_mask=padding_mask, + layer=None if output_layer is None else output_layer - 1 + ) + + res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results} + + feature = res["features"] if ret_conv else res["x"] + if ret_layer_results: + feature = (feature, res["layer_results"]) + return feature, res["padding_mask"] + + +class ConvFeatureExtractionModel(nn.Module): + def __init__( + self, + conv_layers: List[Tuple[int, int, int]], + dropout: float = 0.0, + mode: str = "default", + conv_bias: bool = False, + conv_type: str = "default" + ): + super().__init__() + + assert mode in {"default", "layer_norm"} + + def block( + n_in, + n_out, + k, + stride, + is_layer_norm=False, + is_group_norm=False, + conv_bias=False, + ): + def make_conv(): + conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias) + nn.init.kaiming_normal_(conv.weight) + return conv + + assert ( + is_layer_norm and is_group_norm + ) == False, "layer norm and group norm are exclusive" + + if is_layer_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + nn.Sequential( + TransposeLast(), + Fp32LayerNorm(dim, elementwise_affine=True), + TransposeLast(), + ), + nn.GELU(), + ) + elif is_group_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + Fp32GroupNorm(dim, dim, affine=True), + nn.GELU(), + ) + else: + return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU()) + + self.conv_type = conv_type + if self.conv_type == "default": + in_d = 1 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3, "invalid conv definition: " + str(cl) + (dim, k, stride) = cl + + self.conv_layers.append( + block( + in_d, + dim, + k, + stride, + is_layer_norm=mode == "layer_norm", + is_group_norm=mode == "default" and i == 0, + conv_bias=conv_bias, + ) + ) + in_d = dim + elif self.conv_type == "conv2d": + in_d = 1 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3 + (dim, k, stride) = cl + + self.conv_layers.append( + torch.nn.Conv2d(in_d, dim, k, stride) + ) + self.conv_layers.append(torch.nn.ReLU()) + in_d = dim + elif self.conv_type == "custom": + in_d = 1 + idim = 80 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3 + (dim, k, stride) = cl + self.conv_layers.append( + torch.nn.Conv2d(in_d, dim, k, stride, padding=1) + ) + self.conv_layers.append( + torch.nn.LayerNorm([dim, idim]) + ) + self.conv_layers.append(torch.nn.ReLU()) + in_d = dim + if (i + 1) % 2 == 0: + self.conv_layers.append( + torch.nn.MaxPool2d(2, stride=2, ceil_mode=True) + ) + idim = int(math.ceil(idim / 2)) + else: + pass + + def forward(self, x, mask=None): + + # BxT -> BxCxT + x = x.unsqueeze(1) + if self.conv_type == "custom": + for conv in self.conv_layers: + if isinstance(conv, nn.LayerNorm): + x = x.transpose(1, 2) + x = conv(x).transpose(1, 2) + else: + x = conv(x) + x = x.transpose(2, 3).contiguous() + x = x.view(x.size(0), -1, x.size(-1)) + else: + for conv in self.conv_layers: + x = conv(x) + if self.conv_type == "conv2d": + b, c, t, f = x.size() + x = x.transpose(2, 3).contiguous().view(b, c * f, t) + return x + + +class TransformerEncoder(nn.Module): + def __init__(self, args): + super().__init__() + + self.dropout = args.dropout + self.embedding_dim = args.encoder_embed_dim + + self.pos_conv = nn.Conv1d( + self.embedding_dim, + self.embedding_dim, + 
kernel_size=args.conv_pos, + padding=args.conv_pos // 2, + groups=args.conv_pos_groups, + ) + dropout = 0 + std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim)) + nn.init.normal_(self.pos_conv.weight, mean=0, std=std) + nn.init.constant_(self.pos_conv.bias, 0) + + self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2) + self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU()) + + if hasattr(args, "relative_position_embedding"): + self.relative_position_embedding = args.relative_position_embedding + self.num_buckets = args.num_buckets + self.max_distance = args.max_distance + else: + self.relative_position_embedding = False + self.num_buckets = 0 + self.max_distance = 0 + + self.layers = nn.ModuleList( + [ + TransformerSentenceEncoderLayer( + embedding_dim=self.embedding_dim, + ffn_embedding_dim=args.encoder_ffn_embed_dim, + num_attention_heads=args.encoder_attention_heads, + dropout=self.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + activation_fn=args.activation_fn, + layer_norm_first=args.layer_norm_first, + has_relative_attention_bias=(self.relative_position_embedding and i == 0), + num_buckets=self.num_buckets, + max_distance=self.max_distance, + gru_rel_pos=args.gru_rel_pos, + ) + for i in range(args.encoder_layers) + ] + ) + + self.layer_norm_first = args.layer_norm_first + self.layer_norm = LayerNorm(self.embedding_dim) + self.layerdrop = args.encoder_layerdrop + + self.apply(init_bert_params) + + def forward(self, x, padding_mask=None, streaming_mask=None, layer=None): + x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer) + + if self.layer_norm_first and layer is None: + x = self.layer_norm(x) + + return x, layer_results + + def extract_features(self, x, padding_mask=None, streaming_mask=None, tgt_layer=None): + + if padding_mask is not None: + x[padding_mask] = 0 + + x_conv = self.pos_conv(x.transpose(1, 2)) + x_conv = x_conv.transpose(1, 2) + x += x_conv + + if not self.layer_norm_first: + x = self.layer_norm(x) + + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + layer_results = [] + z = None + if tgt_layer is not None: + layer_results.append((x, z)) + r = None + pos_bias = None + for i, layer in enumerate(self.layers): + dropout_probability = np.random.random() + if not self.training or (dropout_probability > self.layerdrop): + x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, + self_attn_mask=streaming_mask, pos_bias=pos_bias) + if tgt_layer is not None: + layer_results.append((x, z)) + if i == tgt_layer: + r = x + break + + if r is not None: + x = r + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + return x, layer_results + + +class TransformerSentenceEncoderLayer(nn.Module): + """ + Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained + models. 
+ """ + + def __init__( + self, + embedding_dim: float = 768, + ffn_embedding_dim: float = 3072, + num_attention_heads: float = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + activation_fn: str = "relu", + layer_norm_first: bool = False, + has_relative_attention_bias: bool = False, + num_buckets: int = 0, + max_distance: int = 0, + rescale_init: bool = False, + gru_rel_pos: bool = False, + ) -> None: + + super().__init__() + # Initialize parameters + self.embedding_dim = embedding_dim + self.dropout = dropout + self.activation_dropout = activation_dropout + + # Initialize blocks + self.activation_name = activation_fn + self.activation_fn = get_activation_fn(activation_fn) + self.self_attn = MultiheadAttention( + self.embedding_dim, + num_attention_heads, + dropout=attention_dropout, + self_attention=True, + has_relative_attention_bias=has_relative_attention_bias, + num_buckets=num_buckets, + max_distance=max_distance, + rescale_init=rescale_init, + gru_rel_pos=gru_rel_pos, + ) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(self.activation_dropout) + self.dropout3 = nn.Dropout(dropout) + + self.layer_norm_first = layer_norm_first + + # layer norm associated with the self attention layer + self.self_attn_layer_norm = LayerNorm(self.embedding_dim) + + if self.activation_name == "glu": + self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish") + else: + self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) + self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) + + # layer norm associated with the position wise feed-forward NN + self.final_layer_norm = LayerNorm(self.embedding_dim) + + def forward( + self, + x: torch.Tensor, + self_attn_mask: torch.Tensor = None, + self_attn_padding_mask: torch.Tensor = None, + need_weights: bool = False, + pos_bias=None + ): + """ + LayerNorm is applied either before or after the self-attention/ffn + modules similar to the original Transformer imlementation. 
+ """ + residual = x + + if self.layer_norm_first: + x = self.self_attn_layer_norm(x) + x, attn, pos_bias = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=False, + attn_mask=self_attn_mask, + position_bias=pos_bias + ) + x = self.dropout1(x) + x = residual + x + + residual = x + x = self.final_layer_norm(x) + if self.activation_name == "glu": + x = self.fc1(x) + else: + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual + x + else: + x, attn, pos_bias = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=need_weights, + attn_mask=self_attn_mask, + position_bias=pos_bias + ) + + x = self.dropout1(x) + x = residual + x + + x = self.self_attn_layer_norm(x) + + residual = x + if self.activation_name == "glu": + x = self.fc1(x) + else: + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual + x + x = self.final_layer_norm(x) + + return x, attn, pos_bias + diff --git a/main/mydiffusion_zeggs/WavLM/WavLM_ASR.PNG b/main/mydiffusion_zeggs/WavLM/WavLM_ASR.PNG new file mode 100644 index 0000000000000000000000000000000000000000..a882f66fe0eaccf4cde853e98904d90645f0916f Binary files /dev/null and b/main/mydiffusion_zeggs/WavLM/WavLM_ASR.PNG differ diff --git a/main/mydiffusion_zeggs/WavLM/WavLM_SUPERB_Leaderboard.png b/main/mydiffusion_zeggs/WavLM/WavLM_SUPERB_Leaderboard.png new file mode 100644 index 0000000000000000000000000000000000000000..85fa14c7d642022ace70d2f73c5865e560d037f5 Binary files /dev/null and b/main/mydiffusion_zeggs/WavLM/WavLM_SUPERB_Leaderboard.png differ diff --git a/main/mydiffusion_zeggs/WavLM/WavLM_SUPERB_Results.png b/main/mydiffusion_zeggs/WavLM/WavLM_SUPERB_Results.png new file mode 100644 index 0000000000000000000000000000000000000000..a9b635920d54f4c1196ff25e9c2ac6df518f762c Binary files /dev/null and b/main/mydiffusion_zeggs/WavLM/WavLM_SUPERB_Results.png differ diff --git a/main/mydiffusion_zeggs/WavLM/modules_WavLM.py b/main/mydiffusion_zeggs/WavLM/modules_WavLM.py new file mode 100644 index 0000000000000000000000000000000000000000..1dcfc6f061cc189ca51fc90107116f38e2e48daf --- /dev/null +++ b/main/mydiffusion_zeggs/WavLM/modules_WavLM.py @@ -0,0 +1,827 @@ +# -------------------------------------------------------- +# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf) +# Github source: https://github.com/microsoft/unilm/tree/master/wavlm +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + +import math +import warnings +from typing import Dict, Optional, Tuple +import torch +from torch import Tensor, nn +from torch.nn import Parameter +import torch.nn.functional as F + + +class TransposeLast(nn.Module): + def __init__(self, deconstruct_idx=None): + super().__init__() + self.deconstruct_idx = deconstruct_idx + + def forward(self, x): + if self.deconstruct_idx is not None: + x = x[self.deconstruct_idx] + return x.transpose(-2, -1) + + +class Fp32LayerNorm(nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.layer_norm( + input.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + 
self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) + + +class Fp32GroupNorm(nn.GroupNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.group_norm( + input.float(), + self.num_groups, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) + + +class GradMultiply(torch.autograd.Function): + @staticmethod + def forward(ctx, x, scale): + ctx.scale = scale + res = x.new(x) + return res + + @staticmethod + def backward(ctx, grad): + return grad * ctx.scale, None + + +class SamePad(nn.Module): + def __init__(self, kernel_size, causal=False): + super().__init__() + if causal: + self.remove = kernel_size - 1 + else: + self.remove = 1 if kernel_size % 2 == 0 else 0 + + def forward(self, x): + if self.remove > 0: + x = x[:, :, : -self.remove] + return x + + +class Swish(nn.Module): + """Swish function + """ + + def __init__(self): + """Construct an MultiHeadedAttention object.""" + super(Swish, self).__init__() + self.act = torch.nn.Sigmoid() + + def forward(self, x): + return x * self.act(x) + + +class GLU_Linear(nn.Module): + def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True): + super(GLU_Linear, self).__init__() + + self.glu_type = glu_type + self.output_dim = output_dim + + if glu_type == "sigmoid": + self.glu_act = torch.nn.Sigmoid() + elif glu_type == "swish": + self.glu_act = Swish() + elif glu_type == "relu": + self.glu_act = torch.nn.ReLU() + elif glu_type == "gelu": + self.glu_act = torch.nn.GELU() + + if bias_in_glu: + self.linear = nn.Linear(input_dim, output_dim * 2, True) + else: + self.linear = nn.Linear(input_dim, output_dim * 2, False) + + def forward(self, x): + # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case + x = self.linear(x) + + if self.glu_type == "bilinear": + x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2]) + else: + x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2])) + + return x + + +def gelu_accurate(x): + if not hasattr(gelu_accurate, "_a"): + gelu_accurate._a = math.sqrt(2 / math.pi) + return ( + 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))) + ) + + +def gelu(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.gelu(x.float()).type_as(x) + + +def get_activation_fn(activation: str): + """Returns the activation function corresponding to `activation`""" + + if activation == "relu": + return F.relu + elif activation == "gelu": + return gelu + elif activation == "gelu_fast": + warnings.warn( + "--activation-fn=gelu_fast has been renamed to gelu_accurate" + ) + return gelu_accurate + elif activation == "gelu_accurate": + return gelu_accurate + elif activation == "tanh": + return torch.tanh + elif activation == "linear": + return lambda x: x + elif activation == "glu": + return lambda x: x + else: + raise RuntimeError("--activation-fn {} not supported".format(activation)) + + +def init_bert_params(module): + """ + Initialize the weights specific to the BERT Model. + This overrides the default initializations depending on the specified arguments. + 1. 
If normal_init_linear_weights is set then weights of linear + layer will be initialized using the normal distribution and + bais will be set to the specified value. + 2. If normal_init_embed_weights is set then weights of embedding + layer will be initialized using the normal distribution. + 3. If normal_init_proj_weights is set then weights of + in_project_weight for MultiHeadAttention initialized using + the normal distribution (to be validated). + """ + + def normal_(data): + # with FSDP, module params will be on CUDA, so we cast them back to CPU + # so that the RNG is consistent with and without FSDP + data.copy_( + data.cpu().normal_(mean=0.0, std=0.02).to(data.device) + ) + + if isinstance(module, nn.Linear): + normal_(module.weight.data) + if module.bias is not None: + module.bias.data.zero_() + if isinstance(module, nn.Embedding): + normal_(module.weight.data) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + if isinstance(module, MultiheadAttention): + normal_(module.q_proj.weight.data) + normal_(module.k_proj.weight.data) + normal_(module.v_proj.weight.data) + + +def quant_noise(module, p, block_size): + """ + Wraps modules and applies quantization noise to the weights for + subsequent quantization with Iterative Product Quantization as + described in "Training with Quantization Noise for Extreme Model Compression" + + Args: + - module: nn.Module + - p: amount of Quantization Noise + - block_size: size of the blocks for subsequent quantization with iPQ + + Remarks: + - Module weights must have the right sizes wrt the block size + - Only Linear, Embedding and Conv2d modules are supported for the moment + - For more detail on how to quantize by blocks with convolutional weights, + see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks" + - We implement the simplest form of noise here as stated in the paper + which consists in randomly dropping blocks + """ + + # if no quantization noise, don't register hook + if p <= 0: + return module + + # supported modules + assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d)) + + # test whether module.weight has the right sizes wrt block_size + is_conv = module.weight.ndim == 4 + + # 2D matrix + if not is_conv: + assert ( + module.weight.size(1) % block_size == 0 + ), "Input features must be a multiple of block sizes" + + # 4D matrix + else: + # 1x1 convolutions + if module.kernel_size == (1, 1): + assert ( + module.in_channels % block_size == 0 + ), "Input channels must be a multiple of block sizes" + # regular convolutions + else: + k = module.kernel_size[0] * module.kernel_size[1] + assert k % block_size == 0, "Kernel size must be a multiple of block size" + + def _forward_pre_hook(mod, input): + # no noise for evaluation + if mod.training: + if not is_conv: + # gather weight and sizes + weight = mod.weight + in_features = weight.size(1) + out_features = weight.size(0) + + # split weight matrix into blocks and randomly drop selected blocks + mask = torch.zeros( + in_features // block_size * out_features, device=weight.device + ) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_features) + + else: + # gather weight and sizes + weight = mod.weight + in_channels = mod.in_channels + out_channels = mod.out_channels + + # split weight matrix into blocks and randomly drop selected blocks + if mod.kernel_size == (1, 1): + mask = torch.zeros( + int(in_channels // block_size * out_channels), + device=weight.device, + ) + mask.bernoulli_(p) + mask = 
mask.repeat_interleave(block_size, -1).view(-1, in_channels) + else: + mask = torch.zeros( + weight.size(0), weight.size(1), device=weight.device + ) + mask.bernoulli_(p) + mask = ( + mask.unsqueeze(2) + .unsqueeze(3) + .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]) + ) + + # scale weights and apply mask + mask = mask.to( + torch.bool + ) # x.bool() is not currently supported in TorchScript + s = 1 / (1 - p) + mod.weight.data = s * weight.masked_fill(mask, 0) + + module.register_forward_pre_hook(_forward_pre_hook) + return module + + +class MultiheadAttention(nn.Module): + """Multi-headed attention. + + See "Attention Is All You Need" for more details. + """ + + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False, + q_noise=0.0, + qn_block_size=8, + has_relative_attention_bias=False, + num_buckets=32, + max_distance=128, + gru_rel_pos=False, + rescale_init=False, + ): + super().__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout_module = nn.Dropout(dropout) + + self.has_relative_attention_bias = has_relative_attention_bias + self.num_buckets = num_buckets + self.max_distance = max_distance + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(num_buckets, num_heads) + + self.head_dim = embed_dim // num_heads + self.q_head_dim = self.head_dim + self.k_head_dim = self.head_dim + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim ** -0.5 + + self.self_attention = self_attention + self.encoder_decoder_attention = encoder_decoder_attention + + assert not self.self_attention or self.qkv_same_dim, ( + "Self-attention requires query, key and " "value to be of the same size" + ) + + k_bias = True + if rescale_init: + k_bias = False + + k_embed_dim = embed_dim + q_embed_dim = embed_dim + + self.k_proj = quant_noise( + nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size + ) + self.v_proj = quant_noise( + nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size + ) + self.q_proj = quant_noise( + nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size + ) + + self.out_proj = quant_noise( + nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size + ) + + if add_bias_kv: + self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) + self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self.gru_rel_pos = gru_rel_pos + if self.gru_rel_pos: + self.grep_linear = nn.Linear(self.q_head_dim, 8) + self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1)) + + self.reset_parameters() + + def reset_parameters(self): + if self.qkv_same_dim: + # Empirically observed the convergence to be much better with + # the scaled initialization + nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) + else: + nn.init.xavier_uniform_(self.k_proj.weight) + nn.init.xavier_uniform_(self.v_proj.weight) + nn.init.xavier_uniform_(self.q_proj.weight) + + 
nn.init.xavier_uniform_(self.out_proj.weight) + if self.out_proj.bias is not None: + nn.init.constant_(self.out_proj.bias, 0.0) + if self.bias_k is not None: + nn.init.xavier_normal_(self.bias_k) + if self.bias_v is not None: + nn.init.xavier_normal_(self.bias_v) + if self.has_relative_attention_bias: + nn.init.xavier_normal_(self.relative_attention_bias.weight) + + def _relative_positions_bucket(self, relative_positions, bidirectional=True): + num_buckets = self.num_buckets + max_distance = self.max_distance + relative_buckets = 0 + + if bidirectional: + num_buckets = num_buckets // 2 + relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets + relative_positions = torch.abs(relative_positions) + else: + relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions)) + + max_exact = num_buckets // 2 + is_small = relative_positions < max_exact + + relative_postion_if_large = max_exact + ( + torch.log(relative_positions.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_postion_if_large = torch.min( + relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length): + context_position = torch.arange(query_length, dtype=torch.long)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long)[None, :] + relative_position = memory_position - context_position + relative_position_bucket = self._relative_positions_bucket( + relative_position, + bidirectional=True + ) + relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device) + values = self.relative_attention_bias(relative_position_bucket) + values = values.permute([2, 0, 1]) + return values + + def forward( + self, + query, + key: Optional[Tensor], + value: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + need_weights: bool = True, + static_kv: bool = False, + attn_mask: Optional[Tensor] = None, + before_softmax: bool = False, + need_head_weights: bool = False, + position_bias: Optional[Tensor] = None + ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + """Input shape: Time x Batch x Channel + + Args: + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + need_weights (bool, optional): return the attention weights, + averaged over heads (default: False). + attn_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). + before_softmax (bool, optional): return the raw attention + weights and values before the attention softmax. + need_head_weights (bool, optional): return the attention + weights for each head. Implies *need_weights*. Default: + return the average attention weights over all heads. 
+ """ + if need_head_weights: + need_weights = True + + is_tpu = query.device.type == "xla" + + tgt_len, bsz, embed_dim = query.size() + src_len = tgt_len + assert embed_dim == self.embed_dim + assert list(query.size()) == [tgt_len, bsz, embed_dim] + if key is not None: + src_len, key_bsz, _ = key.size() + if not torch.jit.is_scripting(): + assert key_bsz == bsz + assert value is not None + assert src_len, bsz == value.shape[:2] + + if self.has_relative_attention_bias and position_bias is None: + position_bias = self.compute_bias(tgt_len, src_len) + position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len) + + if ( + not is_tpu # don't use PyTorch version on TPUs + and incremental_state is None + and not static_kv + # A workaround for quantization to work. Otherwise JIT compilation + # treats bias in linear module as method. + and not torch.jit.is_scripting() + and self.q_head_dim == self.head_dim + ): + assert key is not None and value is not None + assert attn_mask is None + + attn_mask_rel_pos = None + if position_bias is not None: + attn_mask_rel_pos = position_bias + if self.gru_rel_pos: + query_layer = query.transpose(0, 1) + new_x_shape = query_layer.size()[:-1] + (self.num_heads, -1) + query_layer = query_layer.view(*new_x_shape) + query_layer = query_layer.permute(0, 2, 1, 3) + _B, _H, _L, __ = query_layer.size() + + gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view( + _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1) + gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0 + attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias + + attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len)) + k_proj_bias = self.k_proj.bias + if k_proj_bias is None: + k_proj_bias = torch.zeros_like(self.q_proj.bias) + + x, attn = F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + torch.empty([0]), + torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout_module.p, + self.out_proj.weight, + self.out_proj.bias, + self.training, + # self.training or self.dropout_module.apply_during_inference, + key_padding_mask, + need_weights, + attn_mask_rel_pos, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + ) + return x, attn, position_bias + + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if saved_state is not None and "prev_key" in saved_state: + # previous time steps are cached - no need to recompute + # key and value if they are static + if static_kv: + assert self.encoder_decoder_attention and not self.self_attention + key = value = None + else: + saved_state = None + + if self.self_attention: + q = self.q_proj(query) + k = self.k_proj(query) + v = self.v_proj(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.q_proj(query) + if key is None: + assert value is None + k = v = None + else: + k = self.k_proj(key) + v = self.v_proj(key) + + else: + assert key is not None and value is not None + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + q *= self.scaling + + if self.bias_k is not None: + assert self.bias_v is not None + k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = torch.cat( + 
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + key_padding_mask.new_zeros(key_padding_mask.size(0), 1), + ], + dim=1, + ) + + q = ( + q.contiguous() + .view(tgt_len, bsz * self.num_heads, self.q_head_dim) + .transpose(0, 1) + ) + if k is not None: + k = ( + k.contiguous() + .view(-1, bsz * self.num_heads, self.k_head_dim) + .transpose(0, 1) + ) + if v is not None: + v = ( + v.contiguous() + .view(-1, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + if saved_state is not None: + # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) + if "prev_key" in saved_state: + _prev_key = saved_state["prev_key"] + assert _prev_key is not None + prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + k = prev_key + else: + assert k is not None + k = torch.cat([prev_key, k], dim=1) + src_len = k.size(1) + if "prev_value" in saved_state: + _prev_value = saved_state["prev_value"] + assert _prev_value is not None + prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + v = prev_value + else: + assert v is not None + v = torch.cat([prev_value, v], dim=1) + prev_key_padding_mask: Optional[Tensor] = None + if "prev_key_padding_mask" in saved_state: + prev_key_padding_mask = saved_state["prev_key_padding_mask"] + assert k is not None and v is not None + key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( + key_padding_mask=key_padding_mask, + prev_key_padding_mask=prev_key_padding_mask, + batch_size=bsz, + src_len=k.size(1), + static_kv=static_kv, + ) + + saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_key_padding_mask"] = key_padding_mask + # In this branch incremental_state is never None + assert incremental_state is not None + incremental_state = self._set_input_buffer(incremental_state, saved_state) + assert k is not None + assert k.size(1) == src_len + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. 
+ if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if self.add_zero_attn: + assert v is not None + src_len += 1 + k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) + v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + torch.zeros(key_padding_mask.size(0), 1).type_as( + key_padding_mask + ), + ], + dim=1, + ) + + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) + + assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + attn_weights += attn_mask + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + if not is_tpu: + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), + float("-inf"), + ) + else: + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf")) + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if before_softmax: + return attn_weights, v, position_bias + + if position_bias is not None: + if self.gru_rel_pos == 1: + query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) + _B, _H, _L, __ = query_layer.size() + gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view( + _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1) + gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0 + position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias + + position_bias = position_bias.view(attn_weights.size()) + + attn_weights = attn_weights + position_bias + + attn_weights_float = F.softmax( + attn_weights, dim=-1 + ) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = self.dropout_module(attn_weights) + + assert v is not None + attn = torch.bmm(attn_probs, v) + assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn = self.out_proj(attn) + attn_weights: Optional[Tensor] = None + if need_weights: + attn_weights = attn_weights_float.view( + bsz, self.num_heads, tgt_len, src_len + ).transpose(1, 0) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(dim=0) + + return attn, attn_weights, position_bias + + @staticmethod + def _append_prev_key_padding_mask( + key_padding_mask: Optional[Tensor], + prev_key_padding_mask: Optional[Tensor], + batch_size: int, + src_len: int, + static_kv: bool, + ) -> Optional[Tensor]: + # saved key padding masks have shape (bsz, seq_len) + if prev_key_padding_mask is not None and static_kv: + new_key_padding_mask = prev_key_padding_mask + elif prev_key_padding_mask is not None and key_padding_mask is not None: + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1 + ) + # During incremental decoding, as the padding token enters and + # leaves the frame, there will 
be a time when prev or current + # is None + elif prev_key_padding_mask is not None: + if src_len > prev_key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - prev_key_padding_mask.size(1)), + device=prev_key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), filler.float()], dim=1 + ) + else: + new_key_padding_mask = prev_key_padding_mask.float() + elif key_padding_mask is not None: + if src_len > key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - key_padding_mask.size(1)), + device=key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [filler.float(), key_padding_mask.float()], dim=1 + ) + else: + new_key_padding_mask = key_padding_mask.float() + else: + new_key_padding_mask = prev_key_padding_mask + return new_key_padding_mask + + def _get_input_buffer( + self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] + ) -> Dict[str, Optional[Tensor]]: + result = self.get_incremental_state(incremental_state, "attn_state") + if result is not None: + return result + else: + empty_result: Dict[str, Optional[Tensor]] = {} + return empty_result + + def _set_input_buffer( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + buffer: Dict[str, Optional[Tensor]], + ): + return self.set_incremental_state(incremental_state, "attn_state", buffer) + + def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int): + return attn_weights diff --git a/main/mydiffusion_zeggs/configs/DiffuseStyleGesture.yml b/main/mydiffusion_zeggs/configs/DiffuseStyleGesture.yml new file mode 100644 index 0000000000000000000000000000000000000000..7c3cf38e52395efa06baac70045ffcfa8c440cb8 --- /dev/null +++ b/main/mydiffusion_zeggs/configs/DiffuseStyleGesture.yml @@ -0,0 +1,30 @@ + +# ZEGGS +train_data_path: "../../ubisoft-laforge-ZeroEGGS-main/data/processed_v1/processed/train/train_lmdb/train_lmdb/" # speaker_1_state_0 +val_data_path: "../../ubisoft-laforge-ZeroEGGS-main/data/processed_v1/processed/valid/valid_lmdb/valid_lmdb/" + +# 60 fps + normalized +data_mean: "../../ubisoft-laforge-ZeroEGGS-main/data/processed_v1/processed/mean.npz" +data_std: "../../ubisoft-laforge-ZeroEGGS-main/data/processed_v1/processed/std.npz" + +n_poses: 88 # 88 -> 20*60 +n_codes: 30 +motion_resampling_framerate: 20 # 20 -> 60 +subdivision_stride: 10 # 10 -> 200 +batch_size: 300 # 384 -> 32 +loader_workers: 2 +epochs: 500 # 500 -> 10 +save_per_epochs: 25 # 20 -> 1 +model_save_path: "./output/train_DiffuseStyleGesture" +name: "DiffuseStyleGesture" +log_interval: 50 +weight_decay: 0.0 +lr_anneal_steps: 0 +save_dir: "./zeggs_mymodel3_wavlm" +audio_feat: "wavlm" # wav encoder; mfcc; wavlm + +lr: 0.00003 # 0.00003 -> +betas: [0.5, 0.999] +milestones: [100, 200] +gamma: 0.1 + diff --git a/main/mydiffusion_zeggs/configs/parse_args.py b/main/mydiffusion_zeggs/configs/parse_args.py new file mode 100644 index 0000000000000000000000000000000000000000..fd4da88a598177b029cfd5c8a57f96b1ff758473 --- /dev/null +++ b/main/mydiffusion_zeggs/configs/parse_args.py @@ -0,0 +1,23 @@ +import configargparse +import argparse + +def str2bool(v): + """ from https://stackoverflow.com/a/43357954/1361529 """ + if isinstance(v, bool): + return v + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise configargparse.ArgumentTypeError('Boolean value expected.') + + +def parse_args(): + parser = 
argparse.ArgumentParser(description='DiffuseStyleGesture') + parser.add_argument('--config', default='./configs/DiffuseStyleGesture.yml') + parser.add_argument('--gpu', type=str, default='2') + parser.add_argument('--no_cuda', type=list, default=['2']) + + args = parser.parse_args() + return args diff --git a/main/mydiffusion_zeggs/data_loader/data_preprocessor.py b/main/mydiffusion_zeggs/data_loader/data_preprocessor.py new file mode 100644 index 0000000000000000000000000000000000000000..605cabb5a3d4e8a84a42496b198f04f9ec8ec487 --- /dev/null +++ b/main/mydiffusion_zeggs/data_loader/data_preprocessor.py @@ -0,0 +1,152 @@ +""" create data samples """ +import pdb + +import lmdb +import math +import numpy as np +import pyarrow + + +import torch +import torch.nn.functional as F + +def wavlm_init(device=torch.device('cuda:1')): + import sys + [sys.path.append(i) for i in ['./WavLM']] + from WavLM import WavLM, WavLMConfig + wavlm_model_path = './WavLM/WavLM-Large.pt' + # wavlm_model_path = '../../../My/process/WavLM-Base+.pt' + # load the pre-trained checkpoints + checkpoint = torch.load(wavlm_model_path, map_location=torch.device('cpu')) + cfg = WavLMConfig(checkpoint['cfg']) + model = WavLM(cfg) + model = model.to(device) + model.load_state_dict(checkpoint['model']) + model.eval() + return model + + +def wav2wavlm(model, wav_input_16khz, device=torch.device('cuda:1')): + with torch.no_grad(): + wav_input_16khz = torch.from_numpy(wav_input_16khz).float() + wav_input_16khz = wav_input_16khz.to(device).unsqueeze(0) + rep = model.extract_features(wav_input_16khz)[0] + rep = F.interpolate(rep.transpose(1, 2), size=88, align_corners=True, mode='linear').transpose(1, 2) + return rep.squeeze().cpu().detach().data.cpu().numpy() + + +class DataPreprocessor: + def __init__(self, clip_lmdb_dir, out_lmdb_dir, n_poses, subdivision_stride, pose_resampling_fps, device): + self.n_poses = n_poses + self.subdivision_stride = subdivision_stride + self.skeleton_resampling_fps = pose_resampling_fps + + self.src_lmdb_env = lmdb.open(clip_lmdb_dir, readonly=True, lock=False) + with self.src_lmdb_env.begin() as txn: + self.n_videos = txn.stat()['entries'] + + self.audio_sample_length = int(self.n_poses / self.skeleton_resampling_fps * 16000) + + # create db for samples + map_size = 1024 * 1024 * 20 # in TB + map_size <<= 20 # in B + self.dst_lmdb_env = lmdb.open(out_lmdb_dir, map_size=map_size) + self.n_out_samples = 0 + + self.model = wavlm_init(device) + self.device = device + + def run(self): + src_txn = self.src_lmdb_env.begin(write=False) + + # sampling and normalization + cursor = src_txn.cursor() + for key, value in cursor: + video = pyarrow.deserialize(value) + vid = video['vid'] + clips = video['clips'] + for clip_idx, clip in enumerate(clips): + self._sample_from_clip(vid, clip, self.device) + + # print stats + with self.dst_lmdb_env.begin() as txn: + print('no. 
of samples: ', txn.stat()['entries']) + # close db + self.src_lmdb_env.close() + self.dst_lmdb_env.sync() + self.dst_lmdb_env.close() + + + def _sample_from_clip(self, vid, clip, device): + clip_skeleton = clip['poses'] + clip_audio_raw = clip['audio_raw'] + clip_styles_raw = clip['style_raw'] + clip_mfcc_raw = clip['mfcc_raw'] + + # divide + aux_info = [] + sample_skeletons_list = [] + sample_audio_list = [] + sample_codes_list = [] + sample_mfcc_list = [] + sample_wavlm_list = [] + + MINLEN = min(len(clip_skeleton), int(len(clip_audio_raw) * 60 / 16000), len(clip_mfcc_raw)) + + num_subdivision = math.floor( + (MINLEN - self.n_poses) + / self.subdivision_stride) # floor((K - (N+M)) / S) + 1 + + for i in range(num_subdivision): + start_idx = i * self.subdivision_stride + fin_idx = start_idx + self.n_poses + + sample_skeletons = clip_skeleton[start_idx:fin_idx] + sample_mfcc = clip_mfcc_raw[start_idx:fin_idx] + subdivision_start_time = start_idx / self.skeleton_resampling_fps + subdivision_end_time = fin_idx / self.skeleton_resampling_fps + + # raw audio + audio_start = math.floor(start_idx / len(clip_skeleton) * len(clip_audio_raw)) + audio_end = audio_start + self.audio_sample_length + sample_audio = clip_audio_raw[audio_start:audio_end] + sample_wavlm = wav2wavlm(self.model, sample_audio, device=device) + + motion_info = {'vid': vid, + 'start_frame_no': start_idx, + 'end_frame_no': fin_idx, + 'start_time': subdivision_start_time, + 'end_time': subdivision_end_time} + + sample_skeletons_list.append(sample_skeletons) + sample_mfcc_list.append(sample_mfcc) + sample_wavlm_list.append(sample_wavlm) + sample_audio_list.append(sample_audio) + sample_codes_list.append(clip_styles_raw) + aux_info.append(motion_info) + + # if len(sample_skeletons_list) > 0: + # with self.dst_lmdb_env.begin(write=True) as txn: + # for poses, audio, codes, mfcc, wavlm, aux in zip(sample_skeletons_list, + # sample_audio_list, sample_codes_list, sample_mfcc_list, sample_wavlm_list, aux_info): + # poses = np.asarray(poses) + # + # # save + # k = '{:010}'.format(self.n_out_samples).encode('ascii') + # v = [poses, audio, codes, mfcc, wavlm, aux] + # v = pyarrow.serialize(v).to_buffer() + # txn.put(k, v) + # self.n_out_samples += 1 + + if len(sample_skeletons_list) > 0: + with self.dst_lmdb_env.begin(write=True) as txn: + for poses, codes, wavlm in zip(sample_skeletons_list, sample_codes_list, sample_wavlm_list): + poses = np.asarray(poses) + + # save + k = '{:010}'.format(self.n_out_samples).encode('ascii') + v = [poses, codes, wavlm] + v = pyarrow.serialize(v).to_buffer() + txn.put(k, v) + self.n_out_samples += 1 + diff --git a/main/mydiffusion_zeggs/data_loader/lmdb_data_loader.py b/main/mydiffusion_zeggs/data_loader/lmdb_data_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..18d719270a27bd4664ce2e1e927ea463e88ba844 --- /dev/null +++ b/main/mydiffusion_zeggs/data_loader/lmdb_data_loader.py @@ -0,0 +1,111 @@ +import logging +import pdb +import lmdb as lmdb +import torch +from torch.utils.data import Dataset +import pyarrow +import sys +import os +[sys.path.append(i) for i in ['.', '..']] +from data_loader.data_preprocessor import DataPreprocessor + + +class TrinityDataset(Dataset): + def __init__(self, lmdb_dir, n_poses, subdivision_stride, pose_resampling_fps, model=None, device=torch.device('cuda:0')): + self.lmdb_dir = lmdb_dir + self.n_poses = n_poses + self.subdivision_stride = subdivision_stride + self.skeleton_resampling_fps = pose_resampling_fps + self.lang_model = None + + 
logging.info("Reading data '{}'...".format(lmdb_dir)) + if model is not None: + if 'Long_' in model: + preloaded_dir = lmdb_dir + '_cache_' + model.split('_')[-1] + if 'WavLM' in model: + preloaded_dir = lmdb_dir + '_cache_WavLM' + else: + preloaded_dir = lmdb_dir + '_cache' + if not os.path.exists(preloaded_dir): + data_sampler = DataPreprocessor(lmdb_dir, preloaded_dir, n_poses, + subdivision_stride, pose_resampling_fps, device=device) + data_sampler.run() + else: + logging.info('Found pre-loaded samples from {}'.format(preloaded_dir)) + + # init lmdb + # map_size = 1024 * 20 # in MB + # map_size <<= 20 # in B + self.lmdb_env = lmdb.open(preloaded_dir, readonly=True, lock=False) # default 10485760 + with self.lmdb_env.begin() as txn: + self.n_samples = txn.stat()['entries'] + + def __len__(self): + return self.n_samples + + def __getitem__(self, idx): + with self.lmdb_env.begin(write=False) as txn: + key = '{:010}'.format(idx).encode('ascii') + sample = txn.get(key) + + sample = pyarrow.deserialize(sample) + # pose_seq, audio, styles, mfcc, wavlm, aux_info = sample + pose_seq, styles, wavlm = sample + + # # normalize + # std = np.clip(self.data_std, a_min=0.01, a_max=None) + # pose_seq = (pose_seq - self.data_mean) / std + + # to tensors + pose_seq = torch.from_numpy(pose_seq).reshape((pose_seq.shape[0], -1)).float() + styles = torch.from_numpy(styles).float() + # audio = torch.from_numpy(audio).float() + # mfcc = torch.from_numpy(mfcc).float() + wavlm = torch.from_numpy(wavlm).float() + + # return pose_seq, aux_info, styles, audio, mfcc, wavlm + return pose_seq, styles, wavlm + + +if __name__ == '__main__': + ''' + cd main/mydiffusion_zeggs + python data_loader/lmdb_data_loader.py --config=./configs/DiffuseStyleGesture.yml --no_cuda 0 --gpu 0 + ''' + + from configs.parse_args import parse_args + import os + import yaml + from pprint import pprint + from easydict import EasyDict + from torch.utils.data import DataLoader + + args = parse_args() + + with open(args.config) as f: + config = yaml.safe_load(f) + + for k, v in vars(args).items(): + config[k] = v + # pprint(config) + + args = EasyDict(config) + + train_dataset = TrinityDataset(args.train_data_path, + n_poses=args.n_poses, + subdivision_stride=args.subdivision_stride, + pose_resampling_fps=args.motion_resampling_framerate, model='WavLM', device=torch.device('cuda:0')) + val_dataset = TrinityDataset(args.val_data_path, + n_poses=args.n_poses, + subdivision_stride=args.subdivision_stride, + pose_resampling_fps=args.motion_resampling_framerate, model='WavLM', device=torch.device('cuda:0')) + train_loader = DataLoader(dataset=train_dataset, batch_size=128, + shuffle=True, drop_last=True, num_workers=args.loader_workers, pin_memory=True) + + print(len(train_loader)) + for batch_i, batch in enumerate(train_loader, 0): + # target_vec, aux, style, audio, mfcc, wavlm = batch # [128, 88, 1141], -, [128, 6], [128, 70400], [128, 88, 13] + target_vec, style, wavlm = batch + print(batch_i) + pdb.set_trace() + # print(target_vec.shape, audio.shape) diff --git a/main/mydiffusion_zeggs/end2end.py b/main/mydiffusion_zeggs/end2end.py new file mode 100644 index 0000000000000000000000000000000000000000..cca2c58cea1e2fb66a7de13c08ca0d843dee3879 --- /dev/null +++ b/main/mydiffusion_zeggs/end2end.py @@ -0,0 +1,70 @@ +import pdb +import logging +logging.getLogger().setLevel(logging.INFO) +from torch.utils.data import DataLoader +from data_loader.lmdb_data_loader import TrinityDataset +import torch +import yaml +from pprint import pprint +from easydict 
import EasyDict +from configs.parse_args import parse_args +import os +import sys +[sys.path.append(i) for i in ['.', '..', '../model', '../train']] +from utils.model_util import create_gaussian_diffusion +from training_loop import TrainLoop +from model.mdm import MDM + + +def create_model_and_diffusion(args): + model = MDM(modeltype='', njoints=1141, nfeats=1, cond_mode = 'cross_local_attention3_style1', action_emb = 'tensor', audio_feat=args.audio_feat, + arch='trans_enc', latent_dim=256, n_seed=8, cond_mask_prob=0.1) + diffusion = create_gaussian_diffusion() + return model, diffusion + + +def main(args, device): + # dataset + train_dataset = TrinityDataset(args.train_data_path, + n_poses=args.n_poses, + subdivision_stride=args.subdivision_stride, + pose_resampling_fps=args.motion_resampling_framerate, model='WavLM', device=device) + train_loader = DataLoader(dataset=train_dataset, batch_size=args.batch_size, + shuffle=True, drop_last=True, num_workers=args.loader_workers, pin_memory=True) + + val_dataset = TrinityDataset(args.val_data_path, + n_poses=args.n_poses, + subdivision_stride=args.subdivision_stride, + pose_resampling_fps=args.motion_resampling_framerate, model='WavLM', device=device) + test_loader = DataLoader(dataset=val_dataset, batch_size=args.batch_size, + shuffle=False, drop_last=True, num_workers=args.loader_workers, pin_memory=False) + + logging.info('len of train loader:{}, len of test loader:{}'.format(len(train_loader), len(test_loader))) + + if not os.path.exists(args.model_save_path): + os.mkdir(args.model_save_path) + + model, diffusion = create_model_and_diffusion(args) + model.to(mydevice) + TrainLoop(args, model, diffusion, mydevice, data=train_loader).run_loop() + + +if __name__ == '__main__': + ''' + cd mydiffusion_zeggs/ + ''' + + args = parse_args() + mydevice = torch.device('cuda:' + args.gpu) + torch.cuda.set_device(int(args.gpu)) + + with open(args.config) as f: + config = yaml.safe_load(f) + + for k, v in vars(args).items(): + config[k] = v + pprint(config) + + config = EasyDict(config) + + main(config, mydevice) diff --git a/main/mydiffusion_zeggs/generate/diffwav.py b/main/mydiffusion_zeggs/generate/diffwav.py new file mode 100644 index 0000000000000000000000000000000000000000..9b01b458f429cb1016747c242a2591650cb6c40a --- /dev/null +++ b/main/mydiffusion_zeggs/generate/diffwav.py @@ -0,0 +1,96 @@ +import pdb + +import torch +import torch.nn as nn +from torch.nn import functional as F +from torch.autograd import Variable +import math +import sys +[sys.path.append(i) for i in ['.', '..', '../../..']] +from generate.generate import WavEncoder + +from mydiffwave.src.diffwave.model import DiffWave +from mydiffwave.src.diffwave.params import params +import numpy as np + + +class diffwav_model(nn.Module): + def __init__(self): + super().__init__() + self.WavEncoder = WavEncoder() + self.diffwav_model = DiffWave(params) + self.criterion = nn.SmoothL1Loss() + + def sample(self, batch_size, tmp_audio, beta, T): + wav_feature = self.WavEncoder(tmp_audio).transpose(1, 2) # (b, 240, 32) + noisy_pose = torch.randn(batch_size, 240, 135).transpose(1, 2).to(tmp_audio.device) + alpha = 1 - beta + alpha_cum = np.cumprod(alpha) + for n in range(len(alpha) - 1, -1, -1): + c1 = 1 / alpha[n] ** 0.5 + c2 = beta[n] / (1 - alpha_cum[n]) ** 0.5 + noisy_pose = c1 * (noisy_pose - c2 * self.diffwav_model(noisy_pose, torch.tensor([T[n]], device=noisy_pose.device), wav_feature).squeeze(1)) + if n > 0: + noise = torch.randn_like(noisy_pose) + sigma = ((1.0 - alpha_cum[n - 1]) / (1.0 
- alpha_cum[n]) * beta[n]) ** 0.5 + noisy_pose += sigma * noise + noisy_pose = torch.clamp(noisy_pose, -1.0, 1.0) + return noisy_pose.transpose(1, 2) + + def forward(self, noisy_pose, t, audio, noise): # (b, len, 13) + wav_feature = self.WavEncoder(audio).transpose(1, 2) # (b, 240, 32) + predicted = self.diffwav_model(noisy_pose, t, wav_feature) + loss = self.criterion(predicted, noise) + return loss + + +if __name__ == '__main__': + ''' + cd mydiffusion/generate/ + python diffwav.py + ''' + # z = torch.arange(0, 60).reshape(2, 30) + + device = torch.device('cuda:2') + audio = torch.rand(2, 64000).to(device) + pose = torch.rand(2, 240, 135).transpose(1, 2).to(device) + model = diffwav_model().to(device) + + n_frames = 240 + n_pose_dims = 135 + n_audio_dim = 32 + hop_samples = 1 + + N = pose.shape[0] # 1, 15872 + device = pose.device + + beta = np.linspace(1e-4, 0.05, 50) + noise_level = np.cumprod(1 - beta) + noise_level = torch.tensor(noise_level.astype(np.float32)).to(device) + + t = torch.randint(0, len(beta), [N], device=pose.device) # (batch) + noise_scale = noise_level[t].unsqueeze(1).unsqueeze(1) # (batch, 1) + noise_scale_sqrt = noise_scale ** 0.5 # (batch, 1) + noise = torch.randn_like(pose) # (batch, 15872) + noisy_pose = noise_scale_sqrt * pose + (1.0 - noise_scale) ** 0.5 * noise # (batch, 15872) + + loss = model(noisy_pose, t, audio, pose) # (batch, 1, 15872) + print(loss) + + talpha = 1 - beta + talpha_cum = np.cumprod(talpha) + alpha = 1 - beta + alpha_cum = np.cumprod(alpha) + + T = [] + for s in range(len(beta)): + for t in range(len(beta) - 1): + if talpha_cum[t+1] <= alpha_cum[s] <= talpha_cum[t]: + twiddle = (talpha_cum[t]**0.5 - alpha_cum[s]**0.5) / (talpha_cum[t]**0.5 - talpha_cum[t+1]**0.5) + T.append(t + twiddle) + break + T = np.array(T, dtype=np.float32) + + tmp_audio = torch.rand(1, 64000).to(device) + sampled_seq = model.sample(batch_size=1, tmp_audio=tmp_audio, beta=beta, T=T) + print(sampled_seq.shape) # (4, 32, 128) diff --git a/main/mydiffusion_zeggs/generate/generate.py b/main/mydiffusion_zeggs/generate/generate.py new file mode 100644 index 0000000000000000000000000000000000000000..83af29c2b00b58229f5e0c3e418e3e3c21ce4dfe --- /dev/null +++ b/main/mydiffusion_zeggs/generate/generate.py @@ -0,0 +1,405 @@ +import pdb +import torch +import torch.nn as nn +from torch.nn import functional as F +import math +# from denoising_diffusion_pytorch import myUnet1D, myGaussianDiffusion1D + + +class WavEncoder(nn.Module): # (b, 64000) -> (b, 240, 32) + def __init__(self): + super().__init__() + self.feat_extractor = nn.Sequential( + nn.Conv1d(1, 16, 15, stride=3, padding=800), + nn.BatchNorm1d(16), + nn.LeakyReLU(0.3, inplace=True), + nn.Conv1d(16, 32, 15, stride=3), + nn.BatchNorm1d(32), + nn.LeakyReLU(0.3, inplace=True), + nn.Conv1d(32, 64, 15, stride=5), + nn.BatchNorm1d(64), + nn.LeakyReLU(0.3, inplace=True), + nn.Conv1d(64, 32, 15, stride=6), + ) + + def forward(self, wav_data): + wav_data = wav_data.unsqueeze(1) # add channel dim + out = self.feat_extractor(wav_data) + return out.transpose(1, 2) # to (batch x seq x dim) + + +class Generator_linear(nn.Module): + def __init__(self): + super().__init__() + self.WavEncoder = WavEncoder() + self.project = nn.Linear(32, 512, bias=False) + self.norm = nn.LayerNorm(32) + + def sample(self, x): + wav_feature = self.WavEncoder(x) + wav_feature = self.norm(wav_feature) # (1, 30, 512) + codebook_embedding = self.project(wav_feature).squeeze() + # codebook_embedding = self.norm(codebook_embedding).squeeze() # (1, 30, 512) + 
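        # Editorial note (not part of the original commit): the loop below picks one
        # codebook index per frame via softmax followed by top-1. Since argmax is
        # unchanged by softmax, the per-frame index selection is equivalent to
        # codebook_embedding.argmax(dim=-1); note the squeeze() above assumes batch
        # size 1, so codebook_embedding is iterated frame by frame.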
code = torch.tensor([]).to(x.device) + for k in codebook_embedding: + probs = F.softmax(k, dim=-1) + _, ix = torch.topk(probs, k=1, dim=-1) + code = torch.cat((code, ix)) + return [code.unsqueeze(0).int()] + + def forward(self, x, target=None): + wav_feature = self.WavEncoder(x) + wav_feature = self.norm(wav_feature) # norm before linear + codebook_embedding = self.project(wav_feature) + loss = None + if target is not None: + loss = F.cross_entropy(codebook_embedding.view(-1, codebook_embedding.size(-1)), target.view(-1)) + return codebook_embedding, loss + +''' +Based on the following Se2Seq implementations: +- https://github.com/AuCson/PyTorch-Batch-Attention-Seq2seq +- https://github.com/spro/practical-pytorch/blob/master/seq2seq-translation/seq2seq-translation-batched.ipynb +''' + + +class EncoderRNN(nn.Module): + def __init__(self, input_size, embed_size, hidden_size, n_layers=1, dropout=0.5, pre_trained_embedding=None): + super(EncoderRNN, self).__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.embed_size = embed_size + self.n_layers = n_layers + self.dropout = dropout + + if pre_trained_embedding is not None: # use pre-trained embedding (e.g., word2vec, glove) + assert pre_trained_embedding.shape[0] == input_size + assert pre_trained_embedding.shape[1] == embed_size + self.embedding = nn.Embedding.from_pretrained(torch.FloatTensor(pre_trained_embedding), freeze=False) + else: + self.embedding = nn.Embedding(input_size, embed_size) + + self.gru = nn.GRU(embed_size, hidden_size, n_layers, dropout=self.dropout, bidirectional=True) + + self.do_flatten_parameters = False + if torch.cuda.device_count() > 1: + self.do_flatten_parameters = True + + def forward(self, input_seqs, input_lengths, hidden=None): + ''' + :param input_seqs: + Variable of shape (num_step(T),batch_size(B)), sorted decreasingly by lengths(for packing) + :param input_lengths: + list of sequence length + :param hidden: + initial state of GRU + :returns: + GRU outputs in shape (T,B,hidden_size(H)) + last hidden stat of RNN(i.e. last output for GRU) + ''' + if self.do_flatten_parameters: + self.gru.flatten_parameters() + + embedded = self.embedding(input_seqs) + packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths) + outputs, hidden = self.gru(packed, hidden) + outputs, output_lengths = torch.nn.utils.rnn.pad_packed_sequence(outputs) # unpack (back to padded) + outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:] # Sum bidirectional outputs + return outputs, hidden + + +class Attn(nn.Module): + def __init__(self, hidden_size): + super(Attn, self).__init__() + self.hidden_size = hidden_size + self.attn = nn.Linear(self.hidden_size * 2, hidden_size) + self.v = nn.Parameter(torch.rand(hidden_size)) + stdv = 1. 
/ math.sqrt(self.v.size(0)) + self.v.data.normal_(mean=0, std=stdv) + + def forward(self, hidden, encoder_outputs): + ''' + :param hidden: + previous hidden state of the decoder, in shape (layers*directions,B,H) + :param encoder_outputs: + encoder outputs from Encoder, in shape (T,B,H) + :return + attention energies in shape (B,T) + ''' + max_len = encoder_outputs.size(0) + this_batch_size = encoder_outputs.size(1) + H = hidden.repeat(max_len, 1, 1).transpose(0, 1) + encoder_outputs = encoder_outputs.transpose(0, 1) # [B*T*H] + attn_energies = self.score(H, encoder_outputs) # compute attention score + return F.softmax(attn_energies, dim=1).unsqueeze(1) # normalize with softmax + + def score(self, hidden, encoder_outputs): + energy = torch.tanh(self.attn(torch.cat([hidden, encoder_outputs], 2))) # [B*T*2H]->[B*T*H] + energy = energy.transpose(2, 1) # [B*H*T] + v = self.v.repeat(encoder_outputs.data.shape[0], 1).unsqueeze(1) # [B*1*H] + energy = torch.bmm(v, energy) # [B*1*T] + return energy.squeeze(1) # [B*T] + + +class BahdanauAttnDecoderRNN(nn.Module): + def __init__(self, input_size, hidden_size, output_size, n_layers=1, dropout_p=0.1, + discrete_representation=False, speaker_model=None): + super(BahdanauAttnDecoderRNN, self).__init__() + + # define parameters + self.hidden_size = hidden_size + self.output_size = output_size + self.n_layers = n_layers + self.dropout_p = dropout_p + self.discrete_representation = discrete_representation + self.speaker_model = speaker_model + + # define embedding layer + if self.discrete_representation: + self.embedding = nn.Embedding(output_size, hidden_size) + self.dropout = nn.Dropout(dropout_p) + + if self.speaker_model: + self.speaker_embedding = nn.Embedding(speaker_model.n_words, 8) + + # calc input size + if self.discrete_representation: + input_size = hidden_size # embedding size + linear_input_size = input_size + hidden_size + if self.speaker_model: + linear_input_size += 8 + + # define layers + self.pre_linear = nn.Sequential( + nn.Linear(linear_input_size, hidden_size), + nn.BatchNorm1d(hidden_size), + nn.ReLU(inplace=True) + ) + self.attn = Attn(hidden_size) + self.gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout=dropout_p) + + # self.out = nn.Linear(hidden_size * 2, output_size) + self.out = nn.Linear(hidden_size, output_size) + + self.do_flatten_parameters = False + if torch.cuda.device_count() > 1: + self.do_flatten_parameters = True + + def freeze_attn(self): + for param in self.attn.parameters(): + param.requires_grad = False + + def forward(self, motion_input, last_hidden, encoder_outputs, vid_indices=None): + ''' + :param motion_input: + motion input for current time step, in shape [batch x dim] + :param last_hidden: + last hidden state of the decoder, in shape [layers x batch x hidden_size] + :param encoder_outputs: + encoder outputs in shape [steps x batch x hidden_size] + :param vid_indices: + :return + decoder output + Note: we run this one step at a time i.e. 
you should use a outer loop + to process the whole sequence + ''' + + if self.do_flatten_parameters: + self.gru.flatten_parameters() + + if self.discrete_representation: + word_embedded = self.embedding(motion_input).view(1, motion_input.size(0), -1) # [1 x B x embedding_dim] + motion_input = self.dropout(word_embedded) + else: + motion_input = motion_input.view(1, motion_input.size(0), -1) # [1 x batch x dim] + + # attention + attn_weights = self.attn(last_hidden[-1], encoder_outputs) # [batch x 1 x T] + context = attn_weights.bmm(encoder_outputs.transpose(0, 1)) # [batch x 1 x attn_size] + context = context.transpose(0, 1) # [1 x batch x attn_size] + + # make input vec + rnn_input = torch.cat((motion_input, context), 2) # [1 x batch x (dim + attn_size)] + + if self.speaker_model: + assert vid_indices is not None + speaker_context = self.speaker_embedding(vid_indices).unsqueeze(0) + rnn_input = torch.cat((rnn_input, speaker_context), 2) # [1 x batch x (dim + attn_size + embed_size)] + + rnn_input = self.pre_linear(rnn_input.squeeze(0)) + rnn_input = rnn_input.unsqueeze(0) + + # rnn + output, hidden = self.gru(rnn_input, last_hidden) + + # post-fc + output = output.squeeze(0) # [1 x batch x hidden_size] -> [batch x hidden_size] + output = self.out(output) + + return output, hidden, attn_weights + + +class Generator(nn.Module): + def __init__(self, args, motion_dim, discrete_representation=False, speaker_model=None): + super(Generator, self).__init__() + self.output_size = motion_dim + self.n_layers = args.n_layers + self.discrete_representation = discrete_representation + self.decoder = BahdanauAttnDecoderRNN(input_size=motion_dim, + hidden_size=args.hidden_size, + output_size=self.output_size, + n_layers=self.n_layers, + dropout_p=args.dropout_prob, + discrete_representation=discrete_representation, + speaker_model=speaker_model) + + def freeze_attn(self): + self.decoder.freeze_attn() + + def forward(self, z, motion_input, last_hidden, encoder_output, vid_indices=None): + if z is None: + input_with_noise_vec = motion_input + else: + assert not self.discrete_representation # not valid for discrete representation + input_with_noise_vec = torch.cat([motion_input, z], dim=1) # [bs x (10+z_size)] + + return self.decoder(input_with_noise_vec, last_hidden, encoder_output, vid_indices) + + +class Seq2SeqNet(nn.Module): + def __init__(self, args, pose_dim, n_frames, n_words, word_embed_size, word_embeddings, speaker_model=None): + super().__init__() + self.encoder = EncoderRNN( + n_words, word_embed_size, args.hidden_size, args.n_layers, + dropout=args.dropout_prob, pre_trained_embedding=word_embeddings) + self.decoder = Generator(args, pose_dim, speaker_model=speaker_model) + + self.n_frames = n_frames + self.n_pre_poses = args.n_pre_poses + self.pose_dim = pose_dim + + def forward(self, in_text, in_lengths, poses, vid_indices): + # reshape to (seq x batch x dim) + in_text = in_text.transpose(0, 1) + poses = poses.transpose(0, 1) + + outputs = torch.zeros(self.n_frames, poses.size(1), self.decoder.output_size).to(poses.device) + + # run words through encoder + encoder_outputs, encoder_hidden = self.encoder(in_text, in_lengths, None) + decoder_hidden = encoder_hidden[:self.decoder.n_layers] # use last hidden state from encoder + + # run through decoder one time step at a time + decoder_input = poses[0] # initial pose from the dataset + outputs[0] = decoder_input + + for t in range(1, self.n_frames): + decoder_output, decoder_hidden, _ = self.decoder(None, decoder_input, decoder_hidden, 
encoder_outputs, + vid_indices) + outputs[t] = decoder_output + + if t < self.n_pre_poses: + decoder_input = poses[t] # next input is current target + else: + decoder_input = decoder_output # next input is current prediction + + return outputs.transpose(0, 1) + + +class Generator_gru(nn.Module): + def __init__(self): + super().__init__() + self.WavEncoder = WavEncoder() + self.hidden_size = 200 + self.output_size = 512 + + self.project = nn.GRU(input_size=32, hidden_size=self.hidden_size, num_layers=2, dropout=0.1, bidirectional=True, batch_first=True) + self.norm = nn.LayerNorm(self.hidden_size) + self.out = nn.Linear(self.hidden_size, self.output_size) + + def sample(self, x): + wav_feature = self.WavEncoder(x) + hidden = None + outputs, hidden = self.project(wav_feature, hidden) + outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:] # Sum bidirectional outputs, ((batch, seq_len=240, ) + outputs = self.norm(outputs) + codebook_embedding = self.out(outputs) + code = torch.tensor([]).to(x.device) + for k in codebook_embedding: + probs = F.softmax(k, dim=-1) + _, ix = torch.topk(probs, k=1, dim=-1) + code = torch.cat((code, ix.squeeze(-1))) + return [code.unsqueeze(0).long()] + + def forward(self, x, target=None): # (b, len, 13) + wav_feature = self.WavEncoder(x) # (b, 30, 32) + hidden = None + + outputs, hidden = self.project(wav_feature, hidden) # (batch, seq_len=240, num_directions * hidden_size) + outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:] # Sum bidirectional outputs, ((batch, seq_len=240, ) + + outputs = self.norm(outputs) + codebook_embedding = self.out(outputs) + + loss = None + if target is not None: + loss = F.cross_entropy(codebook_embedding.view(-1, codebook_embedding.size(-1)), target.view(-1)) + return codebook_embedding, loss + + +class Generator_diff(nn.Module): + def __init__(self): + super().__init__() + self.WavEncoder = WavEncoder() + seq_len = 240 + joints = 15 + n_dim = 9 + n_channels = joints * n_dim + audio_dim = 32 + + model = myUnet1D( + dim=64, + dim_mults=(1, 2, 4, 8), + channels=n_channels, + self_condition=True, + audio_dim=audio_dim + ) + + self.diffusion = myGaussianDiffusion1D( + model, + seq_length=seq_len, + timesteps=250, + objective='pred_v', + audio_dim=audio_dim, + loss_type='huber' + ) + + def sample(self, batch_size, tmp_audio): + wav_feature = self.WavEncoder(tmp_audio).transpose(1, 2) # (b, 240, 32) + sampled_seq = self.diffusion.sample(batch_size=batch_size, tmp_audio_feat=wav_feature) + return sampled_seq + + def forward(self, target, x): # (b, len, 13) + wav_feature = self.WavEncoder(x).transpose(1, 2) # (b, 240, 32) + loss = self.diffusion(target, wav_feature) + return loss + + +if __name__ == '__main__': + ''' + cd mydiffusion/generate/ + python generate.py + ''' + audio = torch.rand(2, 64000) + pose = torch.rand(2, 240, 135).transpose(1, 2) + + # z = torch.arange(0, 60).reshape(2, 30) + # model = Generator_gru() + model = Generator_diff() + loss = model(pose, audio) # (b, 30, 32) + pdb.set_trace() + sampled_seq = model.sample(batch_size=1) + print(sampled_seq.shape) # (4, 32, 128) + + + diff --git a/main/mydiffusion_zeggs/inference.py b/main/mydiffusion_zeggs/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..ecc2d79a096302d53cc6cd4a3e4879122e47b000 --- /dev/null +++ b/main/mydiffusion_zeggs/inference.py @@ -0,0 +1,140 @@ +import pdb +import yaml +from pprint import pprint +from easydict import EasyDict +import numpy as np +import torch +from 
configs.parse_args import parse_args +import math +from models.vqvae import VQVAE +import torch.nn as nn +from generate.generate import Generator_gru as Generator +import librosa +import os +import sys + +[sys.path.append(i) for i in ['.', '..', '../process']] +from process.process_bvh import make_bvh_GENEA2020_BT +from process.bvh_to_position import bvh_to_npy +from process.visualize_bvh import visualize + +args = parse_args() +mydevice = torch.device('cuda:' + args.gpu) + + +def main(args, audio_path, normalize=True, mode='position', codebook_model_path=None, end2end_model_path=None, + save_path=None, prefix=None, max_frames=None): + + audio_raw, audio_sr = librosa.load(audio_path, mono=True, sr=16000, res_type='kaiser_fast') + + clip_length = audio_raw.shape[0] + + # divide into synthesize units and do synthesize + unit_time = audio_sr * args.n_poses / args.motion_resampling_framerate # 4 * 16000 + + if clip_length < unit_time: + num_subdivision = 1 + else: + num_subdivision = math.ceil((clip_length - unit_time) / unit_time) + 1 + + if max_frames is not None and num_subdivision >= (max_frames / args.motion_resampling_framerate) / args.n_poses: + num_subdivision = int(max_frames / args.n_poses) # 3600 / 60 / 4 = 15 + + print('num_subdivision: {}, unit_time: {}, clip_length: {}'.format(num_subdivision, unit_time, clip_length)) + + + with torch.no_grad(): + if mode == 'position': + model_VQVAE = VQVAE(args.VQVAE, 15 * 3) # n_joints * n_chanels + elif mode == 'rotation': + model_VQVAE = VQVAE(args.VQVAE, 15 * 9) # n_joints * n_chanels + model_VQVAE = nn.DataParallel(model_VQVAE, device_ids=[eval(i) for i in config.no_cuda]) + model_VQVAE = model_VQVAE.to(mydevice) + checkpoint = torch.load(codebook_model_path, map_location=torch.device('cpu')) + model_VQVAE.load_state_dict(checkpoint['model_dict']) + model_VQVAE = model_VQVAE.eval() + + model = Generator() + model = nn.DataParallel(model, device_ids=[eval(i) for i in config.no_cuda]) + model = model.to(mydevice) + checkpoint = torch.load(end2end_model_path, map_location=torch.device('cpu')) + model.load_state_dict(checkpoint['model_dict']) + model = model.eval() + + result = [] + code = [] + + for i in range(0, num_subdivision): + start_time = i * unit_time + + # prepare pose input + pose_start = math.floor(start_time) + pose_end = pose_start + 64000 + in_audio = audio_raw[pose_start:pose_end] + if len(in_audio) < 64000: + in_audio = np.pad(in_audio, (0, 64000 - len(in_audio)), 'constant') + in_audio = torch.from_numpy(in_audio).unsqueeze(0).to(mydevice) + + out_zs = model.module.sample(in_audio) + + code.append(out_zs[0].squeeze(0).data.cpu().numpy()) + + out_code = np.vstack(code) + # print(torch.from_numpy(out_code.flatten()).to(mydevice).unsqueeze(0).shape) + pose_sample = model_VQVAE.module.decode([torch.from_numpy(out_code.flatten()).to(mydevice).unsqueeze(0)]).squeeze(0).data.cpu().numpy() + result.append(pose_sample) + + out_poses = np.vstack(result) + + if normalize: + data_mean = np.array(args.data_mean).squeeze() + data_std = np.array(args.data_std).squeeze() + std = np.clip(data_std, a_min=0.01, a_max=None) + out_poses = np.multiply(out_poses, std) + data_mean + print(out_poses.shape) + print(out_code.shape) + np.save(os.path.join(save_path, 'code' + prefix + '.npy'), out_code) + np.save(os.path.join(save_path, 'generate' + prefix + '.npy'), out_poses) + return out_poses, out_code + + +if __name__ == '__main__': + ''' + cd codebook/ + python inference.py --config=./configs/codebook.yml --train --no_cuda 3 --gpu 3 + ''' + + with 
open(args.config) as f: + config = yaml.safe_load(f) + + for k, v in vars(args).items(): + config[k] = v + pprint(config) + + config = EasyDict(config) + + # audio_path = "/mnt/nfs7/y50021900/My/tmp/TEST/TestSeq001.wav" + audio_path = "../tmp/TEST/1_wayne_0_103_110.wav" + mode = 'rotation' + codebook_path = '../codebook/BEAT_output_60fps_rotation/train_codebook/' + "codebook_checkpoint_best.bin" + end2end_path = "./BEAT_output_60fps_rotation_gru/train_end2end/end2end_checkpoint_080.bin" # + save_path = "../tmp/TEST/npy_position/" + prefix = 'BEAT_gru' + save_path = os.path.join(save_path, prefix) + if not os.path.exists(save_path): + os.mkdir(save_path) + MAX_FRAMES = 60 * 60 + pipeline_path = '../process/resource/data_pipe_60_rotation.sav' + + out_poses, out_code = main(config, audio_path=audio_path, normalize=True, mode=mode, codebook_model_path=codebook_path, + end2end_model_path=end2end_path, save_path=save_path, prefix=prefix, max_frames=MAX_FRAMES) + print('rotation npy to bvh...') + # make_bvh_GENEA2020_BT(save_path, filename_prefix='GT', poses=poses, smoothing=False, pipeline_path=pipeline_path) + make_bvh_GENEA2020_BT(save_path, prefix, out_poses, smoothing=False, pipeline_path=pipeline_path) + print('bvh to position npy...') + bvh_path_generated = os.path.join(save_path, prefix + '_generated.bvh') + bvh_to_npy(bvh_path_generated, save_path) + print('visualize code...') + npy_generated = np.load(os.path.join(save_path, prefix + '_generated.npy')) + out_video = os.path.join(save_path, prefix + '_generated.mp4') + visualize(npy_generated.reshape((npy_generated.shape[0], -1, 3)), out_video, out_code.flatten(), 'upper') diff --git a/main/mydiffusion_zeggs/mfcc.py b/main/mydiffusion_zeggs/mfcc.py new file mode 100644 index 0000000000000000000000000000000000000000..2405644af6478bbdee88490bb948f10d87f188bb --- /dev/null +++ b/main/mydiffusion_zeggs/mfcc.py @@ -0,0 +1,257 @@ +# Copyright (c) 2006 Carnegie Mellon University +# +# You may copy and modify this freely under the same terms as +# Sphinx-III + +"""Compute MFCC coefficients. + +This module provides functions for computing MFCC (mel-frequency +cepstral coefficients) as used in the Sphinx speech recognition +system. +""" + +__author__ = "David Huggins-Daines " +__version__ = "$Revision: 6390 $" + +import pdb + +import numpy, numpy.fft +import math +import librosa +import os + + +def mel(f): + return 2595. * numpy.log10(1. + f / 700.) + + +def melinv(m): + return 700. * (numpy.power(10., m / 2595.) - 1.) 
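# Editorial sketch (not part of the original commit): mel() and melinv() above implement
# the standard mel-scale formula m = 2595 * log10(1 + f / 700) and its inverse, which
# MFCC.__init__ below uses to place the triangular filterbank edges. A quick round-trip
# sanity check; the helper name _mel_scale_sanity_check is chosen here purely for illustration.
def _mel_scale_sanity_check():
    # The scale is anchored near 1 kHz: 1000 Hz maps to ~999.99 mel.
    assert abs(mel(1000.0) - 1000.0) < 0.1
    # melinv() inverts mel() up to floating-point error.
    assert abs(melinv(mel(440.0)) - 440.0) < 1e-6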
+ + +class MFCC(object): + def __init__(self, nfilt=40, ncep=13, + lowerf=133.3333, upperf=6855.4976, alpha=0.97, + samprate=16000, frate=100, wlen=0.0256, + nfft=512): + # Store parameters + self.samprate = samprate + self.lowerf = lowerf + self.upperf = upperf + self.nfft = nfft + self.ncep = ncep + self.nfilt = nfilt + self.frate = frate + self.fshift = float(samprate) / frate + + # Build Hamming window + self.wlen = int(wlen * samprate) + self.win = numpy.hamming(self.wlen) + + # Prior sample for pre-emphasis + self.prior = 0 + self.alpha = alpha + + # Build mel filter matrix + self.filters = numpy.zeros((nfft // 2 + 1, nfilt), 'd') + dfreq = float(samprate) / nfft + if upperf > samprate / 2: + raise (Exception, + "Upper frequency %f exceeds Nyquist %f" % (upperf, samprate / 2)) + melmax = mel(upperf) + melmin = mel(lowerf) + dmelbw = (melmax - melmin) / (nfilt + 1) + # Filter edges, in Hz + filt_edge = melinv(melmin + dmelbw * numpy.arange(nfilt + 2, dtype='d')) + + for whichfilt in range(0, nfilt): + # Filter triangles, in DFT points + leftfr = round(filt_edge[whichfilt] / dfreq) + centerfr = round(filt_edge[whichfilt + 1] / dfreq) + rightfr = round(filt_edge[whichfilt + 2] / dfreq) + # For some reason this is calculated in Hz, though I think + # it doesn't really matter + fwidth = (rightfr - leftfr) * dfreq + height = 2. / fwidth + + if centerfr != leftfr: + leftslope = height / (centerfr - leftfr) + else: + leftslope = 0 + freq = leftfr + 1 + while freq < centerfr: + self.filters[freq, whichfilt] = (freq - leftfr) * leftslope + freq = freq + 1 + if freq == centerfr: # This is always true + self.filters[freq, whichfilt] = height + freq = freq + 1 + if centerfr != rightfr: + rightslope = height / (centerfr - rightfr) + while freq < rightfr: + self.filters[freq, whichfilt] = (freq - rightfr) * rightslope + freq = freq + 1 + # print("Filter %d: left %d=%f center %d=%f right %d=%f width %d" % + # (whichfilt, + # leftfr, leftfr*dfreq, + # centerfr, centerfr*dfreq, + # rightfr, rightfr*dfreq, + # freq - leftfr)) + # print self.filters[leftfr:rightfr,whichfilt] + + # Build DCT matrix + self.s2dct = s2dctmat(nfilt, ncep, 1. 
/ nfilt) + self.dct = dctmat(nfilt, ncep, numpy.pi / nfilt) + + def sig2s2mfc(self, sig): + nfr = int(len(sig) / self.fshift + 1) + mfcc = numpy.zeros((nfr, self.ncep), 'd') + fr = 0 + while fr < nfr: + start = round(fr * self.fshift) + end = min(len(sig), start + self.wlen) + frame = sig[start:end] + if len(frame) < self.wlen: + frame = numpy.resize(frame, self.wlen) + frame[self.wlen:] = 0 + mfcc[fr] = self.frame2s2mfc(frame) + fr = fr + 1 + return mfcc + + def sig2logspec(self, sig): + nfr = int(len(sig) / self.fshift + 1) + mfcc = numpy.zeros((nfr, self.nfilt), 'd') + fr = 0 + while fr < nfr: + start = round(fr * self.fshift) + end = min(len(sig), start + self.wlen) + frame = sig[start:end] + if len(frame) < self.wlen: + frame = numpy.resize(frame, self.wlen) + frame[self.wlen:] = 0 + mfcc[fr] = self.frame2logspec(frame) + fr = fr + 1 + return mfcc + + def pre_emphasis(self, frame): + # FIXME: Do this with matrix multiplication + outfr = numpy.empty(len(frame), 'd') + outfr[0] = frame[0] - self.alpha * self.prior + for i in range(1, len(frame)): + outfr[i] = frame[i] - self.alpha * frame[i - 1] + self.prior = frame[-1] + return outfr + + def frame2logspec(self, frame): + frame = self.pre_emphasis(frame) * self.win + fft = numpy.fft.rfft(frame, self.nfft) + # Square of absolute value + power = fft.real * fft.real + fft.imag * fft.imag + return numpy.log(numpy.dot(power, self.filters).clip(1e-5, numpy.inf)) + + def frame2s2mfc(self, frame): + logspec = self.frame2logspec(frame) + return numpy.dot(logspec, self.s2dct.T) / self.nfilt + + def sig2s2mfc_energy(self, sig, dn): + nfr = int(len(sig) / self.fshift + 1) + + mfcc = numpy.zeros((nfr, self.ncep + 2), 'd') + fr = 0 + while fr < nfr: + start = int(round(fr * self.fshift)) + end = min(len(sig), start + self.wlen) + frame = sig[start:end] + if len(frame) < self.wlen: + frame = numpy.resize(frame, self.wlen) + frame[self.wlen:] = 0 + mfcc[fr, :-2] = self.frame2s2mfc(frame) + mfcc[fr, -2] = math.log(1 + numpy.mean(numpy.power(frame.astype(float), 2))) + mid = 0.5 * (start + end - 1) + mfcc[fr, -1] = mid / self.samprate + + fr = fr + 1 + return mfcc + + +def s2dctmat(nfilt, ncep, freqstep): + """Return the 'legacy' not-quite-DCT matrix used by Sphinx""" + melcos = numpy.empty((ncep, nfilt), 'double') + for i in range(0, ncep): + freq = numpy.pi * float(i) / nfilt + melcos[i] = numpy.cos(freq * numpy.arange(0.5, float(nfilt) + 0.5, 1.0, 'double')) + melcos[:, 0] = melcos[:, 0] * 0.5 + return melcos + + +def logspec2s2mfc(logspec, ncep=13): + """Convert log-power-spectrum bins to MFCC using the 'legacy' + Sphinx transform""" + nframes, nfilt = logspec.shape + melcos = s2dctmat(nfilt, ncep, 1. / nfilt) + return numpy.dot(logspec, melcos.T) / nfilt + + +def dctmat(N, K, freqstep, orthogonalize=True): + """Return the orthogonal DCT-II/DCT-III matrix of size NxK. + For computing or inverting MFCCs, N is the number of + log-power-spectrum bins while K is the number of cepstra.""" + cosmat = numpy.zeros((N, K), 'double') + for n in range(0, N): + for k in range(0, K): + cosmat[n, k] = numpy.cos(freqstep * (n + 0.5) * k) + if orthogonalize: + cosmat[:, 0] = cosmat[:, 0] * 1. 
/ numpy.sqrt(2) + return cosmat + + +def dct(input, K=13): + """Convert log-power-spectrum to MFCC using the orthogonal DCT-II""" + nframes, N = input.shape + freqstep = numpy.pi / N + cosmat = dctmat(N, K, freqstep) + return numpy.dot(input, cosmat) * numpy.sqrt(2.0 / N) + + +def dct2(input, K=13): + """Convert log-power-spectrum to MFCC using the normalized DCT-II""" + nframes, N = input.shape + freqstep = numpy.pi / N + cosmat = dctmat(N, K, freqstep, False) + return numpy.dot(input, cosmat) * (2.0 / N) + + +def idct(input, K=40): + """Convert MFCC to log-power-spectrum using the orthogonal DCT-III""" + nframes, N = input.shape + freqstep = numpy.pi / K + cosmat = dctmat(K, N, freqstep).T + return numpy.dot(input, cosmat) * numpy.sqrt(2.0 / K) + + +def dct3(input, K=40): + """Convert MFCC to log-power-spectrum using the unnormalized DCT-III""" + nframes, N = input.shape + freqstep = numpy.pi / K + cosmat = dctmat(K, N, freqstep, False) + cosmat[:, 0] = cosmat[:, 0] * 0.5 + return numpy.dot(input, cosmat.T) + + +if __name__ == '__main__': + ''' + cd codebook/Speech2GestureMatching/ + python mfcc.py + ''' + obj = MFCC(frate=20) # 60 -> 20 + root = '/mnt/nfs7/y50021900/My/data/BEAT0909/' + target_path = 'MFCC_20' + if not os.path.exists(os.path.join(root, target_path)): + os.mkdir(os.path.join(root, target_path)) + + for item in os.listdir(os.path.join(root, 'Audio_normalized')): + print(item) + wav_path = os.path.join(root, 'Audio_normalized', item) + wav, fs = librosa.load(wav_path, sr=16000) # 2117849 / 16000 = 132.3655625 + mfcc = obj.sig2s2mfc_energy(wav, None) # (13237, 15) with -1:1.275ใ€2.275ใ€3.375... + print(mfcc[:, :-2].shape) # -1 -> -2 + numpy.savez_compressed(os.path.join(root, target_path, item[:-4] + '.npz'), mfcc=mfcc[:, :-2]) diff --git a/main/mydiffusion_zeggs/output/train_DiffuseStyleGesture/Readme.txt b/main/mydiffusion_zeggs/output/train_DiffuseStyleGesture/Readme.txt new file mode 100644 index 0000000000000000000000000000000000000000..fc262c6626fe5ba5f5d268df13a11622cba5e42d --- /dev/null +++ b/main/mydiffusion_zeggs/output/train_DiffuseStyleGesture/Readme.txt @@ -0,0 +1 @@ +Models saved here \ No newline at end of file diff --git a/main/mydiffusion_zeggs/sample.py b/main/mydiffusion_zeggs/sample.py new file mode 100644 index 0000000000000000000000000000000000000000..d623eec0848384cf8ad8fa5dd360ef1b20f470be --- /dev/null +++ b/main/mydiffusion_zeggs/sample.py @@ -0,0 +1,419 @@ +import sys +[sys.path.append(i) for i in ['.', '..', '../process', '../model', '../../ubisoft-laforge-ZeroEGGS-main', '../../ubisoft-laforge-ZeroEGGS-main/ZEGGS']] +from model.mdm import MDM +from utils.model_util import create_gaussian_diffusion, load_model_wo_clip +import subprocess +import os +from datetime import datetime +from mfcc import MFCC +import librosa +import numpy as np +import yaml +from pprint import pprint +import torch +import torch.nn.functional as F +from easydict import EasyDict +import math +from process_zeggs_bvh import pose2bvh, quat # '../process' +import argparse + +style2onehot = { +'Happy':[1, 0, 0, 0, 0, 0], +'Sad':[0, 1, 0, 0, 0, 0], +'Neutral':[0, 0, 1, 0, 0, 0], +'Old':[0, 0, 0, 1, 0, 0], +'Angry':[0, 0, 0, 0, 1, 0], +'Relaxed':[0, 0, 0, 0, 0, 1], +} + + +def wavlm_init(device=torch.device('cuda:2')): + import sys + [sys.path.append(i) for i in ['./WavLM']] + from WavLM import WavLM, WavLMConfig + wavlm_model_path = './WavLM/WavLM-Large.pt' + checkpoint = torch.load(wavlm_model_path, map_location=torch.device('cpu')) # load the pre-trained checkpoints + cfg = 
WavLMConfig(checkpoint['cfg']) + model = WavLM(cfg) + model = model.to(device) + model.load_state_dict(checkpoint['model']) + model.eval() + return model + + +def wav2wavlm(model, wav_input_16khz, device=torch.device('cuda:2')): + wav_input_16khz = wav_input_16khz.to(device) + rep = model.extract_features(wav_input_16khz)[0] + rep = F.interpolate(rep.transpose(1, 2), size=88, align_corners=True, mode='linear').transpose(1, 2) + return rep + + +def create_model_and_diffusion(args): + model = MDM(modeltype='', njoints=1141, nfeats=1, translation=True, pose_rep='rot6d', glob=True, + glob_rot=True, cond_mode = 'cross_local_attention3_style1', clip_version = 'ViT-B/32', action_emb = 'tensor', audio_feat=args.audio_feat, + arch='trans_enc', latent_dim=256, n_seed=8) # trans_enc, trans_dec, gru, mytrans_enc + diffusion = create_gaussian_diffusion() + return model, diffusion + + +def inference_mfcc(args, mfcc, sample_fn, model, n_frames=0, smoothing=False, SG_filter=False, minibatch=False, skip_timesteps=0, n_seed=8, style=None, seed=123456, smooth_foot=False): + + torch.manual_seed(seed) + + if n_frames == 0: + n_frames = mfcc.shape[0] + if minibatch: + stride_poses = args.n_poses - n_seed + if n_frames < stride_poses: + num_subdivision = 1 + else: + num_subdivision = math.floor(n_frames / stride_poses) + n_frames = num_subdivision * stride_poses + print( + '{}, {}, {}'.format(num_subdivision, stride_poses, n_frames)) + mfcc = mfcc[:n_frames] + + model_kwargs_ = {'y': {}} + model_kwargs_['y']['mask'] = (torch.zeros([1, 1, 1, n_frames]) < 1).to(mydevice) + model_kwargs_['y']['style'] = torch.as_tensor([style]).float().to(mydevice) + model_kwargs_['y']['mask_local'] = torch.ones(1, args.n_poses).bool().to(mydevice) + + # tmp_mfcc = torch.from_numpy(np.load('10_kieks_0_9_16.npz')['mfcc'][:n_frames]).to(torch.float32).unsqueeze(0).to(mydevice) + # model_kwargs_['y']['audio'] = tmp_mfcc.permute(1, 0, 2) + + if minibatch: + audio_reshape = torch.from_numpy(mfcc).to(torch.float32).reshape(num_subdivision, stride_poses, -1).to(mydevice).permute(1, 0, 2) # mfcc[:, :-2] + shape_ = (1, model.njoints, model.nfeats, args.n_poses) + out_list = [] + for i in range(0, num_subdivision): + print(i, num_subdivision) + model_kwargs_['y']['audio'] = audio_reshape[:, i:i + 1, :] + if i == 0: + if n_seed != 0: + pad_zeros = torch.zeros([n_seed, 1, 13]).to(mydevice) # mfcc dims are 13 + model_kwargs_['y']['audio'] = torch.cat((pad_zeros, model_kwargs_['y']['audio']), 0) + model_kwargs_['y']['seed'] = torch.zeros([1, 1141, 1, n_seed]).to(mydevice) + else: + if n_seed != 0: + pad_audio = audio_reshape[-n_seed:, i - 1:i, :] + model_kwargs_['y']['audio'] = torch.cat((pad_audio, model_kwargs_['y']['audio']), 0) + model_kwargs_['y']['seed'] = out_list[-1][..., -n_seed:].to(mydevice) + + sample = sample_fn( + model, + shape_, + clip_denoised=False, + model_kwargs=model_kwargs_, + skip_timesteps=skip_timesteps, # 0 is the default value - i.e. 
don't skip any step + init_image=None, + progress=True, + dump_steps=None, + noise=None, # None, torch.randn(*shape_, device=mydevice) + const_noise=False, + ) + # smoothing motion transition + if len(out_list) > 0 and n_seed != 0: + last_poses = out_list[-1][..., -n_seed:] # # (1, model.njoints, 1, n_seed) + out_list[-1] = out_list[-1][..., :-n_seed] # delete last 4 frames + if smoothing: + # Extract predictions + last_poses_root_pos = last_poses[:, 0:3] # (1, 3, 1, 8) + # last_poses_root_rot = last_poses[:, 3:7] + # last_poses_root_vel = last_poses[:, 7:10] + # last_poses_root_vrt = last_poses[:, 10:13] + next_poses_root_pos = sample[:, 0:3] # (1, 3, 1, 88) + # next_poses_root_rot = sample[:, 3:7] + # next_poses_root_vel = sample[:, 7:10] + # next_poses_root_vrt = sample[:, 10:13] + root_pos = last_poses_root_pos[..., 0] # (1, 3, 1) + predict_pos = next_poses_root_pos[..., 0] + delta_pos = (predict_pos - root_pos).unsqueeze(-1) # # (1, 3, 1, 1) + sample[:, 0:3] = sample[:, 0:3] - delta_pos + + if smooth_foot: + njoints = 75 + length = n_seed + last_poses_lpos = last_poses[:, 13 + njoints * 0: 13 + njoints * 3].reshape([length, njoints, 3]) + last_poses_LeftToeBase = last_poses_lpos[0, -4] + last_poses_RightToeBase = last_poses_lpos[0, -11] + + next_poses_lpos = sample[:, 13 + njoints * 0: 13 + njoints * 3].reshape([args.n_poses, njoints, 3]) + next_poses_LeftToeBase = next_poses_lpos[0, -4] + next_poses_RightToeBase = next_poses_lpos[0, -11] + + delta_poses_LeftToeBase = (next_poses_LeftToeBase - last_poses_LeftToeBase) + delta_poses_RightToeBase = (next_poses_RightToeBase - last_poses_RightToeBase) + + next_poses_lpos[:, -4] = (next_poses_lpos[:, -4] - delta_poses_LeftToeBase) + next_poses_lpos[:, -11] = (next_poses_lpos[:, -11] - delta_poses_RightToeBase) + sample[:, 13 + njoints * 0: 13 + njoints * 3] = next_poses_lpos.reshape(1, -1, 1, args.n_poses) + + for j in range(len(last_poses)): + n = len(last_poses) + prev = last_poses[..., j] + next = sample[..., j] + sample[..., j] = prev * (n - j) / (n + 1) + next * (j + 1) / (n + 1) + out_list.append(sample) + + if n_seed != 0: + out_list[-1] = out_list[-1][..., :-n_seed] + out_list = [i.detach().data.cpu().numpy() for i in out_list] + out_dir_vec = np.vstack(out_list) + sampled_seq = out_dir_vec.squeeze(2).transpose(0, 2, 1).reshape(batch_size, n_frames, model.njoints) + sampled_seq = sampled_seq[:, n_seed:] + else: + out_list = [i.detach().data.cpu().numpy() for i in out_list] + out_dir_vec = np.vstack(out_list) + sampled_seq = out_dir_vec.squeeze(2).transpose(0, 2, 1).reshape(batch_size, n_frames, model.njoints) + else: + model_kwargs_['y']['audio'] = torch.from_numpy(mfcc).to(torch.float32).unsqueeze(0).to(mydevice).permute(1, 0, 2) + shape_ = (batch_size, model.njoints, model.nfeats, n_frames) + model_kwargs_['y']['seed'] = torch.zeros([1, 1141, 1, n_seed]).to(mydevice) + sample = sample_fn( + model, + shape_, + clip_denoised=False, + model_kwargs=model_kwargs_, + skip_timesteps=skip_timesteps, # 0 is the default value - i.e. 
don't skip any step + init_image=None, + progress=True, + dump_steps=None, + noise=None, # None, torch.randn(*shape_, device=mydevice) + const_noise=False, + ) + out_dir_vec = sample.data.cpu().numpy() + sampled_seq = out_dir_vec.squeeze(2).transpose(0, 2, 1).reshape(batch_size, n_frames, model.njoints) + + data_mean_ = np.load("../../ubisoft-laforge-ZeroEGGS-main/data/processed_v1/processed/mean.npz")['mean'].squeeze() + data_std_ = np.load("../../ubisoft-laforge-ZeroEGGS-main/data/processed_v1/processed/std.npz")['std'].squeeze() + + data_mean = np.array(data_mean_).squeeze() + data_std = np.array(data_std_).squeeze() + std = np.clip(data_std, a_min=0.01, a_max=None) + out_poses = np.multiply(sampled_seq[0], std) + data_mean + print(out_poses.shape) + pipeline_path = '../../../My/process/resource/data_pipe_20_rotation.sav' + prefix = str(datetime.now().strftime('%Y%m%d_%H%M%S')) + if smoothing: prefix += '_smoothing' + if smooth_foot: prefix += 'smoothfoot' + if SG_filter: prefix += '_SG' + if minibatch: prefix += '_minibatch' + prefix += '_%s' % (n_frames) + prefix += '_' + str(style) + prefix += '_' + str(seed) + if minibatch: + pose2bvh(out_poses, os.path.join(save_dir, prefix + '.bvh'), length=n_frames - n_seed, smoothing=SG_filter) + else: + pose2bvh(out_poses, os.path.join(save_dir, prefix + '.bvh'), length=n_frames, smoothing=SG_filter) + + +def inference(args, wavlm_model, audio, sample_fn, model, n_frames=0, smoothing=False, SG_filter=False, minibatch=False, skip_timesteps=0, n_seed=8, style=None, seed=123456): + + torch.manual_seed(seed) + + if n_frames == 0: + n_frames = audio.shape[0] * 20 // 16000 + if minibatch: + stride_poses = args.n_poses - n_seed + if n_frames < stride_poses: + num_subdivision = 1 + else: + num_subdivision = math.floor(n_frames / stride_poses) + n_frames = num_subdivision * stride_poses + print( + '{}, {}, {}'.format(num_subdivision, stride_poses, n_frames)) + audio = audio[:int(n_frames * 16000 / 20)] + + model_kwargs_ = {'y': {}} + model_kwargs_['y']['mask'] = (torch.zeros([1, 1, 1, n_frames]) < 1).to(mydevice) + model_kwargs_['y']['style'] = torch.as_tensor([style]).float().to(mydevice) + model_kwargs_['y']['mask_local'] = torch.ones(1, args.n_poses).bool().to(mydevice) + + if minibatch: + audio_reshape = torch.from_numpy(audio).to(torch.float32).reshape(num_subdivision, int(stride_poses * 16000 / 20)).to(mydevice).transpose(0, 1) # mfcc[:, :-2] + shape_ = (1, model.njoints, model.nfeats, args.n_poses) + out_list = [] + for i in range(0, num_subdivision): + print(i, num_subdivision) + model_kwargs_['y']['audio'] = audio_reshape[:, i:i + 1] + + if i == 0: + if n_seed != 0: + pad_zeros = torch.zeros([int(n_seed * 16000 / 20), 1]).to(mydevice) # wavlm dims are 1024 + model_kwargs_['y']['audio'] = torch.cat((pad_zeros, model_kwargs_['y']['audio']), 0) + model_kwargs_['y']['seed'] = torch.zeros([1, 1141, 1, n_seed]).to(mydevice) + else: + if n_seed != 0: + pad_audio = audio_reshape[-int(n_seed * 16000 / 20):, i - 1:i] + model_kwargs_['y']['audio'] = torch.cat((pad_audio, model_kwargs_['y']['audio']), 0) + model_kwargs_['y']['seed'] = out_list[-1][..., -n_seed:].to(mydevice) + + model_kwargs_['y']['audio'] = wav2wavlm(wavlm_model, model_kwargs_['y']['audio'].transpose(0, 1), mydevice) + + sample = sample_fn( + model, + shape_, + clip_denoised=False, + model_kwargs=model_kwargs_, + skip_timesteps=skip_timesteps, # 0 is the default value - i.e. 
don't skip any step + init_image=None, + progress=True, + dump_steps=None, + noise=None, # None, torch.randn(*shape_, device=mydevice) + const_noise=False, + ) + # smoothing motion transition + if len(out_list) > 0 and n_seed != 0: + last_poses = out_list[-1][..., -n_seed:] # # (1, model.njoints, 1, n_seed) + out_list[-1] = out_list[-1][..., :-n_seed] # delete last 4 frames + if smoothing: + # Extract predictions + last_poses_root_pos = last_poses[:, 0:3] # (1, 3, 1, 8) + # last_poses_root_rot = last_poses[:, 3:7] + # last_poses_root_vel = last_poses[:, 7:10] + # last_poses_root_vrt = last_poses[:, 10:13] + next_poses_root_pos = sample[:, 0:3] # (1, 3, 1, 88) + # next_poses_root_rot = sample[:, 3:7] + # next_poses_root_vel = sample[:, 7:10] + # next_poses_root_vrt = sample[:, 10:13] + root_pos = last_poses_root_pos[..., 0] # (1, 3, 1) + predict_pos = next_poses_root_pos[..., 0] + delta_pos = (predict_pos - root_pos).unsqueeze(-1) # # (1, 3, 1, 1) + sample[:, 0:3] = sample[:, 0:3] - delta_pos + + for j in range(len(last_poses)): + n = len(last_poses) + prev = last_poses[..., j] + next = sample[..., j] + sample[..., j] = prev * (n - j) / (n + 1) + next * (j + 1) / (n + 1) + out_list.append(sample) + + if n_seed != 0: + out_list[-1] = out_list[-1][..., :-n_seed] + out_list = [i.detach().data.cpu().numpy() for i in out_list] + out_dir_vec = np.vstack(out_list) + sampled_seq = out_dir_vec.squeeze(2).transpose(0, 2, 1).reshape(batch_size, n_frames, model.njoints) + sampled_seq = sampled_seq[:, n_seed:] + else: + out_list = [i.detach().data.cpu().numpy() for i in out_list] + out_dir_vec = np.vstack(out_list) + sampled_seq = out_dir_vec.squeeze(2).transpose(0, 2, 1).reshape(batch_size, n_frames, model.njoints) + else: + model_kwargs_['y']['audio'] = torch.from_numpy(mfcc).to(torch.float32).unsqueeze(0).to(mydevice).permute(1, 0, 2) + shape_ = (batch_size, model.njoints, model.nfeats, n_frames) + model_kwargs_['y']['seed'] = torch.zeros([1, 1141, 1, n_seed]).to(mydevice) + sample = sample_fn( + model, + shape_, + clip_denoised=False, + model_kwargs=model_kwargs_, + skip_timesteps=skip_timesteps, # 0 is the default value - i.e. 
don't skip any step + init_image=None, + progress=True, + dump_steps=None, + noise=None, # None, torch.randn(*shape_, device=mydevice) + const_noise=False, + ) + out_dir_vec = sample.data.cpu().numpy() + sampled_seq = out_dir_vec.squeeze(2).transpose(0, 2, 1).reshape(batch_size, n_frames, model.njoints) + + data_mean_ = np.load("../../ubisoft-laforge-ZeroEGGS-main/data/processed_v1/processed/mean.npz")['mean'].squeeze() + data_std_ = np.load("../../ubisoft-laforge-ZeroEGGS-main/data/processed_v1/processed/std.npz")['std'].squeeze() + + data_mean = np.array(data_mean_).squeeze() + data_std = np.array(data_std_).squeeze() + std = np.clip(data_std, a_min=0.01, a_max=None) + out_poses = np.multiply(sampled_seq[0], std) + data_mean + print(out_poses.shape) + prefix = str(datetime.now().strftime('%Y%m%d_%H%M%S')) + if smoothing: prefix += '_smoothing' + if SG_filter: prefix += '_SG' + if minibatch: prefix += '_minibatch' + prefix += '_%s' % (n_frames) + prefix += '_' + str(style) + prefix += '_' + str(seed) + if minibatch: + pose2bvh(out_poses, os.path.join(save_dir, prefix + '.bvh'), length=n_frames - n_seed, smoothing=SG_filter) + else: + pose2bvh(out_poses, os.path.join(save_dir, prefix + '.bvh'), length=n_frames, smoothing=SG_filter) + + +def main(args, save_dir, model_path, audio_path=None, mfcc_path=None, audiowavlm_path=None, max_len=0): + if not os.path.exists(save_dir): + os.mkdir(save_dir) + + if audiowavlm_path != None: + mfcc, fs = librosa.load(audiowavlm_path, sr=16000) + + elif audio_path != None and mfcc_path == None: + # normalize_audio + audio_name = audio_path.split('/')[-1] + print('normalize audio: ' + audio_name) + normalize_wav_path = os.path.join(save_dir, 'normalize_' + audio_name) + cmd = ['ffmpeg-normalize', audio_path, '-o', normalize_wav_path, '-ar', '16000'] + subprocess.call(cmd) + + # MFCC, https://github.com/supasorn/synthesizing_obama_network_training + print('extract MFCC...') + obj = MFCC(frate=20) + wav, fs = librosa.load(normalize_wav_path, sr=16000) + mfcc = obj.sig2s2mfc_energy(wav, None) + print(mfcc[:, :-2].shape) # -1 -> -2 # (502, 13) + np.savez_compressed(os.path.join(save_dir, audio_name[:-4] + '.npz'), mfcc=mfcc[:, :-2]) + + elif mfcc_path != None and audio_path == None: + mfcc = np.load(mfcc_path)['mfcc'] + + # sample + print("Creating model and diffusion...") + model, diffusion = create_model_and_diffusion(args) + print(f"Loading checkpoints from [{model_path}]...") + state_dict = torch.load(model_path, map_location='cpu') + load_model_wo_clip(model, state_dict) + model.to(mydevice) + model.eval() + + sample_fn = diffusion.p_sample_loop # predict x_start + + style = style2onehot[audiowavlm_path.split('/')[-1].split('_')[1]] + print(style) + + wavlm_model = wavlm_init(mydevice) + inference(args, wavlm_model, mfcc, sample_fn, model, n_frames=max_len, smoothing=True, SG_filter=True, minibatch=True, skip_timesteps=0, style=style, seed=123456) # style2onehot['Happy'] + + +if __name__ == '__main__': + ''' + cd /ceph/hdd/yangsc21/Python/DSG/ + ''' + + # audio_path = '../../../My/Test_audio/Example1/ZeroEGGS_cut.wav' + # mfcc_path = "../../ubisoft-laforge-ZeroEGGS-main/data/processed_v1/processed/valid/mfcc/015_Happy_4_mirror_x_1_0.npz" # 010_Sad_4_x_1_0.npz + # audiowavlm_path = "./015_Happy_4_x_1_0.wav" + + # prefix = str(datetime.now().strftime('%Y%m%d_%H%M%S')) + # save_dir = 'sample_' + prefix + save_dir = 'sample_dir' + + parser = argparse.ArgumentParser(description='DiffuseStyleGesture') + parser.add_argument('--config', 
default='./configs/DiffuseStyleGesture.yml') + parser.add_argument('--gpu', type=str, default='2') + parser.add_argument('--no_cuda', type=list, default=['2']) + parser.add_argument('--model_path', type=str, default='./model000450000.pt') + parser.add_argument('--audiowavlm_path', type=str, default='') + parser.add_argument('--max_len', type=int, default=0) + args = parser.parse_args() + with open(args.config) as f: + config = yaml.safe_load(f) + for k, v in vars(args).items(): + config[k] = v + pprint(config) + + config = EasyDict(config) + mydevice = torch.device('cuda:' + config.gpu) + torch.cuda.set_device(int(config.gpu)) + + batch_size = 1 + + main(config, save_dir, config.model_path, audio_path=None, mfcc_path=None, audiowavlm_path=config.audiowavlm_path, max_len=config.max_len) + diff --git a/main/mydiffusion_zeggs/zeggs_data_to_lmdb.py b/main/mydiffusion_zeggs/zeggs_data_to_lmdb.py new file mode 100644 index 0000000000000000000000000000000000000000..6859a755d8f634dd776b84109696a4fd2e161ad6 --- /dev/null +++ b/main/mydiffusion_zeggs/zeggs_data_to_lmdb.py @@ -0,0 +1,176 @@ +import os +import glob +import pdb +import subprocess +import numpy as np +import lmdb +import pyarrow +from mfcc import MFCC +import soundfile as sf +import sys +[sys.path.append(i) for i in ['.', '..', '../process']] +from process_zeggs_bvh import preprocess_animation, pose2bvh + + +style2onehot = { +'Happy':[1, 0, 0, 0, 0, 0], +'Sad':[0, 1, 0, 0, 0, 0], +'Neutral':[0, 0, 1, 0, 0, 0], +'Old':[0, 0, 0, 1, 0, 0], +'Angry':[0, 0, 0, 0, 1, 0], +'Relaxed':[0, 0, 0, 0, 0, 1], +} + +def make_lmdb_gesture_dataset(root_path): + + def make_lmdb_gesture_subdataset(base_path, lmdb_subname): + gesture_path = os.path.join(base_path, 'gesture_npz') + audio_path = os.path.join(base_path, 'normalize_audio_npz') + mfcc_path = os.path.join(base_path, 'mfcc') + out_path = os.path.join(base_path, lmdb_name) + if not os.path.exists(out_path): + os.makedirs(out_path) + + map_size = 1024 * 200 # in MB + map_size <<= 20 # in B + dataset_idx = 0 + + db = [lmdb.open(os.path.join(out_path, lmdb_subname), map_size=map_size)] + + # delete existing files + for i in range(1): + with db[i].begin(write=True) as txn: + txn.drop(db[i].open_db()) + + all_poses = [] + bvh_files = sorted(glob.glob(gesture_path + "/*.npz")) + v_i = 0 + + for _, bvh_file in enumerate(bvh_files): + name = os.path.split(bvh_file)[1][:-4] + if name.split('_')[1] in style2onehot: + style = style2onehot[name.split('_')[1]] + else: + continue + + print('process: ' + name) + + poses = np.load(bvh_file)['gesture'] + audio_raw = np.load(os.path.join(audio_path, name + '.npz'))['wav'] + mfcc_raw = np.load(os.path.join(mfcc_path, name + '.npz'))['mfcc'] + + # process + clips = [{'vid': name, 'clips': []}] # train and test + + data_mean = np.load(os.path.join(root_path, 'mean.npz'))['mean'] + data_std = np.load(os.path.join(root_path, 'std.npz'))['std'] + data_mean = np.array(data_mean).squeeze() + data_std = np.array(data_std).squeeze() + std = np.clip(data_std, a_min=0.01, a_max=None) + poses = (poses - data_mean) / std + + poses = np.asarray(poses) + clips[dataset_idx]['clips'].append( + { # 'words': word_list, + 'poses': poses, + 'audio_raw': audio_raw, + 'mfcc_raw': mfcc_raw, # for debug + 'style_raw': np.array(style) # for debug + }) + + # write to db + for i in range(1): + with db[i].begin(write=True) as txn: + if len(clips[i]['clips']) > 0: + k = '{:010}'.format(v_i).encode('ascii') + v = pyarrow.serialize(clips[i]).to_buffer() + txn.put(k, v) + + all_poses.append(poses) 
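+            # Each LMDB value is a pyarrow-serialized dict for one take: the
+            # z-scored poses (std is clipped at 0.01 before dividing), the raw
+            # 16 kHz waveform, the precomputed MFCC and the one-hot style vector
+            # (the last two kept only for debugging); the key is the zero-padded
+            # sample index written above.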
+ v_i += 1 + + print('total length of dataset: ' + str(v_i)) + + # close db + for i in range(1): + db[i].sync() + db[i].close() + + train_path = os.path.join(root_path, 'train') + lmdb_name = 'train_lmdb' + make_lmdb_gesture_subdataset(train_path, lmdb_name) + test_path = os.path.join(root_path, 'valid') + lmdb_name = 'valid_lmdb' + make_lmdb_gesture_subdataset(test_path, lmdb_name) + + +def make_zeggs_dataset(source_path, target): + if not os.path.exists(target): + os.mkdir(target) + + def make_zeggs_subdataset(source_path, target, all_poses): + if not os.path.exists(target): + os.mkdir(target) + target_audio_path = os.path.join(target, 'normalize_audio') + target_audionpz_path = os.path.join(target, 'normalize_audio_npz') + target_gesture_path = os.path.join(target, 'gesture_npz') + target_mfcc_path = os.path.join(target, 'mfcc') + if not os.path.exists(target_audio_path): + os.mkdir(target_audio_path) + if not os.path.exists(target_mfcc_path): + os.mkdir(target_mfcc_path) + if not os.path.exists(target_audionpz_path): + os.mkdir(target_audionpz_path) + if not os.path.exists(target_gesture_path): + os.mkdir(target_gesture_path) + wav_files = sorted(glob.glob(source_path + "/*.wav")) + for _, wav_file in enumerate(wav_files): + name = os.path.split(wav_file)[1][:-4] + print(name) + # audio + print('normalize audio: ' + name + '.wav') + normalize_wav_path = os.path.join(target_audio_path, name + '.wav') + cmd = ['ffmpeg-normalize', wav_file, '-o', normalize_wav_path, '-ar', '16000'] + subprocess.call(cmd) + print('extract MFCC...') + obj = MFCC(frate=20) + # wav, fs = librosa.load(normalize_wav_path, sr=16000) + wav, fs = sf.read(normalize_wav_path) + mfcc = obj.sig2s2mfc_energy(wav, None) + print(mfcc[:, :-2].shape) # -1 -> -2 # (502, 13) + np.savez_compressed(os.path.join(target_mfcc_path, name + '.npz'), mfcc=mfcc[:, :-2]) + np.savez_compressed(os.path.join(target_audionpz_path, name + '.npz'), wav=wav) + # bvh + print('extract gesture...') + bvh_file = os.path.join(source_path, name + '.bvh') + pose, parents, dt, order, njoints = preprocess_animation(bvh_file, fps=20) + print(pose.shape) + np.savez_compressed(os.path.join(target_gesture_path, name + '.npz'), gesture=pose) + all_poses.append(pose) + + return all_poses + + source_path_train = os.path.join(source_path, 'train') + target_train = os.path.join(target, 'train') + all_poses = [] + all_poses = make_zeggs_subdataset(source_path_train, target_train, all_poses) + source_path_test = os.path.join(source_path, 'valid') + target_test = os.path.join(target, 'valid') + all_poses = make_zeggs_subdataset(source_path_test, target_test, all_poses) + + all_poses = np.vstack(all_poses) + pose_mean = np.mean(all_poses, axis=0, dtype=np.float64) + pose_std = np.std(all_poses, axis=0, dtype=np.float64) + np.savez_compressed(os.path.join(target, 'mean.npz'), mean=pose_mean) + np.savez_compressed(os.path.join(target, 'std.npz'), std=pose_std) + + +if __name__ == '__main__': + ''' + python zeggs_data_to_lmdb.py + ''' + source_path = '../../ubisoft-laforge-ZeroEGGS-main/data/processed_v1/trimmed/' + target = '../../ubisoft-laforge-ZeroEGGS-main/data/processed_v1/processed/' + make_zeggs_dataset(source_path, target) + make_lmdb_gesture_dataset(target) + diff --git a/main/prepare/download_a2m_datasets.sh b/main/prepare/download_a2m_datasets.sh new file mode 100644 index 0000000000000000000000000000000000000000..fdfc7a8e42dba84714d9d320013c6e22505d9bfa --- /dev/null +++ b/main/prepare/download_a2m_datasets.sh @@ -0,0 +1,22 @@ +mkdir -p dataset/ +cd 
dataset/ + +echo "The datasets will be stored in the 'dataset' folder\n" + +# HumanAct12 poses +echo "Downloading the HumanAct12 poses dataset" +gdown "https://drive.google.com/uc?id=1130gHSvNyJmii7f6pv5aY5IyQIWc3t7R" +echo "Extracting the HumanAct12 poses dataset" +tar xfzv HumanAct12Poses.tar.gz +echo "Cleaning\n" +rm HumanAct12Poses.tar.gz + +# Donwload UESTC poses estimated with VIBE +echo "Downloading the UESTC poses estimated with VIBE" +gdown "https://drive.google.com/uc?id=1LE-EmYNzECU8o7A2DmqDKtqDMucnSJsy" +echo "Extracting the UESTC poses estimated with VIBE" +tar xjvf uestc.tar.bz2 +echo "Cleaning\n" +rm uestc.tar.bz2 + +echo "Downloading done!" diff --git a/main/prepare/download_glove.sh b/main/prepare/download_glove.sh new file mode 100644 index 0000000000000000000000000000000000000000..ccf270b59ae637313fcde9a856e25ab5044df753 --- /dev/null +++ b/main/prepare/download_glove.sh @@ -0,0 +1,9 @@ +echo -e "Downloading glove (in use by the evaluators, not by MDM itself)" +gdown --fuzzy https://drive.google.com/file/d/1cmXKUT31pqd7_XpJAiWEo1K81TMYHA5n/view?usp=sharing +rm -rf glove + +unzip glove.zip +echo -e "Cleaning\n" +rm glove.zip + +echo -e "Downloading done!" \ No newline at end of file diff --git a/main/prepare/download_recognition_models.sh b/main/prepare/download_recognition_models.sh new file mode 100644 index 0000000000000000000000000000000000000000..4c7a663b02c4df420be45d4aed88835b0a7517a9 --- /dev/null +++ b/main/prepare/download_recognition_models.sh @@ -0,0 +1,12 @@ +mkdir -p assets/actionrecognition/ +cd assets/actionrecognition/ + +echo -e "Downloading the HumanAct12 action recognition model" +wget https://raw.githubusercontent.com/EricGuo5513/action-to-motion/master/model_file/action_recognition_model_humanact12.tar -O humanact12_gru.tar +echo -e + +echo -e "Downloading the UESTC action recognition model" +gdown "https://drive.google.com/uc?id=1bSSD69s1dHY7Uk0RGbGc6p7uhUxSDSBK" +echo -e + +echo -e "Downloading done!" diff --git a/main/prepare/download_recognition_unconstrained_models.sh b/main/prepare/download_recognition_unconstrained_models.sh new file mode 100644 index 0000000000000000000000000000000000000000..487471d43ff4cd4e31ecdf2335537d4c9b540a8b --- /dev/null +++ b/main/prepare/download_recognition_unconstrained_models.sh @@ -0,0 +1,8 @@ +mkdir -p assets/actionrecognition/ +cd assets/actionrecognition/ + +echo -e "Downloading the HumanAct12 action recognition model, adjusted for the unconstrained setting." +gdown "1xfigimkPxKt3a8zvn_ME_NAR6CyTqneK" +echo -e + +echo -e "Downloading done!" diff --git a/main/prepare/download_smpl_files.sh b/main/prepare/download_smpl_files.sh new file mode 100644 index 0000000000000000000000000000000000000000..595175e0fb2016f52f56cdcf448e38765cbb86e5 --- /dev/null +++ b/main/prepare/download_smpl_files.sh @@ -0,0 +1,12 @@ +mkdir -p body_models +cd body_models/ + +echo -e "The smpl files will be stored in the 'body_models/smpl/' folder\n" +gdown "https://drive.google.com/uc?id=1INYlGA76ak_cKGzvpOV2Pe6RkYTlXTW2" +rm -rf smpl + +unzip smpl.zip +echo -e "Cleaning\n" +rm smpl.zip + +echo -e "Downloading done!" 
\ No newline at end of file diff --git a/main/prepare/download_t2m_evaluators.sh b/main/prepare/download_t2m_evaluators.sh new file mode 100644 index 0000000000000000000000000000000000000000..2db73275e75a7d13f35a40a8b16fbedab02561be --- /dev/null +++ b/main/prepare/download_t2m_evaluators.sh @@ -0,0 +1,13 @@ +echo -e "Downloading T2M evaluators" +gdown --fuzzy https://drive.google.com/file/d/1DSaKqWX2HlwBtVH5l7DdW96jeYUIXsOP/view +gdown --fuzzy https://drive.google.com/file/d/1tX79xk0fflp07EZ660Xz1RAFE33iEyJR/view +rm -rf t2m +rm -rf kit + +unzip t2m.zip +unzip kit.zip +echo -e "Cleaning\n" +rm t2m.zip +rm kit.zip + +echo -e "Downloading done!" \ No newline at end of file diff --git a/main/prepare/download_unconstrained_datasets.sh b/main/prepare/download_unconstrained_datasets.sh new file mode 100644 index 0000000000000000000000000000000000000000..8abf7e92d854b015037fdeef3dc587c84b52deb1 --- /dev/null +++ b/main/prepare/download_unconstrained_datasets.sh @@ -0,0 +1,10 @@ +mkdir -p dataset/HumanAct12Poses +cd dataset/HumanAct12Poses + +echo "The datasets will be stored in the 'dataset' folder\n" + +# HumanAct12 poses unconstrained +echo "Downloading the HumanAct12 unconstrained poses dataset" +gdown "1KqOBTtLFgkvWSZb8ao-wdBMG7sTP3Q7d" + +echo "Downloading done!" diff --git a/main/process/process_zeggs_bvh.py b/main/process/process_zeggs_bvh.py new file mode 100644 index 0000000000000000000000000000000000000000..76cb7152ea99a00453fdee994d065806b34a2bba --- /dev/null +++ b/main/process/process_zeggs_bvh.py @@ -0,0 +1,410 @@ +import json +import pdb +import numpy as np +from omegaconf import DictConfig +import os +os.environ['KMP_DUPLICATE_LIB_OK']='True' +import sys +[sys.path.append(i) for i in ['.', '..', '../../ubisoft-laforge-ZeroEGGS-main/ZEGGS']] + +from anim import bvh, quat, txform +from utils_zeggs import write_bvh +import torch +from scipy.signal import savgol_filter + + +bone_names = [ + "Hips", + "Spine", + "Spine1", + "Spine2", + "Spine3", + "Neck", + "Neck1", + "Head", + "HeadEnd", + "RightShoulder", + "RightArm", + "RightForeArm", + "RightHand", + "RightHandThumb1", + "RightHandThumb2", + "RightHandThumb3", + "RightHandThumb4", + "RightHandIndex1", + "RightHandIndex2", + "RightHandIndex3", + "RightHandIndex4", + "RightHandMiddle1", + "RightHandMiddle2", + "RightHandMiddle3", + "RightHandMiddle4", + "RightHandRing1", + "RightHandRing2", + "RightHandRing3", + "RightHandRing4", + "RightHandPinky1", + "RightHandPinky2", + "RightHandPinky3", + "RightHandPinky4", + "RightForeArmEnd", + "RightArmEnd", + "LeftShoulder", + "LeftArm", + "LeftForeArm", + "LeftHand", + "LeftHandThumb1", + "LeftHandThumb2", + "LeftHandThumb3", + "LeftHandThumb4", + "LeftHandIndex1", + "LeftHandIndex2", + "LeftHandIndex3", + "LeftHandIndex4", + "LeftHandMiddle1", + "LeftHandMiddle2", + "LeftHandMiddle3", + "LeftHandMiddle4", + "LeftHandRing1", + "LeftHandRing2", + "LeftHandRing3", + "LeftHandRing4", + "LeftHandPinky1", + "LeftHandPinky2", + "LeftHandPinky3", + "LeftHandPinky4", + "LeftForeArmEnd", + "LeftArmEnd", + "RightUpLeg", + "RightLeg", + "RightFoot", + "RightToeBase", + "RightToeBaseEnd", + "RightLegEnd", + "RightUpLegEnd", + "LeftUpLeg", + "LeftLeg", + "LeftFoot", + "LeftToeBase", + "LeftToeBaseEnd", + "LeftLegEnd", + "LeftUpLegEnd" + ] + + +def preprocess_animation(animation_file, fps=60): + anim_data = bvh.load(animation_file) # 'rotations' (8116, 75, 3), 'positions', 'offsets' (75, 3), 'parents', 'names' (75,), 'order' 'zyx', 'frametime' 0.016667 + nframes = len(anim_data["rotations"]) + + if fps 
!= 60 : + rate = 60 // fps + anim_data["rotations"] = anim_data["rotations"][0:nframes:rate] + anim_data["positions"] = anim_data["positions"][0:nframes:rate] + dt = 1 / fps + nframes = anim_data["positions"].shape[0] + else: + dt = anim_data["frametime"] + + njoints = len(anim_data["parents"]) + + lrot = quat.unroll(quat.from_euler(np.radians(anim_data["rotations"]), anim_data["order"])) + lpos = anim_data["positions"] + grot, gpos = quat.fk(lrot, lpos, anim_data["parents"]) + # Find root (Projected hips on the ground) + root_pos = gpos[:, anim_data["names"].index("Spine2")] * np.array([1, 0, 1]) + # Root direction + root_fwd = quat.mul_vec(grot[:, anim_data["names"].index("Hips")], np.array([[0, 0, 1]])) + root_fwd[:, 1] = 0 + root_fwd = root_fwd / np.sqrt(np.sum(root_fwd * root_fwd, axis=-1))[..., np.newaxis] + # Root rotation + root_rot = quat.normalize( + quat.between(np.array([[0, 0, 1]]).repeat(len(root_fwd), axis=0), root_fwd) + ) + + # Find look at direction + gaze_lookat = quat.mul_vec(grot[:, anim_data["names"].index("Head")], np.array([0, 0, 1])) + gaze_lookat[:, 1] = 0 + gaze_lookat = gaze_lookat / np.sqrt(np.sum(np.square(gaze_lookat), axis=-1))[..., np.newaxis] + # Find gaze position + gaze_distance = 100 # Assume other actor is one meter away + gaze_pos_all = root_pos + gaze_distance * gaze_lookat + gaze_pos = np.median(gaze_pos_all, axis=0) + gaze_pos = gaze_pos[np.newaxis].repeat(nframes, axis=0) + + # Visualize Gaze Pos + visualize_gaze = False + if visualize_gaze: + import matplotlib.pyplot as plt + + plt.scatter(gaze_pos_all[:, 0], gaze_pos_all[:, 2], s=0.1, marker=".") + plt.scatter(gaze_pos[0, 0], gaze_pos[0, 2]) + plt.scatter(root_pos[:, 0], root_pos[:, 2], s=0.1, marker=".") + plt.quiver(root_pos[::60, 0], root_pos[::60, 2], root_fwd[::60, 0], root_fwd[::60, 2]) + plt.gca().set_aspect("equal") + plt.savefig('1.jpg') + plt.show() + + # Compute local gaze dir + gaze_dir = gaze_pos - root_pos + # gaze_dir = gaze_dir / np.sqrt(np.sum(np.square(gaze_dir), axis=-1))[..., np.newaxis] + gaze_dir = quat.mul_vec(quat.inv(root_rot), gaze_dir) + + # Make relative to root + lrot[:, 0] = quat.mul(quat.inv(root_rot), lrot[:, 0]) + lpos[:, 0] = quat.mul_vec(quat.inv(root_rot), lpos[:, 0] - root_pos) + + # Local velocities + lvel = np.zeros_like(lpos) + lvel[1:] = (lpos[1:] - lpos[:-1]) / dt + lvel[0] = lvel[1] - (lvel[3] - lvel[2]) + + lvrt = np.zeros_like(lpos) + lvrt[1:] = quat.to_helical(quat.abs(quat.mul(lrot[1:], quat.inv(lrot[:-1])))) / dt + lvrt[0] = lvrt[1] - (lvrt[3] - lvrt[2]) + + # Root velocities + root_vrt = np.zeros_like(root_pos) + root_vrt[1:] = quat.to_helical(quat.abs(quat.mul(root_rot[1:], quat.inv(root_rot[:-1])))) / dt + root_vrt[0] = root_vrt[1] - (root_vrt[3] - root_vrt[2]) + root_vrt[1:] = quat.mul_vec(quat.inv(root_rot[:-1]), root_vrt[1:]) + root_vrt[0] = quat.mul_vec(quat.inv(root_rot[0]), root_vrt[0]) + + root_vel = np.zeros_like(root_pos) + root_vel[1:] = (root_pos[1:] - root_pos[:-1]) / dt + root_vel[0] = root_vel[1] - (root_vel[3] - root_vel[2]) + root_vel[1:] = quat.mul_vec(quat.inv(root_rot[:-1]), root_vel[1:]) + root_vel[0] = quat.mul_vec(quat.inv(root_rot[0]), root_vel[0]) + + # Compute character space + crot, cpos, cvrt, cvel = quat.fk_vel(lrot, lpos, lvrt, lvel, anim_data["parents"]) + + # Compute 2-axis transforms + ltxy = np.zeros(dtype=np.float32, shape=[len(lrot), njoints, 2, 3]) + ltxy[..., 0, :] = quat.mul_vec(lrot, np.array([1.0, 0.0, 0.0])) + ltxy[..., 1, :] = quat.mul_vec(lrot, np.array([0.0, 1.0, 0.0])) + + ctxy = 
np.zeros(dtype=np.float32, shape=[len(crot), njoints, 2, 3]) + ctxy[..., 0, :] = quat.mul_vec(crot, np.array([1.0, 0.0, 0.0])) + ctxy[..., 1, :] = quat.mul_vec(crot, np.array([0.0, 1.0, 0.0])) + + # return ( + # root_pos, + # root_rot, + # root_vel, + # root_vrt, + # lpos, + # lrot, + # ltxy, + # lvel, + # lvrt, + # cpos, + # crot, + # ctxy, + # cvel, + # cvrt, + # gaze_pos, + # gaze_dir, + # ), anim_data["parents"], dt, anim_data["order"] + + lpos = lpos.reshape(nframes, -1) + ltxy = ltxy.reshape(nframes, -1) + lvel = lvel.reshape(nframes, -1) + lvrt = lvrt.reshape(nframes, -1) + + all_poses = np.concatenate((root_pos, root_rot, root_vel, root_vrt, lpos, ltxy, lvel, lvrt, gaze_dir), axis=1) + + return all_poses, anim_data["parents"], dt, anim_data["order"], njoints + + +def pose2bvh(poses, outpath, length, smoothing=False, smooth_foot=False): + parents = np.array([-1, 0, 1, 2, 3, 4, 5, 6, 7, 4, 9, 10, 11, 12, 13, 14, 15, + 12, 17, 18, 19, 12, 21, 22, 23, 12, 25, 26, 27, 12, 29, 30, 31, 12, + 11, 4, 35, 36, 37, 38, 39, 40, 41, 38, 43, 44, 45, 38, 47, 48, 49, + 38, 51, 52, 53, 38, 55, 56, 57, 38, 37, 0, 61, 62, 63, 64, 63, 62, + 0, 68, 69, 70, 71, 70, 69], dtype=np.int32) + order = 'zyx' + dt = 0.05 + njoints = 75 + + # smoothing + if smoothing: + n_poses = poses.shape[0] + out_poses = np.zeros((n_poses, poses.shape[1])) + for i in range(out_poses.shape[1]): + # if (13 + (njoints - 14) * 9) <= i < (13 + njoints * 9): out_poses[:, i] = savgol_filter(poses[:, i], 41, 2) # NOTE: smoothing on rotation matrices is not optimal + # else: + out_poses[:, i] = savgol_filter(poses[:, i], 15, 2) # NOTE: smoothing on rotation matrices is not optimal + else: + out_poses = poses + + # Extract predictions + P_root_pos = out_poses[:, 0:3] + P_root_rot = out_poses[:, 3:7] + P_root_vel = out_poses[:, 7:10] + P_root_vrt = out_poses[:, 10:13] + P_lpos = out_poses[:, 13 + njoints * 0: 13 + njoints * 3].reshape([length, njoints, 3]) + P_ltxy = out_poses[:, 13 + njoints * 3: 13 + njoints * 9].reshape([length, njoints, 2, 3]) + P_lvel = out_poses[:, 13 + njoints * 9: 13 + njoints * 12].reshape([length, njoints, 3]) + P_lvrt = out_poses[:, 13 + njoints * 12: 13 + njoints * 15].reshape([length, njoints, 3]) + + P_ltxy = torch.as_tensor(P_ltxy, dtype=torch.float32) + P_lrot = quat.from_xform(txform.xform_orthogonalize_from_xy(P_ltxy).cpu().numpy()) # + + if smooth_foot: + pdb.set_trace() + next_poses_LeftToeBase = P_lrot[:, -7] # (length, 4) 7/14, 5/12 + next_poses_RightToeBase = P_lrot[:, -14] + next_poses_LeftToeBase = np.zeros_like(next_poses_LeftToeBase) + next_poses_RightToeBase = np.zeros_like(next_poses_RightToeBase) + P_lrot[:, -7] = next_poses_LeftToeBase + P_lrot[:, -14] = next_poses_RightToeBase + + # 20fps -> 60fps + dt = 1 / 60 + P_root_pos = P_root_pos.repeat(3, axis=0) + P_root_rot = P_root_rot.repeat(3, axis=0) + P_lpos = P_lpos.repeat(3, axis=0) + P_lrot = P_lrot.repeat(3, axis=0) + + write_bvh(outpath, + P_root_pos, + P_root_rot, + P_lpos, + P_lrot, + parents, bone_names, order, dt + ) + +if __name__ == '__main__': + ''' + cd mymdm/process + python process_zeggs_bvh.py + ''' + config_file = "../../ubisoft-laforge-ZeroEGGS-main/configs/data_pipeline_conf_v1.json" + with open(config_file, "r") as f: + conf = json.load(f) + + conf = DictConfig(conf) + + # animation_file = "../../ubisoft-laforge-ZeroEGGS-main/Data/original/001_Neutral_0.bvh" + animation_file = r"E:\ไธ‹่ฝฝ\bvh2fpx" + + # ( + # root_pos, # (8116, 3) + # root_rot, # (8116, 4) + # root_vel, # (8116, 3) # 1 + # root_vrt, # (8116, 3) # 2 + # 
lpos, # (8116, 75, 3) # 3 + # lrot, # (8116, 75, 4) + # ltxy, # (8116, 75, 2, 3) # 4 + # lvel, # (8116, 75, 3) # 5 + # lvrt, # (8116, 75, 3) # 6 + # cpos, # (8116, 75, 3) + # crot, # (8116, 75, 4) + # ctxy, # (8116, 75, 2, 3) + # cvel, # (8116, 75, 3) + # cvrt, # (8116, 75, 3) + # gaze_pos, # (8116, 3) + # gaze_dir, # (8116, 3) # 7 + # ), parents, dt, order = preprocess_animation(animation_file) + for item in os.listdir(os.path.join(animation_file, '20fps')): + print(item) + all_poses, parents, dt, order, njoints = preprocess_animation(os.path.join(animation_file, '20fps', item), fps=60) # 20 + pose2bvh(poses=all_poses, outpath=os.path.join(animation_file, 'processed', item), length=all_poses.shape[0], smoothing=True, smooth_foot=False) + + # length = all_poses.shape[0] + + # root_rot = torch.as_tensor(root_rot, dtype=torch.float32) + # gaze_pos = torch.as_tensor(gaze_pos, dtype=torch.float32) + # root_pos = torch.as_tensor(root_pos, dtype=torch.float32) + + # pose_mean = np.load(r'E:\ไธ‹่ฝฝ\mean.npz')['mean'] + # pose_std = np.load(r'E:\ไธ‹่ฝฝ\std.npz')['std'] + # # normalize + # std = np.clip(pose_std, a_min=0.01, a_max=None) + # all_poses = (all_poses - pose_mean) / std + # np.savez(r"E:\ไธ‹่ฝฝ\bvh\happy-normalize.npz", pose=all_poses) + # out_poses = all_poses + # # denormalize + # out_poses = np.multiply(out_poses, std) + pose_mean + # + # outpath = "../mydiffusion_zeggs/sample_20230104_192239/20230104_193613_smoothing_SG_minibatch_2720_[1, 0, 0, 0, 0, 0]_123456_1.bvh" + # pose2bvh(poses=out_poses, outpath="../mydiffusion_zeggs/sample_20230104_192239/20230104_193613_smoothing_SG_minibatch_2720_[1, 0, 0, 0, 0, 0]_123456_1.bvh", length=length, smoothing=True, smooth_foot=False) + + + + # root_vel_mean = root_vel.mean(axis=0) + # root_vrt_mean = root_vrt.mean(axis=0) + # lpos_mean = lpos.mean(axis=0) + # ltxy_mean = ltxy.mean(axis=0) + # lvel_mean = lvel.mean(axis=0) + # lvrt_mean = lvrt.mean(axis=0) + # gaze_dir_mean = gaze_dir.mean(axis=0) + # + # anim_mean = np.hstack([root_vel_mean.ravel(), root_vrt_mean.ravel(), lpos_mean.ravel(), ltxy_mean.ravel(), lvel_mean.ravel(), lvrt_mean.ravel(), gaze_dir_mean.ravel()]) + # + # root_vel_std = root_vel.std() + 1e-10 + # root_vrt_std = root_vrt.std() + 1e-10 + # lpos_std = lpos.std() + 1e-10 + # ltxy_std = ltxy.std() + 1e-10 + # lvel_std = lvel.std() + 1e-10 + # lvrt_std = lvrt.std() + 1e-10 + # gaze_dir_std = gaze_dir.std() + 1e-10 + # + # anim_input_std = np.hstack([root_vel_std.repeat(len(root_vel_mean.ravel())), + # root_vrt_std.repeat(len(root_vrt_mean.ravel())), + # lpos_std.repeat(len(lpos_mean.ravel())), + # ltxy_std.repeat(len(ltxy_mean.ravel())), + # lvel_std.repeat(len(lvel_mean.ravel())), + # lvrt_std.repeat(len(lvrt_mean.ravel())), + # gaze_dir_std.repeat(len(gaze_dir_mean.ravel()))]) + # + # root_vel_std = root_vel.std(axis=0) + # root_vrt_std = root_vrt.std(axis=0) + # lpos_std = lpos.std(axis=0) + # ltxy_std = ltxy.std(axis=0) + # lvel_std = lvel.std(axis=0) + # lvrt_std = lvrt.std(axis=0) + # gaze_dir_std = gaze_dir.std(axis=0) + # + # anim_output_std = np.hstack([root_vel_std.ravel(), + # root_vrt_std.ravel(), + # lpos_std.ravel(), + # ltxy_std.ravel(), + # lvel_std.ravel(), + # lvrt_std.ravel(), + # gaze_dir_std.ravel()]) + # + # Z_root_vel = torch.as_tensor(root_vel, dtype=torch.float32) + # Z_root_vrt = torch.as_tensor(root_vrt, dtype=torch.float32) + # Z_lpos = torch.as_tensor(lpos, dtype=torch.float32) + # Z_ltxy = torch.as_tensor(ltxy, dtype=torch.float32) + # Z_lvel = torch.as_tensor(lvel, dtype=torch.float32) + # 
Z_lvrt = torch.as_tensor(lvrt, dtype=torch.float32) + # # gaze_dir = torch.as_tensor(gaze_dir, dtype=torch.float32) # + # + # # Compute Local Gaze + # Z_gaze_dir = tquat.quat_inv_mul_vec(root_rot, gaze_pos - root_pos) + # + # + # + # pose_encoding = torch.cat( + # [ + # Z_root_vel.reshape([length, -1]), + # Z_root_vrt.reshape([length, -1]), + # Z_lpos.reshape([length, -1]), + # Z_ltxy.reshape([length, -1]), + # Z_lvel.reshape([length, -1]), + # Z_lvrt.reshape([length, -1]), + # Z_gaze_dir.reshape([length, -1]), + # ], + # dim=1, + # ) # Need to Normalize + # + # pdb.set_trace() + # pose_encoding = (pose_encoding - anim_mean) / anim_input_std + # + + # + # # processed_data_path = "../../ubisoft-laforge-ZeroEGGS-main/Data/processed_v1/processed_data.npz" + # # processed_data = np.load(processed_data_path) + # pdb.set_trace() diff --git a/main/sample/edit.py b/main/sample/edit.py new file mode 100644 index 0000000000000000000000000000000000000000..13457e3b8dcf80aebda4ea0cfe15c3efff268e09 --- /dev/null +++ b/main/sample/edit.py @@ -0,0 +1,199 @@ +# This code is based on https://github.com/openai/guided-diffusion +""" +Generate a large batch of image samples from a model and save them as a large +numpy array. This can be used to produce samples for FID evaluation. +""" +from utils.fixseed import fixseed +import os +import numpy as np +import torch +from utils.parser_util import edit_args +from utils.model_util import create_model_and_diffusion, load_model_wo_clip +from utils import dist_util +from model.cfg_sampler import ClassifierFreeSampleModel +from data_loaders.get_data import get_dataset_loader +from data_loaders.humanml.scripts.motion_process import recover_from_ric +from data_loaders import humanml_utils +import data_loaders.humanml.utils.paramUtil as paramUtil +from data_loaders.humanml.utils.plot_script import plot_3d_motion +import shutil + + +def main(): + args = edit_args() + fixseed(args.seed) + out_path = args.output_dir + name = os.path.basename(os.path.dirname(args.model_path)) + niter = os.path.basename(args.model_path).replace('model', '').replace('.pt', '') + max_frames = 196 if args.dataset in ['kit', 'humanml'] else 60 + fps = 12.5 if args.dataset == 'kit' else 20 + dist_util.setup_dist(args.device) + if out_path == '': + out_path = os.path.join(os.path.dirname(args.model_path), + 'edit_{}_{}_{}_seed{}'.format(name, niter, args.edit_mode, args.seed)) + if args.text_condition != '': + out_path += '_' + args.text_condition.replace(' ', '_').replace('.', '') + + print('Loading dataset...') + assert args.num_samples <= args.batch_size, \ + f'Please either increase batch_size({args.batch_size}) or reduce num_samples({args.num_samples})' + # So why do we need this check? In order to protect GPU from a memory overload in the following line. + # If your GPU can handle batch size larger then default, you can specify it through --batch_size flag. + # If it doesn't, and you still want to sample more prompts, run this script with different seeds + # (specify through the --seed flag) + args.batch_size = args.num_samples # Sampling a single batch from the testset, with exactly args.num_samples + data = get_dataset_loader(name=args.dataset, + batch_size=args.batch_size, + num_frames=max_frames, + split='test', + hml_mode='train') # in train mode, you get both text and motion. 
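+    # The test motions fetched below serve as inpainting targets: with
+    # --edit_mode in_between, frames before prefix_end and after suffix_start
+    # of each clip are kept as ground truth and only the middle is regenerated;
+    # with upper_body, the lower-body channels (HML_LOWER_BODY_MASK) stay fixed
+    # and the rest is regenerated.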
+ # data.fixed_length = n_frames + total_num_samples = args.num_samples * args.num_repetitions + + print("Creating model and diffusion...") + model, diffusion = create_model_and_diffusion(args, data) + + print(f"Loading checkpoints from [{args.model_path}]...") + state_dict = torch.load(args.model_path, map_location='cpu') + load_model_wo_clip(model, state_dict) + + model = ClassifierFreeSampleModel(model) # wrapping model with the classifier-free sampler + model.to(dist_util.dev()) + model.eval() # disable random masking + + iterator = iter(data) + input_motions, model_kwargs = next(iterator) + input_motions = input_motions.to(dist_util.dev()) + texts = [args.text_condition] * args.num_samples + model_kwargs['y']['text'] = texts + if args.text_condition == '': + args.guidance_param = 0. # Force unconditioned generation + + # add inpainting mask according to args + assert max_frames == input_motions.shape[-1] + gt_frames_per_sample = {} + model_kwargs['y']['inpainted_motion'] = input_motions + if args.edit_mode == 'in_between': + model_kwargs['y']['inpainting_mask'] = torch.ones_like(input_motions, dtype=torch.bool, + device=input_motions.device) # True means use gt motion + for i, length in enumerate(model_kwargs['y']['lengths'].cpu().numpy()): + start_idx, end_idx = int(args.prefix_end * length), int(args.suffix_start * length) + gt_frames_per_sample[i] = list(range(0, start_idx)) + list(range(end_idx, max_frames)) + model_kwargs['y']['inpainting_mask'][i, :, :, + start_idx: end_idx] = False # do inpainting in those frames + elif args.edit_mode == 'upper_body': + model_kwargs['y']['inpainting_mask'] = torch.tensor(humanml_utils.HML_LOWER_BODY_MASK, dtype=torch.bool, + device=input_motions.device) # True is lower body data + model_kwargs['y']['inpainting_mask'] = model_kwargs['y']['inpainting_mask'].unsqueeze(0).unsqueeze( + -1).unsqueeze(-1).repeat(input_motions.shape[0], 1, input_motions.shape[2], input_motions.shape[3]) + + all_motions = [] + all_lengths = [] + all_text = [] + + for rep_i in range(args.num_repetitions): + print(f'### Start sampling [repetitions #{rep_i}]') + + # add CFG scale to batch + model_kwargs['y']['scale'] = torch.ones(args.batch_size, device=dist_util.dev()) * args.guidance_param + + sample_fn = diffusion.p_sample_loop + + sample = sample_fn( + model, + (args.batch_size, model.njoints, model.nfeats, max_frames), + clip_denoised=False, + model_kwargs=model_kwargs, + skip_timesteps=0, # 0 is the default value - i.e. 
don't skip any step + init_image=None, + progress=True, + dump_steps=None, + noise=None, + const_noise=False, + ) + + + # Recover XYZ *positions* from HumanML3D vector representation + if model.data_rep == 'hml_vec': + n_joints = 22 if sample.shape[1] == 263 else 21 + sample = data.dataset.t2m_dataset.inv_transform(sample.cpu().permute(0, 2, 3, 1)).float() + sample = recover_from_ric(sample, n_joints) + sample = sample.view(-1, *sample.shape[2:]).permute(0, 2, 3, 1) + + all_text += model_kwargs['y']['text'] + all_motions.append(sample.cpu().numpy()) + all_lengths.append(model_kwargs['y']['lengths'].cpu().numpy()) + + print(f"created {len(all_motions) * args.batch_size} samples") + + + all_motions = np.concatenate(all_motions, axis=0) + all_motions = all_motions[:total_num_samples] # [bs, njoints, 6, seqlen] + all_text = all_text[:total_num_samples] + all_lengths = np.concatenate(all_lengths, axis=0)[:total_num_samples] + + if os.path.exists(out_path): + shutil.rmtree(out_path) + os.makedirs(out_path) + + npy_path = os.path.join(out_path, 'results.npy') + print(f"saving results file to [{npy_path}]") + np.save(npy_path, + {'motion': all_motions, 'text': all_text, 'lengths': all_lengths, + 'num_samples': args.num_samples, 'num_repetitions': args.num_repetitions}) + with open(npy_path.replace('.npy', '.txt'), 'w') as fw: + fw.write('\n'.join(all_text)) + with open(npy_path.replace('.npy', '_len.txt'), 'w') as fw: + fw.write('\n'.join([str(l) for l in all_lengths])) + + print(f"saving visualizations to [{out_path}]...") + skeleton = paramUtil.kit_kinematic_chain if args.dataset == 'kit' else paramUtil.t2m_kinematic_chain + + # Recover XYZ *positions* from HumanML3D vector representation + if model.data_rep == 'hml_vec': + input_motions = data.dataset.t2m_dataset.inv_transform(input_motions.cpu().permute(0, 2, 3, 1)).float() + input_motions = recover_from_ric(input_motions, n_joints) + input_motions = input_motions.view(-1, *input_motions.shape[2:]).permute(0, 2, 3, 1).cpu().numpy() + + + for sample_i in range(args.num_samples): + caption = 'Input Motion' + length = model_kwargs['y']['lengths'][sample_i] + motion = input_motions[sample_i].transpose(2, 0, 1)[:length] + save_file = 'input_motion{:02d}.mp4'.format(sample_i) + animation_save_path = os.path.join(out_path, save_file) + rep_files = [animation_save_path] + print(f'[({sample_i}) "{caption}" | -> {save_file}]') + plot_3d_motion(animation_save_path, skeleton, motion, title=caption, + dataset=args.dataset, fps=fps, vis_mode='gt', + gt_frames=gt_frames_per_sample.get(sample_i, [])) + for rep_i in range(args.num_repetitions): + caption = all_text[rep_i*args.batch_size + sample_i] + if caption == '': + caption = 'Edit [{}] unconditioned'.format(args.edit_mode) + else: + caption = 'Edit [{}]: {}'.format(args.edit_mode, caption) + length = all_lengths[rep_i*args.batch_size + sample_i] + motion = all_motions[rep_i*args.batch_size + sample_i].transpose(2, 0, 1)[:length] + save_file = 'sample{:02d}_rep{:02d}.mp4'.format(sample_i, rep_i) + animation_save_path = os.path.join(out_path, save_file) + rep_files.append(animation_save_path) + print(f'[({sample_i}) "{caption}" | Rep #{rep_i} | -> {save_file}]') + plot_3d_motion(animation_save_path, skeleton, motion, title=caption, + dataset=args.dataset, fps=fps, vis_mode=args.edit_mode, + gt_frames=gt_frames_per_sample.get(sample_i, [])) + # Credit for visualization: https://github.com/EricGuo5513/text-to-motion + + all_rep_save_file = os.path.join(out_path, 'sample{:02d}.mp4'.format(sample_i)) + 
ffmpeg_rep_files = [f' -i {f} ' for f in rep_files] + hstack_args = f' -filter_complex hstack=inputs={args.num_repetitions+1}' + ffmpeg_rep_cmd = f'ffmpeg -y -loglevel warning ' + ''.join(ffmpeg_rep_files) + f'{hstack_args} {all_rep_save_file}' + os.system(ffmpeg_rep_cmd) + print(f'[({sample_i}) "{caption}" | all repetitions | -> {all_rep_save_file}]') + + abs_path = os.path.abspath(out_path) + print(f'[Done] Results are at [{abs_path}]') + + +if __name__ == "__main__": + main() diff --git a/main/sample/generate.py b/main/sample/generate.py new file mode 100644 index 0000000000000000000000000000000000000000..9176d5bbe4037fec282c03b6695e6ff7dfc87dc4 --- /dev/null +++ b/main/sample/generate.py @@ -0,0 +1,256 @@ +# This code is based on https://github.com/openai/guided-diffusion +""" +Generate a large batch of image samples from a model and save them as a large +numpy array. This can be used to produce samples for FID evaluation. +""" +from utils.fixseed import fixseed +import os +import numpy as np +import torch +from utils.parser_util import generate_args +from utils.model_util import create_model_and_diffusion, load_model_wo_clip +from utils import dist_util +from model.cfg_sampler import ClassifierFreeSampleModel +from data_loaders.get_data import get_dataset_loader +from data_loaders.humanml.scripts.motion_process import recover_from_ric +import data_loaders.humanml.utils.paramUtil as paramUtil +from data_loaders.humanml.utils.plot_script import plot_3d_motion +import shutil +from data_loaders.tensors import collate + + +def main(): + args = generate_args() + fixseed(args.seed) + out_path = args.output_dir + name = os.path.basename(os.path.dirname(args.model_path)) + niter = os.path.basename(args.model_path).replace('model', '').replace('.pt', '') + max_frames = 196 if args.dataset in ['kit', 'humanml'] else 60 + fps = 12.5 if args.dataset == 'kit' else 20 + n_frames = min(max_frames, int(args.motion_length*fps)) + is_using_data = not any([args.input_text, args.text_prompt, args.action_file, args.action_name]) + dist_util.setup_dist(args.device) + if out_path == '': + out_path = os.path.join(os.path.dirname(args.model_path), + 'samples_{}_{}_seed{}'.format(name, niter, args.seed)) + if args.text_prompt != '': + out_path += '_' + args.text_prompt.replace(' ', '_').replace('.', '') + elif args.input_text != '': + out_path += '_' + os.path.basename(args.input_text).replace('.txt', '').replace(' ', '_').replace('.', '') + + # this block must be called BEFORE the dataset is loaded + if args.text_prompt != '': + texts = [args.text_prompt] + args.num_samples = 1 + elif args.input_text != '': + assert os.path.exists(args.input_text) + with open(args.input_text, 'r') as fr: + texts = fr.readlines() + texts = [s.replace('\n', '') for s in texts] + args.num_samples = len(texts) + elif args.action_name: + action_text = [args.action_name] + args.num_samples = 1 + elif args.action_file != '': + assert os.path.exists(args.action_file) + with open(args.action_file, 'r') as fr: + action_text = fr.readlines() + action_text = [s.replace('\n', '') for s in action_text] + args.num_samples = len(action_text) + + assert args.num_samples <= args.batch_size, \ + f'Please either increase batch_size({args.batch_size}) or reduce num_samples({args.num_samples})' + # So why do we need this check? In order to protect GPU from a memory overload in the following line. + # If your GPU can handle batch size larger then default, you can specify it through --batch_size flag. 
+ # If it doesn't, and you still want to sample more prompts, run this script with different seeds + # (specify through the --seed flag) + args.batch_size = args.num_samples # Sampling a single batch from the testset, with exactly args.num_samples + + print('Loading dataset...') + data = load_dataset(args, max_frames, n_frames) + total_num_samples = args.num_samples * args.num_repetitions + + print("Creating model and diffusion...") + model, diffusion = create_model_and_diffusion(args, data) + + print(f"Loading checkpoints from [{args.model_path}]...") + state_dict = torch.load(args.model_path, map_location='cpu') + load_model_wo_clip(model, state_dict) + + if args.guidance_param != 1: + model = ClassifierFreeSampleModel(model) # wrapping model with the classifier-free sampler + model.to(dist_util.dev()) + model.eval() # disable random masking + + if is_using_data: + iterator = iter(data) + _, model_kwargs = next(iterator) + else: + collate_args = [{'inp': torch.zeros(n_frames), 'tokens': None, 'lengths': n_frames}] * args.num_samples + is_t2m = any([args.input_text, args.text_prompt]) + if is_t2m: + # t2m + collate_args = [dict(arg, text=txt) for arg, txt in zip(collate_args, texts)] + else: + # a2m + action = data.dataset.action_name_to_action(action_text) + collate_args = [dict(arg, action=one_action, action_text=one_action_text) for + arg, one_action, one_action_text in zip(collate_args, action, action_text)] + _, model_kwargs = collate(collate_args) + + all_motions = [] + all_lengths = [] + all_text = [] + + for rep_i in range(args.num_repetitions): + print(f'### Sampling [repetitions #{rep_i}]') + + # add CFG scale to batch + if args.guidance_param != 1: + model_kwargs['y']['scale'] = torch.ones(args.batch_size, device=dist_util.dev()) * args.guidance_param + + sample_fn = diffusion.p_sample_loop + + sample = sample_fn( + model, + (args.batch_size, model.njoints, model.nfeats, n_frames), + clip_denoised=False, + model_kwargs=model_kwargs, + skip_timesteps=0, # 0 is the default value - i.e. 
don't skip any step + init_image=None, + progress=True, + dump_steps=None, + noise=None, + const_noise=False, + ) + + # Recover XYZ *positions* from HumanML3D vector representation + if model.data_rep == 'hml_vec': + n_joints = 22 if sample.shape[1] == 263 else 21 + sample = data.dataset.t2m_dataset.inv_transform(sample.cpu().permute(0, 2, 3, 1)).float() + sample = recover_from_ric(sample, n_joints) + sample = sample.view(-1, *sample.shape[2:]).permute(0, 2, 3, 1) + + rot2xyz_pose_rep = 'xyz' if model.data_rep in ['xyz', 'hml_vec'] else model.data_rep + rot2xyz_mask = None if rot2xyz_pose_rep == 'xyz' else model_kwargs['y']['mask'].reshape(args.batch_size, n_frames).bool() + sample = model.rot2xyz(x=sample, mask=rot2xyz_mask, pose_rep=rot2xyz_pose_rep, glob=True, translation=True, + jointstype='smpl', vertstrans=True, betas=None, beta=0, glob_rot=None, + get_rotations_back=False) + + if args.unconstrained: + all_text += ['unconstrained'] * args.num_samples + else: + text_key = 'text' if 'text' in model_kwargs['y'] else 'action_text' + all_text += model_kwargs['y'][text_key] + + all_motions.append(sample.cpu().numpy()) + all_lengths.append(model_kwargs['y']['lengths'].cpu().numpy()) + + print(f"created {len(all_motions) * args.batch_size} samples") + + + all_motions = np.concatenate(all_motions, axis=0) + all_motions = all_motions[:total_num_samples] # [bs, njoints, 6, seqlen] + all_text = all_text[:total_num_samples] + all_lengths = np.concatenate(all_lengths, axis=0)[:total_num_samples] + + if os.path.exists(out_path): + shutil.rmtree(out_path) + os.makedirs(out_path) + + npy_path = os.path.join(out_path, 'results.npy') + print(f"saving results file to [{npy_path}]") + np.save(npy_path, + {'motion': all_motions, 'text': all_text, 'lengths': all_lengths, + 'num_samples': args.num_samples, 'num_repetitions': args.num_repetitions}) + with open(npy_path.replace('.npy', '.txt'), 'w') as fw: + fw.write('\n'.join(all_text)) + with open(npy_path.replace('.npy', '_len.txt'), 'w') as fw: + fw.write('\n'.join([str(l) for l in all_lengths])) + + print(f"saving visualizations to [{out_path}]...") + skeleton = paramUtil.kit_kinematic_chain if args.dataset == 'kit' else paramUtil.t2m_kinematic_chain + + sample_files = [] + num_samples_in_out_file = 7 + + sample_print_template, row_print_template, all_print_template, \ + sample_file_template, row_file_template, all_file_template = construct_template_variables(args.unconstrained) + + for sample_i in range(args.num_samples): + rep_files = [] + for rep_i in range(args.num_repetitions): + caption = all_text[rep_i*args.batch_size + sample_i] + length = all_lengths[rep_i*args.batch_size + sample_i] + motion = all_motions[rep_i*args.batch_size + sample_i].transpose(2, 0, 1)[:length] + save_file = sample_file_template.format(sample_i, rep_i) + print(sample_print_template.format(caption, sample_i, rep_i, save_file)) + animation_save_path = os.path.join(out_path, save_file) + plot_3d_motion(animation_save_path, skeleton, motion, dataset=args.dataset, title=caption, fps=fps) + # Credit for visualization: https://github.com/EricGuo5513/text-to-motion + rep_files.append(animation_save_path) + + sample_files = save_multiple_samples(args, out_path, + row_print_template, all_print_template, row_file_template, all_file_template, + caption, num_samples_in_out_file, rep_files, sample_files, sample_i) + + abs_path = os.path.abspath(out_path) + print(f'[Done] Results are at [{abs_path}]') + + +def save_multiple_samples(args, out_path, row_print_template, all_print_template, 
row_file_template, all_file_template, + caption, num_samples_in_out_file, rep_files, sample_files, sample_i): + all_rep_save_file = row_file_template.format(sample_i) + all_rep_save_path = os.path.join(out_path, all_rep_save_file) + ffmpeg_rep_files = [f' -i {f} ' for f in rep_files] + hstack_args = f' -filter_complex hstack=inputs={args.num_repetitions}' if args.num_repetitions > 1 else '' + ffmpeg_rep_cmd = f'ffmpeg -y -loglevel warning ' + ''.join(ffmpeg_rep_files) + f'{hstack_args} {all_rep_save_path}' + os.system(ffmpeg_rep_cmd) + print(row_print_template.format(caption, sample_i, all_rep_save_file)) + sample_files.append(all_rep_save_path) + if (sample_i + 1) % num_samples_in_out_file == 0 or sample_i + 1 == args.num_samples: + # all_sample_save_file = f'samples_{(sample_i - len(sample_files) + 1):02d}_to_{sample_i:02d}.mp4' + all_sample_save_file = all_file_template.format(sample_i - len(sample_files) + 1, sample_i) + all_sample_save_path = os.path.join(out_path, all_sample_save_file) + print(all_print_template.format(sample_i - len(sample_files) + 1, sample_i, all_sample_save_file)) + ffmpeg_rep_files = [f' -i {f} ' for f in sample_files] + vstack_args = f' -filter_complex vstack=inputs={len(sample_files)}' if len(sample_files) > 1 else '' + ffmpeg_rep_cmd = f'ffmpeg -y -loglevel warning ' + ''.join( + ffmpeg_rep_files) + f'{vstack_args} {all_sample_save_path}' + os.system(ffmpeg_rep_cmd) + sample_files = [] + return sample_files + + +def construct_template_variables(unconstrained): + row_file_template = 'sample{:02d}.mp4' + all_file_template = 'samples_{:02d}_to_{:02d}.mp4' + if unconstrained: + sample_file_template = 'row{:02d}_col{:02d}.mp4' + sample_print_template = '[{} row #{:02d} column #{:02d} | -> {}]' + row_file_template = row_file_template.replace('sample', 'row') + row_print_template = '[{} row #{:02d} | all columns | -> {}]' + all_file_template = all_file_template.replace('samples', 'rows') + all_print_template = '[rows {:02d} to {:02d} | -> {}]' + else: + sample_file_template = 'sample{:02d}_rep{:02d}.mp4' + sample_print_template = '["{}" ({:02d}) | Rep #{:02d} | -> {}]' + row_print_template = '[ "{}" ({:02d}) | all repetitions | -> {}]' + all_print_template = '[samples {:02d} to {:02d} | all repetitions | -> {}]' + + return sample_print_template, row_print_template, all_print_template, \ + sample_file_template, row_file_template, all_file_template + + +def load_dataset(args, max_frames, n_frames): + data = get_dataset_loader(name=args.dataset, + batch_size=args.batch_size, + num_frames=max_frames, + split='test', + hml_mode='text_only') + data.fixed_length = n_frames + return data + + +if __name__ == "__main__": + main() diff --git a/main/train/mytrain.py b/main/train/mytrain.py new file mode 100644 index 0000000000000000000000000000000000000000..312e586783f0199b24d9646a2c595889593ffee6 --- /dev/null +++ b/main/train/mytrain.py @@ -0,0 +1,67 @@ +import pdb +import sys +[sys.path.append(i) for i in ['.', '..']] + +from model.mdm import MDM +from utils.model_util import create_gaussian_diffusion + +# from data_loaders.get_data import get_dataset_loader +from train.training_loop import TrainLoop + +import torch +import os +import json + + +device = torch.device('cuda:2') +n_frames = 240 +n_pose_dims = 135 +n_audio_dim = 32 + +# n_frames = 240 +# n_pose_dims = 251 + + +def create_model_and_diffusion(): + model = MDM(modeltype='', njoints=n_pose_dims, nfeats=1, translation=True, pose_rep='rot6d', glob=True, + glob_rot=True, cond_mode = 'text', clip_version = 'ViT-B/32', 
action_emb = 'tensor') + diffusion = create_gaussian_diffusion() + return model, diffusion + + +if __name__ == '__main__': + ''' + python train/mytrain.py --overwrite --save_dir save/mydebug --dataset kit --device 1 + ''' + # modify data/dataset.py + + model, diffusion = create_model_and_diffusion() + model.to(device) + # model.rot2xyz.smpl_model.eval() + + print('Total params: %.2fM' % (sum(p.numel() for p in model.parameters_wo_clip()) / 1000000.0)) + + + from utils.parser_util import train_args + from utils.fixseed import fixseed + + args = train_args() + fixseed(args.seed) + + if args.save_dir is None: + raise FileNotFoundError('save_dir was not specified.') + elif os.path.exists(args.save_dir) and not args.overwrite: + raise FileExistsError('save_dir [{}] already exists.'.format(args.save_dir)) + elif not os.path.exists(args.save_dir): + os.makedirs(args.save_dir) + args_path = os.path.join(args.save_dir, 'args.json') + with open(args_path, 'w') as fw: + json.dump(vars(args), fw, indent=4, sort_keys=True) + + # print("creating data loader...") + # data = get_dataset_loader(name=args.dataset, batch_size=args.batch_size, num_frames=args.num_frames) + + # print(iter(data).next()[1]['y'].keys()) + print("Training...") + TrainLoop(args, model, diffusion, device).run_loop() + diff --git a/main/train/train_mdm.py b/main/train/train_mdm.py new file mode 100644 index 0000000000000000000000000000000000000000..adeabe7fa6c3e8b42d94b98262133c8f2f4d23a8 --- /dev/null +++ b/main/train/train_mdm.py @@ -0,0 +1,49 @@ +# This code is based on https://github.com/openai/guided-diffusion +""" +Train a diffusion model on images. +""" + +import os +import json +from utils.fixseed import fixseed +from utils.parser_util import train_args +from utils import dist_util +from train.training_loop import TrainLoop +from data_loaders.get_data import get_dataset_loader +from utils.model_util import create_model_and_diffusion +from train.train_platforms import ClearmlPlatform, TensorboardPlatform, NoPlatform # required for the eval operation + +def main(): + args = train_args() + fixseed(args.seed) + train_platform_type = eval(args.train_platform_type) + train_platform = train_platform_type(args.save_dir) + train_platform.report_args(args, name='Args') + + if args.save_dir is None: + raise FileNotFoundError('save_dir was not specified.') + elif os.path.exists(args.save_dir) and not args.overwrite: + raise FileExistsError('save_dir [{}] already exists.'.format(args.save_dir)) + elif not os.path.exists(args.save_dir): + os.makedirs(args.save_dir) + args_path = os.path.join(args.save_dir, 'args.json') + with open(args_path, 'w') as fw: + json.dump(vars(args), fw, indent=4, sort_keys=True) + + dist_util.setup_dist(args.device) + + print("creating data loader...") + data = get_dataset_loader(name=args.dataset, batch_size=args.batch_size, num_frames=args.num_frames) + + print("creating model and diffusion...") + model, diffusion = create_model_and_diffusion(args, data) + model.to(dist_util.dev()) + model.rot2xyz.smpl_model.eval() + + print('Total params: %.2fM' % (sum(p.numel() for p in model.parameters_wo_clip()) / 1000000.0)) + print("Training...") + TrainLoop(args, train_platform, model, diffusion, data).run_loop() + train_platform.close() + +if __name__ == "__main__": + main() diff --git a/main/train/train_platforms.py b/main/train/train_platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..b83b3fc9736cf6195ccf89076c35b7046c93810f --- /dev/null +++ b/main/train/train_platforms.py @@ -0,0 +1,52 
@@ +import os + +class TrainPlatform: + def __init__(self, save_dir): + pass + + def report_scalar(self, name, value, iteration, group_name=None): + pass + + def report_args(self, args, name): + pass + + def close(self): + pass + + +class ClearmlPlatform(TrainPlatform): + def __init__(self, save_dir): + from clearml import Task + path, name = os.path.split(save_dir) + self.task = Task.init(project_name='motion_diffusion', + task_name=name, + output_uri=path) + self.logger = self.task.get_logger() + + def report_scalar(self, name, value, iteration, group_name): + self.logger.report_scalar(title=group_name, series=name, iteration=iteration, value=value) + + def report_args(self, args, name): + self.task.connect(args, name=name) + + def close(self): + self.task.close() + + +class TensorboardPlatform(TrainPlatform): + def __init__(self, save_dir): + from torch.utils.tensorboard import SummaryWriter + self.writer = SummaryWriter(log_dir=save_dir) + + def report_scalar(self, name, value, iteration, group_name=None): + self.writer.add_scalar(f'{group_name}/{name}', value, iteration) + + def close(self): + self.writer.close() + + +class NoPlatform(TrainPlatform): + def __init__(self, save_dir): + pass + + diff --git a/main/train/training_loop.py b/main/train/training_loop.py new file mode 100644 index 0000000000000000000000000000000000000000..92f0d3116eec462eb9d52aefd0b970f509c67c38 --- /dev/null +++ b/main/train/training_loop.py @@ -0,0 +1,356 @@ +import functools +import os +import numpy as np + +import blobfile as bf +import torch +from torch.optim import AdamW + +from diffusion import logger +from diffusion.fp16_util import MixedPrecisionTrainer +from diffusion.resample import LossAwareSampler, UniformSampler +from tqdm import tqdm +from diffusion.resample import create_named_schedule_sampler + +import sys +[sys.path.append(i) for i in ['../process', '../../ubisoft-laforge-ZeroEGGS-main', '../mydiffusion_zeggs']] +from generate.generate import WavEncoder +from process_zeggs_bvh import pose2bvh + +# For ImageNet experiments, this was a good default value. +# We found that the lg_loss_scale quickly climbed to +# 20-21 within the first ~1K steps of training. 
+INITIAL_LOG_LOSS_SCALE = 20.0 + + +class TrainLoop: + def __init__(self, args, model, diffusion, device, data=None): + self.args = args + self.data = data + self.model = model + self.diffusion = diffusion + self.cond_mode = model.cond_mode + self.batch_size = args.batch_size + self.microbatch = args.batch_size # deprecating this option + self.lr = args.lr + self.log_interval = args.log_interval + # self.save_interval = args.save_interval + # self.resume_checkpoint = args.resume_checkpoint + self.use_fp16 = False # deprecating this option + self.fp16_scale_growth = 1e-3 # deprecating this option + self.weight_decay = args.weight_decay + self.lr_anneal_steps = args.lr_anneal_steps + + self.step = 0 + self.resume_step = 0 + self.global_batch = self.batch_size # * dist.get_world_size() + # self.num_steps = args.num_steps + self.num_epochs = 40000 + self.n_seed = 8 + + self.sync_cuda = torch.cuda.is_available() + + # self._load_and_sync_parameters() + self.mp_trainer = MixedPrecisionTrainer( + model=self.model, + use_fp16=self.use_fp16, + fp16_scale_growth=self.fp16_scale_growth, + ) + + self.save_dir = args.save_dir + + self.device = device + if args.audio_feat == "wav encoder": + self.WavEncoder = WavEncoder().to(self.device) + self.opt = AdamW([ + {'params': self.mp_trainer.master_params, 'lr':self.lr, 'weight_decay':self.weight_decay}, + {'params': self.WavEncoder.parameters(), 'lr':self.lr} + ]) + elif args.audio_feat == "mfcc" or args.audio_feat == 'wavlm': + self.opt = AdamW([ + {'params': self.mp_trainer.master_params, 'lr':self.lr, 'weight_decay':self.weight_decay} + ]) + + # if self.resume_step: + # self._load_optimizer_state() + # Model was resumed, either due to a restart or a checkpoint + # being specified at the command line. + + self.schedule_sampler_type = 'uniform' + self.schedule_sampler = create_named_schedule_sampler(self.schedule_sampler_type, diffusion) + self.eval_wrapper, self.eval_data, self.eval_gt_data = None, None, None + # if args.dataset in ['kit', 'humanml'] and args.eval_during_training: + # mm_num_samples = 0 # mm is super slow hence we won't run it during training + # mm_num_repeats = 0 # mm is super slow hence we won't run it during training + # gen_loader = get_dataset_loader(name=args.dataset, batch_size=args.eval_batch_size, num_frames=None, + # split=args.eval_split, + # hml_mode='eval') + # + # self.eval_gt_data = get_dataset_loader(name=args.dataset, batch_size=args.eval_batch_size, num_frames=None, + # split=args.eval_split, + # hml_mode='gt') + # self.eval_wrapper = EvaluatorMDMWrapper(args.dataset, self.device) + # self.eval_data = { + # 'test': lambda: eval_humanml.get_mdm_loader( + # model, diffusion, args.eval_batch_size, + # gen_loader, mm_num_samples, mm_num_repeats, gen_loader.dataset.opt.max_motion_length, + # args.eval_num_samples, scale=1., + # ) + # } + self.use_ddp = False + self.ddp_model = self.model + self.mask_train = (torch.zeros([self.batch_size, 1, 1, args.n_poses]) < 1).to(self.device) + self.mask_test = (torch.zeros([1, 1, 1, args.n_poses]) < 1).to(self.device) + # self.tmp_audio = torch.from_numpy(np.load('tmp_audio.npy')).unsqueeze(0).to(self.device) + # self.tmp_mfcc = torch.from_numpy(np.load('10_kieks_0_9_16.npz')['mfcc'][:args.n_poses]).to(torch.float32).unsqueeze(0).to(self.device) + self.mask_local_train = torch.ones(self.batch_size, args.n_poses).bool().to(self.device) + self.mask_local_test = torch.ones(1, args.n_poses).bool().to(self.device) + + # def _load_and_sync_parameters(self): + # resume_checkpoint = 
find_resume_checkpoint() or self.resume_checkpoint + # + # if resume_checkpoint: + # self.resume_step = parse_resume_step_from_filename(resume_checkpoint) + # logger.log(f"loading model from checkpoint: {resume_checkpoint}...") + # self.model.load_state_dict( + # dist_util.load_state_dict( + # resume_checkpoint, map_location=self.device + # ) + # ) + + # def _load_optimizer_state(self): + # main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint + # opt_checkpoint = bf.join( + # bf.dirname(main_checkpoint), f"opt{self.resume_step:09}.pt" + # ) + # if bf.exists(opt_checkpoint): + # logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}") + # state_dict = dist_util.load_state_dict( + # opt_checkpoint, map_location=self.device + # ) + # self.opt.load_state_dict(state_dict) + + def run_loop(self): + + for epoch in range(self.num_epochs): + # print(f'Starting epoch {epoch}') + # for _ in tqdm(range(10)): # 4 steps, batch size, chmod 777 + for batch in tqdm(self.data): + if not (not self.lr_anneal_steps or self.step + self.resume_step < self.lr_anneal_steps): + break + + cond_ = {'y':{}} + + # cond_['y']['text'] = ['A person turns left with medium speed.', 'A human goes slowly about 1.5 meters forward.'] + + # motion = torch.rand(2, 135, 1, 80).to(self.device) + # pose_seq, _, style, audio, mfcc, wavlm = batch # (batch, 240, 135), (batch, 30), (batch, 64000) + # pose_seq, _, style, _, _, wavlm = batch + pose_seq, style, wavlm = batch + motion = pose_seq.permute(0, 2, 1).unsqueeze(2).to(self.device) + + cond_['y']['seed'] = motion[..., 0:self.n_seed] + # motion = motion[..., self.n_seed:] + cond_['y']['style'] = style.to(self.device) + cond_['y']['mask_local'] = self.mask_local_train + + if self.args.audio_feat == 'wav encoder': + # cond_['y']['audio'] = torch.rand(240, 2, 32).to(self.device) + cond_['y']['audio'] = self.WavEncoder(audio.to(self.device)).permute(1, 0, 2) # (batch, 240, 32) + elif self.args.audio_feat == 'mfcc': + # cond_['y']['audio'] = torch.rand(80, 2, 13).to(self.device) + cond_['y']['audio'] = mfcc.to(torch.float32).to(self.device).permute(1, 0, 2) # [self.n_seed:, ...] # (batch, 80, 13) + elif self.args.audio_feat == 'wavlm': + cond_['y']['audio'] = wavlm.to(torch.float32).to(self.device) + + cond_['y']['mask'] = self.mask_train # [..., self.n_seed:] + + self.run_step(motion, cond_) + if self.step % self.log_interval == 0: + for k,v in logger.get_current().name2val.items(): + if k == 'loss': + print('step[{}]: loss[{:0.5f}]'.format(self.step+self.resume_step, v)) + + # if self.step % 10000 == 0: + # sample_fn = self.diffusion.p_sample_loop + # + # model_kwargs_ = {'y': {}} + # model_kwargs_['y']['mask'] = self.mask_test # [..., self.n_seed:] + # model_kwargs_['y']['seed'] = torch.zeros([1, 1141, 1, self.n_seed]).to(self.device) + # model_kwargs_['y']['style'] = torch.zeros([1, 6]).to(self.device) + # model_kwargs_['y']['mask_local'] = self.mask_local_test + # if self.args.audio_feat == 'wav encoder': + # model_kwargs_['y']['audio'] = self.WavEncoder(self.tmp_audio).permute(1, 0, 2) + # # model_kwargs_['y']['audio'] = torch.rand(240, 1, 32).to(self.device) + # elif self.args.audio_feat == 'mfcc': + # model_kwargs_['y']['audio'] = self.tmp_mfcc.permute(1, 0, 2) # [self.n_seed:, ...] 
+ # # model_kwargs_['y']['audio'] = torch.rand(80, 1, 13).to(self.device) + # elif self.args.audio_feat == 'wavlm': + # model_kwargs_['y']['audio'] = torch.randn(1, 1, 1024).to(self.device) + # + # sample = sample_fn( + # self.model, + # (1, 1141, 1, self.args.n_poses), # - self.n_seed + # clip_denoised=False, + # model_kwargs=model_kwargs_, + # skip_timesteps=0, # 0 is the default value - i.e. don't skip any step + # init_image=None, + # progress=True, + # dump_steps=None, + # noise=None, + # const_noise=False, + # ) # (1, 135, 1, 240) + # + # sampled_seq = sample.squeeze(0).permute(1, 2, 0) + # data_mean_ = np.load("../../ubisoft-laforge-ZeroEGGS-main/Data/processed_v1/processed/mean.npz")['mean'] + # data_std_ = np.load("../../ubisoft-laforge-ZeroEGGS-main/Data/processed_v1/processed/std.npz")['std'] + # + # data_mean = np.array(data_mean_).squeeze() + # data_std = np.array(data_std_).squeeze() + # std = np.clip(data_std, a_min=0.01, a_max=None) + # out_poses = np.multiply(np.array(sampled_seq[0].detach().cpu()), std) + data_mean + # + # pipeline_path = '../../../My/process/resource/data_pipe_20_rotation.sav' + # save_path = 'inference_zeggs_mymodel3_wavlm' + # prefix = str(datetime.now().strftime('%Y%m%d_%H%M%S')) + # if not os.path.exists(save_path): + # os.mkdir(save_path) + # # make_bvh_GENEA2020_BT(save_path, prefix, out_poses, smoothing=False, pipeline_path=pipeline_path) + # + # pose2bvh(out_poses, os.path.join(save_path, prefix + '.bvh'), length=self.args.n_poses) + + if self.step % 50000 == 0: + self.save() + # self.model.eval() + # self.evaluate() + # self.model.train() + + # Run for a finite amount of time in integration tests. + if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0: + return + self.step += 1 + if not (not self.lr_anneal_steps or self.step + self.resume_step < self.lr_anneal_steps): + break + # Save the last checkpoint if it wasn't already saved. 
+ # if (self.step - 1) % 50000 != 0: + # self.save() + # self.evaluate() + + + def run_step(self, batch, cond): + self.forward_backward(batch, cond) # torch.Size([64, 251, 1, 196]) cond['y'].keys() dict_keys(['mask', 'lengths', 'text', 'tokens']) + self.mp_trainer.optimize(self.opt) + self._anneal_lr() + self.log_step() + + def forward_backward(self, batch, cond): + self.mp_trainer.zero_grad() + for i in range(0, batch.shape[0], self.microbatch): + # Eliminates the microbatch feature + assert i == 0 + assert self.microbatch == self.batch_size + micro = batch + micro_cond = cond + last_batch = (i + self.microbatch) >= batch.shape[0] + t, weights = self.schedule_sampler.sample(micro.shape[0], self.device) + + compute_losses = functools.partial( + self.diffusion.training_losses, + self.ddp_model, + micro, # [bs, ch, image_size, image_size] # x_start, (2, 135, 1, 240) + t, # [bs](int) sampled timesteps + model_kwargs=micro_cond, + dataset='kit' + ) + + if last_batch or not self.use_ddp: + losses = compute_losses() + else: + with self.ddp_model.no_sync(): + losses = compute_losses() + + if isinstance(self.schedule_sampler, LossAwareSampler): + self.schedule_sampler.update_with_local_losses( + t, losses["loss"].detach() + ) + + loss = (losses["loss"] * weights).mean() + log_loss_dict( + self.diffusion, t, {k: v * weights for k, v in losses.items()} + ) + self.mp_trainer.backward(loss) + + def _anneal_lr(self): + if not self.lr_anneal_steps: + return + frac_done = (self.step + self.resume_step) / self.lr_anneal_steps + lr = self.lr * (1 - frac_done) + for param_group in self.opt.param_groups: + param_group["lr"] = lr + + def log_step(self): + logger.logkv("step", self.step + self.resume_step) + logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch) + + + def ckpt_file_name(self): + return f"model{(self.step+self.resume_step):09d}.pt" + + + def save(self): + def save_checkpoint(params): + state_dict = self.mp_trainer.master_params_to_state_dict(params) + + # Do not save CLIP weights + clip_weights = [e for e in state_dict.keys() if e.startswith('clip_model.')] + for e in clip_weights: + del state_dict[e] + + logger.log(f"saving model...") + filename = self.ckpt_file_name() + with bf.BlobFile(bf.join(self.save_dir, filename), "wb") as f: + torch.save(state_dict, f) + + save_checkpoint(self.mp_trainer.master_params) + + with bf.BlobFile( + bf.join(self.save_dir, f"opt{(self.step+self.resume_step):09d}.pt"), + "wb", + ) as f: + torch.save(self.opt.state_dict(), f) + + +def parse_resume_step_from_filename(filename): + """ + Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the + checkpoint's number of steps. + """ + split = filename.split("model") + if len(split) < 2: + return 0 + split1 = split[-1].split(".")[0] + try: + return int(split1) + except ValueError: + return 0 + + +def get_blob_logdir(): + # You can change this to be a separate path to save checkpoints to + # a blobstore or some external drive. + return logger.get_dir() + + +def find_resume_checkpoint(): + # On your infrastructure, you may want to override this to automatically + # discover the latest checkpoint on your blob storage, etc. + return None + + +def log_loss_dict(diffusion, ts, losses): + for key, values in losses.items(): + logger.logkv_mean(key, values.mean().item()) + # Log the quantiles (four quartiles, in particular). 
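+        # Each sampled timestep is binned into one of four quartiles of
+        # [0, num_timesteps) so the loss can be tracked per noise level.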
+ for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()): + quartile = int(4 * sub_t / diffusion.num_timesteps) + logger.logkv_mean(f"{key}_q{quartile}", sub_loss) diff --git a/main/utils/PYTORCH3D_LICENSE b/main/utils/PYTORCH3D_LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..bed0cebe976e160c9087d1b1054473bdacf75b3b --- /dev/null +++ b/main/utils/PYTORCH3D_LICENSE @@ -0,0 +1,30 @@ +BSD License + +For PyTorch3D software + +Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name Facebook nor the names of its contributors may be used to + endorse or promote products derived from this software without specific + prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/main/utils/config.py b/main/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..091d790e963959c326917688ee267e6a4ec136d1 --- /dev/null +++ b/main/utils/config.py @@ -0,0 +1,17 @@ +import os + +SMPL_DATA_PATH = "./body_models/smpl" + +SMPL_KINTREE_PATH = os.path.join(SMPL_DATA_PATH, "kintree_table.pkl") +SMPL_MODEL_PATH = os.path.join(SMPL_DATA_PATH, "SMPL_NEUTRAL.pkl") +JOINT_REGRESSOR_TRAIN_EXTRA = os.path.join(SMPL_DATA_PATH, 'J_regressor_extra.npy') + +ROT_CONVENTION_TO_ROT_NUMBER = { + 'legacy': 23, + 'no_hands': 21, + 'full_hands': 51, + 'mitten_hands': 33, +} + +GENDERS = ['neutral', 'male', 'female'] +NUM_BETAS = 10 \ No newline at end of file diff --git a/main/utils/dist_util.py b/main/utils/dist_util.py new file mode 100644 index 0000000000000000000000000000000000000000..9f5580a7890010ed4acdfcee8cb4eb7f8618769c --- /dev/null +++ b/main/utils/dist_util.py @@ -0,0 +1,77 @@ +""" +Helpers for distributed training. +""" + +import socket + +import torch as th +import torch.distributed as dist + +# Change this to reflect your cluster layout. +# The GPU for a given rank is (rank % GPUS_PER_NODE). +GPUS_PER_NODE = 8 + +SETUP_RETRY_COUNT = 3 + +used_device = 0 + +def setup_dist(device=0): + """ + Setup a distributed process group. 
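+
+    In this fork the actual process-group initialization is commented out; the
+    call only records the requested device id, which dev() then uses to pick
+    cuda:<id> or the CPU.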
+ """ + global used_device + used_device = device + if dist.is_initialized(): + return + # os.environ["CUDA_VISIBLE_DEVICES"] = str(device) # f"{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}" + + # comm = MPI.COMM_WORLD + # backend = "gloo" if not th.cuda.is_available() else "nccl" + + # if backend == "gloo": + # hostname = "localhost" + # else: + # hostname = socket.gethostbyname(socket.getfqdn()) + # os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0) + # os.environ["RANK"] = str(comm.rank) + # os.environ["WORLD_SIZE"] = str(comm.size) + + # port = comm.bcast(_find_free_port(), root=used_device) + # os.environ["MASTER_PORT"] = str(port) + # dist.init_process_group(backend=backend, init_method="env://") + + +def dev(): + """ + Get the device to use for torch.distributed. + """ + global used_device + if th.cuda.is_available() and used_device>=0: + return th.device(f"cuda:{used_device}") + return th.device("cpu") + + +def load_state_dict(path, **kwargs): + """ + Load a PyTorch file without redundant fetches across MPI ranks. + """ + return th.load(path, **kwargs) + + +def sync_params(params): + """ + Synchronize a sequence of Tensors across ranks from rank 0. + """ + for p in params: + with th.no_grad(): + dist.broadcast(p, 0) + + +def _find_free_port(): + try: + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.bind(("", 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return s.getsockname()[1] + finally: + s.close() diff --git a/main/utils/fixseed.py b/main/utils/fixseed.py new file mode 100644 index 0000000000000000000000000000000000000000..6f44f6ca263dcc410102a50970ce1b78405ba1f1 --- /dev/null +++ b/main/utils/fixseed.py @@ -0,0 +1,18 @@ +import numpy as np +import torch +import random + + +def fixseed(seed): + torch.backends.cudnn.benchmark = False + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + +# SEED = 10 +# EVALSEED = 0 +# # Provoc warning: not fully functionnal yet +# # torch.set_deterministic(True) +# torch.backends.cudnn.benchmark = False +# fixseed(SEED) diff --git a/main/utils/misc.py b/main/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..abe0cdc30f5d183cef8966dba59c44c7b9eb9b0a --- /dev/null +++ b/main/utils/misc.py @@ -0,0 +1,40 @@ +import torch + + +def to_numpy(tensor): + if torch.is_tensor(tensor): + return tensor.cpu().numpy() + elif type(tensor).__module__ != 'numpy': + raise ValueError("Cannot convert {} to numpy array".format( + type(tensor))) + return tensor + + +def to_torch(ndarray): + if type(ndarray).__module__ == 'numpy': + return torch.from_numpy(ndarray) + elif not torch.is_tensor(ndarray): + raise ValueError("Cannot convert {} to torch tensor".format( + type(ndarray))) + return ndarray + + +def cleanexit(): + import sys + import os + try: + sys.exit(0) + except SystemExit: + os._exit(0) + +def load_model_wo_clip(model, state_dict): + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + assert len(unexpected_keys) == 0 + assert all([k.startswith('clip_model.') for k in missing_keys]) + +def freeze_joints(x, joints_to_freeze): + # Freezes selected joint *rotations* as they appear in the first frame + # x [bs, [root+n_joints], joint_dim(6), seqlen] + frozen = x.detach().clone() + frozen[:, joints_to_freeze, :, :] = frozen[:, joints_to_freeze, :, :1] + return frozen diff --git a/main/utils/model_util.py b/main/utils/model_util.py new file mode 100644 index 0000000000000000000000000000000000000000..a1bad78475fa1aec87478d7ae772afd1d698feb5 --- /dev/null 
+++ b/main/utils/model_util.py @@ -0,0 +1,100 @@ +import pdb + +from model.mdm import MDM +from diffusion import gaussian_diffusion as gd +from diffusion.respace import SpacedDiffusion, space_timesteps + + +def load_model_wo_clip(model, state_dict): + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + print(missing_keys, unexpected_keys) + assert len(unexpected_keys) == 0 + assert all([k.startswith('clip_model.') for k in missing_keys]) + + +def create_model_and_diffusion(args, data): + model = MDM(**get_model_args(args, data)) + diffusion = create_gaussian_diffusion(args) + return model, diffusion + + +def get_model_args(args, data): + + # default args + clip_version = 'ViT-B/32' + action_emb = 'tensor' + if args.unconstrained: + cond_mode = 'no_cond' + elif args.dataset in ['kit', 'humanml']: + cond_mode = 'text' + else: + cond_mode = 'action' + if hasattr(data.dataset, 'num_actions'): + num_actions = data.dataset.num_actions + else: + num_actions = 1 + + # SMPL defaults + data_rep = 'rot6d' + njoints = 25 + nfeats = 6 + + if args.dataset == 'humanml': + data_rep = 'hml_vec' + njoints = 263 + nfeats = 1 + elif args.dataset == 'kit': + data_rep = 'hml_vec' + njoints = 251 + nfeats = 1 + + return {'modeltype': '', 'njoints': njoints, 'nfeats': nfeats, 'num_actions': num_actions, + 'translation': True, 'pose_rep': 'rot6d', 'glob': True, 'glob_rot': True, + 'latent_dim': args.latent_dim, 'ff_size': 1024, 'num_layers': args.layers, 'num_heads': 4, + 'dropout': 0.1, 'activation': "gelu", 'data_rep': data_rep, 'cond_mode': cond_mode, + 'cond_mask_prob': args.cond_mask_prob, 'action_emb': action_emb, 'arch': args.arch, + 'emb_trans_dec': args.emb_trans_dec, 'clip_version': clip_version, 'dataset': args.dataset} + + +def create_gaussian_diffusion(): + noise_schedule = 'cosine' + sigma_small = True + lambda_vel = 0.0 + lambda_rcxyz = 0.0 + lambda_fc = 0.0 + + # default params + predict_xstart = True # we always predict x_start (a.k.a. x0), that's our deal! + steps = 1000 + scale_beta = 1. # no scaling + timestep_respacing = '' # can be used for ddim sampling, we don't use it. 
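+    # (In the original guided-diffusion code this accepts strings such as '100'
+    # or 'ddim50' to sample with a subset of the 1000 steps; it is left empty here.)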
+ learn_sigma = False + rescale_timesteps = False + + betas = gd.get_named_beta_schedule(noise_schedule, steps, scale_beta) + loss_type = gd.LossType.MSE + + if not timestep_respacing: + timestep_respacing = [steps] + + return SpacedDiffusion( + use_timesteps=space_timesteps(steps, timestep_respacing), + betas=betas, + model_mean_type=( + gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X + ), + model_var_type=( + ( + gd.ModelVarType.FIXED_LARGE + if not sigma_small + else gd.ModelVarType.FIXED_SMALL + ) + if not learn_sigma + else gd.ModelVarType.LEARNED_RANGE + ), + loss_type=loss_type, + rescale_timesteps=rescale_timesteps, + lambda_vel=lambda_vel, + lambda_rcxyz=lambda_rcxyz, + lambda_fc=lambda_fc, + ) \ No newline at end of file diff --git a/main/utils/parser_util.py b/main/utils/parser_util.py new file mode 100644 index 0000000000000000000000000000000000000000..3e5c27c337643a116cdc4a265f80227bef6c77cc --- /dev/null +++ b/main/utils/parser_util.py @@ -0,0 +1,237 @@ +from argparse import ArgumentParser +import argparse +import os +import json + + +def parse_and_load_from_model(parser): + # args according to the loaded model + # do not try to specify them from cmd line since they will be overwritten + add_data_options(parser) + add_model_options(parser) + add_diffusion_options(parser) + args = parser.parse_args() + args_to_overwrite = [] + for group_name in ['dataset', 'model', 'diffusion']: + args_to_overwrite += get_args_per_group_name(parser, args, group_name) + + # load args from model + model_path = get_model_path_from_args() + args_path = os.path.join(os.path.dirname(model_path), 'args.json') + assert os.path.exists(args_path), 'Arguments json file was not found!' + with open(args_path, 'r') as fr: + model_args = json.load(fr) + + for a in args_to_overwrite: + if a in model_args.keys(): + setattr(args, a, model_args[a]) + + elif 'cond_mode' in model_args: # backward compitability + unconstrained = (model_args['cond_mode'] == 'no_cond') + setattr(args, 'unconstrained', unconstrained) + + else: + print('Warning: was not able to load [{}], using default value [{}] instead.'.format(a, args.__dict__[a])) + + if args.cond_mask_prob == 0: + args.guidance_param = 1 + return args + + +def get_args_per_group_name(parser, args, group_name): + for group in parser._action_groups: + if group.title == group_name: + group_dict = {a.dest: getattr(args, a.dest, None) for a in group._group_actions} + return list(argparse.Namespace(**group_dict).__dict__.keys()) + return ValueError('group_name was not found.') + +def get_model_path_from_args(): + try: + dummy_parser = ArgumentParser() + dummy_parser.add_argument('model_path') + dummy_args, _ = dummy_parser.parse_known_args() + return dummy_args.model_path + except: + raise ValueError('model_path argument must be specified.') + + +def add_base_options(parser): + group = parser.add_argument_group('base') + group.add_argument("--cuda", default=True, type=bool, help="Use cuda device, otherwise use CPU.") + group.add_argument("--device", default=0, type=int, help="Device id to use.") + group.add_argument("--seed", default=10, type=int, help="For fixing random seed.") + group.add_argument("--batch_size", default=64, type=int, help="Batch size during training.") + + +def add_diffusion_options(parser): + group = parser.add_argument_group('diffusion') + group.add_argument("--noise_schedule", default='cosine', choices=['linear', 'cosine'], type=str, + help="Noise schedule type") + group.add_argument("--diffusion_steps", default=1000, 
type=int, + help="Number of diffusion steps (denoted T in the paper)") + group.add_argument("--sigma_small", default=True, type=bool, help="Use smaller sigma values.") + + +def add_model_options(parser): + group = parser.add_argument_group('model') + group.add_argument("--arch", default='trans_enc', + choices=['trans_enc', 'trans_dec', 'gru'], type=str, + help="Architecture types as reported in the paper.") + group.add_argument("--emb_trans_dec", default=False, type=bool, + help="For trans_dec architecture only, if true, will inject condition as a class token" + " (in addition to cross-attention).") + group.add_argument("--layers", default=8, type=int, + help="Number of layers.") + group.add_argument("--latent_dim", default=512, type=int, + help="Transformer/GRU width.") + group.add_argument("--cond_mask_prob", default=.1, type=float, + help="The probability of masking the condition during training." + " For classifier-free guidance learning.") + group.add_argument("--lambda_rcxyz", default=0.0, type=float, help="Joint positions loss.") + group.add_argument("--lambda_vel", default=0.0, type=float, help="Joint velocity loss.") + group.add_argument("--lambda_fc", default=0.0, type=float, help="Foot contact loss.") + group.add_argument("--unconstrained", action='store_true', + help="Model is trained unconditionally. That is, it is constrained by neither text nor action. " + "Currently tested on HumanAct12 only.") + + + +def add_data_options(parser): + group = parser.add_argument_group('dataset') + group.add_argument("--dataset", default='humanml', choices=['humanml', 'kit', 'humanact12', 'uestc'], type=str, + help="Dataset name (choose from list).") + group.add_argument("--data_dir", default="", type=str, + help="If empty, will use defaults according to the specified dataset.") + + +def add_training_options(parser): + group = parser.add_argument_group('training') + group.add_argument("--save_dir", required=True, type=str, + help="Path to save checkpoints and results.") + group.add_argument("--overwrite", action='store_true', + help="If True, will enable to use an already existing save_dir.") + group.add_argument("--train_platform_type", default='NoPlatform', choices=['NoPlatform', 'ClearmlPlatform', 'TensorboardPlatform'], type=str, + help="Choose platform to log results. NoPlatform means no logging.") + group.add_argument("--lr", default=1e-4, type=float, help="Learning rate.") + group.add_argument("--weight_decay", default=0.0, type=float, help="Optimizer weight decay.") + group.add_argument("--lr_anneal_steps", default=0, type=int, help="Number of learning rate anneal steps.") + group.add_argument("--eval_batch_size", default=32, type=int, + help="Batch size during evaluation loop. Do not change this unless you know what you are doing. 
" + "T2m precision calculation is based on fixed batch size 32.") + group.add_argument("--eval_split", default='test', choices=['val', 'test'], type=str, + help="Which split to evaluate on during training.") + group.add_argument("--eval_during_training", action='store_true', + help="If True, will run evaluation during training.") + group.add_argument("--eval_rep_times", default=3, type=int, + help="Number of repetitions for evaluation loop during training.") + group.add_argument("--eval_num_samples", default=1_000, type=int, + help="If -1, will use all samples in the specified split.") + group.add_argument("--log_interval", default=1_000, type=int, + help="Log losses each N steps") + group.add_argument("--save_interval", default=50_000, type=int, + help="Save checkpoints and run evaluation each N steps") + group.add_argument("--num_steps", default=600_000, type=int, + help="Training will stop after the specified number of steps.") + group.add_argument("--num_frames", default=60, type=int, + help="Limit for the maximal number of frames. In HumanML3D and KIT this field is ignored.") + group.add_argument("--resume_checkpoint", default="", type=str, + help="If not empty, will start from the specified checkpoint (path to model###.pt file).") + + +def add_sampling_options(parser): + group = parser.add_argument_group('sampling') + group.add_argument("--model_path", required=True, type=str, + help="Path to model####.pt file to be sampled.") + group.add_argument("--output_dir", default='', type=str, + help="Path to results dir (auto created by the script). " + "If empty, will create dir in parallel to checkpoint.") + group.add_argument("--num_samples", default=10, type=int, + help="Maximal number of prompts to sample, " + "if loading dataset from file, this field will be ignored.") + group.add_argument("--num_repetitions", default=3, type=int, + help="Number of repetitions, per sample (text prompt/action)") + group.add_argument("--guidance_param", default=2.5, type=float, + help="For classifier-free sampling - specifies the s parameter, as defined in the paper.") + + +def add_generate_options(parser): + group = parser.add_argument_group('generate') + group.add_argument("--motion_length", default=6.0, type=float, + help="The length of the sampled motion [in seconds]. " + "Maximum is 9.8 for HumanML3D (text-to-motion), and 2.0 for HumanAct12 (action-to-motion)") + group.add_argument("--input_text", default='', type=str, + help="Path to a text file lists text prompts to be synthesized. If empty, will take text prompts from dataset.") + group.add_argument("--action_file", default='', type=str, + help="Path to a text file that lists names of actions to be synthesized. Names must be a subset of dataset/uestc/info/action_classes.txt if sampling from uestc, " + "or a subset of [warm_up,walk,run,jump,drink,lift_dumbbell,sit,eat,turn steering wheel,phone,boxing,throw] if sampling from humanact12. " + "If no file is specified, will take action names from dataset.") + group.add_argument("--text_prompt", default='', type=str, + help="A text prompt to be generated. If empty, will take text prompts from dataset.") + group.add_argument("--action_name", default='', type=str, + help="An action name to be generated. 
If empty, will take text prompts from dataset.") + + +def add_edit_options(parser): + group = parser.add_argument_group('edit') + group.add_argument("--edit_mode", default='in_between', choices=['in_between', 'upper_body'], type=str, + help="Defines which parts of the input motion will be edited.\n" + "(1) in_between - suffix and prefix motion taken from input motion, " + "middle motion is generated.\n" + "(2) upper_body - lower body joints taken from input motion, " + "upper body is generated.") + group.add_argument("--text_condition", default='', type=str, + help="Editing will be conditioned on this text prompt. " + "If empty, will perform unconditioned editing.") + group.add_argument("--prefix_end", default=0.25, type=float, + help="For in_between editing - Defines the end of input prefix (ratio from all frames).") + group.add_argument("--suffix_start", default=0.75, type=float, + help="For in_between editing - Defines the start of input suffix (ratio from all frames).") + + +def add_evaluation_options(parser): + group = parser.add_argument_group('eval') + group.add_argument("--model_path", required=True, type=str, + help="Path to model####.pt file to be sampled.") + group.add_argument("--eval_mode", default='wo_mm', choices=['wo_mm', 'mm_short', 'debug', 'full'], type=str, + help="wo_mm (t2m only) - 20 repetitions without multi-modality metric; " + "mm_short (t2m only) - 5 repetitions with multi-modality metric; " + "debug - short run, less accurate results." + "full (a2m only) - 20 repetitions.") + group.add_argument("--guidance_param", default=2.5, type=float, + help="For classifier-free sampling - specifies the s parameter, as defined in the paper.") + + +def train_args(): + parser = ArgumentParser() + add_base_options(parser) + add_data_options(parser) + add_model_options(parser) + add_diffusion_options(parser) + add_training_options(parser) + return parser.parse_args() + + +def generate_args(): + parser = ArgumentParser() + # args specified by the user: (all other will be loaded from the model) + add_base_options(parser) + add_sampling_options(parser) + add_generate_options(parser) + return parse_and_load_from_model(parser) + + +def edit_args(): + parser = ArgumentParser() + # args specified by the user: (all other will be loaded from the model) + add_base_options(parser) + add_sampling_options(parser) + add_edit_options(parser) + return parse_and_load_from_model(parser) + + +def evaluation_parser(): + parser = ArgumentParser() + # args specified by the user: (all other will be loaded from the model) + add_base_options(parser) + add_evaluation_options(parser) + return parse_and_load_from_model(parser) \ No newline at end of file diff --git a/main/utils/rotation_conversions.py b/main/utils/rotation_conversions.py new file mode 100644 index 0000000000000000000000000000000000000000..210ae1f0878b3ab223ec3d51d4053751dceb47ff --- /dev/null +++ b/main/utils/rotation_conversions.py @@ -0,0 +1,552 @@ +# This code is based on https://github.com/Mathux/ACTOR.git +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# Check PYTORCH3D_LICENCE before use + +import functools +from typing import Optional + +import torch +import torch.nn.functional as F + + +""" +The transformation matrices returned from the functions in this file assume +the points on which the transformation will be applied are column vectors. +i.e. 
the R matrix is structured as + + R = [ + [Rxx, Rxy, Rxz], + [Ryx, Ryy, Ryz], + [Rzx, Rzy, Rzz], + ] # (3, 3) + +This matrix can be applied to column vectors by post multiplication +by the points e.g. + + points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point + transformed_points = R * points + +To apply the same matrix to points which are row vectors, the R matrix +can be transposed and pre multiplied by the points: + +e.g. + points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point + transformed_points = points * R.transpose(1, 0) +""" + + +def quaternion_to_matrix(quaternions): + """ + Convert rotations given as quaternions to rotation matrices. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + r, i, j, k = torch.unbind(quaternions, -1) + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def _copysign(a, b): + """ + Return a tensor where each element has the absolute value taken from the, + corresponding element of a, with sign taken from the corresponding + element of b. This is like the standard copysign floating-point operation, + but is not careful about negative 0 and NaN. + + Args: + a: source tensor. + b: tensor whose signs will be used, of the same shape as a. + + Returns: + Tensor of the same shape as a with the signs of b. + """ + signs_differ = (a < 0) != (b < 0) + return torch.where(signs_differ, -a, a) + + +def _sqrt_positive_part(x): + """ + Returns torch.sqrt(torch.max(0, x)) + but with a zero subgradient where x is 0. + """ + ret = torch.zeros_like(x) + positive_mask = x > 0 + ret[positive_mask] = torch.sqrt(x[positive_mask]) + return ret + + +def matrix_to_quaternion(matrix): + """ + Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + quaternions with real part first, as tensor of shape (..., 4). + """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.") + m00 = matrix[..., 0, 0] + m11 = matrix[..., 1, 1] + m22 = matrix[..., 2, 2] + o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22) + x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22) + y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22) + z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22) + o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2]) + o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0]) + o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1]) + return torch.stack((o0, o1, o2, o3), -1) + + +def _axis_angle_rotation(axis: str, angle): + """ + Return the rotation matrices for one of the rotations about an axis + of which Euler angles describe, for each value of the angle given. + + Args: + axis: Axis label "X" or "Y or "Z". + angle: any shape tensor of Euler angles in radians + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). 
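+
+    Example:
+        For axis "Z" and angle pi / 2 the returned matrix is (approximately)
+        [[0, -1, 0],
+         [1,  0, 0],
+         [0,  0, 1]].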
+ """ + + cos = torch.cos(angle) + sin = torch.sin(angle) + one = torch.ones_like(angle) + zero = torch.zeros_like(angle) + + if axis == "X": + R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) + if axis == "Y": + R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) + if axis == "Z": + R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) + + return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) + + +def euler_angles_to_matrix(euler_angles, convention: str): + """ + Convert rotations given as Euler angles in radians to rotation matrices. + + Args: + euler_angles: Euler angles in radians as tensor of shape (..., 3). + convention: Convention string of three uppercase letters from + {"X", "Y", and "Z"}. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3: + raise ValueError("Invalid input euler angles.") + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1)) + return functools.reduce(torch.matmul, matrices) + + +def _angle_from_tan( + axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool +): + """ + Extract the first or third Euler angle from the two members of + the matrix which are positive constant times its sine and cosine. + + Args: + axis: Axis label "X" or "Y or "Z" for the angle we are finding. + other_axis: Axis label "X" or "Y or "Z" for the middle axis in the + convention. + data: Rotation matrices as tensor of shape (..., 3, 3). + horizontal: Whether we are looking for the angle for the third axis, + which means the relevant entries are in the same row of the + rotation matrix. If not, they are in the same column. + tait_bryan: Whether the first and third axes in the convention differ. + + Returns: + Euler Angles in radians for each matrix in dataset as a tensor + of shape (...). + """ + + i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis] + if horizontal: + i2, i1 = i1, i2 + even = (axis + other_axis) in ["XY", "YZ", "ZX"] + if horizontal == even: + return torch.atan2(data[..., i1], data[..., i2]) + if tait_bryan: + return torch.atan2(-data[..., i2], data[..., i1]) + return torch.atan2(data[..., i2], -data[..., i1]) + + +def _index_from_letter(letter: str): + if letter == "X": + return 0 + if letter == "Y": + return 1 + if letter == "Z": + return 2 + + +def matrix_to_euler_angles(matrix, convention: str): + """ + Convert rotations given as rotation matrices to Euler angles in radians. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + convention: Convention string of three uppercase letters. + + Returns: + Euler angles in radians as tensor of shape (..., 3). 
+ """ + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.") + i0 = _index_from_letter(convention[0]) + i2 = _index_from_letter(convention[2]) + tait_bryan = i0 != i2 + if tait_bryan: + central_angle = torch.asin( + matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0) + ) + else: + central_angle = torch.acos(matrix[..., i0, i0]) + + o = ( + _angle_from_tan( + convention[0], convention[1], matrix[..., i2], False, tait_bryan + ), + central_angle, + _angle_from_tan( + convention[2], convention[1], matrix[..., i0, :], True, tait_bryan + ), + ) + return torch.stack(o, -1) + + +def random_quaternions( + n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate random quaternions representing rotations, + i.e. versors with nonnegative real part. + + Args: + n: Number of quaternions in a batch to return. + dtype: Type to return. + device: Desired device of returned tensor. Default: + uses the current device for the default tensor type. + requires_grad: Whether the resulting tensor should have the gradient + flag set. + + Returns: + Quaternions as tensor of shape (N, 4). + """ + o = torch.randn((n, 4), dtype=dtype, device=device, requires_grad=requires_grad) + s = (o * o).sum(1) + o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None] + return o + + +def random_rotations( + n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate random rotations as 3x3 rotation matrices. + + Args: + n: Number of rotation matrices in a batch to return. + dtype: Type to return. + device: Device of returned tensor. Default: if None, + uses the current device for the default tensor type. + requires_grad: Whether the resulting tensor should have the gradient + flag set. + + Returns: + Rotation matrices as tensor of shape (n, 3, 3). + """ + quaternions = random_quaternions( + n, dtype=dtype, device=device, requires_grad=requires_grad + ) + return quaternion_to_matrix(quaternions) + + +def random_rotation( + dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate a single random 3x3 rotation matrix. + + Args: + dtype: Type to return + device: Device of returned tensor. Default: if None, + uses the current device for the default tensor type + requires_grad: Whether the resulting tensor should have the gradient + flag set + + Returns: + Rotation matrix as tensor of shape (3, 3). + """ + return random_rotations(1, dtype, device, requires_grad)[0] + + +def standardize_quaternion(quaternions): + """ + Convert a unit quaternion to a standard form: one in which the real + part is non negative. + + Args: + quaternions: Quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Standardized quaternions as tensor of shape (..., 4). + """ + return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions) + + +def quaternion_raw_multiply(a, b): + """ + Multiply two quaternions. + Usual torch rules for broadcasting apply. + + Args: + a: Quaternions as tensor of shape (..., 4), real part first. + b: Quaternions as tensor of shape (..., 4), real part first. 
+ + Returns: + The product of a and b, a tensor of quaternions shape (..., 4). + """ + aw, ax, ay, az = torch.unbind(a, -1) + bw, bx, by, bz = torch.unbind(b, -1) + ow = aw * bw - ax * bx - ay * by - az * bz + ox = aw * bx + ax * bw + ay * bz - az * by + oy = aw * by - ax * bz + ay * bw + az * bx + oz = aw * bz + ax * by - ay * bx + az * bw + return torch.stack((ow, ox, oy, oz), -1) + + +def quaternion_multiply(a, b): + """ + Multiply two quaternions representing rotations, returning the quaternion + representing their composition, i.e. the versorย with nonnegative real part. + Usual torch rules for broadcasting apply. + + Args: + a: Quaternions as tensor of shape (..., 4), real part first. + b: Quaternions as tensor of shape (..., 4), real part first. + + Returns: + The product of a and b, a tensor of quaternions of shape (..., 4). + """ + ab = quaternion_raw_multiply(a, b) + return standardize_quaternion(ab) + + +def quaternion_invert(quaternion): + """ + Given a quaternion representing rotation, get the quaternion representing + its inverse. + + Args: + quaternion: Quaternions as tensor of shape (..., 4), with real part + first, which must be versors (unit quaternions). + + Returns: + The inverse, a tensor of quaternions of shape (..., 4). + """ + + return quaternion * quaternion.new_tensor([1, -1, -1, -1]) + + +def quaternion_apply(quaternion, point): + """ + Apply the rotation given by a quaternion to a 3D point. + Usual torch rules for broadcasting apply. + + Args: + quaternion: Tensor of quaternions, real part first, of shape (..., 4). + point: Tensor of 3D points of shape (..., 3). + + Returns: + Tensor of rotated points of shape (..., 3). + """ + if point.size(-1) != 3: + raise ValueError(f"Points are not in 3D, f{point.shape}.") + real_parts = point.new_zeros(point.shape[:-1] + (1,)) + point_as_quaternion = torch.cat((real_parts, point), -1) + out = quaternion_raw_multiply( + quaternion_raw_multiply(quaternion, point_as_quaternion), + quaternion_invert(quaternion), + ) + return out[..., 1:] + + +def axis_angle_to_matrix(axis_angle): + """ + Convert rotations given as axis/angle to rotation matrices. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle)) + + +def matrix_to_axis_angle(matrix): + """ + Convert rotations given as rotation matrices to axis/angle. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. + """ + return quaternion_to_axis_angle(matrix_to_quaternion(matrix)) + + +def axis_angle_to_quaternion(axis_angle): + """ + Convert rotations given as axis/angle to quaternions. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + quaternions with real part first, as tensor of shape (..., 4). 
+ """ + angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) + half_angles = 0.5 * angles + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + ) + quaternions = torch.cat( + [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1 + ) + return quaternions + + +def quaternion_to_axis_angle(quaternions): + """ + Convert rotations given as quaternions to axis/angle. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. + """ + norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True) + half_angles = torch.atan2(norms, quaternions[..., :1]) + angles = 2 * half_angles + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + ) + return quaternions[..., 1:] / sin_half_angles_over_angles + + +def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor: + """ + Converts 6D rotation representation by Zhou et al. [1] to rotation matrix + using Gram--Schmidt orthogonalisation per Section B of [1]. + Args: + d6: 6D rotation representation, of size (*, 6) + + Returns: + batch of rotation matrices of size (*, 3, 3) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. + Retrieved from http://arxiv.org/abs/1812.07035 + """ + + a1, a2 = d6[..., :3], d6[..., 3:] + b1 = F.normalize(a1, dim=-1) + b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 + b2 = F.normalize(b2, dim=-1) + b3 = torch.cross(b1, b2, dim=-1) + return torch.stack((b1, b2, b3), dim=-2) + + +def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor: + """ + Converts rotation matrices to 6D rotation representation by Zhou et al. [1] + by dropping the last row. Note that 6D representation is not unique. + Args: + matrix: batch of rotation matrices of size (*, 3, 3) + + Returns: + 6D rotation representation, of size (*, 6) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. 
+ Retrieved from http://arxiv.org/abs/1812.07035 + """ + return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..413befcec0678bd0a37ddd606a6fbbee88190419 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,26 @@ +torch==1.9 +numpy==1.19.5 +lmdb +pyarrow==0.14.0 +pyyaml +easydict +ConfigArgParse +tensorboard +torchvision +einops +matplotlib +tqdm +torchsnooper +ema-pytorch +accelerate +soundfile +pandas==1.3.4 +transforms3d +scipy +scikit-learn +librosa +omegaconf +sox +rich +ffmpeg-normalize +blobfile diff --git a/ubisoft-laforge-ZeroEGGS-main/.gitattributes b/ubisoft-laforge-ZeroEGGS-main/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..efcd5443b2454798be4fa6e591972005de190e39 --- /dev/null +++ b/ubisoft-laforge-ZeroEGGS-main/.gitattributes @@ -0,0 +1,3 @@ +*.zip filter=lfs diff=lfs merge=lfs -text +*.z01 filter=lfs diff=lfs merge=lfs -text +*.z02 filter=lfs diff=lfs merge=lfs -text diff --git a/ubisoft-laforge-ZeroEGGS-main/.gitignore b/ubisoft-laforge-ZeroEGGS-main/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..604333749bcad611f205d6e80891da1c443d60b7 --- /dev/null +++ b/ubisoft-laforge-ZeroEGGS-main/.gitignore @@ -0,0 +1,57 @@ +.DS_Store +.huskyrc.json +out +log.log +**/node_modules +*.pyc +*.vsix +**/.vscode/.ropeproject/** +**/testFiles/**/.cache/** +*.noseids +.nyc_output +.vscode-test +__pycache__ +npm-debug.log +**/.mypy_cache/** +!yarn.lock +coverage/ +cucumber-report.json +**/.vscode-test/** +**/.vscode test/** +**/.vscode-smoke/** +**/.venv*/ +port.txt +precommit.hook +pythonFiles/lib/** +debug_coverage*/** +languageServer/** +languageServer.*/** +bin/** +obj/** +.pytest_cache +tmp/** +.python-version +.vs/ +test-results*.xml +xunit-test-results.xml +build/ci/performance/performance-results.json +!build/ +debug*.log +debugpy*.log +pydevd*.log +nodeLanguageServer/** +nodeLanguageServer.*/** +dist/** + +# added +*.npz +data/clean +data/original +data/processed_v1 +data/processed_v2 + +.vscode +.idea +.ipynb_checkpoints +*.bvh +*.wav diff --git a/ubisoft-laforge-ZeroEGGS-main/License.md b/ubisoft-laforge-ZeroEGGS-main/License.md new file mode 100644 index 0000000000000000000000000000000000000000..528e669c22ceab3b967d4d2d235b16ecfb674912 --- /dev/null +++ b/ubisoft-laforge-ZeroEGGS-main/License.md @@ -0,0 +1,171 @@ +Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 International Public License + +By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this +Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 International Public License ("Public License"). To the +extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of +Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the +Licensor receives from making the Licensed Material available under these terms and conditions. + +Section 1 โ€“ Definitions. + +Adapted Material means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed +Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a +manner requiring permission under the Copyright and Similar Rights held by the Licensor. 
For purposes of this Public +License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always +produced where the Licensed Material is synched in timed relation with a moving image. +Copyright and Similar Rights means copyright and/or similar rights closely related to copyright including, without +limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights +are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not +Copyright and Similar Rights. +Effective Technological Measures means those measures that, in the absence of proper authority, may not be circumvented +under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or +similar international agreements. +Exceptions and Limitations means fair use, fair dealing, and/or any other exception or limitation to Copyright and +Similar Rights that applies to Your use of the Licensed Material. +Licensed Material means the artistic or literary work, database, or other material to which the Licensor applied this +Public License. +Licensed Rights means the rights granted to You subject to the terms and conditions of this Public License, which are +limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has +authority to license. +Licensor means the individual(s) or entity(ies) granting rights under this Public License. +NonCommercial means not primarily intended for or directed towards commercial advantage or monetary compensation. For +purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and +Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary +compensation in connection with the exchange. +Share means to provide material to the public by any means or process that requires permission under the Licensed +Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or +importation, and to make material available to the public including in ways that members of the public may access the +material from a place and at a time individually chosen by them. +Sui Generis Database Rights means rights other than copyright resulting from Directive 96/9/EC of the European +Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as +well as other essentially equivalent rights anywhere in the world. +You means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding +meaning. + +Section 2 โ€“ Scope. + +License grant. +Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, +non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to: +reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and +produce and reproduce, but not Share, Adapted Material for NonCommercial purposes only. +Exceptions and Limitations. For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public +License does not apply, and You do not need to comply with its terms and conditions. +Term. The term of this Public License is specified in Section 6(a). 
+Media and formats; technical modifications allowed. The Licensor authorizes You to exercise the Licensed Rights in all +media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The +Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications +necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective +Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2( +a)(4) never produces Adapted Material. +Downstream recipients. +Offer from the Licensor โ€“ Licensed Material. Every recipient of the Licensed Material automatically receives an offer +from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License. +No downstream restrictions. You may not offer or impose any additional or different terms or conditions on, or apply any +Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any +recipient of the Licensed Material. +No endorsement. Nothing in this Public License constitutes or may be construed as permission to assert or imply that You +are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status +by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i). +Other rights. + +Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, +and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to +assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed +Rights, but not otherwise. +Patent and trademark rights are not licensed under this Public License. +To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed +Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory +licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when +the Licensed Material is used other than for NonCommercial purposes. + +Section 3 โ€“ License Conditions. + +Your exercise of the Licensed Rights is expressly made subject to the following conditions. + +Attribution. + +If You Share the Licensed Material, You must: + +retain the following if it is supplied by the Licensor with the Licensed Material: +identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any +reasonable manner requested by the Licensor (including by pseudonym if designated); +a copyright notice; +a notice that refers to this Public License; +a notice that refers to the disclaimer of warranties; +a URI or hyperlink to the Licensed Material to the extent reasonably practicable; +indicate if You modified the Licensed Material and retain an indication of any previous modifications; and +indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink +to, this Public License. +For the avoidance of doubt, You do not have permission under this Public License to Share Adapted Material. 
+You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in +which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or +hyperlink to a resource that includes the required information. +If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent +reasonably practicable. + +Section 4 โ€“ Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material: + +for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a +substantial portion of the contents of the database for NonCommercial purposes only and provided You do not Share +Adapted Material; +if You include all or a substantial portion of the database contents in a database in which You have Sui Generis +Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is +Adapted Material; and +You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the +database. +For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License +where the Licensed Rights include other Copyright and Similar Rights. + +Section 5 โ€“ Disclaimer of Warranties and Limitation of Liability. + +Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed +Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed +Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, +merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or +the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed +in full or in part, this disclaimer may not apply to You. +To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without +limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, +or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if +the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of +liability is not allowed in full or in part, this limitation may not apply to You. +The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the +extent possible, most closely approximates an absolute disclaimer and waiver of all liability. + +Section 6 โ€“ Term and Termination. + +This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to +comply with this Public License, then Your rights under this Public License terminate automatically. +Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates: + +automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the +violation; or +upon express reinstatement by the Licensor. 
+For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your +violations of this Public License. +For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop +distributing the Licensed Material at any time; however, doing so will not terminate this Public License. +Sections 1, 5, 6, 7, and 8 survive termination of this Public License. + +Section 7 โ€“ Other Terms and Conditions. + +The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly +agreed. +Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and +independent of the terms and conditions of this Public License. + +Section 8 โ€“ Interpretation. + +For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or +impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public +License. +To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically +reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be +severed from this Public License without affecting the enforceability of the remaining terms and conditions. +No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed +to by the Licensor. +Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and +immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority. \ No newline at end of file diff --git a/ubisoft-laforge-ZeroEGGS-main/README.md b/ubisoft-laforge-ZeroEGGS-main/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9f950feb231774892ecf8286953d3992e2c220d7 --- /dev/null +++ b/ubisoft-laforge-ZeroEGGS-main/README.md @@ -0,0 +1,181 @@ +
+ +# ZEGGS + +## ZeroEGGS: Zero-shot Example-based Gesture Generation from Speech + +[![Paper](http://img.shields.io/badge/paper-arxiv.2209.07556-B31B1B.svg)](https://arxiv.org/abs/2209.07556) + + +
+
+This repository contains the code for the ZeroEGGS project from
+this [article](https://arxiv.org/abs/2209.07556).
+It also contains our stylized speech and gesture dataset.
+
+ +[![IMAGE ALT TEXT](http://img.youtube.com/vi/YFg7QKWkjwQ/0.jpg)](http://www.youtube.com/watch?v=YFg7QKWkjwQ "Click to watch the video demo") + +[Click](http://www.youtube.com/watch?v=YFg7QKWkjwQ) to watch the video demo + +
+
+## Environment Setup
+
+Create and activate a virtual environment to work in, e.g. using Conda:
+
+```sh
+conda create -n zeggs python=3.8
+conda activate zeggs
+```
+
+Install CUDA and PyTorch 1.12.x. For CUDA 11.3, this would look like:
+
+```sh
+conda install pytorch torchvision cudatoolkit=11.3 -c pytorch
+```
+
+Install the remaining requirements with pip:
+
+```sh
+pip install -r requirements.txt
+```
+
+> You may need to install [`sox`](http://sox.sourceforge.net/) on your system.
+
+## ZEGGS Dataset
+
+ +![zeggs_data](media/zeggs_data.gif) + +
+
+The ZEGGS dataset contains 67 sequences of monologues performed by a female actor speaking in English, covering 19 different motion styles.
+
+The following styles are present in the ZEGGS dataset:
+
+ +| **Style** | **Length (mins)** | **Style** | **Length (mins)** | +|--------------|-----------------------|-------|---------------| +| Agreement | 5.25 | Pensive | 6.21 | +| Angry | 7.95 | Relaxed | 10.81 | +| Disagreement | 5.33 | Sad | 11.80 | +| Distracted | 5.29 | Sarcastic | 6.52 | +| Flirty | 3.27 | Scared | 5.58 | +| Happy | 10.08 | Sneaky | 6.27 | +| Laughing | 3.85 | Still | 5.33 | +| Oration | 3.98 | Threatening | 5.84 | +| Neutral | 11.13 | Tired | 7.13 | +| Old | 11.37 | Total | 134.65 | + +
+
+### Access to the data
+> This repository contains large files. In order to clone this repository including
+> the large zip files, you need to use [git lfs](https://github.com/git-lfs/git-lfs/wiki/Installation).
+> If you still get errors, download the `zip` files directly.
+
+The speech and gesture data are contained in the `./data/Zeggs_data.zip`, `./data/Zeggs_data.z01`, and `./data/Zeggs_data.z02` files. You must put all of these parts in the same folder and extract the `.zip` file with WinRAR or WinZip.
+
+When you extract the zip file, there are two folders:
+
+- The `original` folder contains the original data, where the animation and audio files are in their raw, unprocessed form.
+
+- The `clean` folder contains aligned animation and audio data, with unwanted audio from other speakers removed. For more details on how
+  these files have been processed, check `data_pipeline.py`.
+
+All the animation sequences are in the BVH file format and all the audio data are in WAV format.
+
+## Data Preparation
+
+Extract the data from the `Zeggs_data.zip` file and place it in the `data` folder. Next run:
+
+```sh
+python data_pipeline.py
+```
+
+This processes the data and creates the necessary files for training and evaluation in the `processed` folder. You can
+customize the data pipeline by changing the `data_pipeline_conf.json` config file. Two suggested configurations are provided
+in the `configs` folder. You should change the configuration file name in the script accordingly.
+
+## Training
+
+To train the model, run:
+
+```sh
+python ./main.py -o <configuration file> -n <run name>
+```
+
+For example, to train the model with the default configuration, run:
+
+```sh
+python ./main.py -o "../configs/configs_v1.json" -n "zeggs_v1"
+```
+
+## Inference
+
+After training is finished, or using the provided pretrained models (in `./data/outputs`), you can generate gestures
+given speech and style as input using `generate.py`. The output will be saved in `bvh` format. For full functionality
+(blending, transitions, using pre-extracted style encodings, etc.) you need to use the `generate_gesture` function
+directly. Otherwise, you can use the CLI as explained below.
+
+### Using the CLI
+
+You can run the inference using the CLI in two ways; concrete example commands for both are given in `data/test/cli.txt`.
+
+#### 1. Generating a single sample from a single audio/style pair
+
+The CLI command looks like this:
+
+```sh
+python ./generate.py -o <options file> -s <style bvh file> -a <audio file>
+```
+
+#### 2. Generating a batch of samples from a csv file
+
+Pass a csv file describing the samples with the `-c` flag instead of `-s`/`-a`, i.e.
+`python ./generate.py -o <options file> -c <csv file>`; see `data/test/evaluation.csv` for an example file
+and the sketch below for its layout.
+
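+The following is an illustrative sketch (not part of the original repository) of how such a batch csv could be
+written from Python. The column names and row layout are copied from `data/test/evaluation.csv` in this repository;
+the single row shown simply reuses the Happy example clip, and the output file name `my_evaluation.csv` is a placeholder.
+
+```python
+# Sketch only: write a batch-generation csv in the same layout as data/test/evaluation.csv.
+# Column names are taken verbatim from that file; the row values are illustrative.
+import csv
+
+header = ["base_path", "audio", "style", "file_name", "temperature",
+          "seed", "use_gpu", "frames", "generate"]
+rows = [
+    # As in the shipped file, the style .bvh sits in the third column and the
+    # human-readable label in the fourth; "frames" may be empty or e.g. "250 850".
+    [r"..\data\clean", "015_Happy_4_x_1_0.wav", "015_Happy_4_x_1_0.bvh",
+     "Happy", 1, 1234, "TRUE", "", "TRUE"],
+]
+
+with open("my_evaluation.csv", "w", newline="") as f:
+    writer = csv.writer(f)
+    writer.writerow(header)
+    writer.writerows(rows)
+```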
+This may take a little bit of time ...
+
+                Data Info
+[per-file table omitted]
+
+Total length of dataset is 1454108.0 frames - 24235.1 seconds
+
+{
+    'base_path': '../data',
+    'processed_data_path': 'processed_v1',
+    'save_trimmed_audio': True,
+    'save_trimmed_animation': True,
+    'save_normalized_animations': False,
+    'save_final_data': True,
+    'audio_conf': {'pre_emphasis': False, 'pre_emph_coeff': 0.97, 'centered': True,
+                   'real_amplitude': True, 'normalize_mel_bins': True, 'normalize_range': True,
+                   'min_clipping': 1e-05, 'sampling_rate': 16000, 'mel_fmin': 20, 'mel_fmax': 7600,
+                   'n_mel_channels': 80, 'filter_length': 800, 'hop_length': 200,
+                   'resample_method': 'linear', 'normalize_loudness': True},
+    'audio_feature_type': ['mel_spec', 'energy'],
+    'visualize_spectrogram': False,
+    'visualize_gaze': False,
+    'len_ratios': [0.9, 1.0, 1.1]
+}
+
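+For orientation, the `audio_conf` values above correspond to a fairly standard log-mel-spectrogram setup. The snippet
+below is only an illustrative sketch using `librosa` (listed in the top-level `requirements.txt`), not the repository's
+own feature-extraction code; it assumes `filter_length` is the FFT size in samples, `hop_length` is the hop size in
+samples, `min_clipping` is the floor applied before the log, and the wav file name is just one of the dataset clips.
+
+```python
+# Illustrative sketch only: approximate the audio_conf mel settings with librosa.
+import librosa
+import numpy as np
+
+y, sr = librosa.load("015_Happy_4_x_1_0.wav", sr=16000)  # 'sampling_rate': 16000
+mel = librosa.feature.melspectrogram(
+    y=y,
+    sr=sr,
+    n_fft=800,       # 'filter_length': 800  (50 ms window at 16 kHz)
+    hop_length=200,  # 'hop_length': 200     (12.5 ms -> 80 feature frames per second)
+    n_mels=80,       # 'n_mel_channels': 80
+    fmin=20,         # 'mel_fmin': 20
+    fmax=7600,       # 'mel_fmax': 7600
+)
+log_mel = np.log(np.clip(mel, 1e-5, None))  # 'min_clipping': 1e-05
+print(log_mel.shape)  # (80, number_of_frames)
+```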
+ + diff --git a/ubisoft-laforge-ZeroEGGS-main/data/processed_v1/data_pipeline_conf.json b/ubisoft-laforge-ZeroEGGS-main/data/processed_v1/data_pipeline_conf.json new file mode 100644 index 0000000000000000000000000000000000000000..bf5ae49da636ecfd120ef02ee58f3d437358c19c --- /dev/null +++ b/ubisoft-laforge-ZeroEGGS-main/data/processed_v1/data_pipeline_conf.json @@ -0,0 +1,36 @@ +{ + "base_path": "../data", + "processed_data_path": "processed_v1", + "save_trimmed_audio": true, + "save_trimmed_animation": true, + "save_normalized_animations": false, + "save_final_data": true, + "audio_conf": { + "pre_emphasis": false, + "pre_emph_coeff": 0.97, + "centered": true, + "real_amplitude": true, + "normalize_mel_bins": true, + "normalize_range": true, + "min_clipping": 1e-05, + "sampling_rate": 16000, + "mel_fmin": 20, + "mel_fmax": 7600, + "n_mel_channels": 80, + "filter_length": 800, + "hop_length": 200, + "resample_method": "linear", + "normalize_loudness": true + }, + "audio_feature_type": [ + "mel_spec", + "energy" + ], + "visualize_spectrogram": false, + "visualize_gaze": false, + "len_ratios": [ + 0.9, + 1.0, + 1.1 + ] +} \ No newline at end of file diff --git a/ubisoft-laforge-ZeroEGGS-main/data/processed_v1/info.csv b/ubisoft-laforge-ZeroEGGS-main/data/processed_v1/info.csv new file mode 100644 index 0000000000000000000000000000000000000000..e8ff690fc177aaa53085b9f95046297c1d93a339 --- /dev/null +++ b/ubisoft-laforge-ZeroEGGS-main/data/processed_v1/info.csv @@ -0,0 +1,135 @@ +audio_filename,audio_start_time,audio_end_time,audio_duration,audio_clap_time,anim_fbx_file,anim_start_time,anim_end_time,anim_duration,anim_clap_time,style,capture_session,acting_start_time,acting_end_time,anim_bvh,validation +001_Neutral_0.wav,09:28:00:01,09:30:26:12,00:02:26:11,,20210604_001_Neutral_0_02.fbx,09:28:03:35,09:30:18:50,00:02:15:15,09:28:10:22,Neutral,1,09:28:15:21,09:30:16:47,001_Neutral_0.bvh,FALSE +002_Neutral_1.wav,09:37:24:01,09:39:54:14,00:02:30:13,,20210604_002_Neutral_1_01.fbx,09:37:28:11,09:39:35:10,00:02:24:59,09:37:34:56,Neutral,1,09:37:38:45,09:39:51:37,002_Neutral_1.bvh,FALSE +003_Neutral_2.wav,09:43:51:16,09:46:27:27,00:02:36:11,,20210604_003_Neutral_2_01.fbx,09:43:55:07,09:46:21:30,00:02:26:23,09:44:02:05,Neutral,1,09:44:08:39,09:46:18:38,003_Neutral_2.bvh,FALSE +004_Neutral_3.wav,09:48:03:04,09:50:45:10,00:02:42:05,,20210604_004_Neutral_3_01.fbx,09:48:07:57,09:50:40:34,00:02:32:37,09:48:16:02,Neutral,1,09:48:20:52,09:50:38:51,004_Neutral_3.bvh,FALSE +005_Neutral_4.wav,09:52:16:06,09:55:07:20,00:02:51:13,,20210604_005_Neutral_4_01.fbx,09:52:20:11,09:55:02:36,00:02:42:25,09:52:26:49,Neutral,1,09:52:32:06,09:54:59:14,005_Neutral_4.bvh,TRUE +006_Sad_0.wav,09:32:22:14,09:35:11:09,00:02:48:24,,20210604_006_Sad_0_01.fbx,09:32:27:35,09:34:57:42,00:02:30:07,09:32:33:31,Sad,1,09:32:36:10,09:34:55:16,006_Sad_0.bvh,FALSE +007_Sad_1.wav,09:56:39:23,09:59:36:08,00:02:56:14,,20210604_007_Sad_1_01.fbx,09:56:39:27,09:59:30:18,00:02:50:51,09:56:47:03,Sad,1,09:56:52:05,09:59:28:06,007_Sad_1.bvh,FALSE +008_Sad_2.wav,10:01:32:10,10:04:41:11,00:03:09:00,,20210604_008_Sad_2_01.fbx,10:01:37:42,10:04:35:55,00:02:58:13,10:01:45:51,Sad,1,10:01:52:29,10:04:32:37,008_Sad_2.bvh,FALSE +009_Sad_3.wav,10:08:14:27,10:11:15:22,00:03:00:24,,20210604_009_Sad_3_01.fbx,10:08:21:50,10:10:56:03,00:02:34:13,10:08:29:15,Sad,1,10:08:33:30,10:10:49:34,009_Sad_3.bvh,FALSE 
+010_Sad_4.wav,10:13:02:01,10:15:27:29,00:02:25:27,,20210604_010_Sad_4_01.fbx,10:13:05:40,10:15:18:29,00:02:12:49,10:13:13:41,Sad,1,10:13:18:41,10:15:15:26,010_Sad_4.bvh,TRUE +011_Happy_0.wav,10:19:42:18,10:22:17:11,00:02:34:23,,20210604_011_Happy_0_01.fbx,10:19:48:41,10:22:12:34,00:02:23:53,10:19:56:51,Happy,1,10:19:59:53,10:22:11:02,011_Happy_0.bvh,FALSE +012_Happy_1.wav,10:23:52:13,10:26:10:24,00:02:18:10,,20210604_012_Happy_1_01.fbx,10:23:57:21,10:26:07:12,00:02:09:51,10:24:04:36,Happy,1,10:24:10:44,10:26:03:12,012_Happy_1.bvh,FALSE +013_Happy_2.wav,10:27:23:08,10:29:50:05,00:02:26:26,,20210604_013_Happy_2_01.fbx,10:27:27:07,10:29:45:48,00:02:18:41,10:27:35:01,Happy,1,10:27:39:42,10:29:42:46,013_Happy_2.bvh,FALSE +014_Happy_3.wav,10:41:29:02,10:43:30:21,00:02:01:19,,20210604_014_Happy_3_01.fbx,10:41:28:43,10:43:26:14,00:01:57:31,10:41:37:15,Happy,1,10:41:42:56,10:43:24:00,014_Happy_3.bvh,FALSE +015_Happy_4.wav,10:44:38:02,10:47:20:14,00:02:42:12,,20210604_015_Happy_4_01.fbx,10:44:39:57,10:47:11:46,00:02:31:49,10:44:48:25,Happy,1,10:44:52:42,10:47:09:59,015_Happy_4.bvh,TRUE +016_Relaxed_0.wav,10:50:03:25,10:52:32:25,00:02:29:00,,20210604_016_Relaxed_0_01.fbx,10:50:05:51,10:52:29:28,00:02:23:37,10:50:12:26,Relaxed,1,10:50:15:35,10:52:27:37,016_Relaxed_0.bvh,FALSE +017_Relaxed_1.wav,10:54:10:20,10:56:44:28,00:02:34:08,,20210604_017_Relaxed_1_01.fbx,10:54:10:53,10:56:39:18,00:02:28:25,10:54:17:48,Relaxed,1,10:54:21:17,10:56:36:33,017_Relaxed_1.bvh,FALSE +018_Relaxed_2.wav,10:58:16:00,11:00:50:05,00:02:34:05,,20210604_018_Relaxed_2_01.fbx,10:58:20:25,11:00:42:16,00:02:21:51,10:58:30:16,Relaxed,1,10:58:37:25,11:00:40:07,018_Relaxed_2.bvh,FALSE +019_Relaxed_3.wav,11:02:15:18,11:04:51:23,00:02:36:04,,20210604_019_Relaxed_3_01.fbx,11:02:17:07,11:04:37:32,00:02:20:25,11:02:25:14,Relaxed,1,11:02:29:28,11:04:34:56,019_Relaxed_3.bvh,FALSE +020_Relaxed_4.wav,11:06:36:29,11:09:05:06,00:02:28:07,,20210604_020_Relaxed_4_01.fbx,11:06:33:35,11:09:00:20,00:02:26:45,11:06:40:35,Relaxed,1,11:06:44:56,11:08:57:47,020_Relaxed_4.bvh,TRUE +021_Old_0.wav,11:15:16:08,11:18:04:04,00:02:47:26,,20210604_021_Old_0_03.fbx,11:15:17:03,11:17:56:04,00:02:39:01,11:15:24:49,Old,1,11:15:29:37,11:17:53:13,021_Old_0.bvh,FALSE +022_Old_1.wav,11:20:07:18,11:22:22:04,00:02:14:15,,20210604_022_Old_1_01.fbx,11:20:04:57,11:22:14:50,00:02:09:53,11:20:12:11,Old,1,11:20:15:39,11:22:12:48,022_Old_1.bvh,FALSE +023_Old_2.wav,11:23:28:22,11:26:22:04,00:02:53:11,,20210604_023_Old_2_01.fbx,11:23:29:35,11:26:19:10,00:02:49:35,11:23:37:32,Old,1,11:23:41:56,11:26:16:53,023_Old_2.bvh,FALSE +024_Old_3.wav,11:29:12:11,11:31:40:02,00:02:27:20,,20210604_024_Old_3_01.fbx,11:29:12:27,11:31:38:10,00:02:25:43,No Clapping,Old,1,11:29:17:50,11:31:35:28,024_Old_3.bvh,FALSE +025_Old_4.wav,11:33:12:25,11:35:32:05,00:02:19:09,,20210604_025_Old_4_01.fbx,11:33:11:19,11:35:28:46,00:02:17:27,No Clapping,Old,1,11:33:18:01,11:35:26:41,025_Old_4.bvh,TRUE +026_Angry_0.wav,11:49:26:25,11:51:59:14,00:02:32:18,,20210604_026_Angry_0_01.fbx,11:49:35:28,11:51:56:29,00:02:21:01,11:49:41:46,Angry,1,11:49:45:25,11:51:54:45,026_Angry_0.bvh,FALSE +027_Angry_1.wav,11:57:32:22,11:59:35:04,00:02:02:12,,20210604_027_Angry_1_01.fbx,11:57:30:15,11:59:32:06,00:02:01:51,11:57:37:40,Angry,1,11:57:41:51,11:59:29:59,027_Angry_1.bvh,FALSE +028_Angry_2.wav,12:00:45:14,12:02:57:12,00:02:11:27,,20210604_028_Angry_2_01.fbx,12:00:47:19,12:02:52:50,00:02:05:31,12:00:55:11,Angry,1,12:00:59:08,12:02:51:09,028_Angry_2.bvh,FALSE 
+029_Angry_3.wav,12:04:08:28,12:06:31:16,00:02:22:18,,20210604_029_Angry_3_01.fbx,12:04:06:55,12:06:29:46,00:02:55:51,12:04:14:44,Angry,1,12:04:19:24,12:06:26:44,029_Angry_3.bvh,TRUE +030_Agreement_0.wav,11:36:47:02,11:39:10:27,00:02:23:25,,20210604_031_Agreement_01.fbx,11:36:52:17,11:39:04:38,00:02:12:21,No Clapping,Agreement,1,11:36:56:35,11:39:03:27,030_Agreement_0.bvh,FALSE +031_Disagreement_0.wav,11:41:12:29,11:43:26:16,00:02:13:16,,20210604_032_Disagreement_02.fbx,11:41:10:41,11:43:22:52,00:02:12:11,11:41:17:55,Disagreement,1,11:41:21:47,11:43:21:14,031_Disagreement_0.bvh,FALSE +001_Neutral_0.wav,09:28:00:01,09:30:26:12,00:02:26:11,,20210604_001_Neutral_0_02_mirror.fbx,09:28:03:35,09:30:18:50,00:02:15:15,09:28:10:22,Neutral,1,09:28:15:21,09:30:13:17,001_Neutral_0_mirror.bvh,FALSE +002_Neutral_1.wav,09:37:24:01,09:39:54:14,00:02:30:13,,20210604_002_Neutral_1_01_mirror.fbx,09:37:28:11,09:39:35:10,00:02:24:59,09:37:34:56,Neutral,1,09:37:38:45,09:39:51:37,002_Neutral_1_mirror.bvh,FALSE +003_Neutral_2.wav,09:43:51:16,09:46:27:27,00:02:36:11,,20210604_003_Neutral_2_01_mirror.fbx,09:43:55:07,09:46:21:30,00:02:26:23,09:44:02:05,Neutral,1,09:44:08:39,09:46:18:38,003_Neutral_2_mirror.bvh,FALSE +004_Neutral_3.wav,09:48:03:04,09:50:45:10,00:02:42:05,,20210604_004_Neutral_3_01_mirror.fbx,09:48:07:57,09:50:40:34,00:02:32:37,09:48:16:02,Neutral,1,09:48:20:52,09:50:38:51,004_Neutral_3_mirror.bvh,FALSE +005_Neutral_4.wav,09:52:16:06,09:55:07:20,00:02:51:13,,20210604_005_Neutral_4_01_mirror.fbx,09:52:20:11,09:55:02:36,00:02:42:25,09:52:26:49,Neutral,1,09:52:32:06,09:54:59:14,005_Neutral_4_mirror.bvh,TRUE +006_Sad_0.wav,09:32:22:14,09:35:11:09,00:02:48:24,,20210604_006_Sad_0_01_mirror.fbx,09:32:27:35,09:34:57:42,00:02:30:07,09:32:33:31,Sad,1,09:32:36:10,09:34:55:16,006_Sad_0_mirror.bvh,FALSE +007_Sad_1.wav,09:56:39:23,09:59:36:08,00:02:56:14,,20210604_007_Sad_1_01_mirror.fbx,09:56:39:27,09:59:30:18,00:02:50:51,09:56:47:03,Sad,1,09:56:52:05,09:59:28:06,007_Sad_1_mirror.bvh,FALSE +008_Sad_2.wav,10:01:32:10,10:04:41:11,00:03:09:00,,20210604_008_Sad_2_01_mirror.fbx,10:01:37:42,10:04:35:55,00:02:58:13,10:01:45:51,Sad,1,10:01:52:29,10:04:32:37,008_Sad_2_mirror.bvh,FALSE +009_Sad_3.wav,10:08:14:27,10:11:15:22,00:03:00:24,,20210604_009_Sad_3_01_mirror.fbx,10:08:21:50,10:10:56:03,00:02:34:13,10:08:29:15,Sad,1,10:08:33:30,10:10:49:34,009_Sad_3_mirror.bvh,FALSE +010_Sad_4.wav,10:13:02:01,10:15:27:29,00:02:25:27,,20210604_010_Sad_4_01_mirror.fbx,10:13:05:40,10:15:18:29,00:02:12:49,10:13:13:41,Sad,1,10:13:18:41,10:15:15:26,010_Sad_4_mirror.bvh,TRUE +011_Happy_0.wav,10:19:42:18,10:22:17:11,00:02:34:23,,20210604_011_Happy_0_01_mirror.fbx,10:19:48:41,10:22:12:34,00:02:23:53,10:19:56:51,Happy,1,10:19:59:53,10:22:11:02,011_Happy_0_mirror.bvh,FALSE +012_Happy_1.wav,10:23:52:13,10:26:10:24,00:02:18:10,,20210604_012_Happy_1_01_mirror.fbx,10:23:57:21,10:26:07:12,00:02:09:51,10:24:04:36,Happy,1,10:24:10:44,10:26:03:12,012_Happy_1_mirror.bvh,FALSE +013_Happy_2.wav,10:27:23:08,10:29:50:05,00:02:26:26,,20210604_013_Happy_2_01_mirror.fbx,10:27:27:07,10:29:45:48,00:02:18:41,10:27:35:01,Happy,1,10:27:39:42,10:29:42:46,013_Happy_2_mirror.bvh,FALSE +014_Happy_3.wav,10:41:29:02,10:43:30:21,00:02:01:19,,20210604_014_Happy_3_01_mirror.fbx,10:41:28:43,10:43:26:14,00:01:57:31,10:41:37:15,Happy,1,10:41:42:56,10:43:24:00,014_Happy_3_mirror.bvh,FALSE 
+015_Happy_4.wav,10:44:38:02,10:47:20:14,00:02:42:12,,20210604_015_Happy_4_01_mirror.fbx,10:44:39:57,10:47:11:46,00:02:31:49,10:44:48:25,Happy,1,10:44:52:42,10:47:09:59,015_Happy_4_mirror.bvh,TRUE +016_Relaxed_0.wav,10:50:03:25,10:52:32:25,00:02:29:00,,20210604_016_Relaxed_0_01_mirror.fbx,10:50:05:51,10:52:29:28,00:02:23:37,10:50:12:26,Relaxed,1,10:50:15:35,10:52:27:37,016_Relaxed_0_mirror.bvh,FALSE +017_Relaxed_1.wav,10:54:10:20,10:56:44:28,00:02:34:08,,20210604_017_Relaxed_1_01_mirror.fbx,10:54:10:53,10:56:39:18,00:02:28:25,10:54:17:48,Relaxed,1,10:54:21:17,10:56:36:33,017_Relaxed_1_mirror.bvh,FALSE +018_Relaxed_2.wav,10:58:16:00,11:00:50:05,00:02:34:05,,20210604_018_Relaxed_2_01_mirror.fbx,10:58:20:25,11:00:42:16,00:02:21:51,10:58:30:16,Relaxed,1,10:58:37:25,11:00:40:07,018_Relaxed_2_mirror.bvh,FALSE +019_Relaxed_3.wav,11:02:15:18,11:04:51:23,00:02:36:04,,20210604_019_Relaxed_3_01_mirror.fbx,11:02:17:07,11:04:37:32,00:02:20:25,11:02:25:14,Relaxed,1,11:02:29:28,11:04:34:56,019_Relaxed_3_mirror.bvh,FALSE +020_Relaxed_4.wav,11:06:36:29,11:09:05:06,00:02:28:07,,20210604_020_Relaxed_4_01_mirror.fbx,11:06:33:35,11:09:00:20,00:02:26:45,11:06:40:35,Relaxed,1,11:06:44:56,11:08:57:47,020_Relaxed_4_mirror.bvh,TRUE +021_Old_0.wav,11:15:16:08,11:18:04:04,00:02:47:26,,20210604_021_Old_0_03_mirror.fbx,11:15:17:03,11:17:56:04,00:02:39:01,11:15:24:49,Old,1,11:15:29:37,11:17:53:13,021_Old_0_mirror.bvh,FALSE +022_Old_1.wav,11:20:07:18,11:22:22:04,00:02:14:15,,20210604_022_Old_1_01_mirror.fbx,11:20:04:57,11:22:14:50,00:02:09:53,11:20:12:11,Old,1,11:20:15:39,11:22:12:48,022_Old_1_mirror.bvh,FALSE +023_Old_2.wav,11:23:28:22,11:26:22:04,00:02:53:11,,20210604_023_Old_2_01_mirror.fbx,11:23:29:35,11:26:19:10,00:02:49:35,11:23:37:32,Old,1,11:23:41:56,11:26:16:53,023_Old_2_mirror.bvh,FALSE +024_Old_3.wav,11:29:12:11,11:31:40:02,00:02:27:20,,20210604_024_Old_3_01_mirror.fbx,11:29:12:27,11:31:38:10,00:02:25:43,No Clapping,Old,1,11:29:17:50,11:31:35:28,024_Old_3_mirror.bvh,FALSE +025_Old_4.wav,11:33:12:25,11:35:32:05,00:02:19:09,,20210604_025_Old_4_01_mirror.fbx,11:33:11:19,11:35:28:46,00:02:17:27,No Clapping,Old,1,11:33:18:01,11:35:26:41,025_Old_4_mirror.bvh,TRUE +026_Angry_0.wav,11:49:26:25,11:51:59:14,00:02:32:18,,20210604_026_Angry_0_01_mirror.fbx,11:49:35:28,11:51:56:29,00:02:21:01,11:49:41:46,Angry,1,11:49:45:25,11:51:54:45,026_Angry_0_mirror.bvh,FALSE +027_Angry_1.wav,11:57:32:22,11:59:35:04,00:02:02:12,,20210604_027_Angry_1_01_mirror.fbx,11:57:30:15,11:59:32:06,00:02:01:51,11:57:37:40,Angry,1,11:57:41:51,11:59:29:59,027_Angry_1_mirror.bvh,FALSE +028_Angry_2.wav,12:00:45:14,12:02:57:12,00:02:11:27,,20210604_028_Angry_2_01_mirror.fbx,12:00:47:19,12:02:52:50,00:02:05:31,12:00:55:11,Angry,1,12:00:59:08,12:02:51:09,028_Angry_2_mirror.bvh,FALSE +029_Angry_3.wav,12:04:08:28,12:06:31:16,00:02:22:18,,20210604_029_Angry_3_01_mirror.fbx,12:04:06:55,12:06:29:46,00:02:55:51,12:04:14:44,Angry,1,12:04:19:24,12:06:26:44,029_Angry_3_mirror.bvh,TRUE +030_Agreement_0.wav,11:36:47:02,11:39:10:27,00:02:23:25,,20210604_031_Agreement_01_mirror.fbx,11:36:52:17,11:39:04:38,00:02:12:21,No Clapping,Agreement,1,11:36:56:35,11:39:03:27,030_Agreement_0_mirror.bvh,FALSE +031_Disagreement_0.wav,11:41:12:29,11:43:26:16,00:02:13:16,,20210604_032_Disagreement_02_mirror.fbx,11:41:10:41,11:43:22:52,00:02:12:11,11:41:17:55,Disagreement,1,11:41:21:47,11:43:21:14,031_Disagreement_0_mirror.bvh,FALSE +032_Agreement_1.wav,09:11:50:20 ,09:13:40:08 ,00:01:49:17 
,,20210826_001_Agreement_2_02.fbx,09:11:50:07,09:13:37:20,00:01:47:13,09:11:59:24,Agreement,2,09:12:06:02,09:13:35:25,032_Agreement_1.bvh,FALSE +033_Agreement_2.wav,09:14:08:07 ,09:16:00:13 ,00:01:52:05 ,,20210826_002_Agreement_3_01.fbx,09:14:05:47,09:15:57:56,00:01:59:09,09:14:11:50,Agreement,2,09:14:17:21,09:15:56:07,033_Agreement_2.bvh,TRUE +034_Disagreement_1.wav,09:17:46:11 ,09:19:55:11 ,00:02:09:00 ,,20210826_003_Disagreement_2_02.fbx,09:17:51:25,09:19:45:52,00:01:54:27,09:17:57:53,Disagreement,2,09:18:03:05,09:19:44:12,034_Disagreement_1.bvh,FALSE +035_Disagreement_2.wav,09:20:35:10 ,09:22:27:26 ,00:01:52:15 ,,20210826_004_Disagreement_3_01.fbx,09:20:32:19,09:22:25:22,00:01:53:03,09:20:38:33,Disagreement,2,09:20:44:10,09:22:23:43,035_Disagreement_2.bvh,TRUE +036_Flirty_0.wav,09:23:40:23 ,09:25:34:25 ,00:01:54:02 ,,20210826_005_Flirty_0_01.fbx,09:23:38:17,09:25:31:06,00:01:52:49,09:23:43:48,Flirty,2,09:23:49:02,09:25:29:24,036_Flirty_0.bvh,FALSE +037_Flirty_1.wav,09:26:05:13 ,09:27:51:10 ,00:01:45:26 ,,20210826_006_Flirty_1_01.fbx,09:26:06:29,09:27:49:48,00:01:43:19,09:26:12:11,Flirty,2,09:26:17:53,09:27:47:40,037_Flirty_1.bvh,FALSE +038_Flirty_2.wav,09:28:42:12 ,09:30:58:28 ,00:02:16:15 ,,20210826_007_Flirty_2_01.fbx,09:28:51:59,09:30:50:22,00:01:58:23,09:28:57:12,Flirty,2,09:29:02:20,09:30:48:23,038_Flirty_2.bvh,TRUE +039_Pensive_0.wav,09:39:06:17 ,09:41:44:00 ,00:02:37:12 ,,20210826_008_Pensive_0_01.fbx,09:39:11:07,09:41:36:44,00:02:25:37,09:39:16:28,Pensive,2,09:39:22:03,09:41:35:16,039_Pensive_0.bvh,FALSE +040_Pensive_1.wav,09:42:17:17 ,09:44:56:26 ,00:02:39:08 ,,20210826_009_Pensive_1_01.fbx,09:42:11:23,09:44:36:28,00:02:25:07,09:42:16:53,Pensive,2,09:42:23:20,09:44:34:31,040_Pensive_1.bvh,FALSE +041_Pensive_2.wav,09:45:21:26 ,09:47:34:25 ,00:02:12:29 ,,20210826_010_Pensive_3_01.fbx,09:45:18:17,09:47:19:10,00:02:00:53,09:45:23:51,Pensive,2,09:45:29:15,09:47:17:32,041_Pensive_2.bvh,TRUE +042_Scared_0.wav,09:49:00:06 ,09:51:16:28 ,00:02:16:21 ,,20210826_011_Scared_0_01.fbx,09:48:55:51,09:50:59:52,00:02:04:01,09:49:00:58,Scared,2,09:49:05:14,09:50:57:43,042_Scared_0.bvh,FALSE +043_Scared_1.wav,09:53:19:14 ,09:55:17:23 ,00:01:58:09 ,,20210826_012_Scared_1_01.fbx,09:53:16:17,09:55:08:16,00:01:51:59,09:53:21:24,Scared,2,09:53:25:42,09:55:05:53,043_Scared_1.bvh,FALSE +044_Scared_2.wav,09:55:34:24 ,09:57:57:11 ,00:02:22:16 ,,20210826_013_Scared_2_01.fbx,09:55:39:19,09:57:54:14,00:02:14:55,09:55:45:01,Scared,2,09:55:49:58,09:57:52:07,044_Scared_2.bvh,TRUE +045_Distracted_0.wav,09:58:55:17 ,10:00:53:22 ,00:01:58:05 ,,20210826_014_Distracted_0_01.fbx,09:58:57:15,10:00:47:16,00:01:50:01,09:59:02:42,Distracted,2,09:59:07:36,10:00:45:18,045_Distracted_0.bvh,FALSE +046_Distracted_1.wav,10:01:48:29 ,10:04:12:16 ,00:02:23:17 ,,20210826_015_Distracted_1_01.fbx,10:01:53:49,10:04:03:52,00:02:10:03,10:01:59:28,Distracted,2,10:02:04:18,10:04:02:29,046_Distracted_1.bvh,FALSE +047_Distracted_2.wav,10:04:42:23 ,10:06:51:10 ,00:02:08:17 ,,20210826_016_Distracted_2_01.fbx,10:04:43:53,10:06:38:48,00:01:54:55,10:04:49:40,Distracted,2,10:04:55:36,10:06:37:04,047_Distracted_2.bvh,TRUE +048_Sarcastic_0.wav,10:09:15:26 ,10:11:14:11 ,00:01:58:14 ,,20210826_020_Sarcastic_0_01.fbx,10:09:10:11,10:11:14:58,00:02:04:47,10:09:16:10,Sarcastic,2,10:09:22:21,10:11:13:17,048_Sarcastic_0.bvh,FALSE +049_Sarcastic_1.wav,10:11:45:10 ,10:14:40:14 ,00:02:55:03 ,,20210826_021_Sarcastic_1_01.fbx,10:11:42:47,10:14:28:18,00:02:45:31,10:11:48:00,Sarcastic,2,10:11:52:53,10:14:26:27,049_Sarcastic_1.bvh,FALSE 
+050_Sarcastic_2.wav,10:15:04:12 ,10:17:23:06 ,00:02:18:24 ,,20210826_022_Sarcastic_2_01.fbx,10:15:00:47,10:17:20:50,00:02:20:03,10:15:06:26,Sarcastic,2,10:15:12:38,10:17:19:06,050_Sarcastic_2.bvh,TRUE +051_Threatening_0.wav,10:19:33:27 ,10:21:50:22 ,00:02:16:24 ,,20210826_023_Threatening_0_01.fbx,10:19:40:09,10:21:46:24,00:02:06:15,10:19:45:52,Threatening,2,10:19:51:58,10:21:44:17,051_Threatening_0.bvh,FALSE +052_Threatening_1.wav,10:22:23:06 ,10:24:24:04 ,00:02:00:27 ,,20210826_024_Threatening_1_01.fbx,10:22:20:05,10:24:19:30,00:01:59:25,10:22:26:11,Threatening,2,10:22:31:35,10:24:17:46,052_Threatening_1.bvh,FALSE +053_Threatening_2.wav,10:25:17:25 ,10:27:41:00 ,00:02:23:04 ,,20210826_025_Threatening_2_01.fbx,10:25:11:05,10:27:36:32,00:02:25:27,10:25:16:59,Threatening,2,10:25:22:49,10:27:34:30,053_Threatening_2.bvh,TRUE +054_Still_0.wav,10:29:00:06 ,10:31:00:08 ,00:02:00:02 ,,20210826_026_Still_0_01.fbx,10:28:55:43,10:30:55:56,00:02:00:13,10:29:00:52,Still,2,10:29:05:25,10:30:53:43,054_Still_0.bvh,FALSE +055_Still_1.wav,10:36:58:01 ,10:38:51:14 ,00:01:53:12 ,,20210826_027_Still_1_01.fbx,10:36:50:23,10:38:49:14,00:01:58:51,10:36:57:45,Still,2,10:37:03:10,10:38:47:38,055_Still_1.bvh,FALSE +056_Still_2.wav,10:39:55:17 ,10:41:53:00 ,00:01:57:13 ,,20210826_028_Still_2_01.fbx,10:39:51:51,10:41:50:28,00:01:58:37,10:39:56:41,Still,2,10:40:01:53,10:41:48:56,056_Still_2.bvh,TRUE +057_Laughing_0.wav,10:44:56:12 ,10:47:14:25 ,00:02:18:13 ,,20210826_029_Laughing_0_01.fbx,10:44:58:37,10:47:04:56,00:02:06:19,10:45:04:34,Laughing,2,10:45:09:56,10:47:03:01,057_Laughing_0.bvh,FALSE +058_Laughing_1.wav,10:48:30:08 ,10:52:00:12 ,00:03:30:04 ,,20210826_030_Laughing_1_01.fbx,10:48:30:17,10:50:39:24,00:02:09:07,10:48:35:28,Laughing,2,10:48:39:50,10:50:37:40,058_Laughing_1.bvh,TRUE +059_Sneaky_0.wav,10:52:48:16 ,10:54:59:15 ,00:02:10:29 ,,20210826_032_Sneaky_0_01.fbx,10:52:48:39,10:54:58:40,00:02:10:01,10:52:54:13,Sneaky,2,10:52:59:39,10:54:56:55,059_Sneaky_0.bvh,FALSE +060_Sneaky_1.wav,10:56:12:13 ,10:58:15:18 ,00:02:03:05 ,,20210826_033_Sneaky_1_01.fbx,10:56:06:41,10:58:09:18,00:02:02:37,10:56:12:02,Sneaky,2,10:56:16:59,10:58:07:23,060_Sneaky_1.bvh,FALSE +061_Sneaky_2.wav,10:58:33:03 ,11:01:17:05 ,00:02:44:01 ,,20210826_034_Sneaky_2_01.fbx,10:58:31:51,11:01:14:06,00:02:42:15,10:58:38:30,Sneaky,2,10:58:43:35,11:01:12:01,061_Sneaky_2.bvh,TRUE +062_Tired_0.wav,11:02:59:09 ,11:06:23:15 ,00:03:24:05 ,,20210826_035_Tired_1_01.fbx,11:02:59:57,11:06:08:04,00:03:08:07,11:03:05:51,Tired,2,11:03:10:55,11:06:06:42,062_Tired_0.bvh,FALSE +063_Tired_1.wav,11:06:41:14 ,11:09:18:19 ,00:02:37:04 ,,20210826_036_Tired_2_01.fbx,11:06:41:39,11:09:14:32,00:02:32:53,11:06:50:36,Tired,2,11:06:55:32,11:09:11:32,063_Tired_1.bvh,FALSE +064_Tired_2.wav,11:09:48:04 ,11:12:06:06 ,00:02:18:02 ,,20210826_037_Tired_3_01.fbx,11:09:48:23,11:11:59:20,00:02:10:57,11:09:54:47,Tired,2,11:10:00:29,11:11:56:33,064_Tired_2.bvh,TRUE +065_Speech_0.wav,11:28:37:00 ,11:30:40:09 ,00:02:03:09 ,,20210826_056_Speech_0_02.fbx,11:28:38:23,11:30:12:52,00:01:34:29,11:28:43:10,Speech,2,11:28:48:04,11:30:11:07,065_Speech_0.bvh,TRUE +066_Speech_1.wav,11:31:29:22 ,11:33:28:06 ,00:01:58:14 ,,20210826_057_Speech_1_02.fbx,11:31:32:43,11:33:12:08,00:01:39:25,11:31:38:15,Speech,2,11:31:43:08,11:33:10:04,066_Speech_1.bvh,TRUE +067_Speech_2.wav,11:33:48:20 ,11:35:17:13 ,00:01:28:23 ,,20210826_058_Speech_2_01.fbx,11:33:48:21,11:35:11:02,00:01:22:41,11:33:54:17,Speech,2,11:34:00:10,11:35:09:00,067_Speech_2.bvh,TRUE +032_Agreement_1.wav,09:11:50:20 ,09:13:40:08 ,00:01:49:17 
,,20210826_001_Agreement_2_02_mirror.fbx,09:11:50:07,09:13:37:20,00:01:47:13,09:11:59:24,Agreement,2,09:12:06:02,09:13:35:25,032_Agreement_1_mirror.bvh,FALSE +033_Agreement_2.wav,09:14:08:07 ,09:16:00:13 ,00:01:52:05 ,,20210826_002_Agreement_3_01_mirror.fbx,09:14:05:47,09:15:57:56,00:01:59:09,09:14:11:50,Agreement,2,09:14:17:21,09:15:56:07,033_Agreement_2_mirror.bvh,TRUE +034_Disagreement_1.wav,09:17:46:11 ,09:19:55:11 ,00:02:09:00 ,,20210826_003_Disagreement_2_02_mirror.fbx,09:17:51:25,09:19:45:52,00:01:54:27,09:17:57:53,Disagreement,2,09:18:03:05,09:19:44:12,034_Disagreement_1_mirror.bvh,FALSE +035_Disagreement_2.wav,09:20:35:10 ,09:22:27:26 ,00:01:52:15 ,,20210826_004_Disagreement_3_01_mirror.fbx,09:20:32:19,09:22:25:22,00:01:53:03,09:20:38:33,Disagreement,2,09:20:44:10,09:22:23:43,035_Disagreement_2_mirror.bvh,TRUE +036_Flirty_0.wav,09:23:40:23 ,09:25:34:25 ,00:01:54:02 ,,20210826_005_Flirty_0_01_mirror.fbx,09:23:38:17,09:25:31:06,00:01:52:49,09:23:43:48,Flirty,2,09:23:49:02,09:25:29:24,036_Flirty_0_mirror.bvh,FALSE +037_Flirty_1.wav,09:26:05:13 ,09:27:51:10 ,00:01:45:26 ,,20210826_006_Flirty_1_01_mirror.fbx,09:26:06:29,09:27:49:48,00:01:43:19,09:26:12:11,Flirty,2,09:26:17:53,09:27:47:40,037_Flirty_1_mirror.bvh,FALSE +038_Flirty_2.wav,09:28:42:12 ,09:30:58:28 ,00:02:16:15 ,,20210826_007_Flirty_2_01_mirror.fbx,09:28:51:59,09:30:50:22,00:01:58:23,09:28:57:12,Flirty,2,09:29:02:20,09:30:48:23,038_Flirty_2_mirror.bvh,TRUE +039_Pensive_0.wav,09:39:06:17 ,09:41:44:00 ,00:02:37:12 ,,20210826_008_Pensive_0_01_mirror.fbx,09:39:11:07,09:41:36:44,00:02:25:37,09:39:16:28,Pensive,2,09:39:22:03,09:41:35:16,039_Pensive_0_mirror.bvh,FALSE +040_Pensive_1.wav,09:42:17:17 ,09:44:56:26 ,00:02:39:08 ,,20210826_009_Pensive_1_01_mirror.fbx,09:42:11:23,09:44:36:28,00:02:25:07,09:42:16:53,Pensive,2,09:42:23:20,09:44:34:31,040_Pensive_1_mirror.bvh,FALSE +041_Pensive_2.wav,09:45:21:26 ,09:47:34:25 ,00:02:12:29 ,,20210826_010_Pensive_3_01_mirror.fbx,09:45:18:17,09:47:19:10,00:02:00:53,09:45:23:51,Pensive,2,09:45:29:15,09:47:17:32,041_Pensive_2_mirror.bvh,TRUE +042_Scared_0.wav,09:49:00:06 ,09:51:16:28 ,00:02:16:21 ,,20210826_011_Scared_0_01_mirror.fbx,09:48:55:51,09:50:59:52,00:02:04:01,09:49:00:58,Scared,2,09:49:05:14,09:50:57:43,042_Scared_0_mirror.bvh,FALSE +043_Scared_1.wav,09:53:19:14 ,09:55:17:23 ,00:01:58:09 ,,20210826_012_Scared_1_01_mirror.fbx,09:53:16:17,09:55:08:16,00:01:51:59,09:53:21:24,Scared,2,09:53:25:42,09:55:05:53,043_Scared_1_mirror.bvh,FALSE +044_Scared_2.wav,09:55:34:24 ,09:57:57:11 ,00:02:22:16 ,,20210826_013_Scared_2_01_mirror.fbx,09:55:39:19,09:57:54:14,00:02:14:55,09:55:45:01,Scared,2,09:55:49:58,09:57:52:07,044_Scared_2_mirror.bvh,TRUE +045_Distracted_0.wav,09:58:55:17 ,10:00:53:22 ,00:01:58:05 ,,20210826_014_Distracted_0_01_mirror.fbx,09:58:57:15,10:00:47:16,00:01:50:01,09:59:02:42,Distracted,2,09:59:07:36,10:00:45:18,045_Distracted_0_mirror.bvh,FALSE +046_Distracted_1.wav,10:01:48:29 ,10:04:12:16 ,00:02:23:17 ,,20210826_015_Distracted_1_01_mirror.fbx,10:01:53:49,10:04:03:52,00:02:10:03,10:01:59:28,Distracted,2,10:02:04:18,10:04:02:29,046_Distracted_1_mirror.bvh,FALSE +047_Distracted_2.wav,10:04:42:23 ,10:06:51:10 ,00:02:08:17 ,,20210826_016_Distracted_2_01_mirror.fbx,10:04:43:53,10:06:38:48,00:01:54:55,10:04:49:40,Distracted,2,10:04:55:36,10:06:37:04,047_Distracted_2_mirror.bvh,TRUE +048_Sarcastic_0.wav,10:09:15:26 ,10:11:14:11 ,00:01:58:14 
,,20210826_020_Sarcastic_0_01_mirror.fbx,10:09:10:11,10:11:14:58,00:02:04:47,10:09:16:10,Sarcastic,2,10:09:22:21,10:11:13:17,048_Sarcastic_0_mirror.bvh,FALSE +049_Sarcastic_1.wav,10:11:45:10 ,10:14:40:14 ,00:02:55:03 ,,20210826_021_Sarcastic_1_01_mirror.fbx,10:11:42:47,10:14:28:18,00:02:45:31,10:11:48:00,Sarcastic,2,10:11:52:53,10:14:26:27,049_Sarcastic_1_mirror.bvh,FALSE +050_Sarcastic_2.wav,10:15:04:12 ,10:17:23:06 ,00:02:18:24 ,,20210826_022_Sarcastic_2_01_mirror.fbx,10:15:00:47,10:17:20:50,00:02:20:03,10:15:06:26,Sarcastic,2,10:15:12:38,10:17:19:06,050_Sarcastic_2_mirror.bvh,TRUE +051_Threatening_0.wav,10:19:33:27 ,10:21:50:22 ,00:02:16:24 ,,20210826_023_Threatening_0_01_mirror.fbx,10:19:40:09,10:21:46:24,00:02:06:15,10:19:45:52,Threatening,2,10:19:51:58,10:21:44:17,051_Threatening_0_mirror.bvh,FALSE +052_Threatening_1.wav,10:22:23:06 ,10:24:24:04 ,00:02:00:27 ,,20210826_024_Threatening_1_01_mirror.fbx,10:22:20:05,10:24:19:30,00:01:59:25,10:22:26:11,Threatening,2,10:22:31:35,10:24:17:46,052_Threatening_1_mirror.bvh,FALSE +053_Threatening_2.wav,10:25:17:25 ,10:27:41:00 ,00:02:23:04 ,,20210826_025_Threatening_2_01_mirror.fbx,10:25:11:05,10:27:36:32,00:02:25:27,10:25:16:59,Threatening,2,10:25:22:49,10:27:34:30,053_Threatening_2_mirror.bvh,TRUE +054_Still_0.wav,10:29:00:06 ,10:31:00:08 ,00:02:00:02 ,,20210826_026_Still_0_01_mirror.fbx,10:28:55:43,10:30:55:56,00:02:00:13,10:29:00:52,Still,2,10:29:05:25,10:30:53:43,054_Still_0_mirror.bvh,FALSE +055_Still_1.wav,10:36:58:01 ,10:38:51:14 ,00:01:53:12 ,,20210826_027_Still_1_01_mirror.fbx,10:36:50:23,10:38:49:14,00:01:58:51,10:36:57:45,Still,2,10:37:03:10,10:38:47:38,055_Still_1_mirror.bvh,FALSE +056_Still_2.wav,10:39:55:17 ,10:41:53:00 ,00:01:57:13 ,,20210826_028_Still_2_01_mirror.fbx,10:39:51:51,10:41:50:28,00:01:58:37,10:39:56:41,Still,2,10:40:01:53,10:41:48:56,056_Still_2_mirror.bvh,TRUE +057_Laughing_0.wav,10:44:56:12 ,10:47:14:25 ,00:02:18:13 ,,20210826_029_Laughing_0_01_mirror.fbx,10:44:58:37,10:47:04:56,00:02:06:19,10:45:04:34,Laughing,2,10:45:09:56,10:47:03:01,057_Laughing_0_mirror.bvh,FALSE +058_Laughing_1.wav,10:48:30:08 ,10:52:00:12 ,00:03:30:04 ,,20210826_030_Laughing_1_01_mirror.fbx,10:48:30:17,10:50:39:24,00:02:09:07,10:48:35:28,Laughing,2,10:48:39:50,10:50:37:40,058_Laughing_1_mirror.bvh,TRUE +059_Sneaky_0.wav,10:52:48:16 ,10:54:59:15 ,00:02:10:29 ,,20210826_032_Sneaky_0_01_mirror.fbx,10:52:48:39,10:54:58:40,00:02:10:01,10:52:54:13,Sneaky,2,10:52:59:39,10:54:56:55,059_Sneaky_0_mirror.bvh,FALSE +060_Sneaky_1.wav,10:56:12:13 ,10:58:15:18 ,00:02:03:05 ,,20210826_033_Sneaky_1_01_mirror.fbx,10:56:06:41,10:58:09:18,00:02:02:37,10:56:12:02,Sneaky,2,10:56:16:59,10:58:07:23,060_Sneaky_1_mirror.bvh,FALSE +061_Sneaky_2.wav,10:58:33:03 ,11:01:17:05 ,00:02:44:01 ,,20210826_034_Sneaky_2_01_mirror.fbx,10:58:31:51,11:01:14:06,00:02:42:15,10:58:38:30,Sneaky,2,10:58:43:35,11:01:12:01,061_Sneaky_2_mirror.bvh,TRUE +062_Tired_0.wav,11:02:59:09 ,11:06:23:15 ,00:03:24:05 ,,20210826_035_Tired_1_01_mirror.fbx,11:02:59:57,11:06:08:04,00:03:08:07,11:03:05:51,Tired,2,11:03:10:55,11:06:06:42,062_Tired_0_mirror.bvh,FALSE +063_Tired_1.wav,11:06:41:14 ,11:09:18:19 ,00:02:37:04 ,,20210826_036_Tired_2_01_mirror.fbx,11:06:41:39,11:09:14:32,00:02:32:53,11:06:50:36,Tired,2,11:06:55:32,11:09:11:32,063_Tired_1_mirror.bvh,FALSE +064_Tired_2.wav,11:09:48:04 ,11:12:06:06 ,00:02:18:02 ,,20210826_037_Tired_3_01_mirror.fbx,11:09:48:23,11:11:59:20,00:02:10:57,11:09:54:47,Tired,2,11:10:00:29,11:11:56:33,064_Tired_2_mirror.bvh,TRUE +065_Speech_0.wav,11:28:37:00 ,11:30:40:09 
,00:02:03:09 ,,20210826_056_Speech_0_02_mirror.fbx,11:28:38:23,11:30:12:52,00:01:34:29,11:28:43:10,Speech,2,11:28:48:04,11:30:11:07,065_Speech_0_mirror.bvh,TRUE +066_Speech_1.wav,11:31:29:22 ,11:33:28:06 ,00:01:58:14 ,,20210826_057_Speech_1_02_mirror.fbx,11:31:32:43,11:33:12:08,00:01:39:25,11:31:38:15,Speech,2,11:31:43:08,11:33:10:04,066_Speech_1_mirror.bvh,TRUE +067_Speech_2.wav,11:33:48:20 ,11:35:17:13 ,00:01:28:23 ,,20210826_058_Speech_2_01_mirror.fbx,11:33:48:21,11:35:11:02,00:01:22:41,11:33:54:17,Speech,2,11:34:00:10,11:35:09:00,067_Speech_2_mirror.bvh,TRUE diff --git a/ubisoft-laforge-ZeroEGGS-main/data/processed_v1/processed/mean.npz b/ubisoft-laforge-ZeroEGGS-main/data/processed_v1/processed/mean.npz new file mode 100644 index 0000000000000000000000000000000000000000..c338db6807ebccca38460e269d565fe1a7664398 --- /dev/null +++ b/ubisoft-laforge-ZeroEGGS-main/data/processed_v1/processed/mean.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:722d7ee19360675c7a9d4bf7309d7d96240cee72ebd47b6ba769907a17db3863 +size 4592 diff --git a/ubisoft-laforge-ZeroEGGS-main/data/processed_v1/processed/std.npz b/ubisoft-laforge-ZeroEGGS-main/data/processed_v1/processed/std.npz new file mode 100644 index 0000000000000000000000000000000000000000..39d4e489c530dbce6ae752d192326416d68eb728 --- /dev/null +++ b/ubisoft-laforge-ZeroEGGS-main/data/processed_v1/processed/std.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:61a666d8092125e586c228592f621b8ece3808ca61761c140587e82d7c591edd +size 4121 diff --git a/ubisoft-laforge-ZeroEGGS-main/data/processed_v1/stats.npz b/ubisoft-laforge-ZeroEGGS-main/data/processed_v1/stats.npz new file mode 100644 index 0000000000000000000000000000000000000000..1c532192b6322b4435a91a6b362e6f7f2eb67399 --- /dev/null +++ b/ubisoft-laforge-ZeroEGGS-main/data/processed_v1/stats.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:aa29a43cb94f36348376d2a12787dfe4a6563cd54e722821ca46473cf86a2cef +size 30460 diff --git a/ubisoft-laforge-ZeroEGGS-main/data/test/cli.txt b/ubisoft-laforge-ZeroEGGS-main/data/test/cli.txt new file mode 100644 index 0000000000000000000000000000000000000000..6e7249c2a2d4d19499107b756e26d61760a8ae85 --- /dev/null +++ b/ubisoft-laforge-ZeroEGGS-main/data/test/cli.txt @@ -0,0 +1,2 @@ +python .\generate.py -o ..\data\outputs\v1\options.json -s ..\data\clean\005_Neutral_4_x_1_0.bvh -a ..\data\clean\005_Neutral_4_x_1_0.wav +python .\generate.py -o ..\data\outputs\v1\options.json -c ..\data\test\evaluation.csv \ No newline at end of file diff --git a/ubisoft-laforge-ZeroEGGS-main/data/test/evaluation.csv b/ubisoft-laforge-ZeroEGGS-main/data/test/evaluation.csv new file mode 100644 index 0000000000000000000000000000000000000000..696d9e547871267be45cfbbf80f851e1fb13c1ca --- /dev/null +++ b/ubisoft-laforge-ZeroEGGS-main/data/test/evaluation.csv @@ -0,0 +1,21 @@ +base_path,audio,style,file_name,temperature,seed,use_gpu,frames,generate +..\data\clean,005_Neutral_4_x_1_0.wav,005_Neutral_4_x_1_0.bvh,Neutral,1,1234,TRUE,,TRUE +..\data\clean,010_Sad_4_x_1_0.wav,010_Sad_4_x_1_0.bvh,Sad,1,1234,TRUE,,TRUE +..\data\clean,015_Happy_4_x_1_0.wav,015_Happy_4_x_1_0.bvh,Happy,1,1234,TRUE,,TRUE +..\data\clean,020_Relaxed_4_x_1_0.wav,020_Relaxed_4_x_1_0.bvh,Relaxed,1,1234,TRUE,,TRUE +..\data\clean,025_Old_4_x_1_0.wav,025_Old_4_x_1_0.bvh,Old,1,1234,TRUE,,TRUE +..\data\clean,029_Angry_3_x_1_0.wav,029_Angry_3_x_1_0.bvh,Angry,1,1234,TRUE,250 850,TRUE 
+..\data\clean,033_Agreement_2_x_1_0.wav,033_Agreement_2_x_1_0.bvh,Agreement,1,1234,TRUE,,TRUE +..\data\clean,035_Disagreement_2_x_1_0.wav,035_Disagreement_2_x_1_0.bvh,Disagreement,1,1234,TRUE,,TRUE +..\data\clean,038_Flirty_2_x_1_0.wav,038_Flirty_2_x_1_0.bvh,Flirty,1,1234,TRUE,,TRUE +..\data\clean,041_Pensive_2_x_1_0.wav,041_Pensive_2_x_1_0.bvh,Pensive,1,1234,TRUE,,TRUE +..\data\clean,044_Scared_2_x_1_0.wav,044_Scared_2_x_1_0.bvh,Scared,1,1234,TRUE,,TRUE +..\data\clean,047_Distracted_2_x_1_0.wav,047_Distracted_2_x_1_0.bvh,Distracted,1,1234,TRUE,,TRUE +..\data\clean,050_Sarcastic_2_x_1_0.wav,050_Sarcastic_2_x_1_0.bvh,Sarcastic,1,1234,TRUE,,TRUE +..\data\clean,053_Threatening_2_x_1_0.wav,053_Threatening_2_x_1_0.bvh,Threatening,1,1234,TRUE,,TRUE +..\data\clean,056_Still_2_x_1_0.wav,056_Still_2_x_1_0.bvh,Still,1,1234,TRUE,,TRUE +..\data\clean,058_Laughing_1_x_1_0.wav,058_Laughing_1_x_1_0.bvh,Laughing,1,1234,TRUE,,TRUE +..\data\clean,061_Sneaky_2_x_1_0.wav,061_Sneaky_2_x_1_0.bvh,Sneaky,1,1234,TRUE,,TRUE +..\data\clean,064_Tired_2_x_1_0.wav,064_Tired_2_x_1_0.bvh,Tired,1,1234,TRUE,,TRUE +..\data\clean,067_Speech_2_x_1_0.wav,067_Speech_2_x_1_0.bvh,Speech_1,1,1234,TRUE,,TRUE +..\data\clean,067_Speech_2_x_1_0.wav,067_Speech_2_x_1_0.bvh,Speech_2,1,5678,TRUE,,TRUE diff --git a/ubisoft-laforge-ZeroEGGS-main/requirements.txt b/ubisoft-laforge-ZeroEGGS-main/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..069df5a125befbc2a8355aa695a2503756dc49c3 Binary files /dev/null and b/ubisoft-laforge-ZeroEGGS-main/requirements.txt differ