samadi10 committed on
Commit
eeaa83d
1 Parent(s): 2c84e5f

Added necessary files

.gitignore ADDED
@@ -0,0 +1 @@
1
+ /.DS_STORE
GPT_eval_multi.py ADDED
@@ -0,0 +1,141 @@
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ from torch.utils.tensorboard import SummaryWriter
5
+ import json
6
+ import clip
7
+
8
+ import options.option_transformer as option_trans
9
+ import models.vqvae as vqvae
10
+ import utils.utils_model as utils_model
11
+ import utils.eval_trans as eval_trans
12
+ from dataset import dataset_TM_eval
13
+ import models.t2m_trans as trans
14
+ from options.get_eval_option import get_opt
15
+ from models.evaluator_wrapper import EvaluatorModelWrapper
16
+ import warnings
17
+ warnings.filterwarnings('ignore')
18
+ from exit.utils import base_dir, init_save_folder
19
+
20
+ ##### ---- Exp dirs ---- #####
21
+ args = option_trans.get_args_parser()
22
+ torch.manual_seed(args.seed)
23
+
24
+ args.out_dir = f'{args.out_dir}/eval'
25
+ os.makedirs(args.out_dir, exist_ok = True)
26
+ init_save_folder(args)
27
+
28
+ ##### ---- Logger ---- #####
29
+ logger = utils_model.get_logger(args.out_dir)
30
+ writer = SummaryWriter(args.out_dir)
31
+ logger.info(json.dumps(vars(args), indent=4, sort_keys=True))
32
+
33
+ from utils.word_vectorizer import WordVectorizer
34
+ w_vectorizer = WordVectorizer('./glove', 'our_vab')
35
+ val_loader = dataset_TM_eval.DATALoader(args.dataname, True, 32, w_vectorizer)
36
+
37
+ dataset_opt_path = 'checkpoints/kit/Comp_v6_KLD005/opt.txt' if args.dataname == 'kit' else 'checkpoints/t2m/Comp_v6_KLD005/opt.txt'
38
+
39
+ wrapper_opt = get_opt(dataset_opt_path, torch.device('cuda'))
40
+ eval_wrapper = EvaluatorModelWrapper(wrapper_opt)
41
+
42
+ ##### ---- Network ---- #####
43
+ clip_model, clip_preprocess = clip.load("ViT-B/32", device=torch.device('cuda'), jit=False) # Must set jit=False for training
44
+ clip.model.convert_weights(clip_model) # Actually this line is unnecessary since CLIP is already in float16 by default
45
+ clip_model.eval()
46
+ for p in clip_model.parameters():
47
+ p.requires_grad = False
48
+
49
+ # https://github.com/openai/CLIP/issues/111
50
+ class TextCLIP(torch.nn.Module):
51
+ def __init__(self, model) :
52
+ super(TextCLIP, self).__init__()
53
+ self.model = model
54
+
55
+ def forward(self,text):
56
+ with torch.no_grad():
57
+ word_emb = self.model.token_embedding(text).type(self.model.dtype)
58
+ word_emb = word_emb + self.model.positional_embedding.type(self.model.dtype)
59
+ word_emb = word_emb.permute(1, 0, 2) # NLD -> LND
60
+ word_emb = self.model.transformer(word_emb)
61
+ word_emb = self.model.ln_final(word_emb).permute(1, 0, 2).float()
62
+ enctxt = self.model.encode_text(text).float()
63
+ return enctxt, word_emb
64
+ clip_model = TextCLIP(clip_model)
65
+
66
+ net = vqvae.HumanVQVAE(args, ## use args to define different parameters in different quantizers
67
+ args.nb_code,
68
+ args.code_dim,
69
+ args.output_emb_width,
70
+ args.down_t,
71
+ args.stride_t,
72
+ args.width,
73
+ args.depth,
74
+ args.dilation_growth_rate)
75
+
76
+
77
+ trans_encoder = trans.Text2Motion_Transformer(net,
78
+ num_vq=args.nb_code,
79
+ embed_dim=args.embed_dim_gpt,
80
+ clip_dim=args.clip_dim,
81
+ block_size=args.block_size,
82
+ num_layers=args.num_layers,
83
+ num_local_layer=args.num_local_layer,
84
+ n_head=args.n_head_gpt,
85
+ drop_out_rate=args.drop_out_rate,
86
+ fc_rate=args.ff_rate)
87
+
88
+
89
+ print ('loading checkpoint from {}'.format(args.resume_pth))
90
+ ckpt = torch.load(args.resume_pth, map_location='cpu')
91
+ net.load_state_dict(ckpt['net'], strict=True)
92
+ net.eval()
93
+ net.cuda()
94
+
95
+ if args.resume_trans is not None:
96
+ print ('loading transformer checkpoint from {}'.format(args.resume_trans))
97
+ ckpt = torch.load(args.resume_trans, map_location='cpu')
98
+ trans_encoder.load_state_dict(ckpt['trans'], strict=True)
99
+ trans_encoder.train()
100
+ trans_encoder.cuda()
101
+
102
+
103
+ fid = []
104
+ div = []
105
+ top1 = []
106
+ top2 = []
107
+ top3 = []
108
+ matching = []
109
+ multi = []
110
+ repeat_time = 20
111
+
112
+ from tqdm import tqdm
113
+ for i in tqdm(range(repeat_time)):
114
+ pred_pose_eval, pose, m_length, clip_text, \
115
+ best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, best_multi, writer, logger = eval_trans.evaluation_transformer(args.out_dir, val_loader, net, trans_encoder, logger, writer, 0, best_fid=1000, best_iter=0, best_div=100, best_top1=0, best_top2=0, best_top3=0, best_matching=100, clip_model=clip_model, eval_wrapper=eval_wrapper, dataname=args.dataname, save = False, num_repeat=11, rand_pos=True)
116
+ fid.append(best_fid)
117
+ div.append(best_div)
118
+ top1.append(best_top1)
119
+ top2.append(best_top2)
120
+ top3.append(best_top3)
121
+ matching.append(best_matching)
122
+ multi.append(best_multi)
123
+
124
+ print('final result:')
125
+ print('fid: ', sum(fid)/repeat_time)
126
+ print('div: ', sum(div)/repeat_time)
127
+ print('top1: ', sum(top1)/repeat_time)
128
+ print('top2: ', sum(top2)/repeat_time)
129
+ print('top3: ', sum(top3)/repeat_time)
130
+ print('matching: ', sum(matching)/repeat_time)
131
+ print('multi: ', sum(multi)/repeat_time)
132
+
133
+ fid = np.array(fid)
134
+ div = np.array(div)
135
+ top1 = np.array(top1)
136
+ top2 = np.array(top2)
137
+ top3 = np.array(top3)
138
+ matching = np.array(matching)
139
+ multi = np.array(multi)
140
+ msg_final = f"FID. {np.mean(fid):.3f}, conf. {np.std(fid)*1.96/np.sqrt(repeat_time):.3f}, Diversity. {np.mean(div):.3f}, conf. {np.std(div)*1.96/np.sqrt(repeat_time):.3f}, TOP1. {np.mean(top1):.3f}, conf. {np.std(top1)*1.96/np.sqrt(repeat_time):.3f}, TOP2. {np.mean(top2):.3f}, conf. {np.std(top2)*1.96/np.sqrt(repeat_time):.3f}, TOP3. {np.mean(top3):.3f}, conf. {np.std(top3)*1.96/np.sqrt(repeat_time):.3f}, Matching. {np.mean(matching):.3f}, conf. {np.std(matching)*1.96/np.sqrt(repeat_time):.3f}, Multi. {np.mean(multi):.3f}, conf. {np.std(multi)*1.96/np.sqrt(repeat_time):.3f}"
141
+ logger.info(msg_final)
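
The final `msg_final` above reports, for each metric, the mean over `repeat_time` evaluation runs together with a 95% confidence half-width of the form `1.96 * std / sqrt(n)`. A minimal sketch of that aggregation step, with a hypothetical helper name and placeholder numbers (neither is part of the repository):

```python
import numpy as np

def summarize(values, n_runs):
    """Mean and 95% confidence half-width, matching the formula used in GPT_eval_multi.py."""
    values = np.asarray(values, dtype=np.float64)
    return values.mean(), values.std() * 1.96 / np.sqrt(n_runs)

# Example: aggregate 20 FID measurements the same way the script does.
fid_runs = [0.089, 0.091, 0.087, 0.090] * 5   # placeholder numbers, 20 runs
mean_fid, conf_fid = summarize(fid_runs, n_runs=len(fid_runs))
print(f"FID. {mean_fid:.3f}, conf. {conf_fid:.3f}")
```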
LICENSE-CC-BY-NC-ND-4.0.md ADDED
@@ -0,0 +1,157 @@
1
+ # Attribution-NonCommercial-NoDerivatives 4.0 International
2
+
3
+ > *Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible.*
4
+ >
5
+ > ### Using Creative Commons Public Licenses
6
+ >
7
+ > Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses.
8
+ >
9
+ > * __Considerations for licensors:__ Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. [More considerations for licensors](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensors).
10
+ >
11
+ > * __Considerations for the public:__ By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor’s permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. [More considerations for the public](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensees).
12
+
13
+ ## Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 International Public License
14
+
15
+ By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions.
16
+
17
+ ### Section 1 – Definitions.
18
+
19
+ a. __Adapted Material__ means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image.
20
+
21
+ b. __Copyright and Similar Rights__ means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights.
22
+
23
+ e. __Effective Technological Measures__ means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements.
24
+
25
+ f. __Exceptions and Limitations__ means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material.
26
+
27
+ h. __Licensed Material__ means the artistic or literary work, database, or other material to which the Licensor applied this Public License.
28
+
29
+ i. __Licensed Rights__ means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license.
30
+
31
+ h. __Licensor__ means the individual(s) or entity(ies) granting rights under this Public License.
32
+
33
+ i. __NonCommercial__ means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange.
34
+
35
+ j. __Share__ means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them.
36
+
37
+ k. __Sui Generis Database Rights__ means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world.
38
+
39
+ l. __You__ means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning.
40
+
41
+ ### Section 2 – Scope.
42
+
43
+ a. ___License grant.___
44
+
45
+ 1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to:
46
+
47
+ A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and
48
+
49
+ B. produce and reproduce, but not Share, Adapted Material for NonCommercial purposes only.
50
+
51
+ 2. __Exceptions and Limitations.__ For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions.
52
+
53
+ 3. __Term.__ The term of this Public License is specified in Section 6(a).
54
+
55
+ 4. __Media and formats; technical modifications allowed.__ The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material.
56
+
57
+ 5. __Downstream recipients.__
58
+
59
+ A. __Offer from the Licensor – Licensed Material.__ Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License.
60
+
61
+ B. __No downstream restrictions.__ You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material.
62
+
63
+ 6. __No endorsement.__ Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i).
64
+
65
+ b. ___Other rights.___
66
+
67
+ 1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise.
68
+
69
+ 2. Patent and trademark rights are not licensed under this Public License.
70
+
71
+ 3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes.
72
+
73
+ ### Section 3 – License Conditions.
74
+
75
+ Your exercise of the Licensed Rights is expressly made subject to the following conditions.
76
+
77
+ a. ___Attribution.___
78
+
79
+ 1. If You Share the Licensed Material, You must:
80
+
81
+ A. retain the following if it is supplied by the Licensor with the Licensed Material:
82
+
83
+ i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated);
84
+
85
+ ii. a copyright notice;
86
+
87
+ iii. a notice that refers to this Public License;
88
+
89
+ iv. a notice that refers to the disclaimer of warranties;
90
+
91
+ v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable;
92
+
93
+ B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and
94
+
95
+ C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License.
96
+
97
+ For the avoidance of doubt, You do not have permission under this Public License to Share Adapted Material.
98
+
99
+ 2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information.
100
+
101
+ 3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable.
102
+
103
+ ### Section 4 – Sui Generis Database Rights.
104
+
105
+ Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material:
106
+
107
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only and provided You do not Share Adapted Material;
108
+
109
+ b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material; and
110
+
111
+ c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database.
112
+
113
+ For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights.
114
+
115
+ ### Section 5 – Disclaimer of Warranties and Limitation of Liability.
116
+
117
+ a. __Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.__
118
+
119
+ b. __To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.__
120
+
121
+ c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability.
122
+
123
+ ### Section 6 – Term and Termination.
124
+
125
+ a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically.
126
+
127
+ b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates:
128
+
129
+ 1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or
130
+
131
+ 2. upon express reinstatement by the Licensor.
132
+
133
+ For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License.
134
+
135
+ c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License.
136
+
137
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License.
138
+
139
+ ### Section 7 – Other Terms and Conditions.
140
+
141
+ a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed.
142
+
143
+ b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License.
144
+
145
+ ### Section 8 – Interpretation.
146
+
147
+ a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License.
148
+
149
+ b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions.
150
+
151
+ c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor.
152
+
153
+ d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority.
154
+
155
+ > Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at [creativecommons.org/policies](http://creativecommons.org/policies), Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses.
156
+ >
157
+ > Creative Commons may be contacted at [creativecommons.org](http://creativecommons.org).
dataset/dataset_TM_eval.py ADDED
@@ -0,0 +1,218 @@
1
+ import torch
2
+ from torch.utils import data
3
+ import numpy as np
4
+ from os.path import join as pjoin
5
+ import random
6
+ import codecs as cs
7
+ from tqdm import tqdm
8
+
9
+ import utils.paramUtil as paramUtil
10
+ from torch.utils.data._utils.collate import default_collate
11
+
12
+
13
+ def collate_fn(batch):
14
+ batch.sort(key=lambda x: x[3], reverse=True)
15
+ return default_collate(batch)
16
+
17
+
18
+ '''For training the text-2-motion generative model'''
19
+ class Text2MotionDataset(data.Dataset):
20
+ def __init__(self, dataset_name, is_test, w_vectorizer, feat_bias = 5, max_text_len = 20, unit_length = 4, shuffle=True):
21
+
22
+ self.max_length = 20
23
+ self.pointer = 0
24
+ self.dataset_name = dataset_name
25
+ self.is_test = is_test
26
+ self.max_text_len = max_text_len
27
+ self.unit_length = unit_length
28
+ self.w_vectorizer = w_vectorizer
29
+ if dataset_name == 't2m':
30
+ self.data_root = './dataset/HumanML3D'
31
+ self.motion_dir = pjoin(self.data_root, 'new_joint_vecs')
32
+ self.text_dir = pjoin(self.data_root, 'texts')
33
+ self.joints_num = 22
34
+ radius = 4
35
+ fps = 20
36
+ self.max_motion_length = 196
37
+ dim_pose = 263
38
+ kinematic_chain = paramUtil.t2m_kinematic_chain
39
+ self.meta_dir = 'checkpoints/t2m/VQVAEV3_CB1024_CMT_H1024_NRES3/meta'
40
+ elif dataset_name == 'kit':
41
+ self.data_root = './dataset/KIT-ML'
42
+ self.motion_dir = pjoin(self.data_root, 'new_joint_vecs')
43
+ self.text_dir = pjoin(self.data_root, 'texts')
44
+ self.joints_num = 21
45
+ radius = 240 * 8
46
+ fps = 12.5
47
+ dim_pose = 251
48
+ self.max_motion_length = 196
49
+ kinematic_chain = paramUtil.kit_kinematic_chain
50
+ self.meta_dir = 'checkpoints/kit/VQVAEV3_CB1024_CMT_H1024_NRES3/meta'
51
+
52
+ mean = np.load(pjoin(self.meta_dir, 'mean.npy'))
53
+ std = np.load(pjoin(self.meta_dir, 'std.npy'))
54
+
55
+ if is_test:
56
+ split_file = pjoin(self.data_root, 'test.txt')
57
+ else:
58
+ split_file = pjoin(self.data_root, 'val.txt')
59
+
60
+ min_motion_len = 40 if self.dataset_name =='t2m' else 24
61
+ # min_motion_len = 64
62
+
63
+ joints_num = self.joints_num
64
+
65
+ data_dict = {}
66
+ id_list = []
67
+ with cs.open(split_file, 'r') as f:
68
+ for line in f.readlines():
69
+ id_list.append(line.strip())
70
+
71
+ new_name_list = []
72
+ length_list = []
73
+ for name in tqdm(id_list):
74
+ try:
75
+ motion = np.load(pjoin(self.motion_dir, name + '.npy'))
76
+ if (len(motion)) < min_motion_len or (len(motion) >= 200):
77
+ continue
78
+ text_data = []
79
+ flag = False
80
+ with cs.open(pjoin(self.text_dir, name + '.txt')) as f:
81
+ for line in f.readlines():
82
+ text_dict = {}
83
+ line_split = line.strip().split('#')
84
+ caption = line_split[0]
85
+ tokens = line_split[1].split(' ')
86
+ f_tag = float(line_split[2])
87
+ to_tag = float(line_split[3])
88
+ f_tag = 0.0 if np.isnan(f_tag) else f_tag
89
+ to_tag = 0.0 if np.isnan(to_tag) else to_tag
90
+
91
+ text_dict['caption'] = caption
92
+ text_dict['tokens'] = tokens
93
+ if f_tag == 0.0 and to_tag == 0.0:
94
+ flag = True
95
+ text_data.append(text_dict)
96
+ else:
97
+ try:
98
+ n_motion = motion[int(f_tag*fps) : int(to_tag*fps)]
99
+ if (len(n_motion)) < min_motion_len or (len(n_motion) >= 200):
100
+ continue
101
+ new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
102
+ while new_name in data_dict:
103
+ new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
104
+ data_dict[new_name] = {'motion': n_motion,
105
+ 'length': len(n_motion),
106
+ 'text':[text_dict]}
107
+ new_name_list.append(new_name)
108
+ length_list.append(len(n_motion))
109
+ except:
110
+ print(line_split)
111
+ print(line_split[2], line_split[3], f_tag, to_tag, name)
112
+ # break
113
+
114
+ if flag:
115
+ data_dict[name] = {'motion': motion,
116
+ 'length': len(motion),
117
+ 'text': text_data}
118
+ new_name_list.append(name)
119
+ length_list.append(len(motion))
120
+ except Exception as e:
121
+ # print(e)
122
+ pass
123
+
124
+ name_list, length_list = zip(*sorted(zip(new_name_list, length_list), key=lambda x: x[1]))
125
+ self.mean = mean
126
+ self.std = std
127
+ self.length_arr = np.array(length_list)
128
+ self.data_dict = data_dict
129
+ self.name_list = name_list
130
+ self.reset_max_len(self.max_length)
131
+ self.shuffle = shuffle
132
+
133
+ def reset_max_len(self, length):
134
+ assert length <= self.max_motion_length
135
+ self.pointer = np.searchsorted(self.length_arr, length)
136
+ print("Pointer Pointing at %d"%self.pointer)
137
+ self.max_length = length
138
+
139
+ def inv_transform(self, data):
140
+ return data * self.std + self.mean
141
+
142
+ def forward_transform(self, data):
143
+ return (data - self.mean) / self.std
144
+
145
+ def __len__(self):
146
+ return len(self.data_dict) - self.pointer
147
+
148
+ def __getitem__(self, item):
149
+ idx = self.pointer + item
150
+ name = self.name_list[idx]
151
+ data = self.data_dict[name]
152
+ # data = self.data_dict[self.name_list[idx]]
153
+ motion, m_length, text_list = data['motion'], data['length'], data['text']
154
+ # Randomly select a caption
155
+ text_data = random.choice(text_list)
156
+ caption, tokens = text_data['caption'], text_data['tokens']
157
+
158
+ if len(tokens) < self.max_text_len:
159
+ # pad with "unk"
160
+ tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
161
+ sent_len = len(tokens)
162
+ tokens = tokens + ['unk/OTHER'] * (self.max_text_len + 2 - sent_len)
163
+ else:
164
+ # crop
165
+ tokens = tokens[:self.max_text_len]
166
+ tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
167
+ sent_len = len(tokens)
168
+ pos_one_hots = []
169
+ word_embeddings = []
170
+ for token in tokens:
171
+ word_emb, pos_oh = self.w_vectorizer[token]
172
+ pos_one_hots.append(pos_oh[None, :])
173
+ word_embeddings.append(word_emb[None, :])
174
+ pos_one_hots = np.concatenate(pos_one_hots, axis=0)
175
+ word_embeddings = np.concatenate(word_embeddings, axis=0)
176
+
177
+ if self.unit_length < 10 and self.shuffle:
178
+ coin2 = np.random.choice(['single', 'single', 'double'])
179
+ else:
180
+ coin2 = 'single'
181
+
182
+ if coin2 == 'double':
183
+ m_length = (m_length // self.unit_length - 1) * self.unit_length
184
+ elif coin2 == 'single':
185
+ m_length = (m_length // self.unit_length) * self.unit_length
186
+ idx = random.randint(0, len(motion) - m_length)
187
+ motion = motion[idx:idx+m_length]
188
+
189
+ "Z Normalization"
190
+ motion = (motion - self.mean) / self.std
191
+
192
+ if m_length < self.max_motion_length and self.shuffle:
193
+ motion = np.concatenate([motion,
194
+ np.zeros((self.max_motion_length - m_length, motion.shape[1]))
195
+ ], axis=0)
196
+
197
+ return word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, '_'.join(tokens), name
198
+
199
+
200
+
201
+
202
+ def DATALoader(dataset_name, is_test,
203
+ batch_size, w_vectorizer,
204
+ num_workers = 8, unit_length = 4, shuffle=True) :
205
+
206
+ val_loader = torch.utils.data.DataLoader(Text2MotionDataset(dataset_name, is_test, w_vectorizer, unit_length=unit_length, shuffle=shuffle),
207
+ batch_size,
208
+ shuffle = shuffle,
209
+ num_workers=num_workers,
210
+ collate_fn=collate_fn,
211
+ drop_last = True)
212
+ return val_loader
213
+
214
+
215
+ def cycle(iterable):
216
+ while True:
217
+ for x in iterable:
218
+ yield x
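
GPT_eval_multi.py consumes this loader as `dataset_TM_eval.DATALoader(args.dataname, True, 32, w_vectorizer)`. A hedged usage sketch, assuming the GloVe vectors, HumanML3D data, and meta checkpoints are laid out at the default paths hard-coded above:

```python
# Sketch only: requires ./glove, ./dataset/HumanML3D and the meta mean/std referenced above.
from utils.word_vectorizer import WordVectorizer
from dataset import dataset_TM_eval

w_vectorizer = WordVectorizer('./glove', 'our_vab')
val_loader = dataset_TM_eval.DATALoader('t2m', True, 32, w_vectorizer)

# Each batch follows the return order of __getitem__ above.
word_emb, pos_one_hots, caption, sent_len, motion, m_length, tokens, name = next(iter(val_loader))
print(motion.shape)  # (32, 196, 263) for HumanML3D: motions are zero-padded to max_motion_length
```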
dataset/dataset_TM_train.py ADDED
@@ -0,0 +1,173 @@
1
+ import torch
2
+ from torch.utils import data
3
+ import numpy as np
4
+ from os.path import join as pjoin
5
+ import random
6
+ import codecs as cs
7
+ from tqdm import tqdm
8
+ import utils.paramUtil as paramUtil
9
+ from torch.utils.data._utils.collate import default_collate
10
+ import random
11
+ import math
12
+
13
+ def collate_fn(batch):
14
+ batch.sort(key=lambda x: x[3], reverse=True)
15
+ return default_collate(batch)
16
+
17
+
18
+ '''For training the text-2-motion generative model'''
19
+ class Text2MotionDataset(data.Dataset):
20
+ def __init__(self, dataset_name, feat_bias = 5, unit_length = 4, codebook_size = 1024, tokenizer_name=None, up_low_sep=False):
21
+
22
+ self.max_length = 64
23
+ self.pointer = 0
24
+ self.dataset_name = dataset_name
25
+ self.up_low_sep = up_low_sep
26
+
27
+ self.unit_length = unit_length
28
+ # self.mot_start_idx = codebook_size
29
+ self.mot_end_idx = codebook_size
30
+ self.mot_pad_idx = codebook_size + 1 # [TODO] I think 513 (codebook_size+1) can be whatever, it will be cropped out
31
+ if dataset_name == 't2m':
32
+ self.data_root = './dataset/HumanML3D'
33
+ self.motion_dir = pjoin(self.data_root, 'new_joint_vecs')
34
+ self.text_dir = pjoin(self.data_root, 'texts')
35
+ self.joints_num = 22
36
+ radius = 4
37
+ fps = 20
38
+ self.max_motion_length = 26 if unit_length == 8 else 50
39
+ dim_pose = 263
40
+ kinematic_chain = paramUtil.t2m_kinematic_chain
41
+ elif dataset_name == 'kit':
42
+ self.data_root = './dataset/KIT-ML'
43
+ self.motion_dir = pjoin(self.data_root, 'new_joint_vecs')
44
+ self.text_dir = pjoin(self.data_root, 'texts')
45
+ self.joints_num = 21
46
+ radius = 240 * 8
47
+ fps = 12.5
48
+ dim_pose = 251
49
+ self.max_motion_length = 26 if unit_length == 8 else 50
50
+ kinematic_chain = paramUtil.kit_kinematic_chain
51
+
52
+ split_file = pjoin(self.data_root, 'train.txt')
53
+
54
+
55
+ id_list = []
56
+ with cs.open(split_file, 'r') as f:
57
+ for line in f.readlines():
58
+ id_list.append(line.strip())
59
+
60
+ new_name_list = []
61
+ data_dict = {}
62
+ for name in tqdm(id_list):
63
+ try:
64
+ m_token_list = np.load(pjoin(tokenizer_name, '%s.npy'%name))
65
+
66
+ # Read text
67
+ with cs.open(pjoin(self.text_dir, name + '.txt')) as f:
68
+ text_data = []
69
+ flag = False
70
+ lines = f.readlines()
71
+
72
+ for line in lines:
73
+ try:
74
+ text_dict = {}
75
+ line_split = line.strip().split('#')
76
+ caption = line_split[0]
77
+ t_tokens = line_split[1].split(' ')
78
+ f_tag = float(line_split[2])
79
+ to_tag = float(line_split[3])
80
+ f_tag = 0.0 if np.isnan(f_tag) else f_tag
81
+ to_tag = 0.0 if np.isnan(to_tag) else to_tag
82
+
83
+ text_dict['caption'] = caption
84
+ text_dict['tokens'] = t_tokens
85
+ if f_tag == 0.0 and to_tag == 0.0:
86
+ flag = True
87
+ text_data.append(text_dict)
88
+ else:
89
+ # [INFO] Checked with KIT: execution doesn't reach here, which means f_tag & to_tag are 0.0 (tags for the caption's from-to frame range)
90
+ m_token_list_new = [tokens[int(f_tag*fps/unit_length) : int(to_tag*fps/unit_length)] for tokens in m_token_list if int(f_tag*fps/unit_length) < int(to_tag*fps/unit_length)]
91
+
92
+ if len(m_token_list_new) == 0:
93
+ continue
94
+ new_name = '%s_%f_%f'%(name, f_tag, to_tag)
95
+
96
+ data_dict[new_name] = {'m_token_list': m_token_list_new,
97
+ 'text':[text_dict]}
98
+ new_name_list.append(new_name)
99
+ except:
100
+ pass
101
+
102
+ if flag:
103
+ data_dict[name] = {'m_token_list': m_token_list,
104
+ 'text':text_data}
105
+ new_name_list.append(name)
106
+ except:
107
+ pass
108
+ self.data_dict = data_dict
109
+ self.name_list = new_name_list
110
+
111
+ def __len__(self):
112
+ return len(self.data_dict)
113
+
114
+ def __getitem__(self, item):
115
+ data = self.data_dict[self.name_list[item]]
116
+ m_token_list, text_list = data['m_token_list'], data['text']
117
+ m_tokens = random.choice(m_token_list)
118
+
119
+ text_data = random.choice(text_list)
120
+ caption= text_data['caption']
121
+
122
+
123
+ coin = np.random.choice([False, False, True])
124
+ # print(len(m_tokens))
125
+ if coin:
126
+ # drop one token at the head or tail
127
+ coin2 = np.random.choice([True, False])
128
+ if coin2:
129
+ m_tokens = m_tokens[:-1]
130
+ else:
131
+ m_tokens = m_tokens[1:]
132
+ m_tokens_len = m_tokens.shape[0]
133
+
134
+ if self.up_low_sep:
135
+ new_len = random.randint(20, self.max_motion_length-1)
136
+ len_mult = math.ceil(new_len/m_tokens_len)
137
+ m_tokens = np.tile(m_tokens, (len_mult, 1))[:new_len]
138
+ m_tokens_len = new_len
139
+ if m_tokens_len+1 < self.max_motion_length:
140
+ m_tokens = np.concatenate([m_tokens, np.ones((1, 2), dtype=int) * self.mot_end_idx, np.ones((self.max_motion_length-1-m_tokens_len, 2), dtype=int) * self.mot_pad_idx], axis=0)
141
+ else:
142
+ m_tokens = np.concatenate([m_tokens, np.ones((1, 2), dtype=int) * self.mot_end_idx], axis=0)
143
+ else:
144
+ if m_tokens_len+1 < self.max_motion_length:
145
+ m_tokens = np.concatenate([m_tokens, np.ones((1), dtype=int) * self.mot_end_idx, np.ones((self.max_motion_length-1-m_tokens_len), dtype=int) * self.mot_pad_idx], axis=0)
146
+ else:
147
+ m_tokens = np.concatenate([m_tokens, np.ones((1), dtype=int) * self.mot_end_idx], axis=0)
148
+ return caption, m_tokens, m_tokens_len
149
+
150
+
151
+
152
+
153
+ def DATALoader(dataset_name,
154
+ batch_size, codebook_size, tokenizer_name, unit_length=4,
155
+ num_workers = 8, up_low_sep=False) :
156
+
157
+ train_loader = torch.utils.data.DataLoader(Text2MotionDataset(dataset_name, codebook_size = codebook_size, tokenizer_name = tokenizer_name, unit_length=unit_length, up_low_sep=up_low_sep),
158
+ batch_size,
159
+ shuffle=True,
160
+ num_workers=num_workers,
161
+ #collate_fn=collate_fn,
162
+ drop_last = True)
163
+
164
+
165
+ return train_loader
166
+
167
+
168
+ def cycle(iterable):
169
+ while True:
170
+ for x in iterable:
171
+ yield x
172
+
173
+
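
One detail worth spelling out: in `__getitem__`, every token sequence gets a single `mot_end_idx` appended and is then padded with `mot_pad_idx` up to `max_motion_length`. A small self-contained sketch of that padding scheme with illustrative values (the real codebook size and lengths come from the training options):

```python
import numpy as np

codebook_size = 512                 # illustrative; set from the VQ-VAE's nb_code in practice
mot_end_idx = codebook_size         # end-of-motion token
mot_pad_idx = codebook_size + 1     # padding token, cropped out downstream
max_motion_length = 26              # value used above when unit_length == 8

m_tokens = np.array([3, 17, 250, 8])          # some motion-token sequence
m_tokens_len = m_tokens.shape[0]
if m_tokens_len + 1 < max_motion_length:
    m_tokens = np.concatenate([
        m_tokens,
        np.ones((1,), dtype=int) * mot_end_idx,
        np.ones((max_motion_length - 1 - m_tokens_len,), dtype=int) * mot_pad_idx,
    ], axis=0)
else:
    m_tokens = np.concatenate([m_tokens, np.ones((1,), dtype=int) * mot_end_idx], axis=0)

print(m_tokens)        # [3 17 250 8 512 513 513 ... 513], length 26
```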
dataset/dataset_VQ.py ADDED
@@ -0,0 +1,109 @@
1
+ import torch
2
+ from torch.utils import data
3
+ import numpy as np
4
+ from os.path import join as pjoin
5
+ import random
6
+ import codecs as cs
7
+ from tqdm import tqdm
8
+
9
+
10
+
11
+ class VQMotionDataset(data.Dataset):
12
+ def __init__(self, dataset_name, window_size = 64, unit_length = 4):
13
+ self.window_size = window_size
14
+ self.unit_length = unit_length
15
+ self.dataset_name = dataset_name
16
+
17
+ if dataset_name == 't2m':
18
+ self.data_root = './dataset/HumanML3D'
19
+ self.motion_dir = pjoin(self.data_root, 'new_joint_vecs')
20
+ self.text_dir = pjoin(self.data_root, 'texts')
21
+ self.joints_num = 22
22
+ self.max_motion_length = 196
23
+ self.meta_dir = 'checkpoints/t2m/VQVAEV3_CB1024_CMT_H1024_NRES3/meta'
24
+
25
+ elif dataset_name == 'kit':
26
+ self.data_root = './dataset/KIT-ML'
27
+ self.motion_dir = pjoin(self.data_root, 'new_joint_vecs')
28
+ self.text_dir = pjoin(self.data_root, 'texts')
29
+ self.joints_num = 21
30
+
31
+ self.max_motion_length = 196
32
+ self.meta_dir = 'checkpoints/kit/VQVAEV3_CB1024_CMT_H1024_NRES3/meta'
33
+
34
+ joints_num = self.joints_num
35
+
36
+ mean = np.load(pjoin(self.meta_dir, 'mean.npy'))
37
+ std = np.load(pjoin(self.meta_dir, 'std.npy'))
38
+
39
+ split_file = pjoin(self.data_root, 'train.txt')
40
+
41
+ self.data = []
42
+ self.lengths = []
43
+ id_list = []
44
+ with cs.open(split_file, 'r') as f:
45
+ for line in f.readlines():
46
+ id_list.append(line.strip())
47
+
48
+ for name in tqdm(id_list):
49
+ try:
50
+ motion = np.load(pjoin(self.motion_dir, name + '.npy'))
51
+ if motion.shape[0] < self.window_size:
52
+ continue
53
+ self.lengths.append(motion.shape[0] - self.window_size)
54
+ self.data.append(motion)
55
+ except:
56
+ # Some motions may not exist in the KIT dataset
57
+ pass
58
+
59
+
60
+ self.mean = mean
61
+ self.std = std
62
+ print("Total number of motions {}".format(len(self.data)))
63
+
64
+ def inv_transform(self, data):
65
+ return data * self.std + self.mean
66
+
67
+ def compute_sampling_prob(self) :
68
+
69
+ prob = np.array(self.lengths, dtype=np.float32)
70
+ prob /= np.sum(prob)
71
+ return prob
72
+
73
+ def __len__(self):
74
+ return len(self.data)
75
+
76
+ def __getitem__(self, item):
77
+ motion = self.data[item]
78
+
79
+ idx = random.randint(0, len(motion) - self.window_size)
80
+
81
+ motion = motion[idx:idx+self.window_size]
82
+ "Z Normalization"
83
+ motion = (motion - self.mean) / self.std
84
+
85
+ return motion
86
+
87
+ def DATALoader(dataset_name,
88
+ batch_size,
89
+ num_workers = 8,
90
+ window_size = 64,
91
+ unit_length = 4):
92
+
93
+ trainSet = VQMotionDataset(dataset_name, window_size=window_size, unit_length=unit_length)
94
+ prob = trainSet.compute_sampling_prob()
95
+ sampler = torch.utils.data.WeightedRandomSampler(prob, num_samples = len(trainSet) * 1000, replacement=True)
96
+ train_loader = torch.utils.data.DataLoader(trainSet,
97
+ batch_size,
98
+ shuffle=True,
99
+ #sampler=sampler,
100
+ num_workers=num_workers,
101
+ #collate_fn=collate_fn,
102
+ drop_last = True)
103
+
104
+ return train_loader
105
+
106
+ def cycle(iterable):
107
+ while True:
108
+ for x in iterable:
109
+ yield x
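
The `cycle` helper above is how these loaders are typically consumed when training by iteration count rather than by epoch. A minimal usage sketch, assuming the HumanML3D files are in place (the batch size is illustrative):

```python
from dataset import dataset_VQ

train_loader = dataset_VQ.DATALoader('t2m', batch_size=256, window_size=64, unit_length=4)
train_iter = dataset_VQ.cycle(train_loader)

for step in range(3):             # a real run would take many thousands of steps
    motion = next(train_iter)     # (256, 64, 263): z-normalized 64-frame windows for HumanML3D
    print(step, motion.shape)
```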
dataset/dataset_tokenize.py ADDED
@@ -0,0 +1,125 @@
1
+ import torch
2
+ from torch.utils import data
3
+ import numpy as np
4
+ from os.path import join as pjoin
5
+ import random
6
+ import codecs as cs
7
+ from tqdm import tqdm
8
+
9
+
10
+
11
+ class VQMotionDataset(data.Dataset):
12
+ def __init__(self, dataset_name, feat_bias = 5, window_size = 64, unit_length = 8, fill_max_len=False):
13
+ self.window_size = window_size
14
+ self.unit_length = unit_length
15
+ self.feat_bias = feat_bias
16
+ self.fill_max_len = fill_max_len
17
+
18
+ self.dataset_name = dataset_name
19
+ min_motion_len = 40 if dataset_name =='t2m' else 24
20
+
21
+ if dataset_name == 't2m':
22
+ self.data_root = './dataset/HumanML3D'
23
+ self.motion_dir = pjoin(self.data_root, 'new_joint_vecs')
24
+ self.text_dir = pjoin(self.data_root, 'texts')
25
+ self.joints_num = 22
26
+ radius = 4
27
+ fps = 20
28
+ self.max_motion_length = 196
29
+ self.dim_pose = 263
30
+ self.meta_dir = 'checkpoints/t2m/VQVAEV3_CB1024_CMT_H1024_NRES3/meta'
31
+ #kinematic_chain = paramUtil.t2m_kinematic_chain
32
+ elif dataset_name == 'kit':
33
+ self.data_root = './dataset/KIT-ML'
34
+ self.motion_dir = pjoin(self.data_root, 'new_joint_vecs')
35
+ self.text_dir = pjoin(self.data_root, 'texts')
36
+ self.joints_num = 21
37
+ radius = 240 * 8
38
+ fps = 12.5
39
+ self.dim_pose = 251
40
+ self.max_motion_length = 196
41
+ self.meta_dir = 'checkpoints/kit/VQVAEV3_CB1024_CMT_H1024_NRES3/meta'
42
+ #kinematic_chain = paramUtil.kit_kinematic_chain
43
+
44
+ joints_num = self.joints_num
45
+
46
+ mean = np.load(pjoin(self.meta_dir, 'mean.npy'))
47
+ std = np.load(pjoin(self.meta_dir, 'std.npy'))
48
+
49
+ split_file = pjoin(self.data_root, 'train.txt')
50
+
51
+ data_dict = {}
52
+ id_list = []
53
+ with cs.open(split_file, 'r') as f:
54
+ for line in f.readlines():
55
+ id_list.append(line.strip())
56
+
57
+ new_name_list = []
58
+ length_list = []
59
+ for name in tqdm(id_list):
60
+ try:
61
+ motion = np.load(pjoin(self.motion_dir, name + '.npy'))
62
+ if (len(motion)) < min_motion_len or (len(motion) >= 200):
63
+ continue
64
+
65
+ data_dict[name] = {'motion': motion,
66
+ 'length': len(motion),
67
+ 'name': name}
68
+ new_name_list.append(name)
69
+ length_list.append(len(motion))
70
+ except:
71
+ # Some motions may not exist in the KIT dataset
72
+ pass
73
+
74
+
75
+ self.mean = mean
76
+ self.std = std
77
+ self.length_arr = np.array(length_list)
78
+ self.data_dict = data_dict
79
+ self.name_list = new_name_list
80
+
81
+ def inv_transform(self, data):
82
+ return data * self.std + self.mean
83
+
84
+ def __len__(self):
85
+ return len(self.data_dict)
86
+
87
+ def __getitem__(self, item):
88
+ name = self.name_list[item]
89
+ data = self.data_dict[name]
90
+ motion, m_length = data['motion'], data['length']
91
+
92
+ m_length = (m_length // self.unit_length) * self.unit_length
93
+
94
+ idx = random.randint(0, len(motion) - m_length)
95
+ motion = motion[idx:idx+m_length]
96
+
97
+ if self.fill_max_len:
98
+ motion_zero = np.zeros((self.max_motion_length, self.dim_pose))
99
+ motion_zero[:m_length] = motion
100
+ motion = motion_zero
101
+ motion = (motion - self.mean) / self.std
102
+ return motion, m_length
103
+
104
+ "Z Normalization"
105
+ motion = (motion - self.mean) / self.std
106
+
107
+ return motion, name
108
+
109
+ def DATALoader(dataset_name,
110
+ batch_size = 1,
111
+ num_workers = 8, unit_length = 4, shuffle=True) :
112
+
113
+ train_loader = torch.utils.data.DataLoader(VQMotionDataset(dataset_name, unit_length=unit_length, fill_max_len=batch_size!=1),
114
+ batch_size,
115
+ shuffle=shuffle,
116
+ num_workers=num_workers,
117
+ #collate_fn=collate_fn,
118
+ drop_last = True)
119
+
120
+ return train_loader
121
+
122
+ def cycle(iterable):
123
+ while True:
124
+ for x in iterable:
125
+ yield x
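
dataset_TM_train.py above expects one pre-computed token file per motion, loaded from `tokenizer_name/<name>.npy`. A hedged sketch of how this tokenization loader is typically driven to produce those files; the output directory and the `DummyEncoder` stand-in are assumptions, since the trained VQ-VAE encoder is not part of this commit:

```python
import os
import numpy as np
import torch
from dataset import dataset_tokenize

out_dir = './output/VQVAE/codes'          # assumed output location for the token files
os.makedirs(out_dir, exist_ok=True)

class DummyEncoder(torch.nn.Module):
    """Placeholder for the trained VQ-VAE encoder; replace with the real model's encode call."""
    def forward(self, motion):
        return torch.randint(0, 512, (motion.shape[1] // 4,))   # fake code indices, unit_length = 4

encoder = DummyEncoder()
loader = dataset_tokenize.DATALoader('t2m', batch_size=1, unit_length=4, shuffle=False)

with torch.no_grad():
    for motion, name in loader:
        codes = encoder(motion.float())                          # swap in the real encoder here
        np.save(os.path.join(out_dir, f'{name[0]}.npy'), codes.numpy())
```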
environment.yml ADDED
@@ -0,0 +1,227 @@
1
+ name: MMM
2
+ channels:
3
+ - pytorch
4
+ - nvidia
5
+ - conda-forge
6
+ - defaults
7
+ dependencies:
8
+ - _libgcc_mutex=0.1=conda_forge
9
+ - _openmp_mutex=4.5=2_gnu
10
+ - abseil-cpp=20230802.0=h6a678d5_2
11
+ - absl-py=2.1.0=py312h06a4308_0
12
+ - aiohttp=3.9.5=py312h5eee18b_0
13
+ - aiosignal=1.2.0=pyhd3eb1b0_0
14
+ - asttokens=2.4.1=pyhd8ed1ab_0
15
+ - attrs=23.1.0=py312h06a4308_0
16
+ - blas=1.0=mkl
17
+ - blinker=1.6.2=py312h06a4308_0
18
+ - brotli=1.0.9=h5eee18b_8
19
+ - brotli-bin=1.0.9=h5eee18b_8
20
+ - brotli-python=1.0.9=py312h6a678d5_8
21
+ - bzip2=1.0.8=h5eee18b_6
22
+ - c-ares=1.19.1=h5eee18b_0
23
+ - ca-certificates=2024.6.2=hbcca054_0
24
+ - cachetools=5.3.3=py312h06a4308_0
25
+ - certifi=2024.2.2=py312h06a4308_0
26
+ - cffi=1.16.0=py312h5eee18b_1
27
+ - charset-normalizer=2.0.4=pyhd3eb1b0_0
28
+ - click=8.1.7=py312h06a4308_0
29
+ - comm=0.2.2=pyhd8ed1ab_0
30
+ - contourpy=1.2.0=py312hdb19cb5_0
31
+ - cryptography=42.0.5=py312hdda0065_1
32
+ - cuda-cudart=11.8.89=0
33
+ - cuda-cupti=11.8.87=0
34
+ - cuda-libraries=11.8.0=0
35
+ - cuda-nvrtc=11.8.89=0
36
+ - cuda-nvtx=11.8.86=0
37
+ - cuda-runtime=11.8.0=0
38
+ - cuda-version=12.5=3
39
+ - cycler=0.11.0=pyhd3eb1b0_0
40
+ - cyrus-sasl=2.1.28=h52b45da_1
41
+ - dbus=1.13.18=hb2f20db_0
42
+ - debugpy=1.6.7=py312h6a678d5_0
43
+ - decorator=5.1.1=pyhd8ed1ab_0
44
+ - exceptiongroup=1.2.0=pyhd8ed1ab_2
45
+ - executing=2.0.1=pyhd8ed1ab_0
46
+ - expat=2.6.2=h6a678d5_0
47
+ - ffmpeg=4.3=hf484d3e_0
48
+ - filelock=3.13.1=py312h06a4308_0
49
+ - fontconfig=2.14.1=h4c34cd2_2
50
+ - fonttools=4.51.0=py312h5eee18b_0
51
+ - freetype=2.12.1=h4a9f257_0
52
+ - frozenlist=1.4.0=py312h5eee18b_0
53
+ - glib=2.78.4=h6a678d5_0
54
+ - glib-tools=2.78.4=h6a678d5_0
55
+ - gmp=6.2.1=h295c915_3
56
+ - gnutls=3.6.15=he1e5248_0
57
+ - google-auth=2.29.0=py312h06a4308_0
58
+ - google-auth-oauthlib=0.4.1=py_2
59
+ - grpc-cpp=1.48.2=he1ff14a_4
60
+ - grpcio=1.48.2=py312he1ff14a_4
61
+ - gst-plugins-base=1.14.1=h6a678d5_1
62
+ - gstreamer=1.14.1=h5eee18b_1
63
+ - gtest=1.14.0=hdb19cb5_1
64
+ - icu=73.1=h6a678d5_0
65
+ - idna=3.7=py312h06a4308_0
66
+ - importlib-metadata=7.1.0=pyha770c72_0
67
+ - importlib_metadata=7.1.0=hd8ed1ab_0
68
+ - intel-openmp=2023.1.0=hdb19cb5_46306
69
+ - ipykernel=6.29.3=pyhd33586a_0
70
+ - ipython=8.25.0=pyh707e725_0
71
+ - jedi=0.19.1=pyhd8ed1ab_0
72
+ - jinja2=3.1.4=py312h06a4308_0
73
+ - jpeg=9e=h5eee18b_1
74
+ - jupyter_client=8.6.2=pyhd8ed1ab_0
75
+ - jupyter_core=5.5.0=py312h06a4308_0
76
+ - kiwisolver=1.4.4=py312h6a678d5_0
77
+ - krb5=1.20.1=h143b758_1
78
+ - lame=3.100=h7b6447c_0
79
+ - lcms2=2.12=h3be6417_0
80
+ - ld_impl_linux-64=2.38=h1181459_1
81
+ - lerc=3.0=h295c915_0
82
+ - libbrotlicommon=1.0.9=h5eee18b_8
83
+ - libbrotlidec=1.0.9=h5eee18b_8
84
+ - libbrotlienc=1.0.9=h5eee18b_8
85
+ - libclang=14.0.6=default_hc6dbbc7_1
86
+ - libclang13=14.0.6=default_he11475f_1
87
+ - libcublas=11.11.3.6=0
88
+ - libcufft=10.9.0.58=0
89
+ - libcufile=1.10.0.4=0
90
+ - libcups=2.4.2=h2d74bed_1
91
+ - libcurand=10.3.6.39=0
92
+ - libcusolver=11.4.1.48=0
93
+ - libcusparse=11.7.5.86=0
94
+ - libdeflate=1.17=h5eee18b_1
95
+ - libedit=3.1.20230828=h5eee18b_0
96
+ - libffi=3.4.4=h6a678d5_1
97
+ - libgcc-ng=13.2.0=h77fa898_7
98
+ - libgfortran-ng=11.2.0=h00389a5_1
99
+ - libgfortran5=11.2.0=h1234567_1
100
+ - libglib=2.78.4=hdc74915_0
101
+ - libgomp=13.2.0=h77fa898_7
102
+ - libiconv=1.16=h5eee18b_3
103
+ - libidn2=2.3.4=h5eee18b_0
104
+ - libjpeg-turbo=2.0.0=h9bf148f_0
105
+ - libllvm14=14.0.6=hdb19cb5_3
106
+ - libnpp=11.8.0.86=0
107
+ - libnvjpeg=11.9.0.86=0
108
+ - libpng=1.6.39=h5eee18b_0
109
+ - libpq=12.17=hdbd6064_0
110
+ - libprotobuf=3.20.3=he621ea3_0
111
+ - libsodium=1.0.18=h36c2ea0_1
112
+ - libstdcxx-ng=11.2.0=h1234567_1
113
+ - libtasn1=4.19.0=h5eee18b_0
114
+ - libtiff=4.5.1=h6a678d5_0
115
+ - libunistring=0.9.10=h27cfd23_0
116
+ - libuuid=1.41.5=h5eee18b_0
117
+ - libwebp-base=1.3.2=h5eee18b_0
118
+ - libxcb=1.15=h7f8727e_0
119
+ - libxkbcommon=1.0.1=h5eee18b_1
120
+ - libxml2=2.10.4=hfdd30dd_2
121
+ - llvm-openmp=14.0.6=h9e868ea_0
122
+ - lz4-c=1.9.4=h6a678d5_1
123
+ - markdown=3.4.1=py312h06a4308_0
124
+ - markupsafe=2.1.3=py312h5eee18b_0
125
+ - matplotlib=3.8.4=py312h06a4308_0
126
+ - matplotlib-base=3.8.4=py312h526ad5a_0
127
+ - matplotlib-inline=0.1.7=pyhd8ed1ab_0
128
+ - mkl=2023.1.0=h213fc3f_46344
129
+ - mkl-service=2.4.0=py312h5eee18b_1
130
+ - mkl_fft=1.3.8=py312h5eee18b_0
131
+ - mkl_random=1.2.4=py312hdb19cb5_0
132
+ - mpmath=1.3.0=py312h06a4308_0
133
+ - multidict=6.0.4=py312h5eee18b_0
134
+ - mysql=5.7.24=h721c034_2
135
+ - ncurses=6.4=h6a678d5_0
136
+ - nest-asyncio=1.6.0=pyhd8ed1ab_0
137
+ - nettle=3.7.3=hbbd107a_1
138
+ - networkx=3.1=py312h06a4308_0
139
+ - numpy=1.26.4=py312hc5e2394_0
140
+ - numpy-base=1.26.4=py312h0da6c21_0
141
+ - oauthlib=3.2.2=py312h06a4308_0
142
+ - openh264=2.1.1=h4ff587b_0
143
+ - openjpeg=2.4.0=h3ad879b_0
144
+ - openssl=3.3.0=h4ab18f5_3
145
+ - packaging=23.2=py312h06a4308_0
146
+ - parso=0.8.4=pyhd8ed1ab_0
147
+ - pcre2=10.42=hebb0a14_1
148
+ - pexpect=4.9.0=pyhd8ed1ab_0
149
+ - pickleshare=0.7.5=py_1003
150
+ - pillow=10.3.0=py312h5eee18b_0
151
+ - pip=24.0=py312h06a4308_0
152
+ - platformdirs=4.2.2=pyhd8ed1ab_0
153
+ - plotly=5.19.0=py312he106c6f_0
154
+ - ply=3.11=py312h06a4308_1
155
+ - prompt-toolkit=3.0.42=pyha770c72_0
156
+ - protobuf=3.20.3=py312h6a678d5_0
157
+ - psutil=5.9.0=py312h5eee18b_0
158
+ - ptyprocess=0.7.0=pyhd3deb0d_0
159
+ - pure_eval=0.2.2=pyhd8ed1ab_0
160
+ - pyasn1=0.4.8=pyhd3eb1b0_0
161
+ - pyasn1-modules=0.2.8=py_0
162
+ - pybind11-abi=5=hd3eb1b0_0
163
+ - pycparser=2.21=pyhd3eb1b0_0
164
+ - pygments=2.18.0=pyhd8ed1ab_0
165
+ - pyjwt=2.8.0=py312h06a4308_0
166
+ - pyopenssl=24.0.0=py312h06a4308_0
167
+ - pyparsing=3.0.9=py312h06a4308_0
168
+ - pyqt=5.15.10=py312h6a678d5_0
169
+ - pyqt5-sip=12.13.0=py312h5eee18b_0
170
+ - pysocks=1.7.1=py312h06a4308_0
171
+ - python=3.12.3=h996f2a0_1
172
+ - python-dateutil=2.9.0=pyhd8ed1ab_0
173
+ - pytorch=2.3.0=py3.12_cuda11.8_cudnn8.7.0_0
174
+ - pytorch-cuda=11.8=h7e8668a_5
175
+ - pytorch-mutex=1.0=cuda
176
+ - pyyaml=6.0.1=py312h5eee18b_0
177
+ - pyzmq=25.1.2=py312h6a678d5_0
178
+ - qt-main=5.15.2=h53bd1ea_10
179
+ - re2=2022.04.01=h295c915_0
180
+ - readline=8.2=h5eee18b_0
181
+ - requests=2.32.2=py312h06a4308_0
182
+ - requests-oauthlib=1.3.0=py_0
183
+ - rsa=4.7.2=pyhd3eb1b0_1
184
+ - scipy=1.13.0=py312hc5e2394_0
185
+ - setuptools=69.5.1=py312h06a4308_0
186
+ - sip=6.7.12=py312h6a678d5_0
187
+ - six=1.16.0=pyhd3eb1b0_1
188
+ - sqlite=3.45.3=h5eee18b_0
189
+ - stack_data=0.6.2=pyhd8ed1ab_0
190
+ - sympy=1.12=py312h06a4308_0
191
+ - tbb=2021.8.0=hdb19cb5_0
192
+ - tenacity=8.2.2=py312h06a4308_1
193
+ - tensorboard=2.6.0=py_0
194
+ - tensorboard-plugin-wit=1.6.0=py_0
195
+ - tk=8.6.14=h39e8969_0
196
+ - torchvision=0.18.0=py312_cu118
197
+ - tornado=6.3.3=py312h5eee18b_0
198
+ - traitlets=5.14.3=pyhd8ed1ab_0
199
+ - typing_extensions=4.11.0=py312h06a4308_0
200
+ - tzdata=2024a=h04d1e81_0
201
+ - unicodedata2=15.1.0=py312h5eee18b_0
202
+ - urllib3=2.2.1=py312h06a4308_0
203
+ - wcwidth=0.2.13=pyhd8ed1ab_0
204
+ - werkzeug=3.0.3=py312h06a4308_0
205
+ - wheel=0.43.0=py312h06a4308_0
206
+ - xz=5.4.6=h5eee18b_1
207
+ - yaml=0.2.5=h7b6447c_0
208
+ - yarl=1.9.3=py312h5eee18b_0
209
+ - zeromq=4.3.5=h6a678d5_0
210
+ - zipp=3.17.0=pyhd8ed1ab_0
211
+ - zlib=1.2.13=h5eee18b_1
212
+ - zstd=1.5.5=hc292b87_2
213
+ - pip:
214
+ - beautifulsoup4==4.12.3
215
+ - einops==0.8.0
216
+ - fastjsonschema==2.19.1
217
+ - fsspec==2024.5.0
218
+ - ftfy==6.2.0
219
+ - gdown==5.2.0
220
+ - jsonschema==4.22.0
221
+ - jsonschema-specifications==2023.12.1
222
+ - nbformat==5.10.4
223
+ - referencing==0.35.1
224
+ - regex==2024.5.15
225
+ - rpds-py==0.18.1
226
+ - soupsieve==2.5
227
+ - tqdm==4.66.4
exit/t2m-mean.npy ADDED
Binary file (2.23 kB).
exit/t2m-std.npy ADDED
Binary file (2.23 kB).
exit/utils.py ADDED
@@ -0,0 +1,305 @@
1
+ def get_model(model):
2
+ if hasattr(model, 'module'):
3
+ return model.module
4
+ return model
5
+
6
+ import numpy as np
7
+ import torch
8
+ from utils.motion_process import recover_from_ric
9
+ import copy
10
+ import plotly.graph_objects as go
11
+ import shutil
12
+ import datetime
13
+ import os
14
+ import math
15
+
16
+ kit_bone = [[0, 11], [11, 12], [12, 13], [13, 14], [14, 15], [0, 16], [16, 17], [17, 18], [18, 19], [19, 20], [0, 1], [1, 2], [2, 3], [3, 4], [3, 5], [5, 6], [6, 7], [3, 8], [8, 9], [9, 10]]
17
+ t2m_bone = [[0,2], [2,5],[5,8],[8,11],
18
+ [0,1],[1,4],[4,7],[7,10],
19
+ [0,3],[3,6],[6,9],[9,12],[12,15],
20
+ [9,14],[14,17],[17,19],[19,21],
21
+ [9,13],[13,16],[16,18],[18,20]]
22
+ kit_kit_bone = kit_bone + (np.array(kit_bone)+21).tolist()
23
+ t2m_t2m_bone = t2m_bone + (np.array(t2m_bone)+22).tolist()
24
+
25
+ def axis_standard(skeleton):
26
+ skeleton = skeleton.copy()
27
+ # skeleton = -skeleton
28
+ # skeleton[:, :, 0] *= -1
29
+ # xyz => zxy
30
+ skeleton[..., [1, 2]] = skeleton[..., [2, 1]]
31
+ skeleton[..., [0, 1]] = skeleton[..., [1, 0]]
32
+ return skeleton
33
+
34
+ def visualize_2motions(motion1, std, mean, dataset_name, length, motion2=None, save_path=None):
35
+ motion1 = motion1 * std + mean
36
+ if motion2 is not None:
37
+ motion2 = motion2 * std + mean
38
+ if dataset_name == 'kit':
39
+ first_total_standard = 60
40
+ bone_link = kit_bone
41
+ if motion2 is not None:
42
+ bone_link = kit_kit_bone
43
+ joints_num = 21
44
+ scale = 1/1000
45
+ else:
46
+ first_total_standard = 63
47
+ bone_link = t2m_bone
48
+ if motion2 is not None:
49
+ bone_link = t2m_t2m_bone
50
+ joints_num = 22
51
+ scale = 1#/1000
52
+ joint1 = recover_from_ric(torch.from_numpy(motion1).float(), joints_num).numpy()
53
+ if motion2 is not None:
54
+ joint2 = recover_from_ric(torch.from_numpy(motion2).float(), joints_num).numpy()
55
+ joint_original_forward = np.concatenate((joint1, joint2), axis=1)
56
+ else:
57
+ joint_original_forward = joint1
58
+ animate3d(joint_original_forward[:length]*scale,
59
+ BONE_LINK=bone_link,
60
+ first_total_standard=first_total_standard,
61
+ save_path=save_path) # 'init.html'
62
+
63
+ def animate3d(skeleton, BONE_LINK=t2m_bone, first_total_standard=-1, root_path=None, root_path2=None, save_path=None, axis_standard=axis_standard, axis_visible=True):
64
+ # [animation] https://community.plotly.com/t/3d-scatter-animation/46368/6
65
+
66
+ SHIFT_SCALE = 0
67
+ START_FRAME = 0
68
+ NUM_FRAMES = skeleton.shape[0]
69
+ skeleton = skeleton[START_FRAME:NUM_FRAMES+START_FRAME]
70
+ skeleton = axis_standard(skeleton)
71
+ if BONE_LINK is not None:
72
+ # ground truth
73
+ bone_ids = np.array(BONE_LINK)
74
+ _from = skeleton[:, bone_ids[:, 0]]
75
+ _to = skeleton[:, bone_ids[:, 1]]
76
+ # [f 3(from,to,none) d]
77
+ bones = np.empty(
78
+ (_from.shape[0], 3*_from.shape[1], 3), dtype=_from.dtype)
79
+ bones[:, 0::3] = _from
80
+ bones[:, 1::3] = _to
81
+ bones[:, 2::3] = np.full_like(_to, None)
82
+ display_points = bones
83
+ mode = 'lines+markers'
84
+ else:
85
+ display_points = skeleton
86
+ mode = 'markers'
87
+ # follow this thread: https://community.plotly.com/t/3d-scatter-animation/46368/6
88
+ fig = go.Figure(
89
+ data=go.Scatter3d( x=display_points[0, :first_total_standard, 0],
90
+ y=display_points[0, :first_total_standard, 1],
91
+ z=display_points[0, :first_total_standard, 2],
92
+ name='Nodes0',
93
+ mode=mode,
94
+ marker=dict(size=3, color='blue',)),
95
+ layout=go.Layout(
96
+ scene=dict(aspectmode='data',
97
+ camera=dict(eye=dict(x=3, y=0, z=0.1)))
98
+ )
99
+ )
100
+ if first_total_standard != -1:
101
+ fig.add_traces(data=go.Scatter3d(
102
+ x=display_points[0, first_total_standard:, 0],
103
+ y=display_points[0, first_total_standard:, 1],
104
+ z=display_points[0, first_total_standard:, 2],
105
+ name='Nodes1',
106
+ mode=mode,
107
+ marker=dict(size=3, color='red',)))
108
+
109
+ if root_path is not None:
110
+ root_path = axis_standard(root_path)
111
+ fig.add_traces(data=go.Scatter3d(
112
+ x=root_path[:, 0],
113
+ y=root_path[:, 1],
114
+ z=root_path[:, 2],
115
+ name='root_path',
116
+ mode=mode,
117
+ marker=dict(size=2, color='green',)))
118
+ if root_path2 is not None:
119
+ root_path2 = axis_standard(root_path2)
120
+ fig.add_traces(data=go.Scatter3d(
121
+ x=root_path2[:, 0],
122
+ y=root_path2[:, 1],
123
+ z=root_path2[:, 2],
124
+ name='root_path2',
125
+ mode=mode,
126
+ marker=dict(size=2, color='red',)))
127
+
128
+ frames = []
129
+ # frames.append({'data':copy.deepcopy(fig['data']),'name':f'frame{0}'})
130
+
131
+ def update_trace(k):
132
+ fig.update_traces(x=display_points[k, :first_total_standard, 0],
133
+ y=display_points[k, :first_total_standard, 1],
134
+ z=display_points[k, :first_total_standard, 2],
135
+ mode=mode,
136
+ marker=dict(size=3, ),
137
+ # traces=[0],
138
+ selector = ({'name':'Nodes0'}))
139
+ if first_total_standard != -1:
140
+ fig.update_traces(x=display_points[k, first_total_standard:, 0],
141
+ y=display_points[k, first_total_standard:, 1],
142
+ z=display_points[k, first_total_standard:, 2],
143
+ mode=mode,
144
+ marker=dict(size=3, ),
145
+ # traces=[0],
146
+ selector = ({'name':'Nodes1'}))
147
+
148
+ for k in range(0, len(display_points)):
149
+ update_trace(k)
150
+ frames.append({'data':copy.deepcopy(fig['data']),'name':f'frame{k}'})
151
+ update_trace(0)
152
+
153
+ # frames = [go.Frame(data=[go.Scatter3d(
154
+ # x=display_points[k, :, 0],
155
+ # y=display_points[k, :, 1],
156
+ # z=display_points[k, :, 2],
157
+ # mode=mode,
158
+ # marker=dict(size=3, ))],
159
+ # traces=[0],
160
+ # name=f'frame{k}'
161
+ # )for k in range(len(display_points))]
162
+
163
+
164
+
165
+ fig.update(frames=frames)
166
+
167
+ def frame_args(duration):
168
+ return {
169
+ "frame": {"duration": duration},
170
+ "mode": "immediate",
171
+ "fromcurrent": True,
172
+ "transition": {"duration": duration, "easing": "linear"},
173
+ }
174
+
175
+ sliders = [
176
+ {"pad": {"b": 10, "t": 60},
177
+ "len": 0.9,
178
+ "x": 0.1,
179
+ "y": 0,
180
+
181
+ "steps": [
182
+ {"args": [[f.name], frame_args(0)],
183
+ "label": str(k),
184
+ "method": "animate",
185
+ } for k, f in enumerate(fig.frames)
186
+ ]
187
+ }
188
+ ]
189
+
190
+ fig.update_layout(
191
+ updatemenus=[{"buttons": [
192
+ {
193
+ "args": [None, frame_args(1000/25)],
194
+ "label": "Play",
195
+ "method": "animate",
196
+ },
197
+ {
198
+ "args": [[None], frame_args(0)],
199
+ "label": "Pause",
200
+ "method": "animate",
201
+ }],
202
+
203
+ "direction": "left",
204
+ "pad": {"r": 10, "t": 70},
205
+ "type": "buttons",
206
+ "x": 0.1,
207
+ "y": 0,
208
+ }
209
+ ],
210
+ sliders=sliders
211
+ )
212
+ range_x, aspect_x = get_range(skeleton, 0)
213
+ range_y, aspect_y = get_range(skeleton, 1)
214
+ range_z, aspect_z = get_range(skeleton, 2)
215
+
216
+ fig.update_layout(scene=dict(xaxis=dict(range=range_x, visible=axis_visible),
217
+ yaxis=dict(range=range_y, visible=axis_visible),
218
+ zaxis=dict(range=range_z, visible=axis_visible)
219
+ ),
220
+ scene_aspectmode='manual',
221
+ scene_aspectratio=dict(
222
+ x=aspect_x, y=aspect_y, z=aspect_z)
223
+ )
224
+
225
+ fig.update_layout(sliders=sliders)
226
+ fig.show()
227
+ if save_path is not None:
228
+ fig.write_html(save_path, auto_open=False, include_plotlyjs='cdn', full_html=False)
229
+
230
+ def get_range(skeleton, index):
231
+ _min, _max = skeleton[:, :, index].min(), skeleton[:, :, index].max()
232
+ return [_min, _max], _max-_min
233
+
234
+ # [INFO] from http://juditacs.github.io/2018/12/27/masked-attention.html
235
+ def generate_src_mask(T, length):
236
+ B = len(length)
237
+ mask = torch.arange(T).repeat(B, 1).to(length.device) < length.unsqueeze(-1)
238
+ return mask
239
+
240
+ def copyComplete(source, target):
241
+ '''https://stackoverflow.com/questions/19787348/copy-file-keep-permissions-and-owner'''
242
+ # copy content, stat-info (mode too), timestamps...
243
+ if os.path.isfile(source):
244
+ shutil.copy2(source, target)
245
+ else:
246
+ shutil.copytree(source, target, ignore=shutil.ignore_patterns('__pycache__'))
247
+ # copy owner and group
248
+ st = os.stat(source)
249
+ os.chown(target, st.st_uid, st.st_gid)
250
+
251
+ data_permission = os.access('/data/epinyoan', os.R_OK | os.W_OK | os.X_OK)
252
+ base_dir = '/data' if data_permission else '/home'
253
+ def init_save_folder(args, copysource=True):
254
+ import glob
255
+ global base_dir
256
+ if args.exp_name != 'TEMP':
257
+ date = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
258
+ args.out_dir = f"./{args.out_dir}/{date}_{args.exp_name}/"
259
+ save_source = f'{args.out_dir}source/'
260
+ os.makedirs(save_source, mode=os.umask(0), exist_ok=False)
261
+ else:
262
+ args.out_dir = os.path.join(args.out_dir, f'{args.exp_name}')
263
+
264
+ def uniform(shape, device = None):
265
+ return torch.zeros(shape, device = device).float().uniform_(0, 1)
266
+
267
+ def cosine_schedule(t):
268
+ return torch.cos(t * math.pi * 0.5)
269
+
270
+ def log(t, eps = 1e-20):
271
+ return torch.log(t.clamp(min = eps))
272
+
273
+ def gumbel_noise(t):
274
+ noise = torch.zeros_like(t).uniform_(0, 1)
275
+ return -log(-log(noise))
276
+
277
+ def gumbel_sample(t, temperature = 1., dim = -1):
278
+ return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim)
279
+
280
+ def top_k(logits, thres = 0.9):
281
+ # [INFO] keep only the top (1 - thres) fraction of logits along the last dim (top 10% by default); the rest are filled with -inf
282
+ k = math.ceil((1 - thres) * logits.shape[-1])
283
+ val, ind = logits.topk(k, dim = -1)
284
+ probs = torch.full_like(logits, float('-inf'))
285
+ probs.scatter_(2, ind, val)
286
+ return probs
287
+
288
+ # https://github.com/lucidrains/DALLE-pytorch/issues/318
289
+ # https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
290
+ from torch.nn import functional as F
291
+ def top_p(logits, thres = 0.1):
292
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
293
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
294
+
295
+ # # Remove tokens with cumulative probability above the threshold
296
+ sorted_indices_to_remove = cumulative_probs > (1 - thres)
297
+ # Shift the indices to the right to keep also the first token above the threshold
298
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
299
+ sorted_indices_to_remove[..., 0] = 0
300
+
301
+ # # scatter sorted tensors to original indexing
302
+ indices_to_remove = sorted_indices_to_remove.scatter(dim=-1, index=sorted_indices, src=sorted_indices_to_remove)
303
+
304
+ logits[indices_to_remove] = float('-inf')
305
+ return logits
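exit/utils.py mixes plotting helpers with the mask-and-sample utilities used by the transformer (generate_src_mask, cosine_schedule, top_k, gumbel_sample, top_p). A small illustrative sketch of those sampling helpers with dummy tensors, assuming it is run from the repository root so the module-level imports resolve:

    # Illustrative only: exercises the helpers above with dummy shapes.
    import torch
    from exit.utils import generate_src_mask, cosine_schedule, top_k, gumbel_sample

    lengths = torch.tensor([3, 5])           # two sequences with 3 and 5 valid tokens
    mask = generate_src_mask(6, lengths)     # (2, 6) boolean mask, True inside each valid length
    print(mask)

    t = torch.linspace(0, 1, 5)
    print(cosine_schedule(t))                # masking ratio decays from 1 to 0 across the schedule

    logits = torch.randn(2, 6, 512)          # (batch, seq, codebook) dummy logits
    filtered = top_k(logits, thres=0.9)      # keeps roughly the top 10% per position, rest set to -inf
    idx = gumbel_sample(filtered, temperature=1.0)
    print(idx.shape)                         # (2, 6) sampled code indices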
generate.py ADDED
@@ -0,0 +1,297 @@
1
+ import torch
2
+ import clip
3
+ import models.vqvae as vqvae
4
+ from models.vqvae_sep import VQVAE_SEP
5
+ import models.t2m_trans as trans
6
+ import models.t2m_trans_uplow as trans_uplow
7
+ import numpy as np
8
+ from exit.utils import visualize_2motions
9
+ import options.option_transformer as option_trans
10
+
11
+
12
+
13
+ ##### ---- CLIP ---- #####
14
+ clip_model, clip_preprocess = clip.load("ViT-B/32", device=torch.device('cpu'), jit=False) # Must set jit=False for training
15
+ clip.model.convert_weights(clip_model) # Actually this line is unnecessary since CLIP already defaults to float16
16
+ clip_model.eval()
17
+ for p in clip_model.parameters():
18
+ p.requires_grad = False
19
+
20
+ # https://github.com/openai/CLIP/issues/111
21
+ class TextCLIP(torch.nn.Module):
22
+ def __init__(self, model) :
23
+ super(TextCLIP, self).__init__()
24
+ self.model = model
25
+
26
+ def forward(self,text):
27
+ with torch.no_grad():
28
+ word_emb = self.model.token_embedding(text).type(self.model.dtype)
29
+ word_emb = word_emb + self.model.positional_embedding.type(self.model.dtype)
30
+ word_emb = word_emb.permute(1, 0, 2) # NLD -> LND
31
+ word_emb = self.model.transformer(word_emb)
32
+ word_emb = self.model.ln_final(word_emb).permute(1, 0, 2).float()
33
+ enctxt = self.model.encode_text(text).float()
34
+ return enctxt, word_emb
35
+ clip_model = TextCLIP(clip_model)
36
+
37
+ def get_vqvae(args, is_upper_edit):
38
+ if not is_upper_edit:
39
+ return vqvae.HumanVQVAE(args, ## use args to define different parameters in different quantizers
40
+ args.nb_code,
41
+ args.code_dim,
42
+ args.output_emb_width,
43
+ args.down_t,
44
+ args.stride_t,
45
+ args.width,
46
+ args.depth,
47
+ args.dilation_growth_rate)
48
+ else:
49
+ return VQVAE_SEP(args, ## use args to define different parameters in different quantizers
50
+ args.nb_code,
51
+ args.code_dim,
52
+ args.output_emb_width,
53
+ args.down_t,
54
+ args.stride_t,
55
+ args.width,
56
+ args.depth,
57
+ args.dilation_growth_rate,
58
+ moment={'mean': torch.from_numpy(args.mean).float(),
59
+ 'std': torch.from_numpy(args.std).float()},
60
+ sep_decoder=True)
61
+
62
+ def get_maskdecoder(args, vqvae, is_upper_edit):
63
+ transformer = trans if not is_upper_edit else trans_uplow
64
+ return transformer.Text2Motion_Transformer(vqvae,
65
+ num_vq=args.nb_code,
66
+ embed_dim=args.embed_dim_gpt,
67
+ clip_dim=args.clip_dim,
68
+ block_size=args.block_size,
69
+ num_layers=args.num_layers,
70
+ num_local_layer=args.num_local_layer,
71
+ n_head=args.n_head_gpt,
72
+ drop_out_rate=args.drop_out_rate,
73
+ fc_rate=args.ff_rate)
74
+
75
+ class MMM(torch.nn.Module):
76
+ def __init__(self, args=None, is_upper_edit=False):
77
+ super().__init__()
78
+ self.is_upper_edit = is_upper_edit
79
+
80
+
81
+ args.dataname = args.dataset_name = 't2m'
82
+
83
+ self.vqvae = get_vqvae(args, is_upper_edit)
84
+ ckpt = torch.load(args.resume_pth, map_location='cpu')
85
+ self.vqvae.load_state_dict(ckpt['net'], strict=True)
86
+ if is_upper_edit:
87
+ class VQVAE_WRAPPER(torch.nn.Module):
88
+ def __init__(self, vqvae) :
89
+ super().__init__()
90
+ self.vqvae = vqvae
91
+
92
+ def forward(self, *args, **kwargs):
93
+ return self.vqvae(*args, **kwargs)
94
+ self.vqvae = VQVAE_WRAPPER(self.vqvae)
95
+ self.vqvae.eval()
96
+ self.vqvae
97
+
98
+ self.maskdecoder = get_maskdecoder(args, self.vqvae, is_upper_edit)
99
+ ckpt = torch.load(args.resume_trans, map_location='cpu')
100
+ self.maskdecoder.load_state_dict(ckpt['trans'], strict=True)
101
+ self.maskdecoder.train()
102
+ self.maskdecoder
103
+
104
+ def forward(self, text, lengths=-1, rand_pos=True):
105
+ b = len(text)
106
+ feat_clip_text = clip.tokenize(text, truncate=True)
107
+ feat_clip_text, word_emb = clip_model(feat_clip_text)
108
+ index_motion = self.maskdecoder(feat_clip_text, word_emb, type="sample", m_length=lengths, rand_pos=rand_pos, if_test=False)
109
+
110
+ m_token_length = torch.ceil((lengths)/4).int()
111
+ pred_pose_all = torch.zeros((b, 196, 263))
112
+ for k in range(b):
113
+ pred_pose = self.vqvae(index_motion[k:k+1, :m_token_length[k]], type='decode')
114
+ pred_pose_all[k:k+1, :int(lengths[k].item())] = pred_pose
115
+ return pred_pose_all
116
+
117
+ def inbetween_eval(self, base_pose, m_length, start_f, end_f, inbetween_text):
118
+ bs, seq = base_pose.shape[:2]
119
+ tokens = -1*torch.ones((bs, 50), dtype=torch.long)
120
+ m_token_length = torch.ceil((m_length)/4).int()
121
+ start_t = torch.round((start_f)/4).int()
122
+ end_t = torch.round((end_f)/4).int()
123
+
124
+ for k in range(bs):
125
+ index_motion = self.vqvae(base_pose[k:k+1, :m_length[k]], type='encode')
126
+ tokens[k, :start_t[k]] = index_motion[0][:start_t[k]]
127
+ tokens[k, end_t[k]:m_token_length[k]] = index_motion[0][end_t[k]:m_token_length[k]]
128
+
129
+ text = clip.tokenize(inbetween_text, truncate=True)
130
+ feat_clip_text, word_emb_clip = clip_model(text)
131
+
132
+ mask_id = self.maskdecoder.num_vq + 2
133
+ tokens[tokens==-1] = mask_id
134
+ inpaint_index = self.maskdecoder(feat_clip_text, word_emb_clip, type="sample", m_length=m_length, token_cond=tokens)
135
+
136
+ pred_pose_eval = torch.zeros((bs, seq, base_pose.shape[-1]))
137
+ for k in range(bs):
138
+ pred_pose = self.vqvae(inpaint_index[k:k+1, :m_token_length[k]], type='decode')
139
+ pred_pose_eval[k:k+1, :int(m_length[k].item())] = pred_pose
140
+ return pred_pose_eval
141
+
142
+ def long_range(self, text, lengths, num_transition_token=2, output='concat', index_motion=None):
143
+ b = len(text)
144
+ feat_clip_text = clip.tokenize(text, truncate=True)
145
+ feat_clip_text, word_emb = clip_model(feat_clip_text)
146
+ if index_motion is None:
147
+ index_motion = self.maskdecoder(feat_clip_text, word_emb, type="sample", m_length=lengths, rand_pos=False)
148
+
149
+ m_token_length = torch.ceil((lengths)/4).int()
150
+ if output == 'eval':
151
+ frame_length = m_token_length * 4
152
+ m_token_length = m_token_length.clone()
153
+ m_token_length = m_token_length - 2*num_transition_token
154
+ m_token_length[[0,-1]] += num_transition_token # first and last clips only border a transition on one side
155
+
156
+ half_token_length = (m_token_length/2).int()
157
+ idx_full_len = half_token_length >= 24
158
+ half_token_length[idx_full_len] = half_token_length[idx_full_len] - 1
159
+
160
+ mask_id = self.maskdecoder.num_vq + 2
161
+ tokens = -1*torch.ones((b-1, 50), dtype=torch.long)
162
+ transition_train_length = []
163
+
164
+ for i in range(b-1):
165
+ if output == 'concat':
166
+ i_index_motion = index_motion[i]
167
+ i1_index_motion = index_motion[i+1]
168
+ if output == 'eval':
169
+ if i == 0:
170
+ i_index_motion = index_motion[i, :m_token_length[i]]
171
+ else:
172
+ i_index_motion = index_motion[i, num_transition_token:m_token_length[i] + num_transition_token]
173
+ if i == b-1:
174
+ i1_index_motion = index_motion[i+1, :m_token_length[i+1]]
175
+ else:
176
+ i1_index_motion = index_motion[i+1,
177
+ num_transition_token:m_token_length[i+1] + num_transition_token]
178
+ left_end = half_token_length[i]
179
+ right_start = left_end + num_transition_token
180
+ end = right_start + half_token_length[i+1]
181
+
182
+ tokens[i, :left_end] = i_index_motion[m_token_length[i]-left_end: m_token_length[i]]
183
+ tokens[i, left_end:right_start] = mask_id
184
+ tokens[i, right_start:end] = i1_index_motion[:half_token_length[i+1]]
185
+ transition_train_length.append(end)
186
+ transition_train_length = torch.tensor(transition_train_length).to(index_motion.device)
187
+ text = clip.tokenize(text[:-1], truncate=True)
188
+ feat_clip_text, word_emb_clip = clip_model(text)
189
+ inpaint_index = self.maskdecoder(feat_clip_text, word_emb_clip, type="sample", m_length=transition_train_length*4, token_cond=tokens, max_steps=1)
190
+
191
+ if output == 'concat':
192
+ all_tokens = []
193
+ for i in range(b-1):
194
+ all_tokens.append(index_motion[i, :m_token_length[i]])
195
+ all_tokens.append(inpaint_index[i, tokens[i] == mask_id])
196
+ all_tokens.append(index_motion[-1, :m_token_length[-1]])
197
+ all_tokens = torch.cat(all_tokens).unsqueeze(0)
198
+ pred_pose = self.vqvae(all_tokens, type='decode')
199
+ return pred_pose
200
+ elif output == 'eval':
201
+ all_tokens = []
202
+ for i in range(b):
203
+ motion_token = index_motion[i, :m_token_length[i]]
204
+ if i == 0:
205
+ first_current_trans_tok = inpaint_index[i, tokens[i] == mask_id]
206
+ all_tokens.append(motion_token)
207
+ all_tokens.append(first_current_trans_tok)
208
+ else:
209
+ if i < b-1:
210
+ first_current_trans_tok = inpaint_index[i, tokens[i] == mask_id]
211
+ all_tokens.append(motion_token)
212
+ all_tokens.append(first_current_trans_tok)
213
+ else:
214
+ all_tokens.append(motion_token)
215
+ all_tokens = torch.cat(all_tokens)
216
+ pred_pose_concat = self.vqvae(all_tokens.unsqueeze(0), type='decode')
217
+
218
+ trans_frame = num_transition_token*4
219
+ pred_pose = torch.zeros((b, 196, 263))
220
+ current_point = 0
221
+ for i in range(b):
222
+ if i == 0:
223
+ start_f = torch.tensor(0)
224
+ end_f = frame_length[i]
225
+ else:
226
+ start_f = current_point - trans_frame
227
+ end_f = start_f + frame_length[i]
228
+ current_point = end_f
229
+ pred_pose[i, :frame_length[i]] = pred_pose_concat[0, start_f: end_f]
230
+ return pred_pose
231
+
232
+ def upper_edit(self, pose, m_length, upper_text, lower_mask=None):
233
+ pose = pose.clone().float() # bs, nb_joints, joints_dim, seq_len
234
+ m_tokens_len = torch.ceil((m_length)/4)
235
+ bs, seq = pose.shape[:2]
236
+ max_motion_length = int(seq/4) + 1
237
+ mot_end_idx = self.vqvae.vqvae.num_code
238
+ mot_pad_idx = self.vqvae.vqvae.num_code + 1
239
+ mask_id = self.vqvae.vqvae.num_code + 2
240
+ target_lower = []
241
+ for k in range(bs):
242
+ target = self.vqvae(pose[k:k+1, :m_length[k]], type='encode')
243
+ if m_tokens_len[k]+1 < max_motion_length:
244
+ target = torch.cat([target,
245
+ torch.ones((1, 1, 2), dtype=int, device=target.device) * mot_end_idx,
246
+ torch.ones((1, max_motion_length-1-m_tokens_len[k].int().item(), 2), dtype=int, device=target.device) * mot_pad_idx], axis=1)
247
+ else:
248
+ target = torch.cat([target,
249
+ torch.ones((1, 1, 2), dtype=int, device=target.device) * mot_end_idx], axis=1)
250
+ target_lower.append(target[..., 1])
251
+ target_lower = torch.cat(target_lower, axis=0)
252
+
253
+ ### lower mask ###
254
+ if lower_mask is not None:
255
+ lower_mask = torch.cat([lower_mask, torch.zeros(bs, 1, dtype=int)], dim=1).bool()
256
+ target_lower_masked = target_lower.clone()
257
+ target_lower_masked[lower_mask] = mask_id
258
+ select_end = target_lower == mot_end_idx
259
+ target_lower_masked[select_end] = target_lower[select_end]
260
+ else:
261
+ target_lower_masked = target_lower
262
+ ##################
263
+
264
+ pred_len = m_length
265
+ pred_tok_len = m_tokens_len
266
+ pred_pose_eval = torch.zeros((bs, seq, pose.shape[-1]))
267
+
268
+ # __upper_text__ = ['A man punches with right hand.'] * 32
269
+ text = clip.tokenize(upper_text, truncate=True)
270
+ feat_clip_text, word_emb_clip = clip_model(text)
271
+ # index_motion = trans_encoder(feat_clip_text, idx_lower=target_lower_masked, word_emb=word_emb_clip, type="sample", m_length=pred_len, rand_pos=True, CFG=-1)
272
+ index_motion = self.maskdecoder(feat_clip_text, target_lower_masked, word_emb_clip, type="sample", m_length=pred_len, rand_pos=True)
273
+ for i in range(bs):
274
+ all_tokens = torch.cat([
275
+ index_motion[i:i+1, :int(pred_tok_len[i].item()), None],
276
+ target_lower[i:i+1, :int(pred_tok_len[i].item()), None]
277
+ ], axis=-1)
278
+ pred_pose = self.vqvae(all_tokens, type='decode')
279
+ pred_pose_eval[i:i+1, :int(pred_len[i].item())] = pred_pose
280
+
281
+ return pred_pose_eval
282
+
283
+
284
+ if __name__ == '__main__':
285
+ args = option_trans.get_args_parser()
286
+
287
+ # python generate.py --resume-pth '/home/epinyoan/git/MaskText2Motion/T2M-BD/output/vq/2023-07-19-04-17-17_12_VQVAE_20batchResetNRandom_8192_32/net_last.pth' --resume-trans '/home/epinyoan/git/MaskText2Motion/T2M-BD/output/t2m/2023-10-12-10-11-15_HML3D_45_crsAtt1lyr_40breset_WRONG_THIS_20BRESET/net_last.pth' --text 'the person crouches and walks forward.' --length 156
288
+
289
+ mmm = MMM(args)
290
+ pred_pose = mmm([args.text], torch.tensor([args.length]), rand_pos=False)
291
+
292
+ std = np.load('./exit/t2m-std.npy')
293
+ mean = np.load('./exit/t2m-mean.npy')
294
+ file_name = '_'.join(args.text.split(' '))+'_'+str(args.length)
295
+ visualize_2motions(pred_pose[0].detach().cpu().numpy(), std, mean, 't2m', args.length, save_path='./output/'+file_name+'.html')
296
+
297
+
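Besides the single-prompt path in __main__, the MMM wrapper above exposes long_range for chaining several prompts with generated transition tokens. A hedged usage sketch (prompts, lengths and the output path are placeholders, and --resume-pth / --resume-trans still have to point at real checkpoints):

    # Sketch only: chains two prompts into one continuous motion via MMM.long_range.
    import numpy as np
    import torch
    import options.option_transformer as option_trans
    from generate import MMM
    from exit.utils import visualize_2motions

    args = option_trans.get_args_parser()    # expects --resume-pth and --resume-trans, as in the comment above
    mmm = MMM(args)

    texts = ['a person walks forward.', 'the person turns around and sits down.']  # placeholder prompts
    lengths = torch.tensor([100, 120])       # frames per prompt

    # 'concat' decodes both clips plus the in-painted transition tokens as one sequence
    pred_pose = mmm.long_range(texts, lengths, num_transition_token=2, output='concat')

    std = np.load('./exit/t2m-std.npy')
    mean = np.load('./exit/t2m-mean.npy')
    visualize_2motions(pred_pose[0].detach().cpu().numpy(), std, mean, 't2m',
                       pred_pose.shape[1], save_path='./output/long_range_demo.html')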
models/encdec.py ADDED
@@ -0,0 +1,76 @@
1
+ import torch.nn as nn
2
+ from models.resnet import Resnet1D
3
+
4
+ class PrintModule(nn.Module):
5
+ def __init__(self, me=''):
6
+ super().__init__()
7
+ self.me = me
8
+
9
+ def forward(self, x):
10
+ print(self.me, x.shape)
11
+ return x
12
+
13
+ class Encoder(nn.Module):
14
+ def __init__(self,
15
+ input_emb_width = 3,
16
+ output_emb_width = 512,
17
+ down_t = 3,
18
+ stride_t = 2,
19
+ width = 512,
20
+ depth = 3,
21
+ dilation_growth_rate = 3,
22
+ activation='relu',
23
+ norm=None):
24
+ super().__init__()
25
+
26
+ blocks = []
27
+ filter_t, pad_t = stride_t * 2, stride_t // 2
28
+ blocks.append(nn.Conv1d(input_emb_width, width, 3, 1, 1))
29
+ blocks.append(nn.ReLU())
30
+
31
+ for i in range(down_t):
32
+ input_dim = width
33
+ block = nn.Sequential(
34
+ nn.Conv1d(input_dim, width, filter_t, stride_t, pad_t),
35
+ Resnet1D(width, depth, dilation_growth_rate, activation=activation, norm=norm),
36
+ )
37
+ blocks.append(block)
38
+ blocks.append(nn.Conv1d(width, output_emb_width, 3, 1, 1))
39
+ self.model = nn.Sequential(*blocks)
40
+
41
+ def forward(self, x):
42
+ return self.model(x)
43
+
44
+ class Decoder(nn.Module):
45
+ def __init__(self,
46
+ input_emb_width = 3,
47
+ output_emb_width = 512,
48
+ down_t = 3,
49
+ stride_t = 2,
50
+ width = 512,
51
+ depth = 3,
52
+ dilation_growth_rate = 3,
53
+ activation='relu',
54
+ norm=None):
55
+ super().__init__()
56
+ blocks = []
57
+
58
+ filter_t, pad_t = stride_t * 2, stride_t // 2
59
+ blocks.append(nn.Conv1d(output_emb_width, width, 3, 1, 1))
60
+ blocks.append(nn.ReLU())
61
+ for i in range(down_t):
62
+ out_dim = width
63
+ block = nn.Sequential(
64
+ Resnet1D(width, depth, dilation_growth_rate, reverse_dilation=True, activation=activation, norm=norm),
65
+ nn.Upsample(scale_factor=2, mode='nearest'),
66
+ nn.Conv1d(width, out_dim, 3, 1, 1)
67
+ )
68
+ blocks.append(block)
69
+ blocks.append(nn.Conv1d(width, width, 3, 1, 1))
70
+ blocks.append(nn.ReLU())
71
+ blocks.append(nn.Conv1d(width, input_emb_width, 3, 1, 1))
72
+ self.model = nn.Sequential(*blocks)
73
+
74
+ def forward(self, x):
75
+ return self.model(x)
76
+
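A quick shape sketch for the Encoder/Decoder pair above: the temporal axis is divided by stride_t**down_t on the way down and restored by the nearest-neighbour upsampling on the way back. down_t=2 here is a representative choice matching the /4 token rate that generate.py assumes (m_token_length = ceil(frames/4)), not necessarily the trained configuration:

    # Shape check only, with representative hyper-parameters.
    import torch
    from models.encdec import Encoder, Decoder

    enc = Encoder(input_emb_width=263, output_emb_width=512, down_t=2, stride_t=2)
    dec = Decoder(input_emb_width=263, output_emb_width=512, down_t=2, stride_t=2)

    x = torch.randn(4, 263, 64)        # (batch, pose_dim, frames)
    z = enc(x)
    print(z.shape)                     # torch.Size([4, 512, 16]) -> frames / stride_t**down_t
    x_hat = dec(z)
    print(x_hat.shape)                 # torch.Size([4, 263, 64])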
models/evaluator_wrapper.py ADDED
@@ -0,0 +1,92 @@
1
+
2
+ import torch
3
+ from os.path import join as pjoin
4
+ import numpy as np
5
+ from models.modules import MovementConvEncoder, TextEncoderBiGRUCo, MotionEncoderBiGRUCo
6
+ from utils.word_vectorizer import POS_enumerator
7
+
8
+ def build_models(opt):
9
+ movement_enc = MovementConvEncoder(opt.dim_pose-4, opt.dim_movement_enc_hidden, opt.dim_movement_latent)
10
+ text_enc = TextEncoderBiGRUCo(word_size=opt.dim_word,
11
+ pos_size=opt.dim_pos_ohot,
12
+ hidden_size=opt.dim_text_hidden,
13
+ output_size=opt.dim_coemb_hidden,
14
+ device=opt.device)
15
+
16
+ motion_enc = MotionEncoderBiGRUCo(input_size=opt.dim_movement_latent,
17
+ hidden_size=opt.dim_motion_hidden,
18
+ output_size=opt.dim_coemb_hidden,
19
+ device=opt.device)
20
+
21
+ checkpoint = torch.load(pjoin(opt.checkpoints_dir, opt.dataset_name, 'text_mot_match', 'model', 'finest.tar'),
22
+ map_location=opt.device)
23
+ movement_enc.load_state_dict(checkpoint['movement_encoder'])
24
+ text_enc.load_state_dict(checkpoint['text_encoder'])
25
+ motion_enc.load_state_dict(checkpoint['motion_encoder'])
26
+ print('Loading Evaluation Model Wrapper (Epoch %d) Completed!!' % (checkpoint['epoch']))
27
+ return text_enc, motion_enc, movement_enc
28
+
29
+
30
+ class EvaluatorModelWrapper(object):
31
+
32
+ def __init__(self, opt):
33
+
34
+ if opt.dataset_name == 't2m':
35
+ opt.dim_pose = 263
36
+ elif opt.dataset_name == 'kit':
37
+ opt.dim_pose = 251
38
+ else:
39
+ raise KeyError('Dataset not Recognized!!!')
40
+
41
+ opt.dim_word = 300
42
+ opt.max_motion_length = 196
43
+ opt.dim_pos_ohot = len(POS_enumerator)
44
+ opt.dim_motion_hidden = 1024
45
+ opt.max_text_len = 20
46
+ opt.dim_text_hidden = 512
47
+ opt.dim_coemb_hidden = 512
48
+
49
+ # print(opt)
50
+
51
+ self.text_encoder, self.motion_encoder, self.movement_encoder = build_models(opt)
52
+ self.opt = opt
53
+ self.device = opt.device
54
+
55
+ self.text_encoder.to(opt.device)
56
+ self.motion_encoder.to(opt.device)
57
+ self.movement_encoder.to(opt.device)
58
+
59
+ self.text_encoder.eval()
60
+ self.motion_encoder.eval()
61
+ self.movement_encoder.eval()
62
+
63
+ # Please note that the results do not follow the order of the inputs
64
+ def get_co_embeddings(self, word_embs, pos_ohot, cap_lens, motions, m_lens):
65
+ with torch.no_grad():
66
+ word_embs = word_embs.detach().to(self.device).float()
67
+ pos_ohot = pos_ohot.detach().to(self.device).float()
68
+ motions = motions.detach().to(self.device).float()
69
+
70
+ '''Movement Encoding'''
71
+ movements = self.movement_encoder(motions[..., :-4]).detach()
72
+ m_lens = m_lens // self.opt.unit_length
73
+ motion_embedding = self.motion_encoder(movements, m_lens)
74
+
75
+ '''Text Encoding'''
76
+ text_embedding = self.text_encoder(word_embs, pos_ohot, cap_lens)
77
+ return text_embedding, motion_embedding
78
+
79
+ # Please note that the results do not follow the order of the inputs
80
+ def get_motion_embeddings(self, motions, m_lens):
81
+ with torch.no_grad():
82
+ motions = motions.detach().to(self.device).float()
83
+
84
+ align_idx = np.argsort(m_lens.data.tolist())[::-1].copy()
85
+ motions = motions[align_idx]
86
+ m_lens = m_lens[align_idx]
87
+
88
+ '''Movement Encoding'''
89
+ movements = self.movement_encoder(motions[..., :-4]).detach()
90
+ m_lens = m_lens // self.opt.unit_length
91
+ motion_embedding = self.motion_encoder(movements, m_lens)
92
+ return motion_embedding
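The two embeddings returned by get_co_embeddings live in a shared space, so text-motion agreement can be scored by Euclidean distance. A hedged sketch of an R-precision-style check, assuming a constructed wrapper and a batch from the evaluation loader:

    # Hedged sketch: ranks motions against their captions with the co-embeddings.
    import torch

    def r_precision_at_1(text_emb, motion_emb):
        # pairwise Euclidean distances between every caption and every motion in the batch
        dist = torch.cdist(text_emb, motion_emb)                       # (B, B)
        top1 = dist.argmin(dim=1)                                      # closest motion per caption
        labels = torch.arange(len(text_emb), device=text_emb.device)
        return (top1 == labels).float().mean()

    # text_emb, motion_emb = eval_wrapper.get_co_embeddings(word_embs, pos_ohot, cap_lens, motions, m_lens)
    # print('R@1:', r_precision_at_1(text_emb, motion_emb).item())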
models/modules.py ADDED
@@ -0,0 +1,109 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn.utils.rnn import pack_padded_sequence
4
+
5
+ def init_weight(m):
6
+ if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose1d):
7
+ nn.init.xavier_normal_(m.weight)
8
+ # m.bias.data.fill_(0.01)
9
+ if m.bias is not None:
10
+ nn.init.constant_(m.bias, 0)
11
+
12
+
13
+ class MovementConvEncoder(nn.Module):
14
+ def __init__(self, input_size, hidden_size, output_size):
15
+ super(MovementConvEncoder, self).__init__()
16
+ self.main = nn.Sequential(
17
+ nn.Conv1d(input_size, hidden_size, 4, 2, 1),
18
+ nn.Dropout(0.2, inplace=True),
19
+ nn.LeakyReLU(0.2, inplace=True),
20
+ nn.Conv1d(hidden_size, output_size, 4, 2, 1),
21
+ nn.Dropout(0.2, inplace=True),
22
+ nn.LeakyReLU(0.2, inplace=True),
23
+ )
24
+ self.out_net = nn.Linear(output_size, output_size)
25
+ self.main.apply(init_weight)
26
+ self.out_net.apply(init_weight)
27
+
28
+ def forward(self, inputs):
29
+ inputs = inputs.permute(0, 2, 1)
30
+ outputs = self.main(inputs).permute(0, 2, 1)
31
+ # print(outputs.shape)
32
+ return self.out_net(outputs)
33
+
34
+
35
+
36
+ class TextEncoderBiGRUCo(nn.Module):
37
+ def __init__(self, word_size, pos_size, hidden_size, output_size, device):
38
+ super(TextEncoderBiGRUCo, self).__init__()
39
+ self.device = device
40
+
41
+ self.pos_emb = nn.Linear(pos_size, word_size)
42
+ self.input_emb = nn.Linear(word_size, hidden_size)
43
+ self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
44
+ self.output_net = nn.Sequential(
45
+ nn.Linear(hidden_size * 2, hidden_size),
46
+ nn.LayerNorm(hidden_size),
47
+ nn.LeakyReLU(0.2, inplace=True),
48
+ nn.Linear(hidden_size, output_size)
49
+ )
50
+
51
+ self.input_emb.apply(init_weight)
52
+ self.pos_emb.apply(init_weight)
53
+ self.output_net.apply(init_weight)
54
+ self.hidden_size = hidden_size
55
+ self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True))
56
+
57
+ # input(batch_size, seq_len, dim)
58
+ def forward(self, word_embs, pos_onehot, cap_lens):
59
+ num_samples = word_embs.shape[0]
60
+
61
+ pos_embs = self.pos_emb(pos_onehot)
62
+ inputs = word_embs + pos_embs
63
+ input_embs = self.input_emb(inputs)
64
+ hidden = self.hidden.repeat(1, num_samples, 1)
65
+
66
+ cap_lens = cap_lens.data.tolist()
67
+ emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True)
68
+
69
+ gru_seq, gru_last = self.gru(emb, hidden)
70
+
71
+ gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1)
72
+
73
+ return self.output_net(gru_last)
74
+
75
+
76
+ class MotionEncoderBiGRUCo(nn.Module):
77
+ def __init__(self, input_size, hidden_size, output_size, device):
78
+ super(MotionEncoderBiGRUCo, self).__init__()
79
+ self.device = device
80
+
81
+ self.input_emb = nn.Linear(input_size, hidden_size)
82
+ self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
83
+ self.output_net = nn.Sequential(
84
+ nn.Linear(hidden_size*2, hidden_size),
85
+ nn.LayerNorm(hidden_size),
86
+ nn.LeakyReLU(0.2, inplace=True),
87
+ nn.Linear(hidden_size, output_size)
88
+ )
89
+
90
+ self.input_emb.apply(init_weight)
91
+ self.output_net.apply(init_weight)
92
+ self.hidden_size = hidden_size
93
+ self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True))
94
+
95
+ # input(batch_size, seq_len, dim)
96
+ def forward(self, inputs, m_lens):
97
+ num_samples = inputs.shape[0]
98
+
99
+ input_embs = self.input_emb(inputs)
100
+ hidden = self.hidden.repeat(1, num_samples, 1)
101
+
102
+ cap_lens = m_lens.data.tolist()
103
+ emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True, enforce_sorted=False)
104
+
105
+ gru_seq, gru_last = self.gru(emb, hidden)
106
+
107
+ gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1)
108
+
109
+ return self.output_net(gru_last)
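A dummy forward pass through the movement and motion encoders above. The 263-dim pose, 1024 motion hidden size and 512 co-embedding size follow the t2m settings in EvaluatorModelWrapper; the movement-encoder hidden sizes and the unit_length of 4 are assumptions made for illustration:

    # Shape sketch with dummy tensors; hidden sizes are illustrative, not the trained checkpoint's.
    import torch
    from models.modules import MovementConvEncoder, MotionEncoderBiGRUCo

    mov_enc = MovementConvEncoder(input_size=263 - 4, hidden_size=512, output_size=512)
    mot_enc = MotionEncoderBiGRUCo(input_size=512, hidden_size=1024, output_size=512, device='cpu')

    motions = torch.randn(2, 196, 263)       # (batch, frames, pose_dim)
    movements = mov_enc(motions[..., :-4])   # two stride-2 convs -> (2, 49, 512)
    m_lens = torch.tensor([49, 40])          # frame lengths // assumed unit_length of 4
    emb = mot_enc(movements, m_lens)
    print(movements.shape, emb.shape)        # torch.Size([2, 49, 512]) torch.Size([2, 512])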
models/pos_encoding.py ADDED
@@ -0,0 +1,43 @@
1
+ """
2
+ Various positional encodings for the transformer.
3
+ """
4
+ import math
5
+ import torch
6
+ from torch import nn
7
+
8
+ def PE1d_sincos(seq_length, dim):
9
+ """
10
+ :param d_model: dimension of the model
11
+ :param length: length of positions
12
+ :return: length*d_model position matrix
13
+ """
14
+ if dim % 2 != 0:
15
+ raise ValueError("Cannot use sin/cos positional encoding with "
16
+ "odd dim (got dim={:d})".format(dim))
17
+ pe = torch.zeros(seq_length, dim)
18
+ position = torch.arange(0, seq_length).unsqueeze(1)
19
+ div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) *
20
+ -(math.log(10000.0) / dim)))
21
+ pe[:, 0::2] = torch.sin(position.float() * div_term)
22
+ pe[:, 1::2] = torch.cos(position.float() * div_term)
23
+
24
+ return pe.unsqueeze(1)
25
+
26
+
27
+ class PositionEmbedding(nn.Module):
28
+ """
29
+ Absolute pos embedding (standard), learned.
30
+ """
31
+ def __init__(self, seq_length, dim, dropout, grad=False):
32
+ super().__init__()
33
+ self.embed = nn.Parameter(data=PE1d_sincos(seq_length, dim), requires_grad=grad)
34
+ self.dropout = nn.Dropout(p=dropout)
35
+
36
+ def forward(self, x):
37
+ # x.shape: bs, seq_len, feat_dim
38
+ l = x.shape[1]
39
+ x = x.permute(1, 0, 2) + self.embed[:l].expand(x.permute(1, 0, 2).shape)
40
+ x = self.dropout(x.permute(1, 0, 2))
41
+ return x
42
+
43
+
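A minimal check of the sinusoidal table and the PositionEmbedding wrapper above; the sizes are arbitrary:

    # PE1d_sincos builds a (seq_length, 1, dim) table; the wrapper adds it to (batch, seq, dim) inputs.
    import torch
    from models.pos_encoding import PE1d_sincos, PositionEmbedding

    pe = PE1d_sincos(seq_length=51, dim=512)
    print(pe.shape)                      # torch.Size([51, 1, 512])

    emb = PositionEmbedding(seq_length=51, dim=512, dropout=0.1, grad=False)
    x = torch.randn(2, 16, 512)          # (batch, seq_len, feat_dim)
    print(emb(x).shape)                  # torch.Size([2, 16, 512]) - positions added, shape preserved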
models/quantize_cnn.py ADDED
@@ -0,0 +1,423 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ class QuantizeEMAReset(nn.Module):
7
+ def __init__(self, nb_code, code_dim, args):
8
+ super().__init__()
9
+ self.nb_code = nb_code
10
+ self.code_dim = code_dim
11
+ self.mu = args.mu
12
+ self.reset_codebook()
13
+ self.reset_count = 0
14
+ self.usage = torch.zeros((self.nb_code, 1))
15
+
16
+ def reset_codebook(self):
17
+ self.init = False
18
+ self.code_sum = None
19
+ self.code_count = None
20
+ self.register_buffer('codebook', torch.zeros(self.nb_code, self.code_dim))
21
+
22
+ def _tile(self, x):
23
+ nb_code_x, code_dim = x.shape
24
+ if nb_code_x < self.nb_code:
25
+ n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x
26
+ std = 0.01 / np.sqrt(code_dim)
27
+ out = x.repeat(n_repeats, 1)
28
+ out = out + torch.randn_like(out) * std
29
+ else :
30
+ out = x
31
+ return out
32
+
33
+ def init_codebook(self, x):
34
+ out = self._tile(x)
35
+ self.codebook = out[:self.nb_code]
36
+ self.code_sum = self.codebook.clone()
37
+ self.code_count = torch.ones(self.nb_code, device=self.codebook.device)
38
+ self.init = True
39
+
40
+ @torch.no_grad()
41
+ def compute_perplexity(self, code_idx) :
42
+ # Calculate new centres
43
+ code_onehot = torch.zeros(self.nb_code, code_idx.shape[0], device=code_idx.device) # nb_code, N * L
44
+ code_onehot.scatter_(0, code_idx.view(1, code_idx.shape[0]), 1)
45
+
46
+ code_count = code_onehot.sum(dim=-1) # nb_code
47
+ prob = code_count / torch.sum(code_count)
48
+ perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
49
+ return perplexity
50
+
51
+ @torch.no_grad()
52
+ def update_codebook(self, x, code_idx):
53
+
54
+ code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device) # nb_code, N * L
55
+ code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1)
56
+
57
+ code_sum = torch.matmul(code_onehot, x) # nb_code, w
58
+ code_count = code_onehot.sum(dim=-1) # nb_code
59
+
60
+ out = self._tile(x)
61
+ code_rand = out[torch.randperm(out.shape[0])[:self.nb_code]]
62
+
63
+ # Update centres
64
+ self.code_sum = self.mu * self.code_sum + (1. - self.mu) * code_sum # w, nb_code
65
+ self.code_count = self.mu * self.code_count + (1. - self.mu) * code_count # nb_code
66
+
67
+ usage = (self.code_count.view(self.nb_code, 1) >= 1.0).float()
68
+ self.usage = self.usage.to(usage.device)
69
+ if self.reset_count >= 20:
70
+ self.reset_count = 0
71
+ usage = (usage + self.usage >= 1.0).float()
72
+ else:
73
+ self.reset_count += 1
74
+ self.usage = (usage + self.usage >= 1.0).float()
75
+ usage = torch.ones_like(self.usage, device=x.device)
76
+ code_update = self.code_sum.view(self.nb_code, self.code_dim) / self.code_count.view(self.nb_code, 1)
77
+
78
+ self.codebook = usage * code_update + (1 - usage) * code_rand
79
+ prob = code_count / torch.sum(code_count)
80
+ perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
81
+
82
+
83
+ return perplexity
84
+
85
+ def preprocess(self, x):
86
+ # NCT -> NTC -> [NT, C]
87
+ x = x.permute(0, 2, 1).contiguous()
88
+ x = x.view(-1, x.shape[-1])
89
+ return x
90
+
91
+ def quantize(self, x):
92
+ # Calculate latent code x_l
93
+ k_w = self.codebook.t()
94
+ distance = torch.sum(x ** 2, dim=-1, keepdim=True) - 2 * torch.matmul(x, k_w) + torch.sum(k_w ** 2, dim=0,
95
+ keepdim=True) # (N * L, b)
96
+ _, code_idx = torch.min(distance, dim=-1)
97
+ return code_idx
98
+
99
+ def dequantize(self, code_idx):
100
+ x = F.embedding(code_idx, self.codebook)
101
+ return x
102
+
103
+
104
+ def forward(self, x):
105
+ N, width, T = x.shape
106
+
107
+ # Preprocess
108
+ x = self.preprocess(x)
109
+
110
+ # Init codebook if not inited
111
+ if self.training and not self.init:
112
+ self.init_codebook(x)
113
+
114
+ # quantize and dequantize through bottleneck
115
+ code_idx = self.quantize(x)
116
+ x_d = self.dequantize(code_idx)
117
+
118
+ # Update embeddings
119
+ if self.training:
120
+ perplexity = self.update_codebook(x, code_idx)
121
+ else :
122
+ perplexity = self.compute_perplexity(code_idx)
123
+
124
+ # Loss
125
+ commit_loss = F.mse_loss(x, x_d.detach())
126
+
127
+ # Passthrough
128
+ x_d = x + (x_d - x).detach()
129
+
130
+ # Postprocess
131
+ x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous() #(N, DIM, T)
132
+
133
+ return x_d, commit_loss, perplexity
134
+
135
+
136
+
137
+ class Quantizer(nn.Module):
138
+ def __init__(self, n_e, e_dim, beta):
139
+ super(Quantizer, self).__init__()
140
+
141
+ self.e_dim = e_dim
142
+ self.n_e = n_e
143
+ self.beta = beta
144
+
145
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
146
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
147
+
148
+ def forward(self, z):
149
+
150
+ N, width, T = z.shape
151
+ z = self.preprocess(z)
152
+ assert z.shape[-1] == self.e_dim
153
+ z_flattened = z.contiguous().view(-1, self.e_dim)
154
+
155
+ # B x V
156
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
157
+ torch.sum(self.embedding.weight**2, dim=1) - 2 * \
158
+ torch.matmul(z_flattened, self.embedding.weight.t())
159
+ # B x 1
160
+ min_encoding_indices = torch.argmin(d, dim=1)
161
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
162
+
163
+ # compute loss for embedding
164
+ loss = torch.mean((z_q - z.detach())**2) + self.beta * \
165
+ torch.mean((z_q.detach() - z)**2)
166
+
167
+ # preserve gradients
168
+ z_q = z + (z_q - z).detach()
169
+ z_q = z_q.view(N, T, -1).permute(0, 2, 1).contiguous() #(N, DIM, T)
170
+
171
+ min_encodings = F.one_hot(min_encoding_indices, self.n_e).type(z.dtype)
172
+ e_mean = torch.mean(min_encodings, dim=0)
173
+ perplexity = torch.exp(-torch.sum(e_mean*torch.log(e_mean + 1e-10)))
174
+ return z_q, loss, perplexity
175
+
176
+ def quantize(self, z):
177
+
178
+ assert z.shape[-1] == self.e_dim
179
+
180
+ # B x V
181
+ d = torch.sum(z ** 2, dim=1, keepdim=True) + \
182
+ torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \
183
+ torch.matmul(z, self.embedding.weight.t())
184
+ # B x 1
185
+ min_encoding_indices = torch.argmin(d, dim=1)
186
+ return min_encoding_indices
187
+
188
+ def dequantize(self, indices):
189
+
190
+ index_flattened = indices.view(-1)
191
+ z_q = self.embedding(index_flattened)
192
+ z_q = z_q.view(indices.shape + (self.e_dim, )).contiguous()
193
+ return z_q
194
+
195
+ def preprocess(self, x):
196
+ # NCT -> NTC -> [NT, C]
197
+ x = x.permute(0, 2, 1).contiguous()
198
+ x = x.view(-1, x.shape[-1])
199
+ return x
200
+
201
+
202
+
203
+ class QuantizeReset(nn.Module):
204
+ def __init__(self, nb_code, code_dim, args):
205
+ super().__init__()
206
+ self.nb_code = nb_code
207
+ self.code_dim = code_dim
208
+ self.reset_codebook()
209
+ self.codebook = nn.Parameter(torch.randn(nb_code, code_dim))
210
+
211
+ def reset_codebook(self):
212
+ self.init = False
213
+ self.code_count = None
214
+
215
+ def _tile(self, x):
216
+ nb_code_x, code_dim = x.shape
217
+ if nb_code_x < self.nb_code:
218
+ n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x
219
+ std = 0.01 / np.sqrt(code_dim)
220
+ out = x.repeat(n_repeats, 1)
221
+ out = out + torch.randn_like(out) * std
222
+ else :
223
+ out = x
224
+ return out
225
+
226
+ def init_codebook(self, x):
227
+ out = self._tile(x)
228
+ self.codebook = nn.Parameter(out[:self.nb_code])
229
+ self.code_count = torch.ones(self.nb_code, device=self.codebook.device)
230
+ self.init = True
231
+
232
+ @torch.no_grad()
233
+ def compute_perplexity(self, code_idx) :
234
+ # Calculate new centres
235
+ code_onehot = torch.zeros(self.nb_code, code_idx.shape[0], device=code_idx.device) # nb_code, N * L
236
+ code_onehot.scatter_(0, code_idx.view(1, code_idx.shape[0]), 1)
237
+
238
+ code_count = code_onehot.sum(dim=-1) # nb_code
239
+ prob = code_count / torch.sum(code_count)
240
+ perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
241
+ return perplexity
242
+
243
+ def update_codebook(self, x, code_idx):
244
+
245
+ code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device) # nb_code, N * L
246
+ code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1)
247
+
248
+ code_count = code_onehot.sum(dim=-1) # nb_code
249
+
250
+ out = self._tile(x)
251
+ code_rand = out[:self.nb_code]
252
+
253
+ # Update centres
254
+ self.code_count = code_count # nb_code
255
+ usage = (self.code_count.view(self.nb_code, 1) >= 1.0).float()
256
+
257
+ self.codebook.data = usage * self.codebook.data + (1 - usage) * code_rand
258
+ prob = code_count / torch.sum(code_count)
259
+ perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
260
+
261
+
262
+ return perplexity
263
+
264
+ def preprocess(self, x):
265
+ # NCT -> NTC -> [NT, C]
266
+ x = x.permute(0, 2, 1).contiguous()
267
+ x = x.view(-1, x.shape[-1])
268
+ return x
269
+
270
+ def quantize(self, x):
271
+ # Calculate latent code x_l
272
+ k_w = self.codebook.t()
273
+ distance = torch.sum(x ** 2, dim=-1, keepdim=True) - 2 * torch.matmul(x, k_w) + torch.sum(k_w ** 2, dim=0,
274
+ keepdim=True) # (N * L, b)
275
+ _, code_idx = torch.min(distance, dim=-1)
276
+ return code_idx
277
+
278
+ def dequantize(self, code_idx):
279
+ x = F.embedding(code_idx, self.codebook)
280
+ return x
281
+
282
+
283
+ def forward(self, x):
284
+ N, width, T = x.shape
285
+ # Preprocess
286
+ x = self.preprocess(x)
287
+ # Init codebook if not inited
288
+ if self.training and not self.init:
289
+ self.init_codebook(x)
290
+ # quantize and dequantize through bottleneck
291
+ code_idx = self.quantize(x)
292
+ x_d = self.dequantize(code_idx)
293
+ # Update embeddings
294
+ if self.training:
295
+ perplexity = self.update_codebook(x, code_idx)
296
+ else :
297
+ perplexity = self.compute_perplexity(code_idx)
298
+
299
+ # Loss
300
+ commit_loss = F.mse_loss(x, x_d.detach())
301
+
302
+ # Passthrough
303
+ x_d = x + (x_d - x).detach()
304
+
305
+ # Postprocess
306
+ x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous() #(N, DIM, T)
307
+
308
+ return x_d, commit_loss, perplexity
309
+
310
+
311
+ class QuantizeEMA(nn.Module):
312
+ def __init__(self, nb_code, code_dim, args):
313
+ super().__init__()
314
+ self.nb_code = nb_code
315
+ self.code_dim = code_dim
316
+ self.mu = 0.99
317
+ self.reset_codebook()
318
+
319
+ def reset_codebook(self):
320
+ self.init = False
321
+ self.code_sum = None
322
+ self.code_count = None
323
+ self.register_buffer('codebook', torch.zeros(self.nb_code, self.code_dim).cuda())
324
+
325
+ def _tile(self, x):
326
+ nb_code_x, code_dim = x.shape
327
+ if nb_code_x < self.nb_code:
328
+ n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x
329
+ std = 0.01 / np.sqrt(code_dim)
330
+ out = x.repeat(n_repeats, 1)
331
+ out = out + torch.randn_like(out) * std
332
+ else :
333
+ out = x
334
+ return out
335
+
336
+ def init_codebook(self, x):
337
+ out = self._tile(x)
338
+ self.codebook = out[:self.nb_code]
339
+ self.code_sum = self.codebook.clone()
340
+ self.code_count = torch.ones(self.nb_code, device=self.codebook.device)
341
+ self.init = True
342
+
343
+ @torch.no_grad()
344
+ def compute_perplexity(self, code_idx) :
345
+ # Calculate new centres
346
+ code_onehot = torch.zeros(self.nb_code, code_idx.shape[0], device=code_idx.device) # nb_code, N * L
347
+ code_onehot.scatter_(0, code_idx.view(1, code_idx.shape[0]), 1)
348
+
349
+ code_count = code_onehot.sum(dim=-1) # nb_code
350
+ prob = code_count / torch.sum(code_count)
351
+ perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
352
+ return perplexity
353
+
354
+ @torch.no_grad()
355
+ def update_codebook(self, x, code_idx):
356
+
357
+ code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device) # nb_code, N * L
358
+ code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1)
359
+
360
+ code_sum = torch.matmul(code_onehot, x) # nb_code, w
361
+ code_count = code_onehot.sum(dim=-1) # nb_code
362
+
363
+ # Update centres
364
+ self.code_sum = self.mu * self.code_sum + (1. - self.mu) * code_sum # w, nb_code
365
+ self.code_count = self.mu * self.code_count + (1. - self.mu) * code_count # nb_code
366
+
367
+ code_update = self.code_sum.view(self.nb_code, self.code_dim) / self.code_count.view(self.nb_code, 1)
368
+
369
+ self.codebook = code_update
370
+ prob = code_count / torch.sum(code_count)
371
+ perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
372
+
373
+ return perplexity
374
+
375
+ def preprocess(self, x):
376
+ # NCT -> NTC -> [NT, C]
377
+ x = x.permute(0, 2, 1).contiguous()
378
+ x = x.view(-1, x.shape[-1])
379
+ return x
380
+
381
+ def quantize(self, x):
382
+ # Calculate latent code x_l
383
+ k_w = self.codebook.t()
384
+ distance = torch.sum(x ** 2, dim=-1, keepdim=True) - 2 * torch.matmul(x, k_w) + torch.sum(k_w ** 2, dim=0,
385
+ keepdim=True) # (N * L, b)
386
+ _, code_idx = torch.min(distance, dim=-1)
387
+ return code_idx
388
+
389
+ def dequantize(self, code_idx):
390
+ x = F.embedding(code_idx, self.codebook)
391
+ return x
392
+
393
+
394
+ def forward(self, x):
395
+ N, width, T = x.shape
396
+
397
+ # Preprocess
398
+ x = self.preprocess(x)
399
+
400
+ # Init codebook if not inited
401
+ if self.training and not self.init:
402
+ self.init_codebook(x)
403
+
404
+ # quantize and dequantize through bottleneck
405
+ code_idx = self.quantize(x)
406
+ x_d = self.dequantize(code_idx)
407
+
408
+ # Update embeddings
409
+ if self.training:
410
+ perplexity = self.update_codebook(x, code_idx)
411
+ else :
412
+ perplexity = self.compute_perplexity(code_idx)
413
+
414
+ # Loss
415
+ commit_loss = F.mse_loss(x, x_d.detach())
416
+
417
+ # Passthrough
418
+ x_d = x + (x_d - x).detach()
419
+
420
+ # Postprocess
421
+ x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous() #(N, DIM, T)
422
+
423
+ return x_d, commit_loss, perplexity
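A straight-through round trip through QuantizeEMAReset with dummy data; the args object is a stand-in namespace supplying only the EMA decay mu that the class reads, and the sizes are illustrative:

    # Minimal VQ round trip on CPU.
    import torch
    from types import SimpleNamespace
    from models.quantize_cnn import QuantizeEMAReset

    quant = QuantizeEMAReset(nb_code=512, code_dim=32, args=SimpleNamespace(mu=0.99))
    quant.train()                                   # training mode initialises and EMA-updates the codebook

    x = torch.randn(4, 32, 16)                      # (N, code_dim, T) encoder output
    x_d, commit_loss, perplexity = quant(x)
    print(x_d.shape)                                # torch.Size([4, 32, 16]), same as the input
    print(commit_loss.item(), perplexity.item())    # commitment loss and codebook perplexity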
models/resnet.py ADDED
@@ -0,0 +1,82 @@
1
+ import torch.nn as nn
2
+ import torch
3
+
4
+ class nonlinearity(nn.Module):
5
+ def __init__(self):
6
+ super().__init__()
7
+
8
+ def forward(self, x):
9
+ # swish
10
+ return x * torch.sigmoid(x)
11
+
12
+ class ResConv1DBlock(nn.Module):
13
+ def __init__(self, n_in, n_state, dilation=1, activation='silu', norm=None, dropout=None):
14
+ super().__init__()
15
+ padding = dilation
16
+ self.norm = norm
17
+ if norm == "LN":
18
+ self.norm1 = nn.LayerNorm(n_in)
19
+ self.norm2 = nn.LayerNorm(n_in)
20
+ elif norm == "GN":
21
+ self.norm1 = nn.GroupNorm(num_groups=32, num_channels=n_in, eps=1e-6, affine=True)
22
+ self.norm2 = nn.GroupNorm(num_groups=32, num_channels=n_in, eps=1e-6, affine=True)
23
+ elif norm == "BN":
24
+ self.norm1 = nn.BatchNorm1d(num_features=n_in, eps=1e-6, affine=True)
25
+ self.norm2 = nn.BatchNorm1d(num_features=n_in, eps=1e-6, affine=True)
26
+
27
+ else:
28
+ self.norm1 = nn.Identity()
29
+ self.norm2 = nn.Identity()
30
+
31
+ if activation == "relu":
32
+ self.activation1 = nn.ReLU()
33
+ self.activation2 = nn.ReLU()
34
+
35
+ elif activation == "silu":
36
+ self.activation1 = nonlinearity()
37
+ self.activation2 = nonlinearity()
38
+
39
+ elif activation == "gelu":
40
+ self.activation1 = nn.GELU()
41
+ self.activation2 = nn.GELU()
42
+
43
+
44
+
45
+ self.conv1 = nn.Conv1d(n_in, n_state, 3, 1, padding, dilation)
46
+ self.conv2 = nn.Conv1d(n_state, n_in, 1, 1, 0,)
47
+
48
+
49
+ def forward(self, x):
50
+ x_orig = x
51
+ if self.norm == "LN":
52
+ x = self.norm1(x.transpose(-2, -1))
53
+ x = self.activation1(x.transpose(-2, -1))
54
+ else:
55
+ x = self.norm1(x)
56
+ x = self.activation1(x)
57
+
58
+ x = self.conv1(x)
59
+
60
+ if self.norm == "LN":
61
+ x = self.norm2(x.transpose(-2, -1))
62
+ x = self.activation2(x.transpose(-2, -1))
63
+ else:
64
+ x = self.norm2(x)
65
+ x = self.activation2(x)
66
+
67
+ x = self.conv2(x)
68
+ x = x + x_orig
69
+ return x
70
+
71
+ class Resnet1D(nn.Module):
72
+ def __init__(self, n_in, n_depth, dilation_growth_rate=1, reverse_dilation=True, activation='relu', norm=None):
73
+ super().__init__()
74
+
75
+ blocks = [ResConv1DBlock(n_in, n_in, dilation=dilation_growth_rate ** depth, activation=activation, norm=norm) for depth in range(n_depth)]
76
+ if reverse_dilation:
77
+ blocks = blocks[::-1]
78
+
79
+ self.model = nn.Sequential(*blocks)
80
+
81
+ def forward(self, x):
82
+ return self.model(x)
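The residual stack above preserves both the channel count and the temporal length, since every ResConv1DBlock pads by its dilation; only the receptive field grows with dilation_growth_rate ** depth. A tiny shape check:

    # Shape-preserving dilated residual stack.
    import torch
    from models.resnet import Resnet1D

    net = Resnet1D(n_in=512, n_depth=3, dilation_growth_rate=3, activation='relu', norm=None)
    x = torch.randn(2, 512, 16)     # (batch, channels, frames)
    print(net(x).shape)             # torch.Size([2, 512, 16])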
models/t2m_trans.py ADDED
@@ -0,0 +1,626 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.nn import functional as F
5
+ from torch.distributions import Categorical
6
+ import models.pos_encoding as pos_encoding
7
+ from exit.utils import cosine_schedule, uniform, top_k, gumbel_sample, top_p
8
+ from tqdm import tqdm
9
+ from einops import rearrange, repeat
10
+ from exit.utils import get_model, generate_src_mask
11
+
12
+ class PatchUpSampling(nn.Module):
13
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
14
+ super().__init__()
15
+ self.dim = dim
16
+ self.up_sampling = nn.Linear(dim, 4 * dim, bias=False)
17
+ self.norm = norm_layer(dim)
18
+
19
+ def forward(self, x):
20
+ """
21
+ x: B, F, C
22
+ """
23
+ x = self.norm(x)
24
+ x = self.up_sampling(x)
25
+ x0 = x[:, :, 0::4]
26
+ x1 = x[:, :, 1::4]
27
+ x2 = x[:, :, 2::4]
28
+ x3 = x[:, :, 3::4]
29
+ x = torch.cat([x0, x1, x2, x3], 1)
30
+ return x
31
+
32
+ class Decoder_Transformer(nn.Module):
33
+ def __init__(self,
34
+ code_dim=1024,
35
+ embed_dim=512,
36
+ output_dim=263,
37
+ block_size=16,
38
+ num_layers=2,
39
+ n_head=8,
40
+ drop_out_rate=0.1,
41
+ fc_rate=4):
42
+
43
+ super().__init__()
44
+ self.joint_embed = nn.Linear(code_dim, embed_dim)
45
+ self.drop = nn.Dropout(drop_out_rate)
46
+ # transformer block
47
+ self.blocks = nn.Sequential(*[Block(embed_dim, block_size, n_head, drop_out_rate, fc_rate) for _ in range(num_layers)])
48
+ self.up_sample = PatchUpSampling(embed_dim)
49
+ self.pos_embed = pos_encoding.PositionEmbedding(block_size, embed_dim, 0.0, False)
50
+ self.head = nn.Sequential(nn.LayerNorm(embed_dim),
51
+ nn.Linear(embed_dim, output_dim))
52
+ self.block_size = block_size
53
+ self.n_head = n_head
54
+ self.apply(self._init_weights)
55
+
56
+ def get_block_size(self):
57
+ return self.block_size
58
+
59
+ def _init_weights(self, module):
60
+ if isinstance(module, (nn.Linear, nn.Embedding)):
61
+ module.weight.data.normal_(mean=0.0, std=0.02)
62
+ if isinstance(module, nn.Linear) and module.bias is not None:
63
+ module.bias.data.zero_()
64
+ elif isinstance(module, nn.LayerNorm):
65
+ module.bias.data.zero_()
66
+ module.weight.data.fill_(1.0)
67
+
68
+ def forward(self, token_embeddings):
69
+ # token_embeddings = self.tok_emb(idx)
70
+ # B, T = src_mask.shape
71
+ # src_mask = src_mask.view(B, 1, 1, T).repeat(1, self.n_head, T, 1)
72
+
73
+ token_embeddings = token_embeddings.permute(0, 2, 1)
74
+ token_embeddings = self.joint_embed(token_embeddings)
75
+ x = self.pos_embed(token_embeddings)
76
+
77
+
78
+ for block in self.blocks:
79
+ x = block(x)
80
+ x = self.up_sample(x)
81
+
82
+
83
+ x = self.head(x).permute(0, 2, 1)
84
+ return x
85
+
86
+ # https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py#L342C9-L343C33
87
+ class PatchMerging(nn.Module):
88
+ def __init__(self, input_feats, dim, norm_layer=nn.LayerNorm):
89
+ super().__init__()
90
+ self.dim = dim
91
+ self.reduction = nn.Linear(4 * input_feats, dim, bias=False)
92
+ self.norm = norm_layer(4 * input_feats)
93
+
94
+ def forward(self, x):
95
+ """
96
+ x: B, F, C
97
+ """
98
+ x0 = x[:, 0::4, :] # B F/2 C
99
+ x1 = x[:, 1::4, :]
100
+ x2 = x[:, 2::4, :] # B F/2 C
101
+ x3 = x[:, 3::4, :]
102
+ x = torch.cat([x0, x1, x2, x3], -1) # B F/2 2*C
103
+ x = self.norm(x)
104
+ x = self.reduction(x)
105
+ return x
106
+
107
+ class Encoder_Transformer(nn.Module):
108
+ def __init__(self,
109
+ input_feats=1024,
110
+ embed_dim=512,
111
+ output_dim=263,
112
+ block_size=16,
113
+ num_layers=2,
114
+ n_head=8,
115
+ drop_out_rate=0.1,
116
+ fc_rate=4):
117
+
118
+ super().__init__()
119
+ self.joint_embed = nn.Linear(input_feats, embed_dim)
120
+ self.drop = nn.Dropout(drop_out_rate)
121
+ # transformer block
122
+ self.blocks = nn.Sequential(*[Block(embed_dim, block_size, n_head, drop_out_rate, fc_rate) for _ in range(num_layers)])
123
+ self.weighted_mean_norm = nn.LayerNorm(embed_dim)
124
+ self.weighted_mean = torch.nn.Conv1d(in_channels=block_size, out_channels=1, kernel_size=1)
125
+
126
+ self.pos_embed = pos_encoding.PositionEmbedding(block_size, embed_dim, 0.0, False)
127
+ self.head = nn.Sequential(nn.LayerNorm(embed_dim),
128
+ nn.Linear(embed_dim, output_dim))
129
+ self.block_size = block_size
130
+ self.n_head = n_head
131
+ self.apply(self._init_weights)
132
+
133
+ def get_block_size(self):
134
+ return self.block_size
135
+
136
+ def _init_weights(self, module):
137
+ if isinstance(module, (nn.Linear, nn.Embedding)):
138
+ module.weight.data.normal_(mean=0.0, std=0.02)
139
+ if isinstance(module, nn.Linear) and module.bias is not None:
140
+ module.bias.data.zero_()
141
+ elif isinstance(module, nn.LayerNorm):
142
+ module.bias.data.zero_()
143
+ module.weight.data.fill_(1.0)
144
+
145
+ def forward(self, joints):
146
+ # B, T = src_mask.shape
147
+
148
+ joints = joints.permute(0,2,1)
149
+ # token_embeddings = self.joint_embed(joints)
150
+
151
+ block_step_len = int(len(self.blocks)/3)
152
+
153
+ x = self.joint_embed(joints)
154
+ token_len = int(x.shape[1]/self.block_size)
155
+ _original_shape = list(x.shape)
156
+ x = x.view(x.shape[0]*token_len, self.block_size, -1)
157
+
158
+ x = self.pos_embed(x)
159
+ for block in self.blocks:
160
+ x = block(x)
161
+ x = self.weighted_mean_norm(x)
162
+ x = self.weighted_mean(x)
163
+ _original_shape[1] = int(_original_shape[1] / self.block_size)
164
+ x = x.view(*_original_shape)
165
+
166
+ x = self.head(x).permute(0, 2, 1)
167
+ return x
168
+
169
+ class Text2Motion_Transformer(nn.Module):
170
+
171
+ def __init__(self,
172
+ vqvae,
173
+ num_vq=1024,
174
+ embed_dim=512,
175
+ clip_dim=512,
176
+ block_size=16,
177
+ num_layers=2,
178
+ num_local_layer=0,
179
+ n_head=8,
180
+ drop_out_rate=0.1,
181
+ fc_rate=4):
182
+ super().__init__()
183
+ self.n_head = n_head
184
+ self.trans_base = CrossCondTransBase(vqvae, num_vq, embed_dim, clip_dim, block_size, num_layers, num_local_layer, n_head, drop_out_rate, fc_rate)
185
+ self.trans_head = CrossCondTransHead(num_vq, embed_dim, block_size, num_layers, n_head, drop_out_rate, fc_rate)
186
+ self.block_size = block_size
187
+ self.num_vq = num_vq
188
+
189
+ # self.skip_trans = Skip_Connection_Transformer(num_vq, embed_dim, clip_dim, block_size, num_layers, n_head, drop_out_rate, fc_rate)
190
+
191
+ def get_block_size(self):
192
+ return self.block_size
193
+
194
+ def forward(self, *args, type='forward', **kwargs):
195
+ '''type=[forward, sample]'''
196
+ if type=='forward':
197
+ return self.forward_function(*args, **kwargs)
198
+ elif type=='sample':
199
+ return self.sample(*args, **kwargs)
200
+ elif type=='inpaint':
201
+ return self.inpaint(*args, **kwargs)
202
+ else:
203
+ raise ValueError(f'Unknown "{type}" type')
204
+
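+ # Illustrative call patterns for the type-based dispatch above (here `model` denotes a
+ # Text2Motion_Transformer instance; variable names are assumptions, arguments follow the
+ # signatures defined below):
+ #   logits = model(idxs, clip_feature, src_mask, type='forward')
+ #   tokens = model(clip_feature, word_emb, m_length, type='sample')
+ #   tokens = model(first_tokens, last_tokens, clip_feature, word_emb, type='inpaint')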
205
+ def get_attn_mask(self, src_mask, att_txt=None):
206
+ if att_txt is None:
207
+ att_txt = torch.tensor([[True]]*src_mask.shape[0]).to(src_mask.device)
208
+ src_mask = torch.cat([att_txt, src_mask], dim=1)
209
+ B, T = src_mask.shape
210
+ src_mask = src_mask.view(B, 1, 1, T).repeat(1, self.n_head, T, 1)
211
+ return src_mask
212
+
213
+ def forward_function(self, idxs, clip_feature, src_mask=None, att_txt=None, word_emb=None):
214
+ if src_mask is not None:
215
+ src_mask = self.get_attn_mask(src_mask, att_txt)
216
+ feat = self.trans_base(idxs, clip_feature, src_mask, word_emb)
217
+ logits = self.trans_head(feat, src_mask)
218
+
219
+ return logits
220
+
221
+ def sample(self, clip_feature, word_emb, m_length=None, if_test=False, rand_pos=True, CFG=-1, token_cond=None, max_steps = 10):
222
+ max_length = 49
223
+ batch_size = clip_feature.shape[0]
224
+ mask_id = self.num_vq + 2
225
+ pad_id = self.num_vq + 1
226
+ end_id = self.num_vq
227
+ shape = (batch_size, self.block_size - 1)
228
+ topk_filter_thres = .9
229
+ starting_temperature = 1.0
230
+ scores = torch.ones(shape, dtype = torch.float32, device = clip_feature.device)
231
+
232
+ m_tokens_len = torch.ceil((m_length)/4).long()
233
+ src_token_mask = generate_src_mask(self.block_size-1, m_tokens_len+1)
234
+ src_token_mask_noend = generate_src_mask(self.block_size-1, m_tokens_len)
235
+ if token_cond is not None:
236
+ ids = token_cond.clone()
237
+ ids[~src_token_mask_noend] = pad_id
238
+ num_token_cond = (ids==mask_id).sum(-1)
239
+ else:
240
+ ids = torch.full(shape, mask_id, dtype = torch.long, device = clip_feature.device)
241
+
242
+ # [TODO] confirm that these 2 lines are not necessary (repeated below; they may not be needed at all)
243
+ ids[~src_token_mask] = pad_id # [INFO] replace with pad id
244
+ ids.scatter_(-1, m_tokens_len[..., None].long(), end_id) # [INFO] replace with end id
245
+
246
+ sample_max_steps = torch.round(max_steps/max_length*m_tokens_len) + 1e-8
247
+ for step in range(max_steps):
248
+ timestep = torch.clip(step/(sample_max_steps), max=1)
249
+ if len(m_tokens_len)==1 and step > 0 and torch.clip(step-1/(sample_max_steps), max=1).cpu().item() == timestep:
250
+ break
251
+ rand_mask_prob = cosine_schedule(timestep) # timestep #
252
+ num_token_masked = (rand_mask_prob * m_tokens_len).long().clip(min=1)
253
+
254
+ if token_cond is not None:
255
+ num_token_masked = (rand_mask_prob * num_token_cond).long().clip(min=1)
256
+ scores[token_cond!=mask_id] = 0
257
+
258
+ # [INFO] rm no motion frames
259
+ scores[~src_token_mask_noend] = 0
260
+ scores = scores/scores.sum(-1)[:, None] # normalize only unmasked token
261
+
262
+ # if rand_pos:
263
+ # sorted_score_indices = scores.multinomial(scores.shape[-1], replacement=False) # stochastic
264
+ # else:
265
+ sorted, sorted_score_indices = scores.sort(descending=True) # deterministic
266
+
267
+ ids[~src_token_mask] = pad_id # [INFO] replace with pad id
268
+ ids.scatter_(-1, m_tokens_len[..., None].long(), end_id) # [INFO] replace with end id
269
+ ## [INFO] Set the "num_token_masked" positions with the highest "scores" to "mask_id"
270
+ select_masked_indices = generate_src_mask(sorted_score_indices.shape[1], num_token_masked)
271
+ # [INFO] repeat last_index so the unused scatter_ slots simply re-mask the already-selected last position.
272
+ last_index = sorted_score_indices.gather(-1, num_token_masked.unsqueeze(-1)-1)
273
+ sorted_score_indices = sorted_score_indices * select_masked_indices + (last_index*~select_masked_indices)
274
+ ids.scatter_(-1, sorted_score_indices, mask_id)
275
+
276
+ logits = self.forward(ids, clip_feature, src_token_mask, word_emb=word_emb)[:,1:]
277
+ filtered_logits = logits #top_p(logits, .5) # #top_k(logits, topk_filter_thres)
278
+ if rand_pos:
279
+ temperature = 1 #starting_temperature * (steps_until_x0 / timesteps) # temperature is annealed
280
+ else:
281
+ temperature = 0 #starting_temperature * (steps_until_x0 / timesteps) # temperature is annealed
282
+
283
+ # [INFO] if temperature==0: is equal to argmax (filtered_logits.argmax(dim = -1))
284
+ # pred_ids = filtered_logits.argmax(dim = -1)
285
+ pred_ids = gumbel_sample(filtered_logits, temperature = temperature, dim = -1)
286
+ is_mask = ids == mask_id
287
+
288
+ ids = torch.where(
289
+ is_mask,
290
+ pred_ids,
291
+ ids
292
+ )
293
+
294
+ # if timestep == 1.:
295
+ # print(probs_without_temperature.shape)
296
+ probs_without_temperature = logits.softmax(dim = -1)
297
+ scores = 1 - probs_without_temperature.gather(-1, pred_ids[..., None])
298
+ scores = rearrange(scores, '... 1 -> ...')
299
+ scores = scores.masked_fill(~is_mask, 0)
300
+ if if_test:
301
+ return ids
302
+ return ids
303
+
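+ # [INFO] sample() above is MaskGIT-style parallel decoding: ids start fully masked (or from
+ # token_cond); each step re-masks a cosine-scheduled number of the least-confident positions
+ # (scores = 1 - p(predicted token)), re-predicts every masked position in parallel, and the
+ # effective number of steps scales with the motion length via sample_max_steps.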
304
+ def inpaint(self, first_tokens, last_tokens, clip_feature=None, word_emb=None, inpaint_len=2, rand_pos=False):
305
+ # support only one sample
306
+ assert first_tokens.shape[0] == 1
307
+ assert last_tokens.shape[0] == 1
308
+ max_steps = 20
309
+ max_length = 49
310
+ batch_size = first_tokens.shape[0]
311
+ mask_id = self.num_vq + 2
312
+ pad_id = self.num_vq + 1
313
+ end_id = self.num_vq
314
+ shape = (batch_size, self.block_size - 1)
315
+ scores = torch.ones(shape, dtype = torch.float32, device = first_tokens.device)
316
+
317
+ # force add first / last tokens
318
+ first_partition_pos_idx = first_tokens.shape[1]
319
+ second_partition_pos_idx = first_partition_pos_idx + inpaint_len
320
+ end_pos_idx = second_partition_pos_idx + last_tokens.shape[1]
321
+
322
+ m_tokens_len = torch.ones(batch_size, device = first_tokens.device)*end_pos_idx
323
+
324
+ src_token_mask = generate_src_mask(self.block_size-1, m_tokens_len+1)
325
+ src_token_mask_noend = generate_src_mask(self.block_size-1, m_tokens_len)
326
+ ids = torch.full(shape, mask_id, dtype = torch.long, device = first_tokens.device)
327
+
328
+ ids[:, :first_partition_pos_idx] = first_tokens
329
+ ids[:, second_partition_pos_idx:end_pos_idx] = last_tokens
330
+ src_token_mask_noend[:, :first_partition_pos_idx] = False
331
+ src_token_mask_noend[:, second_partition_pos_idx:end_pos_idx] = False
332
+
333
+ # [TODO] confirm that these 2 lines are not necessary (repeated below; they may not be needed at all)
334
+ ids[~src_token_mask] = pad_id # [INFO] replace with pad id
335
+ ids.scatter_(-1, m_tokens_len[..., None].long(), end_id) # [INFO] replace with end id
336
+
337
+ temp = []
338
+ sample_max_steps = torch.round(max_steps/max_length*m_tokens_len) + 1e-8
339
+
340
+ if clip_feature is None:
341
+ clip_feature = torch.zeros(1, 512).to(first_tokens.device)
342
+ att_txt = torch.zeros((batch_size,1), dtype=torch.bool, device = first_tokens.device)
343
+ else:
344
+ att_txt = torch.ones((batch_size,1), dtype=torch.bool, device = first_tokens.device)
345
+
346
+ for step in range(max_steps):
347
+ timestep = torch.clip(step/(sample_max_steps), max=1)
348
+ rand_mask_prob = cosine_schedule(timestep) # timestep #
349
+ num_token_masked = (rand_mask_prob * m_tokens_len).long().clip(min=1)
350
+ # [INFO] rm no motion frames
351
+ scores[~src_token_mask_noend] = 0
352
+ # [INFO] rm begin and end frames
353
+ scores[:, :first_partition_pos_idx] = 0
354
+ scores[:, second_partition_pos_idx:end_pos_idx] = 0
355
+ scores = scores/scores.sum(-1)[:, None] # normalize only unmasked token
356
+
357
+ sorted, sorted_score_indices = scores.sort(descending=True) # deterministic
358
+
359
+ ids[~src_token_mask] = pad_id # [INFO] replace with pad id
360
+ ids.scatter_(-1, m_tokens_len[..., None].long(), end_id) # [INFO] replace with end id
361
+ ## [INFO] Set the "num_token_masked" positions with the highest "scores" to "mask_id"
362
+ select_masked_indices = generate_src_mask(sorted_score_indices.shape[1], num_token_masked)
363
+ # [INFO] repeat last_index so the unused scatter_ slots simply re-mask the already-selected last position.
364
+ last_index = sorted_score_indices.gather(-1, num_token_masked.unsqueeze(-1)-1)
365
+ sorted_score_indices = sorted_score_indices * select_masked_indices + (last_index*~select_masked_indices)
366
+ ids.scatter_(-1, sorted_score_indices, mask_id)
367
+
368
+ # [TODO] force-replace begin/end tokens because the number of masked tokens can exceed the actual inpainting frames
369
+ ids[:, :first_partition_pos_idx] = first_tokens
370
+ ids[:, second_partition_pos_idx:end_pos_idx] = last_tokens
371
+
372
+ logits = self.forward(ids, clip_feature, src_token_mask, word_emb=word_emb)[:,1:]
373
+ filtered_logits = logits #top_k(logits, topk_filter_thres)
374
+ if rand_pos:
375
+ temperature = 1 #starting_temperature * (steps_until_x0 / timesteps) # temperature is annealed
376
+ else:
377
+ temperature = 0 #starting_temperature * (steps_until_x0 / timesteps) # temperature is annealed
378
+
379
+ # [INFO] if temperature==0: is equal to argmax (filtered_logits.argmax(dim = -1))
380
+ # pred_ids = filtered_logits.argmax(dim = -1)
381
+ pred_ids = gumbel_sample(filtered_logits, temperature = temperature, dim = -1)
382
+ is_mask = ids == mask_id
383
+ temp.append(is_mask[:1])
384
+
385
+ ids = torch.where(
386
+ is_mask,
387
+ pred_ids,
388
+ ids
389
+ )
390
+
391
+ probs_without_temperature = logits.softmax(dim = -1)
392
+ scores = 1 - probs_without_temperature.gather(-1, pred_ids[..., None])
393
+ scores = rearrange(scores, '... 1 -> ...')
394
+ scores = scores.masked_fill(~is_mask, 0)
395
+ return ids
396
+
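+ # [INFO] inpaint() above keeps first_tokens / last_tokens fixed and fills inpaint_len masked
+ # tokens in between, reusing the same iterative re-masking loop as sample(); if clip_feature
+ # is None, a zero text embedding is substituted.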
397
+ class Attention(nn.Module):
398
+
399
+ def __init__(self, embed_dim=512, block_size=16, n_head=8, drop_out_rate=0.1):
400
+ super().__init__()
401
+ assert embed_dim % 8 == 0
402
+ # key, query, value projections for all heads
403
+ self.key = nn.Linear(embed_dim, embed_dim)
404
+ self.query = nn.Linear(embed_dim, embed_dim)
405
+ self.value = nn.Linear(embed_dim, embed_dim)
406
+
407
+ self.attn_drop = nn.Dropout(drop_out_rate)
408
+ self.resid_drop = nn.Dropout(drop_out_rate)
409
+
410
+ self.proj = nn.Linear(embed_dim, embed_dim)
411
+ self.n_head = n_head
412
+
413
+ def forward(self, x, src_mask):
414
+ B, T, C = x.size()
415
+
416
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
417
+ k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
418
+ q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
419
+ v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
420
+ # self-attention with an optional padding mask (not causal): (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
421
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
422
+ if src_mask is not None:
423
+ att[~src_mask] = float('-inf')
424
+ att = F.softmax(att, dim=-1)
425
+ att = self.attn_drop(att)
426
+ y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
427
+ y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
428
+
429
+ # output projection
430
+ y = self.resid_drop(self.proj(y))
431
+ return y
432
+
433
+ class Block(nn.Module):
434
+
435
+ def __init__(self, embed_dim=512, block_size=16, n_head=8, drop_out_rate=0.1, fc_rate=4):
436
+ super().__init__()
437
+ self.ln1 = nn.LayerNorm(embed_dim)
438
+ self.ln2 = nn.LayerNorm(embed_dim)
439
+ self.attn = Attention(embed_dim, block_size, n_head, drop_out_rate)
440
+ self.mlp = nn.Sequential(
441
+ nn.Linear(embed_dim, fc_rate * embed_dim),
442
+ nn.GELU(),
443
+ nn.Linear(fc_rate * embed_dim, embed_dim),
444
+ nn.Dropout(drop_out_rate),
445
+ )
446
+
447
+ def forward(self, x, src_mask=None):
448
+ x = x + self.attn(self.ln1(x), src_mask)
449
+ x = x + self.mlp(self.ln2(x))
450
+ return x
451
+
452
+ class CrossAttention(nn.Module):
453
+
454
+ def __init__(self, embed_dim=512, block_size=16, n_head=8, drop_out_rate=0.1):
455
+ super().__init__()
456
+ assert embed_dim % 8 == 0
457
+ # key, query, value projections for all heads
458
+ self.key = nn.Linear(embed_dim, embed_dim)
459
+ self.query = nn.Linear(embed_dim, embed_dim)
460
+ self.value = nn.Linear(embed_dim, embed_dim)
461
+
462
+ self.attn_drop = nn.Dropout(drop_out_rate)
463
+ self.resid_drop = nn.Dropout(drop_out_rate)
464
+
465
+ self.proj = nn.Linear(embed_dim, embed_dim)
466
+ # lower-triangular mask buffer (kept from the self-attention template; not applied in the cross-attention forward below)
467
+ self.register_buffer("mask", torch.tril(torch.ones(block_size, 77)).view(1, 1, block_size, 77))
468
+ self.n_head = n_head
469
+
470
+ def forward(self, x,word_emb):
471
+ B, T, C = x.size()
472
+ B, N, D = word_emb.size()
473
+
474
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
475
+ k = self.key(word_emb).view(B, N, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
476
+ q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
477
+ v = self.value(word_emb).view(B, N, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
478
+ # cross-attention between motion queries and word-embedding keys/values: (B, nh, T, hs) x (B, nh, hs, N) -> (B, nh, T, N)
479
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
480
+ att = F.softmax(att, dim=-1)
481
+ att = self.attn_drop(att)
482
+ y = att @ v # (B, nh, T, N) x (B, nh, N, hs) -> (B, nh, T, hs)
483
+ y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
484
+
485
+ # output projection
486
+ y = self.resid_drop(self.proj(y))
487
+ return y
488
+
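+ # [INFO] In CrossAttention, queries come from the motion token sequence (length T) while
+ # keys/values come from the CLIP word embeddings (length N, 77 for CLIP), so each motion
+ # token attends over the text tokens.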
489
+ class Block_crossatt(nn.Module):
490
+
491
+ def __init__(self, embed_dim=512, block_size=16, n_head=8, drop_out_rate=0.1, fc_rate=4):
492
+ super().__init__()
493
+ self.ln1 = nn.LayerNorm(embed_dim)
494
+ self.ln2 = nn.LayerNorm(embed_dim)
495
+ self.ln3 = nn.LayerNorm(embed_dim)
496
+ self.attn = CrossAttention(embed_dim, block_size, n_head, drop_out_rate)
497
+ self.mlp = nn.Sequential(
498
+ nn.Linear(embed_dim, fc_rate * embed_dim),
499
+ nn.GELU(),
500
+ nn.Linear(fc_rate * embed_dim, embed_dim),
501
+ nn.Dropout(drop_out_rate),
502
+ )
503
+
504
+ def forward(self, x,word_emb):
505
+ x = x + self.attn(self.ln1(x), self.ln3(word_emb))
506
+ x = x + self.mlp(self.ln2(x))
507
+ return x
508
+
509
+ class CrossCondTransBase(nn.Module):
510
+
511
+ def __init__(self,
512
+ vqvae,
513
+ num_vq=1024,
514
+ embed_dim=512,
515
+ clip_dim=512,
516
+ block_size=16,
517
+ num_layers=2,
518
+ num_local_layer = 1,
519
+ n_head=8,
520
+ drop_out_rate=0.1,
521
+ fc_rate=4):
522
+ super().__init__()
523
+ self.vqvae = vqvae
524
+ # self.tok_emb = nn.Embedding(num_vq + 3, embed_dim).requires_grad_(False)
525
+ self.learn_tok_emb = nn.Embedding(3, self.vqvae.vqvae.code_dim)# [INFO] 3 = [end_id, blank_id, mask_id]
526
+ self.to_emb = nn.Linear(self.vqvae.vqvae.code_dim, embed_dim)
527
+
528
+ self.cond_emb = nn.Linear(clip_dim, embed_dim)
529
+ self.pos_embedding = nn.Embedding(block_size, embed_dim)
530
+ self.drop = nn.Dropout(drop_out_rate)
531
+ # transformer block
532
+ self.blocks = nn.Sequential(*[Block(embed_dim, block_size, n_head, drop_out_rate, fc_rate) for _ in range(num_layers-num_local_layer)])
533
+ self.pos_embed = pos_encoding.PositionEmbedding(block_size, embed_dim, 0.0, False)
534
+
535
+ self.num_local_layer = num_local_layer
536
+ if num_local_layer > 0:
537
+ self.word_emb = nn.Linear(clip_dim, embed_dim)
538
+ self.cross_att = nn.Sequential(*[Block_crossatt(embed_dim, block_size, n_head, drop_out_rate, fc_rate) for _ in range(num_local_layer)])
539
+ self.block_size = block_size
540
+
541
+ self.apply(self._init_weights)
542
+
543
+ def get_block_size(self):
544
+ return self.block_size
545
+
546
+ def _init_weights(self, module):
547
+ if isinstance(module, (nn.Linear, nn.Embedding)):
548
+ module.weight.data.normal_(mean=0.0, std=0.02)
549
+ if isinstance(module, nn.Linear) and module.bias is not None:
550
+ module.bias.data.zero_()
551
+ elif isinstance(module, nn.LayerNorm):
552
+ module.bias.data.zero_()
553
+ module.weight.data.fill_(1.0)
554
+
555
+ def forward(self, idx, clip_feature, src_mask, word_emb):
556
+ if len(idx) == 0:
557
+ token_embeddings = self.cond_emb(clip_feature).unsqueeze(1)
558
+ else:
559
+ b, t = idx.size()
560
+ assert t <= self.block_size, "Cannot forward, model block size is exhausted."
561
+ # forward the Trans model
562
+ not_learn_idx = idx<self.vqvae.vqvae.num_code
563
+ learn_idx = ~not_learn_idx
564
+
565
+ token_embeddings = torch.empty((*idx.shape, self.vqvae.vqvae.code_dim), device=idx.device)
566
+ token_embeddings[not_learn_idx] = self.vqvae.vqvae.quantizer.dequantize(idx[not_learn_idx]).requires_grad_(False)
567
+ token_embeddings[learn_idx] = self.learn_tok_emb(idx[learn_idx]-self.vqvae.vqvae.num_code)
568
+ token_embeddings = self.to_emb(token_embeddings)
569
+
570
+ if self.num_local_layer > 0:
571
+ word_emb = self.word_emb(word_emb)
572
+ token_embeddings = self.pos_embed(token_embeddings)
573
+ for module in self.cross_att:
574
+ token_embeddings = module(token_embeddings, word_emb)
575
+ token_embeddings = torch.cat([self.cond_emb(clip_feature).unsqueeze(1), token_embeddings], dim=1)
576
+
577
+ x = self.pos_embed(token_embeddings)
578
+ for block in self.blocks:
579
+ x = block(x, src_mask)
580
+
581
+ return x
582
+
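+ # [INFO] In CrossCondTransBase.forward above, indices below num_code are embedded by
+ # dequantizing them with the frozen VQ-VAE codebook, while the special tokens
+ # (end / pad / mask) use the small learnable embedding learn_tok_emb; if num_local_layer > 0,
+ # the token embeddings first cross-attend to the CLIP word embeddings, and the global blocks
+ # are conditioned by prepending the pooled CLIP feature as the first token.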
583
+
584
+ class CrossCondTransHead(nn.Module):
585
+
586
+ def __init__(self,
587
+ num_vq=1024,
588
+ embed_dim=512,
589
+ block_size=16,
590
+ num_layers=2,
591
+ n_head=8,
592
+ drop_out_rate=0.1,
593
+ fc_rate=4):
594
+ super().__init__()
595
+
596
+ self.blocks = nn.Sequential(*[Block(embed_dim, block_size, n_head, drop_out_rate, fc_rate) for _ in range(num_layers)])
597
+ self.ln_f = nn.LayerNorm(embed_dim)
598
+ self.head = nn.Linear(embed_dim, num_vq, bias=False)
599
+ self.block_size = block_size
600
+
601
+ self.apply(self._init_weights)
602
+
603
+ def get_block_size(self):
604
+ return self.block_size
605
+
606
+ def _init_weights(self, module):
607
+ if isinstance(module, (nn.Linear, nn.Embedding)):
608
+ module.weight.data.normal_(mean=0.0, std=0.02)
609
+ if isinstance(module, nn.Linear) and module.bias is not None:
610
+ module.bias.data.zero_()
611
+ elif isinstance(module, nn.LayerNorm):
612
+ module.bias.data.zero_()
613
+ module.weight.data.fill_(1.0)
614
+
615
+ def forward(self, x, src_mask):
616
+ for block in self.blocks:
617
+ x = block(x, src_mask)
618
+ x = self.ln_f(x)
619
+ logits = self.head(x)
620
+ return logits
621
+
622
+
623
+
624
+
625
+
626
+
models/t2m_trans_uplow.py ADDED
@@ -0,0 +1,637 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.nn import functional as F
5
+ from torch.distributions import Categorical
6
+ import models.pos_encoding as pos_encoding
7
+ from exit.utils import cosine_schedule, uniform, top_k, gumbel_sample, top_p
8
+ from tqdm import tqdm
9
+ from einops import rearrange, repeat
10
+ from exit.utils import get_model, generate_src_mask
11
+
12
+
13
+ class PatchUpSampling(nn.Module):
14
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
15
+ super().__init__()
16
+ self.dim = dim
17
+ self.up_sampling = nn.Linear(dim, 4 * dim, bias=False)
18
+ self.norm = norm_layer(dim)
19
+
20
+ def forward(self, x):
21
+ """
22
+ x: B, F, C
23
+ """
24
+ x = self.norm(x)
25
+ x = self.up_sampling(x)
26
+ x0 = x[:, :, 0::4]
27
+ x1 = x[:, :, 1::4]
28
+ x2 = x[:, :, 2::4]
29
+ x3 = x[:, :, 3::4]
30
+ x = torch.cat([x0, x1, x2, x3], 1)
31
+ return x
32
+
33
+ class Decoder_Transformer(nn.Module):
34
+ def __init__(self,
35
+ code_dim=1024,
36
+ embed_dim=512,
37
+ output_dim=263,
38
+ block_size=16,
39
+ num_layers=2,
40
+ n_head=8,
41
+ drop_out_rate=0.1,
42
+ fc_rate=4):
43
+
44
+ super().__init__()
45
+ self.joint_embed = nn.Linear(code_dim, embed_dim)
46
+ self.drop = nn.Dropout(drop_out_rate)
47
+ # transformer block
48
+ self.blocks = nn.Sequential(*[Block(embed_dim, block_size, n_head, drop_out_rate, fc_rate) for _ in range(num_layers)])
49
+ self.up_sample = PatchUpSampling(embed_dim)
50
+ self.pos_embed = pos_encoding.PositionEmbedding(block_size, embed_dim, 0.0, False)
51
+ self.head = nn.Sequential(nn.LayerNorm(embed_dim),
52
+ nn.Linear(embed_dim, output_dim))
53
+ self.block_size = block_size
54
+ self.n_head = n_head
55
+ self.apply(self._init_weights)
56
+
57
+ def get_block_size(self):
58
+ return self.block_size
59
+
60
+ def _init_weights(self, module):
61
+ if isinstance(module, (nn.Linear, nn.Embedding)):
62
+ module.weight.data.normal_(mean=0.0, std=0.02)
63
+ if isinstance(module, nn.Linear) and module.bias is not None:
64
+ module.bias.data.zero_()
65
+ elif isinstance(module, nn.LayerNorm):
66
+ module.bias.data.zero_()
67
+ module.weight.data.fill_(1.0)
68
+
69
+ def forward(self, token_embeddings):
70
+ # token_embeddings = self.tok_emb(idx)
71
+ # B, T = src_mask.shape
72
+ # src_mask = src_mask.view(B, 1, 1, T).repeat(1, self.n_head, T, 1)
73
+
74
+ token_embeddings = token_embeddings.permute(0, 2, 1)
75
+ token_embeddings = self.joint_embed(token_embeddings)
76
+ x = self.pos_embed(token_embeddings)
77
+
78
+ # block_step_len = int(len(self.blocks)/3)
79
+ # mask_temp = get_attn_mask(_range=3, _max=x.shape[1]).to(src_mask.device)
80
+ # eye = torch.eye(x.shape[1]).unsqueeze(0).unsqueeze(0).to(src_mask.device).bool()
81
+ # src_mask = src_mask*mask_temp + eye
82
+
83
+ for block in self.blocks:
84
+ x = block(x)
85
+ x = self.up_sample(x)
86
+
87
+ # mask_2 = mask_1.repeat(1, 1, 2, 2)
88
+ # for block in self.blocks[block_step_len:2*block_step_len]:
89
+ # x = block(x, mask_2)
90
+ # x = self.up_sample(x)
91
+
92
+ # mask_3 = mask_2.repeat(1, 1, 2, 2)
93
+ # for block in self.blocks[2*block_step_len:]:
94
+ # x = block(x, mask_3)
95
+
96
+ x = self.head(x).permute(0, 2, 1)
97
+ return x
98
+
99
+ # https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py#L342C9-L343C33
100
+ class PatchMerging(nn.Module):
101
+ def __init__(self, input_feats, dim, norm_layer=nn.LayerNorm):
102
+ super().__init__()
103
+ self.dim = dim
104
+ self.reduction = nn.Linear(4 * input_feats, dim, bias=False)
105
+ self.norm = norm_layer(4 * input_feats)
106
+
107
+ def forward(self, x):
108
+ """
109
+ x: B, F, C
110
+ """
111
+ x0 = x[:, 0::4, :] # B F/2 C
112
+ x1 = x[:, 1::4, :]
113
+ x2 = x[:, 2::4, :] # B F/2 C
114
+ x3 = x[:, 3::4, :]
115
+ x = torch.cat([x0, x1, x2, x3], -1) # B F/2 2*C
116
+ x = self.norm(x)
117
+ x = self.reduction(x)
118
+ return x
119
+
120
+ class Encoder_Transformer(nn.Module):
121
+ def __init__(self,
122
+ input_feats=1024,
123
+ embed_dim=512,
124
+ output_dim=263,
125
+ block_size=16,
126
+ num_layers=2,
127
+ n_head=8,
128
+ drop_out_rate=0.1,
129
+ fc_rate=4):
130
+
131
+ super().__init__()
132
+ self.joint_embed = nn.Linear(input_feats, embed_dim)
133
+ self.drop = nn.Dropout(drop_out_rate)
134
+ # transformer block
135
+ self.blocks = nn.Sequential(*[Block(embed_dim, block_size, n_head, drop_out_rate, fc_rate) for _ in range(num_layers)])
136
+ # self.patch_merging1 = PatchMerging(input_feats, embed_dim)
137
+ # self.patch_merging2 = PatchMerging(embed_dim)
138
+ self.weighted_mean_norm = nn.LayerNorm(embed_dim)
139
+ self.weighted_mean = torch.nn.Conv1d(in_channels=block_size, out_channels=1, kernel_size=1)
140
+
141
+ self.pos_embed = pos_encoding.PositionEmbedding(block_size, embed_dim, 0.0, False)
142
+ self.head = nn.Sequential(nn.LayerNorm(embed_dim),
143
+ nn.Linear(embed_dim, output_dim))
144
+ self.block_size = block_size
145
+ self.n_head = n_head
146
+ self.apply(self._init_weights)
147
+
148
+ def get_block_size(self):
149
+ return self.block_size
150
+
151
+ def _init_weights(self, module):
152
+ if isinstance(module, (nn.Linear, nn.Embedding)):
153
+ module.weight.data.normal_(mean=0.0, std=0.02)
154
+ if isinstance(module, nn.Linear) and module.bias is not None:
155
+ module.bias.data.zero_()
156
+ elif isinstance(module, nn.LayerNorm):
157
+ module.bias.data.zero_()
158
+ module.weight.data.fill_(1.0)
159
+
160
+ def forward(self, joints):
161
+ # B, T = src_mask.shape
162
+ # src_mask = src_mask.view(B, 1, 1, T).repeat(1, self.n_head, T, 1)
163
+
164
+ joints = joints.permute(0,2,1)
165
+ # token_embeddings = self.joint_embed(joints)
166
+
167
+ block_step_len = int(len(self.blocks)/3)
168
+
169
+ x = self.joint_embed(joints)
170
+ token_len = int(x.shape[1]/self.block_size)
171
+ _original_shape = list(x.shape)
172
+ x = x.view(x.shape[0]*token_len, self.block_size, -1)
173
+
174
+ # mask_temp = get_attn_mask(_range=3, _max=x.shape[1]).to(src_mask.device)
175
+ # eye = torch.eye(x.shape[1]).unsqueeze(0).unsqueeze(0).to(src_mask.device).bool()
176
+ # src_mask = src_mask*mask_temp + eye
177
+
178
+ x = self.pos_embed(x)
179
+ for block in self.blocks:
180
+ x = block(x)
181
+ x = self.weighted_mean_norm(x)
182
+ x = self.weighted_mean(x)
183
+ _original_shape[1] = int(_original_shape[1] / self.block_size)
184
+ x = x.view(*_original_shape)
185
+
186
+ # for block in self.blocks[block_step_len:2*block_step_len]:
187
+ # x = block(x)
188
+ # x = self.patch_merging2(x)
189
+
190
+ # for block in self.blocks[2*block_step_len:]:
191
+ # x = block(x)
192
+ x = self.head(x).permute(0, 2, 1)
193
+ return x
194
+
195
+ class Text2Motion_Transformer(nn.Module):
196
+
197
+ def __init__(self,
198
+ vqvae,
199
+ num_vq=1024,
200
+ embed_dim=512,
201
+ clip_dim=512,
202
+ block_size=16,
203
+ num_layers=2,
204
+ num_local_layer=0,
205
+ n_head=8,
206
+ drop_out_rate=0.1,
207
+ fc_rate=4):
208
+ super().__init__()
209
+ self.n_head = n_head
210
+ self.trans_base = CrossCondTransBase(vqvae, num_vq, embed_dim, clip_dim, block_size, num_layers, num_local_layer, n_head, drop_out_rate, fc_rate)
211
+ self.trans_head = CrossCondTransHead(num_vq, embed_dim, block_size, num_layers, n_head, drop_out_rate, fc_rate)
212
+ self.block_size = block_size
213
+ self.num_vq = num_vq
214
+
215
+
216
+ def get_block_size(self):
217
+ return self.block_size
218
+
219
+ def forward(self, *args, type='forward', **kwargs):
220
+ '''type=[forward, sample]'''
221
+ if type=='forward':
222
+ return self.forward_function(*args, **kwargs)
223
+ elif type=='sample':
224
+ return self.sample(*args, **kwargs)
225
+ elif type=='inpaint':
226
+ return self.inpaint(*args, **kwargs)
227
+ else:
228
+ raise ValueError(f'Unknown "{type}" type')
229
+
230
+ def get_attn_mask(self, src_mask, att_txt=None, txt_mark=None):
231
+ if att_txt is None:
232
+ att_txt = torch.tensor([[True]]*src_mask.shape[0]).to(src_mask.device)
233
+ src_mask = torch.cat([att_txt, src_mask], dim=1)
234
+ B, T = src_mask.shape
235
+ src_mask = src_mask.view(B, 1, 1, T).repeat(1, self.n_head, T, 1)
236
+ if txt_mark is not None:
237
+ att_txt_txt = torch.tensor([[True]]*txt_mark.shape[0]).to(txt_mark.device)
238
+ txt_mark = torch.cat([att_txt_txt, txt_mark], dim=1)
239
+ src_mask[:, :, :, 0] = txt_mark.view(B, 1, T).repeat(1, self.n_head, 1)
240
+ return src_mask
241
+
242
+ def forward_function(self, idx_upper, idx_lower, clip_feature, src_mask=None, att_txt=None, txt_mark=None, word_emb=None):
243
+ # MLD:
244
+ # if att_txt is None:
245
+ # att_txt = torch.tensor([[True]]*src_mask.shape[0]).to(src_mask.device)
246
+ # src_mask = torch.cat([att_txt, src_mask], dim=1)
247
+ # logits = self.skip_trans(idxs, clip_feature, src_mask)
248
+
249
+ # T2M-BD
250
+ if src_mask is not None:
251
+ src_mask = self.get_attn_mask(src_mask, att_txt, txt_mark)
252
+ feat = self.trans_base(idx_upper, idx_lower, clip_feature, src_mask, word_emb)
253
+ logits = self.trans_head(feat, src_mask)
254
+
255
+ return logits
256
+
257
+ def sample(self, clip_feature, idx_lower, word_emb, m_length=None, if_test=False, rand_pos=False, CFG=-1):
258
+ max_steps = 20
259
+ max_length = 49
260
+ batch_size = clip_feature.shape[0]
261
+ mask_id = self.num_vq + 2
262
+ pad_id = self.num_vq + 1
263
+ end_id = self.num_vq
264
+ shape = (batch_size, self.block_size - 1)
265
+ topk_filter_thres = .9
266
+ starting_temperature = 1.0
267
+ scores = torch.ones(shape, dtype = torch.float32, device = clip_feature.device)
268
+
269
+ m_tokens_len = torch.ceil((m_length)/4)
270
+ src_token_mask = generate_src_mask(self.block_size-1, m_tokens_len+1)
271
+ src_token_mask_noend = generate_src_mask(self.block_size-1, m_tokens_len)
272
+ ids = torch.full(shape, mask_id, dtype = torch.long, device = clip_feature.device)
273
+
274
+ # [TODO] confirm that these 2 lines are not necessary (repeated below; they may not be needed at all)
275
+ ids[~src_token_mask] = pad_id # [INFO] replace with pad id
276
+ ids.scatter_(-1, m_tokens_len[..., None].long(), end_id) # [INFO] replace with end id
277
+
278
+ ### PlayGround ####
279
+ # score high = mask
280
+ # m_tokens_len = torch.ceil((m_length)/4)
281
+ # src_token_mask = generate_src_mask(self.block_size-1, m_tokens_len+1)
282
+
283
+ # # mock
284
+ # timestep = torch.tensor(.5)
285
+ # rand_mask_prob = cosine_schedule(timestep)
286
+ # scores = torch.arange(self.block_size - 1).repeat(batch_size, 1).cuda()
287
+ # scores[1] = torch.flip(torch.arange(self.block_size - 1), dims=(0,))
288
+
289
+ # # iteration
290
+ # num_token_masked = (rand_mask_prob * m_tokens_len).int().clip(min=1)
291
+ # scores[~src_token_mask] = -1e5
292
+ # masked_indices = scores.argsort(dim=-1, descending=True) # This is flipped the order. The highest score is the first in order.
293
+ # masked_indices = masked_indices < num_token_masked.unsqueeze(-1) # So it can filter out by "< num_token_masked". We want to filter the high score as a mask
294
+ # ids[masked_indices] = mask_id
295
+ #########################
296
+ temp = []
297
+ sample_max_steps = torch.round(max_steps/max_length*m_tokens_len) + 1e-8
298
+ for step in range(max_steps):
299
+ timestep = torch.clip(step/(sample_max_steps), max=1)
300
+ rand_mask_prob = cosine_schedule(timestep) # timestep #
301
+ num_token_masked = (rand_mask_prob * m_tokens_len).long().clip(min=1)
302
+ # [INFO] rm no motion frames
303
+ scores[~src_token_mask_noend] = 0
304
+ scores = scores/scores.sum(-1)[:, None] # normalize only unmasked token
305
+
306
+ # if rand_pos:
307
+ # sorted_score_indices = scores.multinomial(scores.shape[-1], replacement=False) # stochastic
308
+ # else:
309
+ sorted, sorted_score_indices = scores.sort(descending=True) # deterministic
310
+
311
+ ids[~src_token_mask] = pad_id # [INFO] replace with pad id
312
+ ids.scatter_(-1, m_tokens_len[..., None].long(), end_id) # [INFO] replace with end id
313
+ ## [INFO] Set the "num_token_masked" positions with the highest "scores" to "mask_id"
314
+ select_masked_indices = generate_src_mask(sorted_score_indices.shape[1], num_token_masked)
315
+ # [INFO] repeat last_index so the unused scatter_ slots simply re-mask the already-selected last position.
316
+ last_index = sorted_score_indices.gather(-1, num_token_masked.unsqueeze(-1)-1)
317
+ sorted_score_indices = sorted_score_indices * select_masked_indices + (last_index*~select_masked_indices)
318
+ ids.scatter_(-1, sorted_score_indices, mask_id)
319
+ # if torch.isclose(timestep, torch.tensor(0.7647), atol=.01):
320
+ # print('masked_indices:', ids[0], src_token_mask[0])
321
+
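+ # [INFO] Classifier-free guidance: when CFG != -1, the batch is duplicated with and without
+ # text attention (att_txt True/False) and the logits are extrapolated as
+ # (1 + CFG) * text-conditioned - CFG * unconditioned.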
322
+ if CFG!=-1:
323
+ # print('ids:', ids.shape, clip_feature.shape, src_token_mask.shape)
324
+ _ids = ids.repeat(2,1)
325
+ _clip_feature = clip_feature.repeat(2,1)
326
+ _src_token_mask = src_token_mask.repeat(2,1)
327
+ att_txt = torch.cat( (torch.ones((batch_size,1), dtype=torch.bool),
328
+ torch.zeros((batch_size,1), dtype=torch.bool) )).to(_ids.device)
329
+ logits = self.forward(_ids, idx_lower, _clip_feature, _src_token_mask, att_txt)[:,1:]
330
+ logits_textcond = logits[:batch_size]
331
+ logits_uncond = logits[batch_size:]
332
+ # logits = (1-CFG)*logits_textcond + CFG*logits_uncond
333
+ logits = (1+CFG)*logits_textcond - CFG*logits_uncond
334
+ else:
335
+ logits = self.forward(ids, idx_lower, clip_feature, src_token_mask, word_emb=word_emb)[:,1:]
336
+ filtered_logits = logits #top_p(logits, .5) # #top_k(logits, topk_filter_thres)
337
+ if rand_pos:
338
+ temperature = 1 #starting_temperature * (steps_until_x0 / timesteps) # temperature is annealed
339
+ else:
340
+ temperature = 0 #starting_temperature * (steps_until_x0 / timesteps) # temperature is annealed
341
+
342
+ # [INFO] if temperature==0: is equal to argmax (filtered_logits.argmax(dim = -1))
343
+ # pred_ids = filtered_logits.argmax(dim = -1)
344
+ pred_ids = gumbel_sample(filtered_logits, temperature = temperature, dim = -1)
345
+ is_mask = ids == mask_id
346
+ temp.append(is_mask[:1])
347
+
348
+ # mid = is_mask[0][:m_tokens_len[0].int()]
349
+ # mid = mid.nonzero(as_tuple=True)[0]
350
+ # print(is_mask[0].sum(), m_tokens_len[0])
351
+
352
+ ids = torch.where(
353
+ is_mask,
354
+ pred_ids,
355
+ ids
356
+ )
357
+
358
+ # if timestep == 1.:
359
+ # print(probs_without_temperature.shape)
360
+ probs_without_temperature = logits.softmax(dim = -1)
361
+ scores = 1 - probs_without_temperature.gather(-1, pred_ids[..., None])
362
+ scores = rearrange(scores, '... 1 -> ...')
363
+ scores = scores.masked_fill(~is_mask, 0)
364
+ if if_test:
365
+ return ids, temp
366
+ return ids
367
+
368
+ def inpaint(self, first_tokens, last_tokens, clip_feature=None, inpaint_len=2, rand_pos=False):
369
+ # support only one sample
370
+ assert first_tokens.shape[0] == 1
371
+ assert last_tokens.shape[0] == 1
372
+ max_steps = 20
373
+ max_length = 49
374
+ batch_size = first_tokens.shape[0]
375
+ mask_id = self.num_vq + 2
376
+ pad_id = self.num_vq + 1
377
+ end_id = self.num_vq
378
+ shape = (batch_size, self.block_size - 1)
379
+ scores = torch.ones(shape, dtype = torch.float32, device = first_tokens.device)
380
+
381
+ # force add first / last tokens
382
+ first_partition_pos_idx = first_tokens.shape[1]
383
+ second_partition_pos_idx = first_partition_pos_idx + inpaint_len
384
+ end_pos_idx = second_partition_pos_idx + last_tokens.shape[1]
385
+
386
+ m_tokens_len = torch.ones(batch_size, device = first_tokens.device)*end_pos_idx
387
+
388
+ src_token_mask = generate_src_mask(self.block_size-1, m_tokens_len+1)
389
+ src_token_mask_noend = generate_src_mask(self.block_size-1, m_tokens_len)
390
+ ids = torch.full(shape, mask_id, dtype = torch.long, device = first_tokens.device)
391
+
392
+ ids[:, :first_partition_pos_idx] = first_tokens
393
+ ids[:, second_partition_pos_idx:end_pos_idx] = last_tokens
394
+ src_token_mask_noend[:, :first_partition_pos_idx] = False
395
+ src_token_mask_noend[:, second_partition_pos_idx:end_pos_idx] = False
396
+
397
+ # [TODO] confirm that these 2 lines are not necessary (repeated below; they may not be needed at all)
398
+ ids[~src_token_mask] = pad_id # [INFO] replace with pad id
399
+ ids.scatter_(-1, m_tokens_len[..., None].long(), end_id) # [INFO] replace with end id
400
+
401
+ temp = []
402
+ sample_max_steps = torch.round(max_steps/max_length*m_tokens_len) + 1e-8
403
+
404
+ if clip_feature is None:
405
+ clip_feature = torch.zeros(1, 512).to(first_tokens.device)
406
+ att_txt = torch.zeros((batch_size,1), dtype=torch.bool, device = first_tokens.device)
407
+ else:
408
+ att_txt = torch.ones((batch_size,1), dtype=torch.bool, device = first_tokens.device)
409
+
410
+ for step in range(max_steps):
411
+ timestep = torch.clip(step/(sample_max_steps), max=1)
412
+ rand_mask_prob = cosine_schedule(timestep) # timestep #
413
+ num_token_masked = (rand_mask_prob * m_tokens_len).long().clip(min=1)
414
+ # [INFO] rm no motion frames
415
+ scores[~src_token_mask_noend] = 0
416
+ # [INFO] rm begin and end frames
417
+ scores[:, :first_partition_pos_idx] = 0
418
+ scores[:, second_partition_pos_idx:end_pos_idx] = 0
419
+ scores = scores/scores.sum(-1)[:, None] # normalize only unmasked token
420
+
421
+ sorted, sorted_score_indices = scores.sort(descending=True) # deterministic
422
+
423
+ ids[~src_token_mask] = pad_id # [INFO] replace with pad id
424
+ ids.scatter_(-1, m_tokens_len[..., None].long(), end_id) # [INFO] replace with end id
425
+ ## [INFO] Set the "num_token_masked" positions with the highest "scores" to "mask_id"
426
+ select_masked_indices = generate_src_mask(sorted_score_indices.shape[1], num_token_masked)
427
+ # [INFO] repeat last_index so the unused scatter_ slots simply re-mask the already-selected last position.
428
+ last_index = sorted_score_indices.gather(-1, num_token_masked.unsqueeze(-1)-1)
429
+ sorted_score_indices = sorted_score_indices * select_masked_indices + (last_index*~select_masked_indices)
430
+ ids.scatter_(-1, sorted_score_indices, mask_id)
431
+
432
+ # [TODO] force-replace begin/end tokens because the number of masked tokens can exceed the actual inpainting frames
433
+ ids[:, :first_partition_pos_idx] = first_tokens
434
+ ids[:, second_partition_pos_idx:end_pos_idx] = last_tokens
435
+
436
+ logits = self.forward(ids, clip_feature, src_token_mask, att_txt)[:,1:]
437
+ filtered_logits = logits #top_k(logits, topk_filter_thres)
438
+ if rand_pos:
439
+ temperature = 1 #starting_temperature * (steps_until_x0 / timesteps) # temperature is annealed
440
+ else:
441
+ temperature = 0 #starting_temperature * (steps_until_x0 / timesteps) # temperature is annealed
442
+
443
+ # [INFO] if temperature==0: is equal to argmax (filtered_logits.argmax(dim = -1))
444
+ # pred_ids = filtered_logits.argmax(dim = -1)
445
+ pred_ids = gumbel_sample(filtered_logits, temperature = temperature, dim = -1)
446
+ is_mask = ids == mask_id
447
+ temp.append(is_mask[:1])
448
+
449
+ ids = torch.where(
450
+ is_mask,
451
+ pred_ids,
452
+ ids
453
+ )
454
+
455
+ probs_without_temperature = logits.softmax(dim = -1)
456
+ scores = 1 - probs_without_temperature.gather(-1, pred_ids[..., None])
457
+ scores = rearrange(scores, '... 1 -> ...')
458
+ scores = scores.masked_fill(~is_mask, 0)
459
+ return ids
460
+
461
+ class Attention(nn.Module):
462
+
463
+ def __init__(self, embed_dim=512, block_size=16, n_head=8, drop_out_rate=0.1):
464
+ super().__init__()
465
+ assert embed_dim % 8 == 0
466
+ # key, query, value projections for all heads
467
+ self.key = nn.Linear(embed_dim, embed_dim)
468
+ self.query = nn.Linear(embed_dim, embed_dim)
469
+ self.value = nn.Linear(embed_dim, embed_dim)
470
+
471
+ self.attn_drop = nn.Dropout(drop_out_rate)
472
+ self.resid_drop = nn.Dropout(drop_out_rate)
473
+
474
+ self.proj = nn.Linear(embed_dim, embed_dim)
475
+ self.n_head = n_head
476
+
477
+ def forward(self, x, src_mask):
478
+ B, T, C = x.size()
479
+
480
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
481
+ k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
482
+ q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
483
+ v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
484
+ # self-attention with an optional padding mask (not causal): (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
485
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
486
+ if src_mask is not None:
487
+ att[~src_mask] = float('-inf')
488
+ att = F.softmax(att, dim=-1)
489
+ att = self.attn_drop(att)
490
+ y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
491
+ y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
492
+
493
+ # output projection
494
+ y = self.resid_drop(self.proj(y))
495
+ return y
496
+
497
+ class Block(nn.Module):
498
+
499
+ def __init__(self, embed_dim=512, block_size=16, n_head=8, drop_out_rate=0.1, fc_rate=4):
500
+ super().__init__()
501
+ self.ln1 = nn.LayerNorm(embed_dim)
502
+ self.ln2 = nn.LayerNorm(embed_dim)
503
+ self.attn = Attention(embed_dim, block_size, n_head, drop_out_rate)
504
+ self.mlp = nn.Sequential(
505
+ nn.Linear(embed_dim, fc_rate * embed_dim),
506
+ nn.GELU(),
507
+ nn.Linear(fc_rate * embed_dim, embed_dim),
508
+ nn.Dropout(drop_out_rate),
509
+ )
510
+
511
+ def forward(self, x, src_mask=None):
512
+ x = x + self.attn(self.ln1(x), src_mask)
513
+ x = x + self.mlp(self.ln2(x))
514
+ return x
515
+
516
+ from models.t2m_trans import Block_crossatt
517
+ class CrossCondTransBase(nn.Module):
518
+
519
+ def __init__(self,
520
+ vqvae,
521
+ num_vq=1024,
522
+ embed_dim=512,
523
+ clip_dim=512,
524
+ block_size=16,
525
+ num_layers=2,
526
+ num_local_layer = 1,
527
+ n_head=8,
528
+ drop_out_rate=0.1,
529
+ fc_rate=4):
530
+ super().__init__()
531
+ self.vqvae = vqvae
532
+ # self.tok_emb = nn.Embedding(num_vq + 3, embed_dim).requires_grad_(False)
533
+ self.learn_tok_emb = nn.Embedding(3, int(self.vqvae.vqvae.code_dim/2))# [INFO] 3 = [end_id, blank_id, mask_id]
534
+ self.to_emb = nn.Linear(self.vqvae.vqvae.code_dim, embed_dim)
535
+
536
+ self.cond_emb = nn.Linear(clip_dim, embed_dim)
537
+ self.pos_embedding = nn.Embedding(block_size, embed_dim)
538
+ self.drop = nn.Dropout(drop_out_rate)
539
+ # transformer block
540
+ self.blocks = nn.Sequential(*[Block(embed_dim, block_size, n_head, drop_out_rate, fc_rate) for _ in range(num_layers-num_local_layer)])
541
+ self.pos_embed = pos_encoding.PositionEmbedding(block_size, embed_dim, 0.0, False)
542
+
543
+ self.num_local_layer = num_local_layer
544
+ if num_local_layer > 0:
545
+ self.word_emb = nn.Linear(clip_dim, embed_dim)
546
+ self.cross_att = nn.Sequential(*[Block_crossatt(embed_dim, block_size, n_head, drop_out_rate, fc_rate) for _ in range(num_local_layer)])
547
+ self.block_size = block_size
548
+
549
+ self.apply(self._init_weights)
550
+
551
+ def get_block_size(self):
552
+ return self.block_size
553
+
554
+ def _init_weights(self, module):
555
+ if isinstance(module, (nn.Linear, nn.Embedding)):
556
+ module.weight.data.normal_(mean=0.0, std=0.02)
557
+ if isinstance(module, nn.Linear) and module.bias is not None:
558
+ module.bias.data.zero_()
559
+ elif isinstance(module, nn.LayerNorm):
560
+ module.bias.data.zero_()
561
+ module.weight.data.fill_(1.0)
562
+
563
+ def forward(self, idx_upper, idx_lower, clip_feature, src_mask, word_emb):
564
+ if len(idx_upper) == 0:
565
+ token_embeddings = self.cond_emb(clip_feature).unsqueeze(1)
566
+ else:
567
+ b, t = idx_upper.size()
568
+ assert t <= self.block_size, "Cannot forward, model block size is exhausted."
569
+ # forward the Trans model
570
+ learn_idx_upper = idx_upper>=self.vqvae.vqvae.num_code
571
+ learn_idx_lower = idx_lower>=self.vqvae.vqvae.num_code
572
+
573
+ code_dim = self.vqvae.vqvae.code_dim
574
+ token_embeddings = torch.empty((*idx_upper.shape, code_dim), device=idx_upper.device)
575
+ token_embeddings[..., :int(code_dim/2)][~learn_idx_upper] = self.vqvae.vqvae.quantizer_upper.dequantize(idx_upper[~learn_idx_upper]).requires_grad_(False)
576
+ token_embeddings[..., :int(code_dim/2)][learn_idx_upper] = self.learn_tok_emb(idx_upper[learn_idx_upper]-self.vqvae.vqvae.num_code)
577
+ token_embeddings[..., int(code_dim/2):][~learn_idx_lower] = self.vqvae.vqvae.quantizer_lower.dequantize(idx_lower[~learn_idx_lower]).requires_grad_(False)
578
+ token_embeddings[..., int(code_dim/2):][learn_idx_lower] = self.learn_tok_emb(idx_lower[learn_idx_lower]-self.vqvae.vqvae.num_code)
579
+ token_embeddings = self.to_emb(token_embeddings)
580
+
581
+ if self.num_local_layer > 0:
582
+ word_emb = self.word_emb(word_emb)
583
+ token_embeddings = self.pos_embed(token_embeddings)
584
+ for module in self.cross_att:
585
+ token_embeddings = module(token_embeddings, word_emb)
586
+ token_embeddings = torch.cat([self.cond_emb(clip_feature).unsqueeze(1), token_embeddings], dim=1)
587
+
588
+ x = self.pos_embed(token_embeddings)
589
+ for block in self.blocks:
590
+ x = block(x, src_mask)
591
+
592
+ return x
593
+
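+ # [INFO] In this upper/lower variant, the first half of code_dim holds the upper-body token
+ # embedding and the second half the lower-body embedding; each half is dequantized from its
+ # own quantizer (quantizer_upper / quantizer_lower), with learn_tok_emb shared for the
+ # special end / pad / mask tokens.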
594
+
595
+ class CrossCondTransHead(nn.Module):
596
+
597
+ def __init__(self,
598
+ num_vq=1024,
599
+ embed_dim=512,
600
+ block_size=16,
601
+ num_layers=2,
602
+ n_head=8,
603
+ drop_out_rate=0.1,
604
+ fc_rate=4):
605
+ super().__init__()
606
+
607
+ self.blocks = nn.Sequential(*[Block(embed_dim, block_size, n_head, drop_out_rate, fc_rate) for _ in range(num_layers)])
608
+ self.ln_f = nn.LayerNorm(embed_dim)
609
+ self.head = nn.Linear(embed_dim, num_vq, bias=False)
610
+ self.block_size = block_size
611
+
612
+ self.apply(self._init_weights)
613
+
614
+ def get_block_size(self):
615
+ return self.block_size
616
+
617
+ def _init_weights(self, module):
618
+ if isinstance(module, (nn.Linear, nn.Embedding)):
619
+ module.weight.data.normal_(mean=0.0, std=0.02)
620
+ if isinstance(module, nn.Linear) and module.bias is not None:
621
+ module.bias.data.zero_()
622
+ elif isinstance(module, nn.LayerNorm):
623
+ module.bias.data.zero_()
624
+ module.weight.data.fill_(1.0)
625
+
626
+ def forward(self, x, src_mask):
627
+ for block in self.blocks:
628
+ x = block(x, src_mask)
629
+ x = self.ln_f(x)
630
+ logits = self.head(x)
631
+ return logits
632
+
633
+
634
+
635
+
636
+
637
+
models/vqvae.py ADDED
@@ -0,0 +1,162 @@
1
+ import torch.nn as nn
2
+ from models.encdec import Encoder, Decoder
3
+ from models.quantize_cnn import QuantizeEMAReset, Quantizer, QuantizeEMA, QuantizeReset
4
+ from models.t2m_trans import Decoder_Transformer, Encoder_Transformer
5
+ from exit.utils import generate_src_mask
6
+
7
+ class VQVAE_251(nn.Module):
8
+ def __init__(self,
9
+ args,
10
+ nb_code=1024,
11
+ code_dim=512,
12
+ output_emb_width=512,
13
+ down_t=3,
14
+ stride_t=2,
15
+ width=512,
16
+ depth=3,
17
+ dilation_growth_rate=3,
18
+ activation='relu',
19
+ norm=None):
20
+
21
+ super().__init__()
22
+ self.code_dim = code_dim
23
+ self.num_code = nb_code
24
+ self.quant = args.quantizer
25
+ output_dim = 251 if args.dataname == 'kit' else 263
26
+ self.encoder = Encoder(output_dim, output_emb_width, down_t, stride_t, width, depth, dilation_growth_rate, activation=activation, norm=norm)
27
+
28
+ # Transformer Encoder
29
+ # self.encoder = Encoder_Transformer(
30
+ # input_feats=output_dim,
31
+ # embed_dim=512, # 1024
32
+ # output_dim=512,
33
+ # block_size=4,
34
+ # num_layers=6,
35
+ # n_head=16
36
+ # )
37
+
38
+ # Transformer Encoder 4 frames
39
+ # from exit.motiontransformer import MotionTransformerEncoder
40
+ # in_feature = 251 if args.dataname == 'kit' else 263
41
+ # self.encoder2 = MotionTransformerEncoder(in_feature, args.code_dim, num_frames=4, num_layers=2)
42
+
43
+ self.decoder = Decoder(output_dim, output_emb_width, down_t, stride_t, width, depth, dilation_growth_rate, activation=activation, norm=norm)
44
+ # self.decoder = Decoder_Transformer(
45
+ # code_dim=512,
46
+ # embed_dim=512, # 1024
47
+ # output_dim=output_dim,
48
+ # block_size=49,
49
+ # num_layers=6,
50
+ # n_head=8
51
+ # )
52
+ if args.quantizer == "ema_reset":
53
+ self.quantizer = QuantizeEMAReset(nb_code, code_dim, args)
54
+ elif args.quantizer == "orig":
55
+ self.quantizer = Quantizer(nb_code, code_dim, 1.0)
56
+ elif args.quantizer == "ema":
57
+ self.quantizer = QuantizeEMA(nb_code, code_dim, args)
58
+ elif args.quantizer == "reset":
59
+ self.quantizer = QuantizeReset(nb_code, code_dim, args)
60
+
61
+
62
+ def preprocess(self, x):
63
+ # (bs, T, Jx3) -> (bs, Jx3, T)
64
+ x = x.permute(0,2,1).float()
65
+ return x
66
+
67
+
68
+ def postprocess(self, x):
69
+ # (bs, Jx3, T) -> (bs, T, Jx3)
70
+ x = x.permute(0,2,1)
71
+ return x
72
+
73
+
74
+ def encode(self, x):
75
+ N, T, _ = x.shape
76
+ x_in = self.preprocess(x)
77
+ x_encoder = self.encoder(x_in)
78
+ x_encoder = self.postprocess(x_encoder)
79
+ x_encoder = x_encoder.contiguous().view(-1, x_encoder.shape[-1]) # (NT, C)
80
+ code_idx = self.quantizer.quantize(x_encoder)
81
+ code_idx = code_idx.view(N, -1)
82
+ return code_idx
83
+
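+ # [INFO] encode() above flattens the encoder output to (N*T', C), quantizes each frame to its
+ # nearest codebook entry, and returns the code indices reshaped to (N, T'), where T' is the
+ # temporally downsampled length.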
84
+
85
+ def forward(self, x):
86
+
87
+ x_in = self.preprocess(x)
88
+ # Encode
89
+ # _x_in = x_in.reshape( int(x_in.shape[0]*4), x_in.shape[1], 16)
90
+ # x_encoder = self.encoder(_x_in)
91
+ # x_encoder = x_encoder.reshape(x_in.shape[0], -1, int(x_in.shape[2]/4))
92
+
93
+ # [Transformer Encoder]
94
+ # _x_in = x_in.reshape( int(x_in.shape[0]*x_in.shape[2]/4), x_in.shape[1], 4)
95
+ # _x_in = _x_in.permute(0,2,1)
96
+ # x_encoder = self.encoder2(_x_in)
97
+ # x_encoder = x_encoder.permute(0,2,1)
98
+ # x_encoder = x_encoder.reshape(x_in.shape[0], -1, int(x_in.shape[2]/4))
99
+
100
+ x_encoder = self.encoder(x_in)
101
+
102
+ ## quantization
103
+ x_quantized, loss, perplexity = self.quantizer(x_encoder)
104
+
105
+ ## decoder
106
+ x_decoder = self.decoder(x_quantized)
107
+ x_out = self.postprocess(x_decoder)
108
+ return x_out, loss, perplexity
109
+
110
+
111
+ def forward_decoder(self, x):
112
+ # x = x.clone()
113
+ # pad_mask = x >= self.code_dim
114
+ # x[pad_mask] = 0
115
+
116
+ x_d = self.quantizer.dequantize(x)
117
+ x_d = x_d.permute(0, 2, 1).contiguous()
118
+
119
+ # pad_mask = pad_mask.unsqueeze(1)
120
+ # x_d = x_d * ~pad_mask
121
+
122
+ # decoder
123
+ x_decoder = self.decoder(x_d)
124
+ x_out = self.postprocess(x_decoder)
125
+ return x_out
126
+
127
+
128
+
129
+ class HumanVQVAE(nn.Module):
130
+ def __init__(self,
131
+ args,
132
+ nb_code=512,
133
+ code_dim=512,
134
+ output_emb_width=512,
135
+ down_t=3,
136
+ stride_t=2,
137
+ width=512,
138
+ depth=3,
139
+ dilation_growth_rate=3,
140
+ activation='relu',
141
+ norm=None):
142
+
143
+ super().__init__()
144
+
145
+ self.nb_joints = 21 if args.dataname == 'kit' else 22
146
+ self.vqvae = VQVAE_251(args, nb_code, code_dim, code_dim, down_t, stride_t, width, depth, dilation_growth_rate, activation=activation, norm=norm)
147
+
148
+ def forward(self, x, type='full'):
149
+ '''type=[full, encode, decode]'''
150
+ if type=='full':
151
+ x_out, loss, perplexity = self.vqvae(x)
152
+ return x_out, loss, perplexity
153
+ elif type=='encode':
154
+ b, t, c = x.size()
155
+ quants = self.vqvae.encode(x) # (N, T)
156
+ return quants
157
+ elif type=='decode':
158
+ x_out = self.vqvae.forward_decoder(x)
159
+ return x_out
160
+ else:
161
+ raise ValueError(f'Unknown "{type}" type')
162
+
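+ # Illustrative usage of the type-based dispatch above (here `model` denotes a HumanVQVAE
+ # instance; variable names are assumptions):
+ #   x_out, loss, perplexity = model(motion)      # type='full': encode, quantize, decode
+ #   codes = model(motion, type='encode')          # codebook indices, one per downsampled frame
+ #   recon = model(codes, type='decode')           # decode indices back to motion features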
models/vqvae_sep.py ADDED
@@ -0,0 +1,257 @@
1
+ import torch.nn as nn
2
+ from models.encdec import Encoder, Decoder
3
+ from models.quantize_cnn import QuantizeEMAReset, Quantizer, QuantizeEMA, QuantizeReset
4
+ from models.t2m_trans import Decoder_Transformer, Encoder_Transformer
5
+ from exit.utils import generate_src_mask
6
+ import torch
7
+ from utils.humanml_utils import HML_UPPER_BODY_MASK, HML_LOWER_BODY_MASK, UPPER_JOINT_Y_MASK
8
+
9
+ class VQVAE_SEP(nn.Module):
10
+ def __init__(self,
11
+ args,
12
+ nb_code=512,
13
+ code_dim=512,
14
+ output_emb_width=512,
15
+ down_t=3,
16
+ stride_t=2,
17
+ width=512,
18
+ depth=3,
19
+ dilation_growth_rate=3,
20
+ activation='relu',
21
+ norm=None,
22
+ moment=None,
23
+ sep_decoder=False):
24
+ super().__init__()
25
+ if args.dataname == 'kit':
26
+ self.nb_joints = 21
27
+ output_dim = 251
28
+ upper_dim = 120
29
+ lower_dim = 131
30
+ else:
31
+ self.nb_joints = 22
32
+ output_dim = 263
33
+ upper_dim = 156
34
+ lower_dim = 107
35
+ self.code_dim = code_dim
36
+ if moment is not None:
37
+ self.moment = moment
38
+ self.register_buffer('mean_upper', torch.tensor([0.1216, 0.2488, 0.2967, 0.5027, 0.4053, 0.4100, 0.5703, 0.4030, 0.4078, 0.1994, 0.1992, 0.0661, 0.0639], dtype=torch.float32))
39
+ self.register_buffer('std_upper', torch.tensor([0.0164, 0.0412, 0.0523, 0.0864, 0.0695, 0.0703, 0.1108, 0.0853, 0.0847, 0.1289, 0.1291, 0.2463, 0.2484], dtype=torch.float32))
40
+ # self.quantizer = QuantizeEMAReset(nb_code, code_dim, args)
41
+
42
+ # self.encoder = Encoder(output_dim, output_emb_width, down_t, stride_t, width, depth, dilation_growth_rate, activation=activation, norm=norm)
43
+ self.sep_decoder = sep_decoder
44
+ if self.sep_decoder:
45
+ self.decoder_upper = Decoder(upper_dim, int(code_dim/2), down_t, stride_t, width, depth, dilation_growth_rate, activation=activation, norm=norm)
46
+ self.decoder_lower = Decoder(lower_dim, int(code_dim/2), down_t, stride_t, width, depth, dilation_growth_rate, activation=activation, norm=norm)
47
+ else:
48
+ self.decoder = Decoder(output_dim, code_dim, down_t, stride_t, width, depth, dilation_growth_rate, activation=activation, norm=norm)
49
+
50
+
51
+ self.num_code = nb_code
52
+
53
+ self.encoder_upper = Encoder(upper_dim, int(code_dim/2), down_t, stride_t, width, depth, dilation_growth_rate, activation=activation, norm=norm)
54
+ self.encoder_lower = Encoder(lower_dim, int(code_dim/2), down_t, stride_t, width, depth, dilation_growth_rate, activation=activation, norm=norm)
55
+ self.quantizer_upper = QuantizeEMAReset(nb_code, int(code_dim/2), args)
56
+ self.quantizer_lower = QuantizeEMAReset(nb_code, int(code_dim/2), args)
57
+
58
+ def rand_emb_idx(self, x_quantized, quantizer, idx_noise):
59
+ # x_quantized = x_quantized.detach()
60
+ x_quantized = x_quantized.permute(0,2,1)
61
+ mask = torch.bernoulli(idx_noise * torch.ones((*x_quantized.shape[:2], 1),
62
+ device=x_quantized.device))
63
+ r_indices = torch.randint(int(self.num_code/2), x_quantized.shape[:2], device=x_quantized.device)
64
+ r_emb = quantizer.dequantize(r_indices)
65
+ x_quantized = mask * r_emb + (1-mask) * x_quantized
66
+ x_quantized = x_quantized.permute(0,2,1)
67
+ return x_quantized
68
+
69
+ def normalize(self, data):
70
+ return (data - self.moment['mean']) / self.moment['std']
71
+
72
+ def denormalize(self, data):
73
+ return data * self.moment['std'] + self.moment['mean']
74
+
75
+ def normalize_upper(self, data):
76
+ return (data - self.mean_upper) / self.std_upper
77
+
78
+ def denormalize_upper(self, data):
79
+ return data * self.std_upper + self.mean_upper
80
+
81
+ def shift_upper_down(self, data):
82
+ data = data.clone()
83
+ data = self.denormalize(data)
84
+ shift_y = data[..., 3:4].clone()
85
+ data[..., UPPER_JOINT_Y_MASK] -= shift_y
86
+ _data = data.clone()
87
+ data = self.normalize(data)
88
+ data[..., UPPER_JOINT_Y_MASK] = self.normalize_upper(_data[..., UPPER_JOINT_Y_MASK])
89
+ return data
90
+
91
+ def shift_upper_up(self, data):
92
+ _data = data.clone()
93
+ data = self.denormalize(data)
94
+ data[..., UPPER_JOINT_Y_MASK] = self.denormalize_upper(_data[..., UPPER_JOINT_Y_MASK])
95
+ shift_y = data[..., 3:4].clone()
96
+ data[..., UPPER_JOINT_Y_MASK] += shift_y
97
+ data = self.normalize(data)
98
+ return data
99
+
100
+ def forward(self, x, *args, type='full', **kwargs):
101
+ '''type=[full, encode, decode]'''
102
+ if type=='full':
103
+ x = x.float()
104
+ x = self.shift_upper_down(x)
105
+
106
+ upper_emb = x[..., HML_UPPER_BODY_MASK]
107
+ lower_emb = x[..., HML_LOWER_BODY_MASK]
108
+ upper_emb = self.preprocess(upper_emb)
109
+ upper_emb = self.encoder_upper(upper_emb)
110
+ upper_emb, loss_upper, perplexity = self.quantizer_upper(upper_emb)
111
+
112
+ lower_emb = self.preprocess(lower_emb)
113
+ lower_emb = self.encoder_lower(lower_emb)
114
+ lower_emb, loss_lower, perplexity = self.quantizer_lower(lower_emb)
115
+ loss = loss_upper + loss_lower
116
+
117
+ if 'idx_noise' in kwargs and kwargs['idx_noise'] > 0:
118
+ upper_emb = self.rand_emb_idx(upper_emb, self.quantizer_upper, kwargs['idx_noise'])
119
+ lower_emb = self.rand_emb_idx(lower_emb, self.quantizer_lower, kwargs['idx_noise'])
120
+
121
+
122
+ # x_in = self.preprocess(x)
123
+ # x_encoder = self.encoder(x_in)
124
+
125
+ # ## quantization
126
+ # x_quantized, loss, perplexity = self.quantizer(x_encoder)
127
+
128
+ ## decoder
129
+ if self.sep_decoder:
130
+ x_decoder_upper = self.decoder_upper(upper_emb)
131
+ x_decoder_upper = self.postprocess(x_decoder_upper)
132
+ x_decoder_lower = self.decoder_lower(lower_emb)
133
+ x_decoder_lower = self.postprocess(x_decoder_lower)
134
+ x_out = merge_upper_lower(x_decoder_upper, x_decoder_lower)
135
+ x_out = self.shift_upper_up(x_out)
136
+
137
+ else:
138
+ x_quantized = torch.cat([upper_emb, lower_emb], dim=1)
139
+ x_decoder = self.decoder(x_quantized)
140
+ x_out = self.postprocess(x_decoder)
141
+
142
+ return x_out, loss, perplexity
143
+ elif type=='encode':
144
+ N, T, _ = x.shape
145
+ x = self.shift_upper_down(x)
146
+
147
+ upper_emb = x[..., HML_UPPER_BODY_MASK]
148
+ upper_emb = self.preprocess(upper_emb)
149
+ upper_emb = self.encoder_upper(upper_emb)
150
+ upper_emb = self.postprocess(upper_emb)
151
+ upper_emb = upper_emb.reshape(-1, upper_emb.shape[-1])
152
+ upper_code_idx = self.quantizer_upper.quantize(upper_emb)
153
+ upper_code_idx = upper_code_idx.view(N, -1)
154
+
155
+ lower_emb = x[..., HML_LOWER_BODY_MASK]
156
+ lower_emb = self.preprocess(lower_emb)
157
+ lower_emb = self.encoder_lower(lower_emb)
158
+ lower_emb = self.postprocess(lower_emb)
159
+ lower_emb = lower_emb.reshape(-1, lower_emb.shape[-1])
160
+ lower_code_idx = self.quantizer_lower.quantize(lower_emb)
161
+ lower_code_idx = lower_code_idx.view(N, -1)
162
+
163
+ code_idx = torch.cat([upper_code_idx.unsqueeze(-1), lower_code_idx.unsqueeze(-1)], dim=-1)
164
+ return code_idx
165
+
166
+ elif type=='decode':
167
+ if self.sep_decoder:
168
+ x_d_upper = self.quantizer_upper.dequantize(x[..., 0])
169
+ x_d_upper = x_d_upper.permute(0, 2, 1).contiguous()
170
+ x_d_upper = self.decoder_upper(x_d_upper)
171
+ x_d_upper = self.postprocess(x_d_upper)
172
+
173
+ x_d_lower = self.quantizer_lower.dequantize(x[..., 1])
174
+ x_d_lower = x_d_lower.permute(0, 2, 1).contiguous()
175
+ x_d_lower = self.decoder_lower(x_d_lower)
176
+ x_d_lower = self.postprocess(x_d_lower)
177
+
178
+ x_out = merge_upper_lower(x_d_upper, x_d_lower)
179
+ x_out = self.shift_upper_up(x_out)
180
+ return x_out
181
+ else:
182
+ x_d_upper = self.quantizer_upper.dequantize(x[..., 0])
183
+ x_d_lower = self.quantizer_lower.dequantize(x[..., 1])
184
+ x_d = torch.cat([x_d_upper, x_d_lower], dim=-1)
185
+ x_d = x_d.permute(0, 2, 1).contiguous()
186
+ x_decoder = self.decoder(x_d)
187
+ x_out = self.postprocess(x_decoder)
188
+ return x_out
189
+
190
+ def preprocess(self, x):
191
+ # (bs, T, Jx3) -> (bs, Jx3, T)
192
+ x = x.permute(0,2,1).float()
193
+ return x
194
+
195
+ def postprocess(self, x):
196
+ # (bs, Jx3, T) -> (bs, T, Jx3)
197
+ x = x.permute(0,2,1)
198
+ return x
199
+
200
+
201
+ def merge_upper_lower(upper_emb, lower_emb):
202
+ motion = torch.empty(*upper_emb.shape[:2], 263).to(upper_emb.device)
203
+ motion[..., HML_UPPER_BODY_MASK] = upper_emb
204
+ motion[..., HML_LOWER_BODY_MASK] = lower_emb
205
+ return motion
206
+
207
+ def upper_lower_sep(motion, joints_num):
208
+ # root
209
+ _root = motion[..., :4] # root
210
+
211
+ # position
212
+ start_indx = 1 + 2 + 1
213
+ end_indx = start_indx + (joints_num - 1) * 3
214
+ positions = motion[..., start_indx:end_indx]
215
+ positions = positions.view(*motion.shape[:2], (joints_num - 1), 3)
216
+
217
+ # 6drot
218
+ start_indx = end_indx
219
+ end_indx = start_indx + (joints_num - 1) * 6
220
+ _6d_rot = motion[..., start_indx:end_indx]
221
+ _6d_rot = _6d_rot.view(*motion.shape[:2], (joints_num - 1), 6)
222
+
223
+ # joint_velo
224
+ start_indx = end_indx
225
+ end_indx = start_indx + joints_num * 3
226
+ joint_velo = motion[..., start_indx:end_indx]
227
+ joint_velo = joint_velo.view(*motion.shape[:2], joints_num, 3)
228
+
229
+ # foot_contact
230
+ foot_contact = motion[..., end_indx:]
231
+
232
+ ################################################################################################
233
+ #### Lower Body
234
+ if joints_num == 22:
235
+ lower_body = torch.tensor([0,1,2,4,5,7,8,10,11])
236
+ else:
237
+ lower_body = torch.tensor([0, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20])
238
+ lower_body_exclude_root = lower_body[1:] - 1
239
+
240
+ LOW_positions = positions[:,:, lower_body_exclude_root].view(*motion.shape[:2], -1)
241
+ LOW_6d_rot = _6d_rot[:,:, lower_body_exclude_root].view(*motion.shape[:2], -1)
242
+ LOW_joint_velo = joint_velo[:,:, lower_body].view(*motion.shape[:2], -1)
243
+ lower_emb = torch.cat([_root, LOW_positions, LOW_6d_rot, LOW_joint_velo, foot_contact], dim=-1)
244
+
245
+ #### Upper Body
246
+ if joints_num == 22:
247
+ upper_body = torch.tensor([3,6,9,12,13,14,15,16,17,18,19,20,21])
248
+ else:
249
+ upper_body = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
250
+ upper_body_exclude_root = upper_body - 1
251
+
252
+ UP_positions = positions[:,:, upper_body_exclude_root].view(*motion.shape[:2], -1)
253
+ UP_6d_rot = _6d_rot[:,:, upper_body_exclude_root].view(*motion.shape[:2], -1)
254
+ UP_joint_velo = joint_velo[:,:, upper_body].view(*motion.shape[:2], -1)
255
+ upper_emb = torch.cat([UP_positions, UP_6d_rot, UP_joint_velo], dim=-1)
256
+
257
+ return upper_emb, lower_emb
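For orientation, a small self-contained sketch of the mask-based split/merge that VQVAE_SEP and merge_upper_lower rely on. It uses a toy 10-dimensional feature vector and a made-up upper-body mask; the real HML_UPPER_BODY_MASK / HML_LOWER_BODY_MASK from utils.humanml_utils are 263-dimensional boolean masks over the HumanML3D feature layout.

import torch

D = 10                                          # toy feature size (real HumanML3D vectors have 263 dims)
upper_mask = torch.zeros(D, dtype=torch.bool)
upper_mask[:6] = True                           # hypothetical "upper-body" channels
lower_mask = ~upper_mask

motion = torch.randn(2, 16, D)                  # (batch, frames, features)
upper = motion[..., upper_mask]                 # (2, 16, 6) -> would feed encoder_upper
lower = motion[..., lower_mask]                 # (2, 16, 4) -> would feed encoder_lower

# merge_upper_lower: scatter both halves back into a full feature vector
merged = torch.empty_like(motion)
merged[..., upper_mask] = upper
merged[..., lower_mask] = lower
assert torch.equal(merged, motion)

The same boolean-mask indexing is what lets the separate upper/lower encoders, quantizers and decoders operate on disjoint slices of the full motion feature vector.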
options/get_eval_option.py ADDED
@@ -0,0 +1,83 @@
1
+ from argparse import Namespace
2
+ import re
3
+ from os.path import join as pjoin
4
+
5
+
6
+ def is_float(numStr):
7
+ flag = False
8
+ numStr = str(numStr).strip().lstrip('-').lstrip('+')
9
+ try:
10
+ reg = re.compile(r'^[-+]?[0-9]+\.[0-9]+$')
11
+ res = reg.match(str(numStr))
12
+ if res:
13
+ flag = True
14
+ except Exception as ex:
15
+ print("is_float() - error: " + str(ex))
16
+ return flag
17
+
18
+
19
+ def is_number(numStr):
20
+ flag = False
21
+ numStr = str(numStr).strip().lstrip('-').lstrip('+')
22
+ if str(numStr).isdigit():
23
+ flag = True
24
+ return flag
25
+
26
+
27
+ def get_opt(opt_path, device):
28
+ opt = Namespace()
29
+ opt_dict = vars(opt)
30
+
31
+ skip = ('-------------- End ----------------',
32
+ '------------ Options -------------',
33
+ '\n')
34
+ print('Reading', opt_path)
35
+ with open(opt_path) as f:
36
+ for line in f:
37
+ if line.strip() not in skip:
38
+ # print(line.strip())
39
+ key, value = line.strip().split(': ')
40
+ if value in ('True', 'False'):
41
+ opt_dict[key] = (value == 'True')
42
+ # print(key, value)
43
+ elif is_float(value):
44
+ opt_dict[key] = float(value)
45
+ elif is_number(value):
46
+ opt_dict[key] = int(value)
47
+ else:
48
+ opt_dict[key] = str(value)
49
+
50
+ # print(opt)
51
+ opt_dict['which_epoch'] = 'finest'
52
+ opt.save_root = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name)
53
+ opt.model_dir = pjoin(opt.save_root, 'model')
54
+ opt.meta_dir = pjoin(opt.save_root, 'meta')
55
+
56
+ if opt.dataset_name == 't2m':
57
+ opt.data_root = './dataset/HumanML3D/'
58
+ opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs')
59
+ opt.text_dir = pjoin(opt.data_root, 'texts')
60
+ opt.joints_num = 22
61
+ opt.dim_pose = 263
62
+ opt.max_motion_length = 196
63
+ opt.max_motion_frame = 196
64
+ opt.max_motion_token = 55
65
+ elif opt.dataset_name == 'kit':
66
+ opt.data_root = './dataset/KIT-ML/'
67
+ opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs')
68
+ opt.text_dir = pjoin(opt.data_root, 'texts')
69
+ opt.joints_num = 21
70
+ opt.dim_pose = 251
71
+ opt.max_motion_length = 196
72
+ opt.max_motion_frame = 196
73
+ opt.max_motion_token = 55
74
+ else:
75
+ raise KeyError('Dataset not recognized')
76
+
77
+ opt.dim_word = 300
78
+ opt.num_classes = 200 // opt.unit_length
79
+ opt.is_train = False
80
+ opt.is_continue = False
81
+ opt.device = device
82
+
83
+ return opt
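A minimal sketch of the opt.txt parsing convention implemented above, run on in-memory lines instead of a checkpoint file; the keys and values below are made up for illustration, and the typing rules mirror is_float / is_number.

import re
from argparse import Namespace

lines = ["dataset_name: t2m", "unit_length: 4", "lambda_kld: 0.005", "is_train: True"]
opt = Namespace()
for line in lines:
    key, value = line.strip().split(': ')
    if value in ('True', 'False'):
        setattr(opt, key, value == 'True')              # booleans
    elif re.fullmatch(r'[-+]?[0-9]+\.[0-9]+', value):
        setattr(opt, key, float(value))                 # floats, as in is_float()
    elif value.lstrip('+-').isdigit():
        setattr(opt, key, int(value))                   # ints, as in is_number()
    else:
        setattr(opt, key, value)                        # everything else stays a string
print(opt)  # Namespace(dataset_name='t2m', is_train=True, lambda_kld=0.005, unit_length=4)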
options/option_transformer.py ADDED
@@ -0,0 +1,72 @@
1
+ import argparse
2
+
3
+ def get_args_parser():
4
+ parser = argparse.ArgumentParser(description='Optimal Transport AutoEncoder training for Amass',
5
+ add_help=True,
6
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
7
+
8
+ ## dataloader
9
+
10
+ parser.add_argument('--dataname', type=str, default='t2m', help='dataset name (t2m or kit)')
11
+ parser.add_argument('--batch-size', default=128, type=int, help='batch size')
12
+ parser.add_argument('--fps', default=[20], nargs="+", type=int, help='frames per second')
13
+ parser.add_argument('--seq-len', type=int, default=64, help='training motion length')
14
+
15
+ ## optimization
16
+ parser.add_argument('--total-iter', default=300000, type=int, help='number of total iterations to run')
17
+ parser.add_argument('--warm-up-iter', default=1000, type=int, help='number of total iterations for warmup')
18
+ parser.add_argument('--lr', default=2e-4, type=float, help='max learning rate')
19
+ parser.add_argument('--lr-scheduler', default=[150000], nargs="+", type=int, help="learning rate schedule (iterations)")
20
+ parser.add_argument('--gamma', default=0.05, type=float, help="learning rate decay")
21
+
22
+ parser.add_argument('--weight-decay', default=1e-6, type=float, help='weight decay')
23
+ parser.add_argument('--decay-option',default='all', type=str, choices=['all', 'noVQ'], help='disable weight decay on codebook')
24
+ parser.add_argument('--optimizer',default='adamw', type=str, choices=['adam', 'adamw'], help='optimizer type')
25
+
26
+ ## vqvae arch
27
+ parser.add_argument("--code-dim", type=int, default=32, help="embedding dimension")
28
+ parser.add_argument("--nb-code", type=int, default=8192, help="nb of embedding")
29
+ parser.add_argument("--mu", type=float, default=0.99, help="exponential moving average to update the codebook")
30
+ parser.add_argument("--down-t", type=int, default=2, help="downsampling rate")
31
+ parser.add_argument("--stride-t", type=int, default=2, help="stride size")
32
+ parser.add_argument("--width", type=int, default=512, help="width of the network")
33
+ parser.add_argument("--depth", type=int, default=3, help="depth of the network")
34
+ parser.add_argument("--dilation-growth-rate", type=int, default=3, help="dilation growth rate")
35
+ parser.add_argument("--output-emb-width", type=int, default=512, help="output embedding width")
36
+ parser.add_argument('--vq-act', type=str, default='relu', choices = ['relu', 'silu', 'gelu'], help='activation function in the VQ-VAE')
37
+
38
+ ## gpt arch
39
+ parser.add_argument("--block-size", type=int, default=51, help="seq len")
40
+ parser.add_argument("--embed-dim-gpt", type=int, default=1024, help="embedding dimension")
41
+ parser.add_argument("--clip-dim", type=int, default=512, help="latent dimension in the clip feature")
42
+ parser.add_argument("--num-layers", type=int, default=9, help="nb of transformer layers")
43
+ parser.add_argument("--num-local-layer", type=int, default=2, help="nb of transformer local layers")
44
+ parser.add_argument("--n-head-gpt", type=int, default=16, help="nb of heads")
45
+ parser.add_argument("--ff-rate", type=int, default=4, help="feedforward size")
46
+ parser.add_argument("--drop-out-rate", type=float, default=0.1, help="dropout ratio in the pos encoding")
47
+
48
+ ## quantizer
49
+ parser.add_argument("--quantizer", type=str, default='ema_reset', choices = ['ema', 'orig', 'ema_reset', 'reset'], help="quantizer variant")
50
+ parser.add_argument('--quantbeta', type=float, default=1.0, help='commitment loss weight (beta) for the standard VQ quantizer')
51
+
52
+ ## resume
53
+ parser.add_argument("--resume-pth", type=str, default=None, help='resume vq pth')
54
+ parser.add_argument("--resume-trans", type=str, default=None, help='resume gpt pth')
55
+
56
+
57
+ ## output directory
58
+ parser.add_argument('--out-dir', type=str, default='output', help='output directory')
59
+ parser.add_argument('--exp-name', type=str, default='exp_debug', help='name of the experiment, will create a file inside out-dir')
60
+ parser.add_argument('--vq-name', type=str, default='VQVAE', help='name of the generated dataset .npy, will create a file inside out-dir')
61
+ ## other
62
+ parser.add_argument('--print-iter', default=200, type=int, help='print frequency')
63
+ parser.add_argument('--eval-iter', default=10000, type=int, help='evaluation frequency')
64
+ parser.add_argument('--seed', default=123, type=int, help='seed for initializing training. ')
65
+ parser.add_argument("--if-maxtest", action='store_true', help="test in max")
66
+ parser.add_argument('--pkeep', type=float, default=.5, help='keep rate for gpt training')
67
+
68
+ ## generator
69
+ parser.add_argument('--text', type=str, help='text')
70
+ parser.add_argument('--length', type=int, help='length')
71
+
72
+ return parser.parse_args()
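Assuming the repository root is on PYTHONPATH, the parser above can be exercised programmatically as follows; the flag values are arbitrary examples, only the flag names come from the definitions above.

import sys
import options.option_transformer as option_trans

sys.argv = ['train_t2m_trans.py',
            '--exp-name', 'exp_debug',
            '--vq-name', 'VQVAE',
            '--nb-code', '8192',
            '--pkeep', '0.5']
args = option_trans.get_args_parser()    # calls parser.parse_args() internally
print(args.nb_code, args.pkeep)          # 8192 0.5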
options/option_vq.py ADDED
@@ -0,0 +1,62 @@
1
+ import argparse
2
+
3
+ def get_args_parser():
4
+ parser = argparse.ArgumentParser(description='Optimal Transport AutoEncoder training for AIST',
5
+ add_help=True,
6
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
7
+
8
+ ## dataloader
9
+ parser.add_argument('--dataname', type=str, default='kit', help='dataset name (t2m or kit)')
10
+ parser.add_argument('--batch-size', default=256, type=int, help='batch size')
11
+ parser.add_argument('--window-size', type=int, default=64, help='training motion length')
12
+
13
+ ## optimization
14
+ parser.add_argument('--total-iter', default=300000, type=int, help='number of total iterations to run')
15
+ parser.add_argument('--warm-up-iter', default=1000, type=int, help='number of total iterations for warmup')
16
+ parser.add_argument('--lr', default=2e-4, type=float, help='max learning rate')
17
+ parser.add_argument('--lr-scheduler', default=[200000], nargs="+", type=int, help="learning rate schedule (iterations)")
18
+ parser.add_argument('--gamma', default=0.05, type=float, help="learning rate decay")
19
+
20
+ parser.add_argument('--weight-decay', default=0.0, type=float, help='weight decay')
21
+ parser.add_argument("--commit", type=float, default=0.02, help="hyper-parameter for the commitment loss")
22
+ parser.add_argument('--loss-vel', type=float, default=0.5, help='hyper-parameter for the velocity loss')
23
+ parser.add_argument('--recons-loss', type=str, default='l1_smooth', help='reconstruction loss')
24
+
25
+ ## vqvae arch
26
+ parser.add_argument("--code-dim", type=int, default=32, help="embedding dimension")
27
+ parser.add_argument("--nb-code", type=int, default=8192, help="nb of embedding")
28
+ parser.add_argument("--mu", type=float, default=0.99, help="exponential moving average to update the codebook")
29
+ parser.add_argument("--down-t", type=int, default=2, help="downsampling rate")
30
+ parser.add_argument("--stride-t", type=int, default=2, help="stride size")
31
+ parser.add_argument("--width", type=int, default=512, help="width of the network")
32
+ parser.add_argument("--depth", type=int, default=3, help="depth of the network")
33
+ parser.add_argument("--dilation-growth-rate", type=int, default=3, help="dilation growth rate")
34
+ parser.add_argument("--output-emb-width", type=int, default=512, help="output embedding width")
35
+ parser.add_argument('--vq-act', type=str, default='relu', choices = ['relu', 'silu', 'gelu'], help='activation function in the VQ-VAE')
36
+ parser.add_argument('--vq-norm', type=str, default=None, help='normalization layer in the VQ-VAE (None to disable)')
37
+
38
+ ## quantizer
39
+ parser.add_argument("--quantizer", type=str, default='ema_reset', choices = ['ema', 'orig', 'ema_reset', 'reset'], help="quantizer variant")
40
+ parser.add_argument('--beta', type=float, default=1.0, help='commitment loss in standard VQ')
41
+
42
+ ## resume
43
+ parser.add_argument("--resume-pth", type=str, default=None, help='resume pth for VQ')
44
+ parser.add_argument("--resume-gpt", type=str, default=None, help='resume pth for GPT')
45
+
46
+
47
+ ## output directory
48
+ parser.add_argument('--out-dir', type=str, default='output', help='output directory')
49
+ parser.add_argument('--results-dir', type=str, default='visual_results/', help='output directory')
50
+ parser.add_argument('--visual-name', type=str, default='baseline', help='name prefix for saved visualizations')
51
+ parser.add_argument('--exp-name', type=str, default='exp_debug', help='name of the experiment, will create a file inside out-dir')
52
+ ## other
53
+ parser.add_argument('--print-iter', default=200, type=int, help='print frequency')
54
+ parser.add_argument('--eval-iter', default=5000, type=int, help='evaluation frequency')
55
+ parser.add_argument('--seed', default=123, type=int, help='seed for initializing training.')
56
+
57
+ parser.add_argument('--vis-gt', action='store_true', help='whether visualize GT motions')
58
+ parser.add_argument('--nb-vis', default=20, type=int, help='nb of visualizations')
59
+
60
+ parser.add_argument('--sep-uplow', action='store_true', help='use separate upper-/lower-body encoders and codebooks (VQVAE_SEP)')
61
+
62
+ return parser.parse_args()
train_t2m_trans.py ADDED
@@ -0,0 +1,265 @@
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+
5
+ from torch.utils.tensorboard import SummaryWriter
6
+ from os.path import join as pjoin
7
+ from torch.distributions import Categorical
8
+ import json
9
+ import clip
10
+
11
+ import options.option_transformer as option_trans
12
+ import models.vqvae as vqvae
13
+ import utils.utils_model as utils_model
14
+ import utils.eval_trans as eval_trans
15
+ from dataset import dataset_TM_train
16
+ from dataset import dataset_TM_eval
17
+ from dataset import dataset_tokenize
18
+ import models.t2m_trans as trans
19
+ from options.get_eval_option import get_opt
20
+ from models.evaluator_wrapper import EvaluatorModelWrapper
21
+ import warnings
22
+ warnings.filterwarnings('ignore')
23
+ from exit.utils import get_model, visualize_2motions
24
+ from tqdm import tqdm
25
+ from exit.utils import get_model, visualize_2motions, generate_src_mask, init_save_folder, uniform, cosine_schedule
26
+ from einops import rearrange, repeat
27
+ import torch.nn.functional as F
28
+ from exit.utils import base_dir
29
+
30
+ ##### ---- Exp dirs ---- #####
31
+ args = option_trans.get_args_parser()
32
+ torch.manual_seed(args.seed)
33
+
34
+ # args.out_dir = os.path.join(args.out_dir, f'{args.exp_name}')
35
+ init_save_folder(args)
36
+
37
+ # [TODO] make the 'output/' folder an argument
38
+ args.vq_dir = f'./output/vq/{args.vq_name}' #os.path.join("./dataset/KIT-ML" if args.dataname == 'kit' else "./dataset/HumanML3D", f'{args.vq_name}')
39
+ codebook_dir = f'{args.vq_dir}/codebook/'
40
+ args.resume_pth = f'{args.vq_dir}/net_last.pth'
41
+ os.makedirs(args.vq_dir, exist_ok = True)
42
+ os.makedirs(codebook_dir, exist_ok = True)
43
+ os.makedirs(args.out_dir, exist_ok = True)
44
+ os.makedirs(args.out_dir+'/html', exist_ok=True)
45
+
46
+ ##### ---- Logger ---- #####
47
+ logger = utils_model.get_logger(args.out_dir)
48
+ writer = SummaryWriter(args.out_dir)
49
+ logger.info(json.dumps(vars(args), indent=4, sort_keys=True))
50
+
51
+
52
+ from utils.word_vectorizer import WordVectorizer
53
+ w_vectorizer = WordVectorizer('./glove', 'our_vab')
54
+ val_loader = dataset_TM_eval.DATALoader(args.dataname, False, 32, w_vectorizer)
55
+
56
+ dataset_opt_path = 'checkpoints/kit/Comp_v6_KLD005/opt.txt' if args.dataname == 'kit' else 'checkpoints/t2m/Comp_v6_KLD005/opt.txt'
57
+
58
+ wrapper_opt = get_opt(dataset_opt_path, torch.device('cuda'))
59
+ eval_wrapper = EvaluatorModelWrapper(wrapper_opt)
60
+
61
+ ##### ---- Network ---- #####
62
+ clip_model, clip_preprocess = clip.load("ViT-B/32", device=torch.device('cuda'), jit=False) # Must set jit=False for training
63
+ clip.model.convert_weights(clip_model) # Actually this line is unnecessary since CLIP is already in float16 by default
64
+ clip_model.eval()
65
+ for p in clip_model.parameters():
66
+ p.requires_grad = False
67
+
68
+ # https://github.com/openai/CLIP/issues/111
69
+ class TextCLIP(torch.nn.Module):
70
+ def __init__(self, model) :
71
+ super(TextCLIP, self).__init__()
72
+ self.model = model
73
+
74
+ def forward(self,text):
75
+ with torch.no_grad():
76
+ word_emb = self.model.token_embedding(text).type(self.model.dtype)
77
+ word_emb = word_emb + self.model.positional_embedding.type(self.model.dtype)
78
+ word_emb = word_emb.permute(1, 0, 2) # NLD -> LND
79
+ word_emb = self.model.transformer(word_emb)
80
+ word_emb = self.model.ln_final(word_emb).permute(1, 0, 2).float()
81
+ enctxt = self.model.encode_text(text).float()
82
+ return enctxt, word_emb
83
+ clip_model = TextCLIP(clip_model)
84
+
85
+ net = vqvae.HumanVQVAE(args, ## use args to define different parameters in different quantizers
86
+ args.nb_code,
87
+ args.code_dim,
88
+ args.output_emb_width,
89
+ args.down_t,
90
+ args.stride_t,
91
+ args.width,
92
+ args.depth,
93
+ args.dilation_growth_rate)
94
+
95
+
96
+ trans_encoder = trans.Text2Motion_Transformer(vqvae=net,
97
+ num_vq=args.nb_code,
98
+ embed_dim=args.embed_dim_gpt,
99
+ clip_dim=args.clip_dim,
100
+ block_size=args.block_size,
101
+ num_layers=args.num_layers,
102
+ num_local_layer=args.num_local_layer,
103
+ n_head=args.n_head_gpt,
104
+ drop_out_rate=args.drop_out_rate,
105
+ fc_rate=args.ff_rate)
106
+
107
+ print ('loading checkpoint from {}'.format(args.resume_pth))
108
+ ckpt = torch.load(args.resume_pth, map_location='cpu')
109
+ net.load_state_dict(ckpt['net'], strict=True)
110
+ net.eval()
111
+ net.cuda()
112
+
113
+ if args.resume_trans is not None:
114
+ print ('loading transformer checkpoint from {}'.format(args.resume_trans))
115
+ ckpt = torch.load(args.resume_trans, map_location='cpu')
116
+ trans_encoder.load_state_dict(ckpt['trans'], strict=True)
117
+ trans_encoder.train()
118
+ trans_encoder.cuda()
119
+ trans_encoder = torch.nn.DataParallel(trans_encoder)
120
+
121
+ ##### ---- Optimizer & Scheduler ---- #####
122
+ optimizer = utils_model.initial_optim(args.decay_option, args.lr, args.weight_decay, trans_encoder, args.optimizer)
123
+ scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_scheduler, gamma=args.gamma)
124
+
125
+ ##### ---- Optimization goals ---- #####
126
+ loss_ce = torch.nn.CrossEntropyLoss(reduction='none')
127
+
128
+ ##### ---- get code ---- #####
129
+ ##### ---- Dataloader ---- #####
130
+ if len(os.listdir(codebook_dir)) == 0:
131
+ train_loader_token = dataset_tokenize.DATALoader(args.dataname, 1, unit_length=2**args.down_t)
132
+ for batch in train_loader_token:
133
+ pose, name = batch
134
+ bs, seq = pose.shape[0], pose.shape[1]
135
+
136
+ pose = pose.cuda().float() # bs, nb_joints, joints_dim, seq_len
137
+ target = net(pose, type='encode')
138
+ target = target.cpu().numpy()
139
+ np.save(pjoin(codebook_dir, name[0] +'.npy'), target)
140
+
141
+
142
+ train_loader = dataset_TM_train.DATALoader(args.dataname, args.batch_size, args.nb_code, codebook_dir, unit_length=2**args.down_t)
143
+ train_loader_iter = dataset_TM_train.cycle(train_loader)
144
+
145
+
146
+ ##### ---- Training ---- #####
147
+ best_fid=1000
148
+ best_iter=0
149
+ best_div=100
150
+ best_top1=0
151
+ best_top2=0
152
+ best_top3=0
153
+ best_matching=100
154
+ # pred_pose_eval, pose, m_length, clip_text, best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, best_multi, writer, logger = eval_trans.evaluation_transformer(args.out_dir, val_loader, net, trans_encoder, logger, writer, 0, best_fid=1000, best_iter=0, best_div=100, best_top1=0, best_top2=0, best_top3=0, best_matching=100, clip_model=clip_model, eval_wrapper=eval_wrapper)
155
+
156
+ def get_acc(cls_pred, target, mask):
157
+ cls_pred = torch.masked_select(cls_pred, mask.unsqueeze(-1)).view(-1, cls_pred.shape[-1])
158
+ target_all = torch.masked_select(target, mask)
159
+ probs = torch.softmax(cls_pred, dim=-1)
160
+ _, cls_pred_index = torch.max(probs, dim=-1)
161
+ right_num = (cls_pred_index == target_all).sum()
162
+ return right_num*100/mask.sum()
163
+
164
+ # while nb_iter <= args.total_iter:
165
+ for nb_iter in tqdm(range(1, args.total_iter + 1), position=0, leave=True):
166
+ batch = next(train_loader_iter)
167
+ clip_text, m_tokens, m_tokens_len = batch
168
+ m_tokens, m_tokens_len = m_tokens.cuda(), m_tokens_len.cuda()
169
+ bs = m_tokens.shape[0]
170
+ target = m_tokens # (bs, 26)
171
+ target = target.cuda()
172
+ batch_size, max_len = target.shape[:2]
173
+
174
+ # Random Drop Text
175
+ # text_mask = np.random.random(len(clip_text)) > .05
176
+ # clip_text = np.array(clip_text)
177
+ # clip_text[~text_mask] = ''
178
+
179
+ text = clip.tokenize(clip_text, truncate=True).cuda()
180
+
181
+ feat_clip_text, word_emb = clip_model(text)
182
+
183
+ # [INFO] Swap input tokens
184
+ if args.pkeep == -1:
185
+ proba = np.random.rand(1)[0]
186
+ mask = torch.bernoulli(proba * torch.ones(target.shape,
187
+ device=target.device))
188
+ else:
189
+ mask = torch.bernoulli(args.pkeep * torch.ones(target.shape,
190
+ device=target.device))
191
+ # randomize only motion tokens (not pad tokens), so pad tokens never get mixed in
192
+ seq_mask_no_end = generate_src_mask(max_len, m_tokens_len)
193
+ mask = torch.logical_or(mask, ~seq_mask_no_end).int()
194
+ r_indices = torch.randint_like(target, args.nb_code)
195
+ input_indices = mask*target+(1-mask)*r_indices
196
+
197
+ # Time step masking
198
+ mask_id = get_model(net).vqvae.num_code + 2
199
+ # rand_time = uniform((batch_size,), device = target.device)
200
+ # rand_mask_probs = cosine_schedule(rand_time)
201
+ rand_mask_probs = torch.zeros(batch_size, device = m_tokens_len.device).float().uniform_(0.5, 1)
202
+ # rand_mask_probs = cosine_schedule(rand_mask_probs)
203
+ num_token_masked = (m_tokens_len * rand_mask_probs).round().clamp(min = 1)
204
+ seq_mask = generate_src_mask(max_len, m_tokens_len+1)
205
+ batch_randperm = torch.rand((batch_size, max_len), device = target.device) - seq_mask_no_end.int()
206
+ batch_randperm = batch_randperm.argsort(dim = -1)
207
+ mask_token = batch_randperm < rearrange(num_token_masked, 'b -> b 1')
208
+
209
+ # masked_target = torch.where(mask_token, input=input_indices, other=-1)
210
+ masked_input_indices = torch.where(mask_token, mask_id, input_indices)
211
+
212
+ att_txt = None # CFG: torch.rand((seq_mask.shape[0], 1)) > 0.1
213
+ cls_pred = trans_encoder(masked_input_indices, feat_clip_text, src_mask = seq_mask, att_txt=att_txt, word_emb=word_emb)[:, 1:]
214
+
215
+ # [INFO] Compute xent loss as a batch
216
+ weights = seq_mask_no_end / (seq_mask_no_end.sum(-1).unsqueeze(-1) * seq_mask_no_end.shape[0])
217
+ cls_pred_seq_masked = cls_pred[seq_mask_no_end, :].view(-1, cls_pred.shape[-1])
218
+ target_seq_masked = target[seq_mask_no_end]
219
+ weight_seq_masked = weights[seq_mask_no_end]
220
+ loss_cls = F.cross_entropy(cls_pred_seq_masked, target_seq_masked, reduction = 'none')
221
+ loss_cls = (loss_cls * weight_seq_masked).sum()
222
+
223
+ ## global loss
224
+ optimizer.zero_grad()
225
+ loss_cls.backward()
226
+ optimizer.step()
227
+ scheduler.step()
228
+
229
+ if nb_iter % args.print_iter == 0 :
230
+ probs_seq_masked = torch.softmax(cls_pred_seq_masked, dim=-1)
231
+ _, cls_pred_seq_masked_index = torch.max(probs_seq_masked, dim=-1)
232
+ target_seq_masked = torch.masked_select(target, seq_mask_no_end)
233
+ right_seq_masked = (cls_pred_seq_masked_index == target_seq_masked).sum()
234
+
235
+ writer.add_scalar('./Loss/all', loss_cls, nb_iter)
236
+ writer.add_scalar('./ACC/every_token', right_seq_masked*100/seq_mask_no_end.sum(), nb_iter)
237
+
238
+ # [INFO] log mask/nomask separately
239
+ no_mask_token = ~mask_token * seq_mask_no_end
240
+ writer.add_scalar('./ACC/masked', get_acc(cls_pred, target, mask_token), nb_iter)
241
+ writer.add_scalar('./ACC/no_masked', get_acc(cls_pred, target, no_mask_token), nb_iter)
242
+
243
+ # msg = f"Train. Iter {nb_iter} : Loss. {avg_loss_cls:.5f}, ACC. {avg_acc:.4f}"
244
+ # logger.info(msg)
245
+
246
+ if nb_iter==0 or nb_iter % args.eval_iter == 0 or nb_iter == args.total_iter:
247
+ num_repeat = 1
248
+ rand_pos = False
249
+ if nb_iter == args.total_iter:
250
+ num_repeat = -30
251
+ rand_pos = True
252
+ val_loader = dataset_TM_eval.DATALoader(args.dataname, True, 32, w_vectorizer)
253
+ pred_pose_eval, pose, m_length, clip_text, best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, best_multi, writer, logger = eval_trans.evaluation_transformer(args.out_dir, val_loader, net, trans_encoder, logger, writer, nb_iter, best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, clip_model=clip_model, eval_wrapper=eval_wrapper, dataname=args.dataname, num_repeat=num_repeat, rand_pos=rand_pos)
254
+ # for i in range(4):
255
+ # x = pose[i].detach().cpu().numpy()
256
+ # y = pred_pose_eval[i].detach().cpu().numpy()
257
+ # l = m_length[i]
258
+ # caption = clip_text[i]
259
+ # cleaned_name = '-'.join(caption[:200].split('/'))
260
+ # visualize_2motions(x, val_loader.dataset.std, val_loader.dataset.mean, args.dataname, l, y, save_path=f'{args.out_dir}/html/{str(nb_iter)}_{cleaned_name}_{l}.html')
261
+
262
+ if nb_iter == args.total_iter:
263
+ msg_final = f"Train. Iter {best_iter} : FID. {best_fid:.5f}, Diversity. {best_div:.4f}, TOP1. {best_top1:.4f}, TOP2. {best_top2:.4f}, TOP3. {best_top3:.4f}"
264
+ logger.info(msg_final)
265
+ break
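The time-step masking in the loop above relies on valid tokens being left-aligned before the pad tokens. Below is a standalone, slightly more explicit restatement of that step (ranks via a double argsort instead of the single-argsort trick); the vocabulary size, lengths and mask id are made-up values rather than this repo's configuration.

import torch

batch_size, max_len, nb_code = 3, 8, 512
mask_id = nb_code + 2                                   # stand-in for the [MASK] token id
tokens = torch.randint(0, nb_code, (batch_size, max_len))
lengths = torch.tensor([8, 5, 3])                       # valid (non-pad) tokens per sample

valid = torch.arange(max_len)[None, :] < lengths[:, None]   # same role as generate_src_mask
ratio = torch.empty(batch_size).uniform_(0.5, 1.0)          # fraction of tokens to mask
num_masked = (lengths * ratio).round().clamp(min=1)

# Random score per position, pads pushed to the back, then mask the
# num_masked lowest-ranked (i.e. uniformly random, valid) positions.
scores = torch.rand(batch_size, max_len) + (~valid).float()
ranks = scores.argsort(dim=-1).argsort(dim=-1)
mask_token = ranks < num_masked[:, None]

masked_tokens = torch.where(mask_token, mask_id, tokens)
assert (mask_token.sum(-1) == num_masked).all()
assert not (mask_token & ~valid).any()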
train_vq.py ADDED
@@ -0,0 +1,194 @@
1
+ import os
2
+ import json
3
+
4
+ import torch
5
+ import torch.optim as optim
6
+ from torch.utils.tensorboard import SummaryWriter
7
+
8
+ import models.vqvae as vqvae
9
+ import utils.losses as losses
10
+ import options.option_vq as option_vq
11
+ import utils.utils_model as utils_model
12
+ from dataset import dataset_VQ, dataset_TM_eval
13
+ import utils.eval_trans as eval_trans
14
+ from options.get_eval_option import get_opt
15
+ from models.evaluator_wrapper import EvaluatorModelWrapper
16
+ import warnings
17
+ warnings.filterwarnings('ignore')
18
+ from utils.word_vectorizer import WordVectorizer
19
+ from tqdm import tqdm
20
+ from exit.utils import get_model, generate_src_mask, init_save_folder
21
+ from models.vqvae_sep import VQVAE_SEP
22
+
23
+ def update_lr_warm_up(optimizer, nb_iter, warm_up_iter, lr):
24
+
25
+ current_lr = lr * (nb_iter + 1) / (warm_up_iter + 1)
26
+ for param_group in optimizer.param_groups:
27
+ param_group["lr"] = current_lr
28
+
29
+ return optimizer, current_lr
30
+
31
+ ##### ---- Exp dirs ---- #####
32
+ args = option_vq.get_args_parser()
33
+ torch.manual_seed(args.seed)
34
+
35
+ args.out_dir = os.path.join(args.out_dir, f'vq') # /{args.exp_name}
36
+ # os.makedirs(args.out_dir, exist_ok = True)
37
+ init_save_folder(args)
38
+
39
+ ##### ---- Logger ---- #####
40
+ logger = utils_model.get_logger(args.out_dir)
41
+ writer = SummaryWriter(args.out_dir)
42
+ logger.info(json.dumps(vars(args), indent=4, sort_keys=True))
43
+
44
+
45
+
46
+ w_vectorizer = WordVectorizer('./glove', 'our_vab')
47
+
48
+ if args.dataname == 'kit' :
49
+ dataset_opt_path = 'checkpoints/kit/Comp_v6_KLD005/opt.txt'
50
+ args.nb_joints = 21
51
+
52
+ else :
53
+ dataset_opt_path = 'checkpoints/t2m/Comp_v6_KLD005/opt.txt'
54
+ args.nb_joints = 22
55
+
56
+ logger.info(f'Training on {args.dataname}, motions are with {args.nb_joints} joints')
57
+
58
+ wrapper_opt = get_opt(dataset_opt_path, torch.device('cuda'))
59
+ eval_wrapper = EvaluatorModelWrapper(wrapper_opt)
60
+
61
+
62
+ ##### ---- Dataloader ---- #####
63
+ train_loader = dataset_VQ.DATALoader(args.dataname,
64
+ args.batch_size,
65
+ window_size=args.window_size,
66
+ unit_length=2**args.down_t)
67
+
68
+ train_loader_iter = dataset_VQ.cycle(train_loader)
69
+
70
+ val_loader = dataset_TM_eval.DATALoader(args.dataname, False,
71
+ 32,
72
+ w_vectorizer,
73
+ unit_length=2**args.down_t)
74
+
75
+ ##### ---- Network ---- #####
76
+ if args.sep_uplow:
77
+ net = VQVAE_SEP(args, ## use args to define different parameters in different quantizers
78
+ args.nb_code,
79
+ args.code_dim,
80
+ args.output_emb_width,
81
+ args.down_t,
82
+ args.stride_t,
83
+ args.width,
84
+ args.depth,
85
+ args.dilation_growth_rate,
86
+ args.vq_act,
87
+ args.vq_norm,
88
+ {'mean': torch.from_numpy(train_loader.dataset.mean).cuda().float(),
89
+ 'std': torch.from_numpy(train_loader.dataset.std).cuda().float()},
90
+ True)
91
+ else:
92
+ net = vqvae.HumanVQVAE(args, ## use args to define different parameters in different quantizers
93
+ args.nb_code,
94
+ args.code_dim,
95
+ args.output_emb_width,
96
+ args.down_t,
97
+ args.stride_t,
98
+ args.width,
99
+ args.depth,
100
+ args.dilation_growth_rate,
101
+ args.vq_act,
102
+ args.vq_norm)
103
+
104
+
105
+ if args.resume_pth :
106
+ logger.info('loading checkpoint from {}'.format(args.resume_pth))
107
+ ckpt = torch.load(args.resume_pth, map_location='cpu')
108
+ net.load_state_dict(ckpt['net'], strict=True)
109
+ net.train()
110
+ net.cuda()
111
+
112
+ ##### ---- Optimizer & Scheduler ---- #####
113
+ optimizer = optim.AdamW(net.parameters(), lr=args.lr, betas=(0.9, 0.99), weight_decay=args.weight_decay)
114
+ scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_scheduler, gamma=args.gamma)
115
+
116
+
117
+ Loss = losses.ReConsLoss(args.recons_loss, args.nb_joints)
118
+
119
+ ##### ------ warm-up ------- #####
120
+ avg_recons, avg_perplexity, avg_commit = 0., 0., 0.
121
+
122
+ for nb_iter in tqdm(range(1, args.warm_up_iter)):
123
+
124
+ optimizer, current_lr = update_lr_warm_up(optimizer, nb_iter, args.warm_up_iter, args.lr)
125
+
126
+ gt_motion = next(train_loader_iter)
127
+ gt_motion = gt_motion.cuda().float() # (bs, 64, dim)
128
+
129
+ pred_motion, loss_commit, perplexity = net(gt_motion)
130
+ loss_motion = Loss(pred_motion, gt_motion)
131
+ loss_vel = Loss.forward_joint(pred_motion, gt_motion)
132
+
133
+ loss = loss_motion + args.commit * loss_commit + args.loss_vel * loss_vel
134
+
135
+ optimizer.zero_grad()
136
+ loss.backward()
137
+ optimizer.step()
138
+
139
+ avg_recons += loss_motion.item()
140
+ avg_perplexity += perplexity.item()
141
+ avg_commit += loss_commit.item()
142
+
143
+ if nb_iter % args.print_iter == 0 :
144
+ avg_recons /= args.print_iter
145
+ avg_perplexity /= args.print_iter
146
+ avg_commit /= args.print_iter
147
+
148
+ logger.info(f"Warmup. Iter {nb_iter} : lr {current_lr:.5f} \t Commit. {avg_commit:.5f} \t PPL. {avg_perplexity:.2f} \t Recons. {avg_recons:.5f}")
149
+
150
+ avg_recons, avg_perplexity, avg_commit = 0., 0., 0.
151
+
152
+ ##### ---- Training ---- #####
153
+ avg_recons, avg_perplexity, avg_commit = 0., 0., 0.
154
+ best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, writer, logger = eval_trans.evaluation_vqvae(args.out_dir, val_loader, net, logger, writer, 0, best_fid=1000, best_iter=0, best_div=100, best_top1=0, best_top2=0, best_top3=0, best_matching=100, eval_wrapper=eval_wrapper)
155
+
156
+ for nb_iter in tqdm(range(1, args.total_iter + 1)):
157
+
158
+ gt_motion = next(train_loader_iter)
159
+ gt_motion = gt_motion.cuda().float() # bs, nb_joints, joints_dim, seq_len
160
+
161
+ if args.sep_uplow:
162
+ pred_motion, loss_commit, perplexity = net(gt_motion, idx_noise=0)
163
+ else:
164
+ pred_motion, loss_commit, perplexity = net(gt_motion)
165
+ loss_motion = Loss(pred_motion, gt_motion)
166
+ loss_vel = Loss.forward_joint(pred_motion, gt_motion)
167
+
168
+ loss = loss_motion + args.commit * loss_commit + args.loss_vel * loss_vel
169
+
170
+ optimizer.zero_grad()
171
+ loss.backward()
172
+ optimizer.step()
173
+ scheduler.step()
174
+
175
+ avg_recons += loss_motion.item()
176
+ avg_perplexity += perplexity.item()
177
+ avg_commit += loss_commit.item()
178
+
179
+ if nb_iter % args.print_iter == 0 :
180
+ avg_recons /= args.print_iter
181
+ avg_perplexity /= args.print_iter
182
+ avg_commit /= args.print_iter
183
+
184
+ writer.add_scalar('./Train/L1', avg_recons, nb_iter)
185
+ writer.add_scalar('./Train/PPL', avg_perplexity, nb_iter)
186
+ writer.add_scalar('./Train/Commit', avg_commit, nb_iter)
187
+
188
+ logger.info(f"Train. Iter {nb_iter} : \t Commit. {avg_commit:.5f} \t PPL. {avg_perplexity:.2f} \t Recons. {avg_recons:.5f}")
189
+
190
+ avg_recons, avg_perplexity, avg_commit = 0., 0., 0.
191
+
192
+ if nb_iter % args.eval_iter==0 :
193
+ best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, writer, logger = eval_trans.evaluation_vqvae(args.out_dir, val_loader, net, logger, writer, nb_iter, best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, eval_wrapper=eval_wrapper)
194
+
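For reference, a compact sketch of the learning-rate behaviour configured above: a linear ramp via update_lr_warm_up followed by MultiStepLR decay. The dummy parameter and the handful of iterations are purely illustrative; the milestone and gamma are the option_vq.py defaults.

import torch

param = torch.nn.Parameter(torch.zeros(1))              # dummy parameter
optimizer = torch.optim.AdamW([param], lr=2e-4, betas=(0.9, 0.99), weight_decay=0.0)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[200000], gamma=0.05)

warm_up_iter, base_lr = 1000, 2e-4
for nb_iter in range(1, 5):
    # same rule as update_lr_warm_up(): ramp linearly towards the base lr
    current_lr = base_lr * (nb_iter + 1) / (warm_up_iter + 1)
    for group in optimizer.param_groups:
        group['lr'] = current_lr
    print(f'iter {nb_iter}: lr={current_lr:.2e}')
# after warm-up, the main loop calls scheduler.step(), which multiplies the lr by 0.05 at iteration 200000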
utils/eval_trans.py ADDED
@@ -0,0 +1,824 @@
1
+ import os
2
+
3
+ import clip
4
+ import numpy as np
5
+ import torch
6
+ from scipy import linalg
7
+
8
+ # import visualization.plot_3d_global as plot_3d
9
+ from utils.motion_process import recover_from_ric
10
+ from exit.utils import get_model, visualize_2motions, generate_src_mask
11
+ from tqdm import tqdm
12
+
13
+
14
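+ # NOTE: depends on plot_3d, whose import is commented out above; calling this helper as-is would raise a NameError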
+ def tensorborad_add_video_xyz(writer, xyz, nb_iter, tag, nb_vis=4, title_batch=None, outname=None):
15
+ xyz = xyz[:1]
16
+ bs, seq = xyz.shape[:2]
17
+ xyz = xyz.reshape(bs, seq, -1, 3)
18
+ plot_xyz = plot_3d.draw_to_batch(xyz.cpu().numpy(),title_batch, outname)
19
+ plot_xyz =np.transpose(plot_xyz, (0, 1, 4, 2, 3))
20
+ writer.add_video(tag, plot_xyz, nb_iter, fps = 20)
21
+
22
+ @torch.no_grad()
23
+ def evaluation_vqvae(out_dir, val_loader, net, logger, writer, nb_iter, best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, eval_wrapper, draw = True, save = True, savegif=False, savenpy=False) :
24
+ net.eval()
25
+ nb_sample = 0
26
+
27
+ draw_org = []
28
+ draw_pred = []
29
+ draw_text = []
30
+
31
+
32
+ motion_annotation_list = []
33
+ motion_pred_list = []
34
+
35
+ R_precision_real = 0
36
+ R_precision = 0
37
+
38
+ nb_sample = 0
39
+ matching_score_real = 0
40
+ matching_score_pred = 0
41
+ for batch in val_loader:
42
+ word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, token, name = batch
43
+
44
+ motion = motion.cuda()
45
+ et, em = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, motion, m_length)
46
+ bs, seq = motion.shape[0], motion.shape[1]
47
+
48
+ num_joints = 21 if motion.shape[-1] == 251 else 22
49
+
50
+ pred_pose_eval = torch.zeros((bs, seq, motion.shape[-1])).cuda()
51
+
52
+ for i in range(bs):
53
+ pose = val_loader.dataset.inv_transform(motion[i:i+1, :m_length[i], :].detach().cpu().numpy())
54
+ # pose_xyz = recover_from_ric(torch.from_numpy(pose).float().cuda(), num_joints)
55
+
56
+
57
+ pred_pose, loss_commit, perplexity = net(motion[i:i+1, :m_length[i]])
58
+ # pred_denorm = val_loader.dataset.inv_transform(pred_pose.detach().cpu().numpy())
59
+ # pred_xyz = recover_from_ric(torch.from_numpy(pred_denorm).float().cuda(), num_joints)
60
+
61
+ # if savenpy:
62
+ # np.save(os.path.join(out_dir, name[i]+'_gt.npy'), pose_xyz[:, :m_length[i]].cpu().numpy())
63
+ # np.save(os.path.join(out_dir, name[i]+'_pred.npy'), pred_xyz.detach().cpu().numpy())
64
+
65
+ pred_pose_eval[i:i+1,:m_length[i],:] = pred_pose
66
+
67
+ # if i < min(4, bs):
68
+ # draw_org.append(pose_xyz)
69
+ # draw_pred.append(pred_xyz)
70
+ # draw_text.append(caption[i])
71
+
72
+ et_pred, em_pred = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, pred_pose_eval, m_length)
73
+
74
+ motion_pred_list.append(em_pred)
75
+ motion_annotation_list.append(em)
76
+
77
+ temp_R, temp_match = calculate_R_precision(et.cpu().numpy(), em.cpu().numpy(), top_k=3, sum_all=True)
78
+ R_precision_real += temp_R
79
+ matching_score_real += temp_match
80
+ temp_R, temp_match = calculate_R_precision(et_pred.cpu().numpy(), em_pred.cpu().numpy(), top_k=3, sum_all=True)
81
+ R_precision += temp_R
82
+ matching_score_pred += temp_match
83
+
84
+ nb_sample += bs
85
+
86
+ motion_annotation_np = torch.cat(motion_annotation_list, dim=0).cpu().numpy()
87
+ motion_pred_np = torch.cat(motion_pred_list, dim=0).cpu().numpy()
88
+ gt_mu, gt_cov = calculate_activation_statistics(motion_annotation_np)
89
+ mu, cov= calculate_activation_statistics(motion_pred_np)
90
+
91
+ diversity_real = calculate_diversity(motion_annotation_np, 300 if nb_sample > 300 else 100)
92
+ diversity = calculate_diversity(motion_pred_np, 300 if nb_sample > 300 else 100)
93
+
94
+ R_precision_real = R_precision_real / nb_sample
95
+ R_precision = R_precision / nb_sample
96
+
97
+ matching_score_real = matching_score_real / nb_sample
98
+ matching_score_pred = matching_score_pred / nb_sample
99
+
100
+ fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)
101
+
102
+ msg = f"--> \t Eva. Iter {nb_iter} :, FID. {fid:.4f}, Diversity Real. {diversity_real:.4f}, Diversity. {diversity:.4f}, R_precision_real. {R_precision_real}, R_precision. {R_precision}, matching_score_real. {matching_score_real}, matching_score_pred. {matching_score_pred}"
103
+ logger.info(msg)
104
+
105
+ if draw:
106
+ writer.add_scalar('./Test/FID', fid, nb_iter)
107
+ writer.add_scalar('./Test/Diversity', diversity, nb_iter)
108
+ writer.add_scalar('./Test/top1', R_precision[0], nb_iter)
109
+ writer.add_scalar('./Test/top2', R_precision[1], nb_iter)
110
+ writer.add_scalar('./Test/top3', R_precision[2], nb_iter)
111
+ writer.add_scalar('./Test/matching_score', matching_score_pred, nb_iter)
112
+
113
+
114
+ # if nb_iter % 5000 == 0 :
115
+ # for ii in range(4):
116
+ # tensorborad_add_video_xyz(writer, draw_org[ii], nb_iter, tag='./Vis/org_eval'+str(ii), nb_vis=1, title_batch=[draw_text[ii]], outname=[os.path.join(out_dir, 'gt'+str(ii)+'.gif')] if savegif else None)
117
+
118
+ # if nb_iter % 5000 == 0 :
119
+ # for ii in range(4):
120
+ # tensorborad_add_video_xyz(writer, draw_pred[ii], nb_iter, tag='./Vis/pred_eval'+str(ii), nb_vis=1, title_batch=[draw_text[ii]], outname=[os.path.join(out_dir, 'pred'+str(ii)+'.gif')] if savegif else None)
121
+
122
+
123
+ if fid < best_fid :
124
+ msg = f"--> --> \t FID Improved from {best_fid:.5f} to {fid:.5f} !!!"
125
+ logger.info(msg)
126
+ best_fid, best_iter = fid, nb_iter
127
+ # if save:
128
+ # torch.save({'net' : net.state_dict()}, os.path.join(out_dir, 'net_best_fid.pth'))
129
+
130
+ if abs(diversity_real - diversity) < abs(diversity_real - best_div) :
131
+ msg = f"--> --> \t Diversity Improved from {best_div:.5f} to {diversity:.5f} !!!"
132
+ logger.info(msg)
133
+ best_div = diversity
134
+ # if save:
135
+ # torch.save({'net' : net.state_dict()}, os.path.join(out_dir, 'net_best_div.pth'))
136
+
137
+ if R_precision[0] > best_top1 :
138
+ msg = f"--> --> \t Top1 Improved from {best_top1:.4f} to {R_precision[0]:.4f} !!!"
139
+ logger.info(msg)
140
+ best_top1 = R_precision[0]
141
+ # if save:
142
+ # torch.save({'net' : net.state_dict()}, os.path.join(out_dir, 'net_best_top1.pth'))
143
+
144
+ if R_precision[1] > best_top2 :
145
+ msg = f"--> --> \t Top2 Improved from {best_top2:.4f} to {R_precision[1]:.4f} !!!"
146
+ logger.info(msg)
147
+ best_top2 = R_precision[1]
148
+
149
+ if R_precision[2] > best_top3 :
150
+ msg = f"--> --> \t Top3 Improved from {best_top3:.4f} to {R_precision[2]:.4f} !!!"
151
+ logger.info(msg)
152
+ best_top3 = R_precision[2]
153
+
154
+ if matching_score_pred < best_matching :
155
+ msg = f"--> --> \t matching_score Improved from {best_matching:.5f} to {matching_score_pred:.5f} !!!"
156
+ logger.info(msg)
157
+ best_matching = matching_score_pred
158
+ # if save:
159
+ # torch.save({'net' : net.state_dict()}, os.path.join(out_dir, 'net_best_matching.pth'))
160
+
161
+ if save:
162
+ torch.save({'net' : net.state_dict()}, os.path.join(out_dir, 'net_last.pth'))
163
+
164
+ net.train()
165
+ return best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, writer, logger
166
+
167
+
168
+ @torch.no_grad()
169
+ def evaluation_transformer(out_dir, val_loader, net, trans, logger, writer, nb_iter, best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, clip_model, eval_wrapper, dataname='t2m', draw = True, save = True, savegif=False, num_repeat=1, rand_pos=False, CFG=-1) :
170
+ if num_repeat < 0:
171
+ is_avg_all = True
172
+ num_repeat = -num_repeat
173
+ else:
174
+ is_avg_all = False
175
+
176
+
177
+ trans.eval()
178
+ nb_sample = 0
179
+
180
+ draw_org = []
181
+ draw_pred = []
182
+ draw_text = []
183
+ draw_text_pred = []
184
+
185
+ motion_annotation_list = []
186
+ motion_pred_list = []
187
+ motion_multimodality = []
188
+ R_precision_real = 0
189
+ R_precision = 0
190
+ matching_score_real = 0
191
+ matching_score_pred = 0
192
+
193
+ nb_sample = 0
194
+ blank_id = get_model(trans).num_vq
195
+ for batch in tqdm(val_loader):
196
+ word_embeddings, pos_one_hots, clip_text, sent_len, pose, m_length, token, name = batch
197
+
198
+ bs, seq = pose.shape[:2]
199
+ num_joints = 21 if pose.shape[-1] == 251 else 22
200
+
201
+ text = clip.tokenize(clip_text, truncate=True).cuda()
202
+
203
+ feat_clip_text, word_emb = clip_model(text)
204
+
205
+ motion_multimodality_batch = []
206
+ m_tokens_len = torch.ceil((m_length)/4)
207
+
208
+
209
+ pred_len = m_length.cuda()
210
+ pred_tok_len = m_tokens_len
211
+
212
+
213
+ for i in range(num_repeat):
214
+ pred_pose_eval = torch.zeros((bs, seq, pose.shape[-1])).cuda()
215
+ # pred_len = torch.ones(bs).long()
216
+
217
+ index_motion = trans(feat_clip_text, word_emb, type="sample", m_length=pred_len, rand_pos=rand_pos, CFG=CFG)
218
+ # [INFO] 1. this get the last index of blank_id
219
+ # pred_length = (index_motion == blank_id).int().argmax(1).float()
220
+ # [INFO] 2. this get the first index of blank_id
221
+ pred_length = (index_motion >= blank_id).int()
222
+ pred_length = torch.topk(pred_length, k=1, dim=1).indices.squeeze().float()
223
+ # pred_length[pred_length==0] = index_motion.shape[1] # if blank_id in the first frame, set length to max
224
+ # [INFO] need to run single sample at a time b/c it's conv
225
+ for k in range(bs):
226
+ ######### [INFO] Eval only the predicted length
227
+ # if pred_length[k] == 0:
228
+ # pred_len[k] = seq
229
+ # continue
230
+ # pred_pose = net(index_motion[k:k+1, :int(pred_length[k].item())], type='decode')
231
+ # cur_len = pred_pose.shape[1]
232
+
233
+ # pred_len[k] = min(cur_len, seq)
234
+ # pred_pose_eval[k:k+1, :cur_len] = pred_pose[:, :seq]
235
+ # et_pred, em_pred = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, pred_pose_eval, pred_len)
236
+ ######################################################
237
+
238
+ ######### [INFO] Eval by m_length
239
+ pred_pose = net(index_motion[k:k+1, :int(pred_tok_len[k].item())], type='decode')
240
+ pred_pose_eval[k:k+1, :int(pred_len[k].item())] = pred_pose
241
+ et_pred, em_pred = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, pred_pose_eval, m_length)
242
+ ######################################################
243
+
244
+ motion_multimodality_batch.append(em_pred.reshape(bs, 1, -1))
245
+
246
+ if i == 0 or is_avg_all:
247
+ pose = pose.cuda().float()
248
+
249
+ et, em = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, pose, m_length)
250
+ motion_annotation_list.append(em)
251
+ motion_pred_list.append(em_pred)
252
+
253
+ # if draw:
254
+ # pose = val_loader.dataset.inv_transform(pose.detach().cpu().numpy())
255
+ # pose_xyz = recover_from_ric(torch.from_numpy(pose).float().cuda(), num_joints)
256
+
257
+
258
+ # for j in range(min(4, bs)):
259
+ # draw_org.append(pose_xyz[j][:m_length[j]].unsqueeze(0))
260
+ # draw_text.append(clip_text[j])
261
+
262
+ temp_R, temp_match = calculate_R_precision(et.cpu().numpy(), em.cpu().numpy(), top_k=3, sum_all=True)
263
+ R_precision_real += temp_R
264
+ matching_score_real += temp_match
265
+ temp_R, temp_match = calculate_R_precision(et_pred.cpu().numpy(), em_pred.cpu().numpy(), top_k=3, sum_all=True)
266
+ R_precision += temp_R
267
+ matching_score_pred += temp_match
268
+
269
+ nb_sample += bs
270
+ motion_multimodality.append(torch.cat(motion_multimodality_batch, dim=1))
271
+
272
+ motion_annotation_np = torch.cat(motion_annotation_list, dim=0).cpu().numpy()
273
+ motion_pred_np = torch.cat(motion_pred_list, dim=0).cpu().numpy()
274
+ gt_mu, gt_cov = calculate_activation_statistics(motion_annotation_np)
275
+ mu, cov= calculate_activation_statistics(motion_pred_np)
276
+
277
+ diversity_real = calculate_diversity(motion_annotation_np, 300 if nb_sample > 300 else 100)
278
+ diversity = calculate_diversity(motion_pred_np, 300 if nb_sample > 300 else 100)
279
+
280
+ R_precision_real = R_precision_real / nb_sample
281
+ R_precision = R_precision / nb_sample
282
+
283
+ matching_score_real = matching_score_real / nb_sample
284
+ matching_score_pred = matching_score_pred / nb_sample
285
+
286
+ multimodality = 0
287
+ motion_multimodality = torch.cat(motion_multimodality, dim=0).cpu().numpy()
288
+ if num_repeat > 1:
289
+ multimodality = calculate_multimodality(motion_multimodality, 10)
290
+
291
+ fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)
292
+
293
+ msg = f"--> \t Eva. Iter {nb_iter} :, \n\
294
+ FID. {fid:.4f} , \n\
295
+ Diversity Real. {diversity_real:.4f}, \n\
296
+ Diversity. {diversity:.4f}, \n\
297
+ R_precision_real. {R_precision_real}, \n\
298
+ R_precision. {R_precision}, \n\
299
+ matching_score_real. {matching_score_real}, \n\
300
+ matching_score_pred. {matching_score_pred}, \n\
301
+ multimodality. {multimodality:.4f}"
302
+ logger.info(msg)
303
+
304
+
305
+ if draw:
306
+ writer.add_scalar('./Test/FID', fid, nb_iter)
307
+ writer.add_scalar('./Test/Diversity', diversity, nb_iter)
308
+ writer.add_scalar('./Test/top1', R_precision[0], nb_iter)
309
+ writer.add_scalar('./Test/top2', R_precision[1], nb_iter)
310
+ writer.add_scalar('./Test/top3', R_precision[2], nb_iter)
311
+ writer.add_scalar('./Test/matching_score', matching_score_pred, nb_iter)
312
+ writer.add_scalar('./Test/multimodality', multimodality, nb_iter)
313
+
314
+ # if nb_iter % 10000 == 0 :
315
+ # for ii in range(4):
316
+ # tensorborad_add_video_xyz(writer, draw_org[ii], nb_iter, tag='./Vis/org_eval'+str(ii), nb_vis=1, title_batch=[draw_text[ii]], outname=[os.path.join(out_dir, 'gt'+str(ii)+'.gif')] if savegif else None)
317
+ # if nb_iter % 10000 == 0 :
318
+ # for ii in range(4):
319
+ # tensorborad_add_video_xyz(writer, draw_pred[ii], nb_iter, tag='./Vis/pred_eval'+str(ii), nb_vis=1, title_batch=[draw_text_pred[ii]], outname=[os.path.join(out_dir, 'pred'+str(ii)+'.gif')] if savegif else None)
320
+
321
+
322
+ if fid < best_fid :
323
+ msg = f"--> --> \t FID Improved from {best_fid:.5f} to {fid:.5f} !!!"
324
+ logger.info(msg)
325
+ best_fid, best_iter = fid, nb_iter
326
+ # if save:
327
+ # torch.save({'trans' : get_model(trans).state_dict()}, os.path.join(out_dir, 'net_best_fid.pth'))
328
+
329
+ if matching_score_pred < best_matching :
330
+ msg = f"--> --> \t matching_score Improved from {best_matching:.5f} to {matching_score_pred:.5f} !!!"
331
+ logger.info(msg)
332
+ best_matching = matching_score_pred
333
+
334
+ if abs(diversity_real - diversity) < abs(diversity_real - best_div) :
335
+ msg = f"--> --> \t Diversity Improved from {best_div:.5f} to {diversity:.5f} !!!"
336
+ logger.info(msg)
337
+ best_div = diversity
338
+
339
+ if R_precision[0] > best_top1 :
340
+ msg = f"--> --> \t Top1 Improved from {best_top1:.4f} to {R_precision[0]:.4f} !!!"
341
+ logger.info(msg)
342
+ best_top1 = R_precision[0]
343
+
344
+ if R_precision[1] > best_top2 :
345
+ msg = f"--> --> \t Top2 Improved from {best_top2:.4f} to {R_precision[1]:.4f} !!!"
346
+ logger.info(msg)
347
+ best_top2 = R_precision[1]
348
+
349
+ if R_precision[2] > best_top3 :
350
+ msg = f"--> --> \t Top3 Improved from {best_top3:.4f} to {R_precision[2]:.4f} !!!"
351
+ logger.info(msg)
352
+ best_top3 = R_precision[2]
353
+
354
+ if save:
355
+ torch.save({'trans' : get_model(trans).state_dict()}, os.path.join(out_dir, 'net_last.pth'))
356
+
357
+ trans.train()
358
+ return pred_pose_eval, pose, m_length, clip_text, best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, multimodality, writer, logger
359
+
360
+ def evaluation_transformer_uplow(out_dir, val_loader, net, trans, logger, writer, nb_iter, best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, clip_model, eval_wrapper, dataname, draw = True, save = True, savegif=False, num_repeat=1, rand_pos=False, CFG=-1) :
361
+ from utils.humanml_utils import HML_UPPER_BODY_MASK, HML_LOWER_BODY_MASK
362
+
363
+ trans.eval()
364
+ nb_sample = 0
365
+
366
+ draw_org = []
367
+ draw_pred = []
368
+ draw_text = []
369
+ draw_text_pred = []
370
+
371
+ motion_annotation_list = []
372
+ motion_pred_list = []
373
+ motion_multimodality = []
374
+ R_precision_real = 0
375
+ R_precision = 0
376
+ matching_score_real = 0
377
+ matching_score_pred = 0
378
+
379
+ nb_sample = 0
380
+ blank_id = get_model(trans).num_vq
381
+ for batch in tqdm(val_loader):
382
+ word_embeddings, pos_one_hots, clip_text, sent_len, pose, m_length, token, name = batch
383
+ pose = pose.cuda().float()
384
+ pose_lower = pose[..., HML_LOWER_BODY_MASK]
385
+ bs, seq = pose.shape[:2]
386
+ num_joints = 21 if pose.shape[-1] == 251 else 22
387
+
388
+ text = clip.tokenize(clip_text, truncate=True).cuda()
389
+
390
+ feat_clip_text, word_emb = clip_model(text)
391
+
392
+ motion_multimodality_batch = []
393
+ m_tokens_len = torch.ceil((m_length)/4)
394
+
395
+
396
+ pred_len = m_length.cuda()
397
+ pred_tok_len = m_tokens_len
398
+
399
+ max_motion_length = int(seq/4) + 1
400
+ mot_end_idx = get_model(net).vqvae.num_code
401
+ mot_pad_idx = get_model(net).vqvae.num_code + 1
402
+ target_lower = []
403
+ for k in range(bs):
404
+ target = net(pose[k:k+1, :m_length[k]], type='encode')
405
+ if m_tokens_len[k]+1 < max_motion_length:
406
+ target = torch.cat([target,
407
+ torch.ones((1, 1, 2), dtype=int, device=target.device) * mot_end_idx,
408
+ torch.ones((1, max_motion_length-1-m_tokens_len[k].int().item(), 2), dtype=int, device=target.device) * mot_pad_idx], axis=1)
409
+ else:
410
+ target = torch.cat([target,
411
+ torch.ones((1, 1, 2), dtype=int, device=target.device) * mot_end_idx], axis=1)
412
+ target_lower.append(target[..., 1])
413
+ target_lower = torch.cat(target_lower, axis=0)
414
+
415
+ for i in range(num_repeat):
416
+ pred_pose_eval = torch.zeros((bs, seq, pose.shape[-1])).cuda()
417
+ # pred_len = torch.ones(bs).long()
418
+
419
+ index_motion = trans(feat_clip_text, target_lower, word_emb, type="sample", m_length=pred_len, rand_pos=rand_pos, CFG=CFG)
420
+ # [INFO] 1. this get the last index of blank_id
421
+ # pred_length = (index_motion == blank_id).int().argmax(1).float()
422
+ # [INFO] 2. this get the first index of blank_id
423
+ pred_length = (index_motion >= blank_id).int()
424
+ pred_length = torch.topk(pred_length, k=1, dim=1).indices.squeeze().float()
425
+ # pred_length[pred_length==0] = index_motion.shape[1] # if blank_id in the first frame, set length to max
426
+ # [INFO] need to run single sample at a time b/c it's conv
427
+ for k in range(bs):
428
+ ######### [INFO] Eval only the predicted length
429
+ # if pred_length[k] == 0:
430
+ # pred_len[k] = seq
431
+ # continue
432
+ # pred_pose = net(index_motion[k:k+1, :int(pred_length[k].item())], type='decode')
433
+ # cur_len = pred_pose.shape[1]
434
+
435
+ # pred_len[k] = min(cur_len, seq)
436
+ # pred_pose_eval[k:k+1, :cur_len] = pred_pose[:, :seq]
437
+ # et_pred, em_pred = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, pred_pose_eval, pred_len)
438
+ ######################################################
439
+
440
+ ######### [INFO] Eval by m_length
441
+ all_tokens = torch.cat([
442
+ index_motion[k:k+1, :int(pred_tok_len[k].item()), None],
443
+ target_lower[k:k+1, :int(pred_tok_len[k].item()), None]
444
+ ], axis=-1)
445
+ pred_pose = net(all_tokens, type='decode')
446
+ pred_pose_eval[k:k+1, :int(pred_len[k].item())] = pred_pose
447
+ pred_pose_eval[..., HML_LOWER_BODY_MASK] = pose_lower
448
+ et_pred, em_pred = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, pred_pose_eval, m_length)
449
+ ######################################################
450
+
451
+ motion_multimodality_batch.append(em_pred.reshape(bs, 1, -1))
452
+
453
+ if i == 0:
454
+
455
+
456
+ et, em = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, pose, m_length)
457
+ motion_annotation_list.append(em)
458
+ motion_pred_list.append(em_pred)
459
+
460
+ # if draw:
461
+ # pose = val_loader.dataset.inv_transform(pose.detach().cpu().numpy())
462
+ # pose_xyz = recover_from_ric(torch.from_numpy(pose).float().cuda(), num_joints)
463
+
464
+
465
+ # for j in range(min(4, bs)):
466
+ # draw_org.append(pose_xyz[j][:m_length[j]].unsqueeze(0))
467
+ # draw_text.append(clip_text[j])
468
+
469
+ temp_R, temp_match = calculate_R_precision(et.cpu().numpy(), em.cpu().numpy(), top_k=3, sum_all=True)
470
+ R_precision_real += temp_R
471
+ matching_score_real += temp_match
472
+ temp_R, temp_match = calculate_R_precision(et_pred.cpu().numpy(), em_pred.cpu().numpy(), top_k=3, sum_all=True)
473
+ R_precision += temp_R
474
+ matching_score_pred += temp_match
475
+
476
+ nb_sample += bs
477
+ motion_multimodality.append(torch.cat(motion_multimodality_batch, dim=1))
478
+
479
+ motion_annotation_np = torch.cat(motion_annotation_list, dim=0).cpu().numpy()
480
+ motion_pred_np = torch.cat(motion_pred_list, dim=0).cpu().numpy()
481
+ gt_mu, gt_cov = calculate_activation_statistics(motion_annotation_np)
482
+ mu, cov = calculate_activation_statistics(motion_pred_np)
483
+
484
+ diversity_real = calculate_diversity(motion_annotation_np, 300 if nb_sample > 300 else 100)
485
+ diversity = calculate_diversity(motion_pred_np, 300 if nb_sample > 300 else 100)
486
+
487
+ R_precision_real = R_precision_real / nb_sample
488
+ R_precision = R_precision / nb_sample
489
+
490
+ matching_score_real = matching_score_real / nb_sample
491
+ matching_score_pred = matching_score_pred / nb_sample
492
+
493
+ multimodality = 0
494
+ motion_multimodality = torch.cat(motion_multimodality, dim=0).cpu().numpy()
495
+ if num_repeat > 1:
496
+ multimodality = calculate_multimodality(motion_multimodality, 10)
497
+
498
+ fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)
499
+
500
+ msg = f"--> \t Eva. Iter {nb_iter} :, \n\
501
+ FID. {fid:.4f} , \n\
502
+ Diversity Real. {diversity_real:.4f}, \n\
503
+ Diversity. {diversity:.4f}, \n\
504
+ R_precision_real. {R_precision_real}, \n\
505
+ R_precision. {R_precision}, \n\
506
+ matching_score_real. {matching_score_real}, \n\
507
+ matching_score_pred. {matching_score_pred}, \n\
508
+ multimodality. {multimodality:.4f}"
509
+ logger.info(msg)
510
+
511
+
512
+ if draw:
513
+ writer.add_scalar('./Test/FID', fid, nb_iter)
514
+ writer.add_scalar('./Test/Diversity', diversity, nb_iter)
515
+ writer.add_scalar('./Test/top1', R_precision[0], nb_iter)
516
+ writer.add_scalar('./Test/top2', R_precision[1], nb_iter)
517
+ writer.add_scalar('./Test/top3', R_precision[2], nb_iter)
518
+ writer.add_scalar('./Test/matching_score', matching_score_pred, nb_iter)
519
+ writer.add_scalar('./Test/multimodality', multimodality, nb_iter)
520
+
521
+ # if nb_iter % 10000 == 0 :
522
+ # for ii in range(4):
523
+ # tensorborad_add_video_xyz(writer, draw_org[ii], nb_iter, tag='./Vis/org_eval'+str(ii), nb_vis=1, title_batch=[draw_text[ii]], outname=[os.path.join(out_dir, 'gt'+str(ii)+'.gif')] if savegif else None)
524
+ # if nb_iter % 10000 == 0 :
525
+ # for ii in range(4):
526
+ # tensorborad_add_video_xyz(writer, draw_pred[ii], nb_iter, tag='./Vis/pred_eval'+str(ii), nb_vis=1, title_batch=[draw_text_pred[ii]], outname=[os.path.join(out_dir, 'pred'+str(ii)+'.gif')] if savegif else None)
527
+
528
+
529
+ if fid < best_fid :
530
+ msg = f"--> --> \t FID Improved from {best_fid:.5f} to {fid:.5f} !!!"
531
+ logger.info(msg)
532
+ best_fid, best_iter = fid, nb_iter
533
+ # if save:
534
+ # torch.save({'trans' : get_model(trans).state_dict()}, os.path.join(out_dir, 'net_best_fid.pth'))
535
+
536
+ if matching_score_pred < best_matching :
537
+ msg = f"--> --> \t matching_score Improved from {best_matching:.5f} to {matching_score_pred:.5f} !!!"
538
+ logger.info(msg)
539
+ best_matching = matching_score_pred
540
+
541
+ if abs(diversity_real - diversity) < abs(diversity_real - best_div) :
542
+ msg = f"--> --> \t Diversity Improved from {best_div:.5f} to {diversity:.5f} !!!"
543
+ logger.info(msg)
544
+ best_div = diversity
545
+
546
+ if R_precision[0] > best_top1 :
547
+ msg = f"--> --> \t Top1 Improved from {best_top1:.4f} to {R_precision[0]:.4f} !!!"
548
+ logger.info(msg)
549
+ best_top1 = R_precision[0]
550
+
551
+ if R_precision[1] > best_top2 :
552
+ msg = f"--> --> \t Top2 Improved from {best_top2:.4f} to {R_precision[1]:.4f} !!!"
553
+ logger.info(msg)
554
+ best_top2 = R_precision[1]
555
+
556
+ if R_precision[2] > best_top3 :
557
+ msg = f"--> --> \t Top3 Improved from {best_top3:.4f} to {R_precision[2]:.4f} !!!"
558
+ logger.info(msg)
559
+ best_top3 = R_precision[2]
560
+
561
+ if save:
562
+ torch.save({'trans' : get_model(trans).state_dict()}, os.path.join(out_dir, 'net_last.pth'))
563
+
564
+ trans.train()
565
+ return pred_pose_eval, pose, m_length, clip_text, best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, multimodality, writer, logger
566
+
567
+ @torch.no_grad()
568
+ def evaluation_transformer_test(out_dir, val_loader, net, trans, logger, writer, nb_iter, best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, best_multi, clip_model, eval_wrapper, draw = True, save = True, savegif=False, savenpy=False) :
569
+
570
+ trans.eval()
571
+ nb_sample = 0
572
+
573
+ draw_org = []
574
+ draw_pred = []
575
+ draw_text = []
576
+ draw_text_pred = []
577
+ draw_name = []
578
+
579
+ motion_annotation_list = []
580
+ motion_pred_list = []
581
+ motion_multimodality = []
582
+ R_precision_real = 0
583
+ R_precision = 0
584
+ matching_score_real = 0
585
+ matching_score_pred = 0
586
+
587
+ nb_sample = 0
588
+
589
+ for batch in val_loader:
590
+
591
+ word_embeddings, pos_one_hots, clip_text, sent_len, pose, m_length, token, name = batch
592
+ bs, seq = pose.shape[:2]
593
+ num_joints = 21 if pose.shape[-1] == 251 else 22
594
+
595
+ text = clip.tokenize(clip_text, truncate=True).cuda()
596
+
597
+ feat_clip_text = clip_model.encode_text(text).float()
598
+ motion_multimodality_batch = []
599
+ for i in range(30):
600
+ pred_pose_eval = torch.zeros((bs, seq, pose.shape[-1])).cuda()
601
+ pred_len = torch.ones(bs).long()
602
+
603
+ for k in range(bs):
604
+ try:
605
+ index_motion = trans.sample(feat_clip_text[k:k+1], True)
606
+ except:
607
+ index_motion = torch.ones(1,1).cuda().long()
608
+
609
+ pred_pose = net.forward_decoder(index_motion)
610
+ cur_len = pred_pose.shape[1]
611
+
612
+ pred_len[k] = min(cur_len, seq)
613
+ pred_pose_eval[k:k+1, :cur_len] = pred_pose[:, :seq]
614
+
615
+ if i == 0 and (draw or savenpy):
616
+ pred_denorm = val_loader.dataset.inv_transform(pred_pose.detach().cpu().numpy())
617
+ pred_xyz = recover_from_ric(torch.from_numpy(pred_denorm).float().cuda(), num_joints)
618
+
619
+ if savenpy:
620
+ np.save(os.path.join(out_dir, name[k]+'_pred.npy'), pred_xyz.detach().cpu().numpy())
621
+
622
+ if draw:
623
+ if i == 0:
624
+ draw_pred.append(pred_xyz)
625
+ draw_text_pred.append(clip_text[k])
626
+ draw_name.append(name[k])
627
+
628
+ et_pred, em_pred = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, pred_pose_eval, pred_len)
629
+
630
+ motion_multimodality_batch.append(em_pred.reshape(bs, 1, -1))
631
+
632
+ if i == 0:
633
+ pose = pose.cuda().float()
634
+
635
+ et, em = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, pose, m_length)
636
+ motion_annotation_list.append(em)
637
+ motion_pred_list.append(em_pred)
638
+
639
+ if draw or savenpy:
640
+ pose = val_loader.dataset.inv_transform(pose.detach().cpu().numpy())
641
+ pose_xyz = recover_from_ric(torch.from_numpy(pose).float().cuda(), num_joints)
642
+
643
+ if savenpy:
644
+ for j in range(bs):
645
+ np.save(os.path.join(out_dir, name[j]+'_gt.npy'), pose_xyz[j][:m_length[j]].unsqueeze(0).cpu().numpy())
646
+
647
+ if draw:
648
+ for j in range(bs):
649
+ draw_org.append(pose_xyz[j][:m_length[j]].unsqueeze(0))
650
+ draw_text.append(clip_text[j])
651
+
652
+ temp_R, temp_match = calculate_R_precision(et.cpu().numpy(), em.cpu().numpy(), top_k=3, sum_all=True)
653
+ R_precision_real += temp_R
654
+ matching_score_real += temp_match
655
+ temp_R, temp_match = calculate_R_precision(et_pred.cpu().numpy(), em_pred.cpu().numpy(), top_k=3, sum_all=True)
656
+ R_precision += temp_R
657
+ matching_score_pred += temp_match
658
+
659
+ nb_sample += bs
660
+
661
+ motion_multimodality.append(torch.cat(motion_multimodality_batch, dim=1))
662
+
663
+ motion_annotation_np = torch.cat(motion_annotation_list, dim=0).cpu().numpy()
664
+ motion_pred_np = torch.cat(motion_pred_list, dim=0).cpu().numpy()
665
+ gt_mu, gt_cov = calculate_activation_statistics(motion_annotation_np)
666
+ mu, cov = calculate_activation_statistics(motion_pred_np)
667
+
668
+ diversity_real = calculate_diversity(motion_annotation_np, 300 if nb_sample > 300 else 100)
669
+ diversity = calculate_diversity(motion_pred_np, 300 if nb_sample > 300 else 100)
670
+
671
+ R_precision_real = R_precision_real / nb_sample
672
+ R_precision = R_precision / nb_sample
673
+
674
+ matching_score_real = matching_score_real / nb_sample
675
+ matching_score_pred = matching_score_pred / nb_sample
676
+
677
+ multimodality = 0
678
+ motion_multimodality = torch.cat(motion_multimodality, dim=0).cpu().numpy()
679
+ multimodality = calculate_multimodality(motion_multimodality, 10)
680
+
681
+ fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)
682
+
683
+ msg = f"--> \t Eva. Iter {nb_iter} :, FID. {fid:.4f}, Diversity Real. {diversity_real:.4f}, Diversity. {diversity:.4f}, R_precision_real. {R_precision_real}, R_precision. {R_precision}, matching_score_real. {matching_score_real}, matching_score_pred. {matching_score_pred}, multimodality. {multimodality:.4f}"
684
+ logger.info(msg)
685
+
686
+
687
+ if draw:
688
+ for ii in range(len(draw_org)):
689
+ tensorborad_add_video_xyz(writer, draw_org[ii], nb_iter, tag='./Vis/'+draw_name[ii]+'_org', nb_vis=1, title_batch=[draw_text[ii]], outname=[os.path.join(out_dir, draw_name[ii]+'_skel_gt.gif')] if savegif else None)
690
+
691
+ tensorborad_add_video_xyz(writer, draw_pred[ii], nb_iter, tag='./Vis/'+draw_name[ii]+'_pred', nb_vis=1, title_batch=[draw_text_pred[ii]], outname=[os.path.join(out_dir, draw_name[ii]+'_skel_pred.gif')] if savegif else None)
692
+
693
+ trans.train()
694
+ return fid, best_iter, diversity, R_precision[0], R_precision[1], R_precision[2], matching_score_pred, multimodality, writer, logger
695
+
696
+ # (X - X_train)*(X - X_train) = -2X*X_train + X*X + X_train*X_train
697
+ def euclidean_distance_matrix(matrix1, matrix2):
698
+ """
699
+ Params:
700
+ -- matrix1: N1 x D
701
+ -- matrix2: N2 x D
702
+ Returns:
703
+ -- dist: N1 x N2
704
+ dist[i, j] == distance(matrix1[i], matrix2[j])
705
+ """
706
+ assert matrix1.shape[1] == matrix2.shape[1]
707
+ d1 = -2 * np.dot(matrix1, matrix2.T) # shape (num_test, num_train)
708
+ d2 = np.sum(np.square(matrix1), axis=1, keepdims=True) # shape (num_test, 1)
709
+ d3 = np.sum(np.square(matrix2), axis=1) # shape (num_train, )
710
+ dists = np.sqrt(d1 + d2 + d3) # broadcasting
711
+ return dists
712
+
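As a quick check of the identity in the comment above (||x - y||^2 = x.x - 2 x.y + y.y), here is a minimal NumPy sketch comparing the expanded form used in euclidean_distance_matrix against direct pairwise distances (illustrative values only):

import numpy as np

a = np.array([[0.0, 0.0], [1.0, 1.0]])   # N1 x D
b = np.array([[3.0, 4.0]])               # N2 x D

direct = np.linalg.norm(a[:, None, :] - b[None, :, :], axis=-1)
expanded = np.sqrt(-2 * a @ b.T + (a ** 2).sum(1, keepdims=True) + (b ** 2).sum(1))
assert np.allclose(direct, expanded)     # both give [[5.0], [3.6055...]]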
713
+
714
+
715
+ def calculate_top_k(mat, top_k):
716
+ size = mat.shape[0]
717
+ gt_mat = np.expand_dims(np.arange(size), 1).repeat(size, 1)
718
+ bool_mat = (mat == gt_mat)
719
+ correct_vec = False
720
+ top_k_list = []
721
+ for i in range(top_k):
722
+ # print(correct_vec, bool_mat[:, i])
723
+ correct_vec = (correct_vec | bool_mat[:, i])
724
+ # print(correct_vec)
725
+ top_k_list.append(correct_vec[:, None])
726
+ top_k_mat = np.concatenate(top_k_list, axis=1)
727
+ return top_k_mat
728
+
729
+
730
+ def calculate_R_precision(embedding1, embedding2, top_k, sum_all=False):
731
+ dist_mat = euclidean_distance_matrix(embedding1, embedding2)
732
+ matching_score = dist_mat.trace()
733
+ argmax = np.argsort(dist_mat, axis=1)
734
+ top_k_mat = calculate_top_k(argmax, top_k)
735
+ if sum_all:
736
+ return top_k_mat.sum(axis=0), matching_score
737
+ else:
738
+ return top_k_mat, matching_score
739
+
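A minimal sketch of calculate_R_precision on dummy embeddings (assuming these helpers live in a utils.eval_trans module): when each motion embedding is a slightly noisy copy of its text embedding, every query retrieves its own motion at rank 1.

import numpy as np
from utils.eval_trans import calculate_R_precision

rng = np.random.default_rng(0)
et = rng.normal(size=(32, 512))              # stand-in text embeddings
em = et + 0.01 * rng.normal(size=(32, 512))  # motions that almost match their text

top_k_counts, matching = calculate_R_precision(et, em, top_k=3, sum_all=True)
print(top_k_counts / 32)   # -> [1. 1. 1.]  (Top-1/2/3 all perfect on this toy data)
print(matching / 32)       # -> small average matching (distance) score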
740
+ def calculate_multimodality(activation, multimodality_times):
741
+ assert len(activation.shape) == 3
742
+ assert activation.shape[1] > multimodality_times
743
+ num_per_sent = activation.shape[1]
744
+
745
+ first_dices = np.random.choice(num_per_sent, multimodality_times, replace=False)
746
+ second_dices = np.random.choice(num_per_sent, multimodality_times, replace=False)
747
+ dist = linalg.norm(activation[:, first_dices] - activation[:, second_dices], axis=2)
748
+ return dist.mean()
749
+
750
+
751
+ def calculate_diversity(activation, diversity_times):
752
+ assert len(activation.shape) == 2
753
+ assert activation.shape[0] > diversity_times
754
+ num_samples = activation.shape[0]
755
+
756
+ first_indices = np.random.choice(num_samples, diversity_times, replace=False)
757
+ second_indices = np.random.choice(num_samples, diversity_times, replace=False)
758
+ dist = linalg.norm(activation[first_indices] - activation[second_indices], axis=1)
759
+ return dist.mean()
760
+
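Illustrative input shapes for the two sampling-based metrics above (dummy activations, module path assumed to be utils.eval_trans): calculate_diversity expects a flat (num_samples, dim) matrix, while calculate_multimodality expects (num_texts, repeats_per_text, dim) with more repeats than multimodality_times.

import numpy as np
from utils.eval_trans import calculate_diversity, calculate_multimodality

rng = np.random.default_rng(0)
flat_feats = rng.normal(size=(512, 256))          # num_samples x dim
per_text_feats = rng.normal(size=(64, 30, 256))   # num_texts x repeats x dim

div = calculate_diversity(flat_feats, diversity_times=300)
mm = calculate_multimodality(per_text_feats, multimodality_times=10)
print(div, mm)   # two scalar averages of pairwise distances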
761
+
762
+
763
+ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
764
+
765
+ mu1 = np.atleast_1d(mu1)
766
+ mu2 = np.atleast_1d(mu2)
767
+
768
+ sigma1 = np.atleast_2d(sigma1)
769
+ sigma2 = np.atleast_2d(sigma2)
770
+
771
+ assert mu1.shape == mu2.shape, \
772
+ 'Training and test mean vectors have different lengths'
773
+ assert sigma1.shape == sigma2.shape, \
774
+ 'Training and test covariances have different dimensions'
775
+
776
+ diff = mu1 - mu2
777
+
778
+ # Product might be almost singular
779
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
780
+ if not np.isfinite(covmean).all():
781
+ msg = ('fid calculation produces singular product; '
782
+ 'adding %s to diagonal of cov estimates') % eps
783
+ print(msg)
784
+ offset = np.eye(sigma1.shape[0]) * eps
785
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
786
+
787
+ # Numerical error might give slight imaginary component
788
+ if np.iscomplexobj(covmean):
789
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
790
+ m = np.max(np.abs(covmean.imag))
791
+ raise ValueError('Imaginary component {}'.format(m))
792
+ covmean = covmean.real
793
+
794
+ tr_covmean = np.trace(covmean)
795
+
796
+ return (diff.dot(diff) + np.trace(sigma1)
797
+ + np.trace(sigma2) - 2 * tr_covmean)
798
+
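A small sanity check for calculate_frechet_distance (illustrative, module path assumed to be utils.eval_trans): the distance of a Gaussian fit against itself is ~0, and with equal covariances it reduces to the squared mean shift.

import numpy as np
from utils.eval_trans import calculate_frechet_distance

rng = np.random.default_rng(0)
x = rng.normal(size=(1000, 8))
mu, cov = x.mean(axis=0), np.cov(x, rowvar=False)

print(calculate_frechet_distance(mu, cov, mu, cov))        # ~0.0
print(calculate_frechet_distance(mu, cov, mu + 1.0, cov))  # ~8.0 (squared shift of 1.0 in 8 dims)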
799
+
800
+
801
+ def calculate_activation_statistics(activations):
802
+
803
+ mu = np.mean(activations, axis=0)
804
+ cov = np.cov(activations, rowvar=False)
805
+ return mu, cov
806
+
807
+
808
+ def calculate_frechet_feature_distance(feature_list1, feature_list2):
809
+ feature_list1 = np.stack(feature_list1)
810
+ feature_list2 = np.stack(feature_list2)
811
+
812
+ # normalize the scale
813
+ mean = np.mean(feature_list1, axis=0)
814
+ std = np.std(feature_list1, axis=0) + 1e-10
815
+ feature_list1 = (feature_list1 - mean) / std
816
+ feature_list2 = (feature_list2 - mean) / std
817
+
818
+ dist = calculate_frechet_distance(
819
+ mu1=np.mean(feature_list1, axis=0),
820
+ sigma1=np.cov(feature_list1, rowvar=False),
821
+ mu2=np.mean(feature_list2, axis=0),
822
+ sigma2=np.cov(feature_list2, rowvar=False),
823
+ )
824
+ return dist
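And a hedged end-to-end sketch of calculate_frechet_feature_distance (module path assumed to be utils.eval_trans), with random vectors standing in for the real evaluator features:

import numpy as np
from utils.eval_trans import calculate_frechet_feature_distance

rng = np.random.default_rng(0)
gt_feats = [rng.normal(size=16) for _ in range(300)]    # stand-in ground-truth features
pred_feats = [rng.normal(size=16) for _ in range(300)]  # stand-in generated features

# features are normalised by the statistics of the first list, then compared with a Frechet distance
print(calculate_frechet_feature_distance(gt_feats, pred_feats))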
utils/humanml_utils.py ADDED
@@ -0,0 +1,68 @@
1
+ import numpy as np
2
+
3
+ HML_JOINT_NAMES = [
4
+ 'pelvis',
5
+ 'left_hip',
6
+ 'right_hip',
7
+ 'spine1',
8
+ 'left_knee',
9
+ 'right_knee',
10
+ 'spine2',
11
+ 'left_ankle',
12
+ 'right_ankle',
13
+ 'spine3',
14
+ 'left_foot',
15
+ 'right_foot',
16
+ 'neck',
17
+ 'left_collar',
18
+ 'right_collar',
19
+ 'head',
20
+ 'left_shoulder',
21
+ 'right_shoulder',
22
+ 'left_elbow',
23
+ 'right_elbow',
24
+ 'left_wrist',
25
+ 'right_wrist',
26
+ ]
27
+
28
+ NUM_HML_JOINTS = len(HML_JOINT_NAMES) # 22 SMPLH body joints
29
+
30
+ HML_LOWER_BODY_JOINTS = [HML_JOINT_NAMES.index(name) for name in ['pelvis', 'left_hip', 'right_hip', 'left_knee', 'right_knee', 'left_ankle', 'right_ankle', 'left_foot', 'right_foot',]]
31
+ SMPL_UPPER_BODY_JOINTS = [i for i in range(len(HML_JOINT_NAMES)) if i not in HML_LOWER_BODY_JOINTS]
32
+
33
+
34
+ # Recover global angle and positions for rotation data
35
+ # root_rot_velocity (B, seq_len, 1)
36
+ # root_linear_velocity (B, seq_len, 2)
37
+ # root_y (B, seq_len, 1)
38
+ # ric_data (B, seq_len, (joint_num - 1)*3)
39
+ # rot_data (B, seq_len, (joint_num - 1)*6)
40
+ # local_velocity (B, seq_len, joint_num*3)
41
+ # foot contact (B, seq_len, 4)
42
+ HML_ROOT_BINARY = np.array([True] + [False] * (NUM_HML_JOINTS-1))
43
+ HML_ROOT_MASK = np.concatenate(([True]*(1+2+1),
44
+ HML_ROOT_BINARY[1:].repeat(3),
45
+ HML_ROOT_BINARY[1:].repeat(6),
46
+ HML_ROOT_BINARY.repeat(3),
47
+ [False] * 4))
48
+ HML_LOWER_BODY_JOINTS_BINARY = np.array([i in HML_LOWER_BODY_JOINTS for i in range(NUM_HML_JOINTS)])
49
+ HML_LOWER_BODY_MASK = np.concatenate(([True]*(1+2+1),
50
+ HML_LOWER_BODY_JOINTS_BINARY[1:].repeat(3),
51
+ HML_LOWER_BODY_JOINTS_BINARY[1:].repeat(6),
52
+ HML_LOWER_BODY_JOINTS_BINARY.repeat(3),
53
+ [True]*4))
54
+ HML_UPPER_BODY_MASK = ~HML_LOWER_BODY_MASK
55
+
56
+
57
+ ALL_JOINT_FALSE = np.full(HML_ROOT_BINARY.shape, False)
58
+ HML_UPPER_BODY_JOINTS_BINARY = np.array([i in SMPL_UPPER_BODY_JOINTS for i in range(NUM_HML_JOINTS)])
59
+
60
+ UPPER_JOINT_Y_TRUE = np.array([ALL_JOINT_FALSE[1:], HML_UPPER_BODY_JOINTS_BINARY[1:], ALL_JOINT_FALSE[1:]])
61
+ UPPER_JOINT_Y_TRUE = UPPER_JOINT_Y_TRUE.T
62
+ UPPER_JOINT_Y_TRUE = UPPER_JOINT_Y_TRUE.reshape(ALL_JOINT_FALSE[1:].shape[0]*3)
63
+
64
+ UPPER_JOINT_Y_MASK = np.concatenate(([False]*(1+2+1),
65
+ UPPER_JOINT_Y_TRUE,
66
+ ALL_JOINT_FALSE[1:].repeat(6),
67
+ ALL_JOINT_FALSE.repeat(3),
68
+ [False] * 4))
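The masks above index the flattened HumanML3D feature vector whose layout is listed in the comment block: 4 root channels, (J-1)*3 ric positions, (J-1)*6 rotations, J*3 local velocities and 4 foot contacts, i.e. 263 channels for the 22 joints. A small sketch to check the arithmetic (it only assumes the constants above are importable):

import numpy as np
from utils.humanml_utils import (NUM_HML_JOINTS, HML_ROOT_MASK,
                                 HML_LOWER_BODY_MASK, HML_UPPER_BODY_MASK)

J = NUM_HML_JOINTS                                       # 22
feat_dim = 4 + (J - 1) * 3 + (J - 1) * 6 + J * 3 + 4     # = 263
assert HML_ROOT_MASK.shape[0] == feat_dim
assert HML_LOWER_BODY_MASK.shape[0] == feat_dim
# every feature channel belongs to exactly one of the two body masks
assert np.all(HML_LOWER_BODY_MASK ^ HML_UPPER_BODY_MASK)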
utils/losses.py ADDED
@@ -0,0 +1,30 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class ReConsLoss(nn.Module):
5
+ def __init__(self, recons_loss, nb_joints):
6
+ super(ReConsLoss, self).__init__()
7
+
8
+ if recons_loss == 'l1':
9
+ self.Loss = torch.nn.L1Loss()
10
+ elif recons_loss == 'l2' :
11
+ self.Loss = torch.nn.MSELoss()
12
+ elif recons_loss == 'l1_smooth' :
13
+ self.Loss = torch.nn.SmoothL1Loss()
14
+
15
+ # 4 global motion associated to root
16
+ # 12 local motion (3 local xyz, 3 vel xyz, 6 rot6d)
17
+ # 3 global vel xyz
18
+ # 4 foot contact
19
+ self.nb_joints = nb_joints
20
+ self.motion_dim = (nb_joints - 1) * 12 + 4 + 3 + 4
21
+
22
+ def forward(self, motion_pred, motion_gt) :
23
+ loss = self.Loss(motion_pred[..., : self.motion_dim], motion_gt[..., :self.motion_dim])
24
+ return loss
25
+
26
+ def forward_joint(self, motion_pred, motion_gt) :
27
+ loss = self.Loss(motion_pred[..., 4 : (self.nb_joints - 1) * 3 + 4], motion_gt[..., 4 : (self.nb_joints - 1) * 3 + 4])
28
+ return loss
29
+
30
+
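A minimal usage sketch for ReConsLoss (illustrative tensors; the 263 channels follow the per-joint layout noted in the comment above, (nb_joints - 1) * 12 + 4 + 3 + 4):

import torch
from utils.losses import ReConsLoss

loss_fn = ReConsLoss(recons_loss='l1_smooth', nb_joints=22)
print(loss_fn.motion_dim)              # (22 - 1) * 12 + 4 + 3 + 4 = 263

pred = torch.randn(8, 64, 263)         # batch x frames x features
gt = torch.randn(8, 64, 263)
print(loss_fn(pred, gt))               # reconstruction loss over the motion features
print(loss_fn.forward_joint(pred, gt)) # loss restricted to the local joint positions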
utils/motion_process.py ADDED
@@ -0,0 +1,59 @@
1
+ import torch
2
+ from utils.quaternion import quaternion_to_cont6d, qrot, qinv
3
+
4
+ def recover_root_rot_pos(data):
5
+ rot_vel = data[..., 0]
6
+ r_rot_ang = torch.zeros_like(rot_vel).to(data.device)
7
+ '''Get Y-axis rotation from rotation velocity'''
8
+ r_rot_ang[..., 1:] = rot_vel[..., :-1]
9
+ r_rot_ang = torch.cumsum(r_rot_ang, dim=-1)
10
+
11
+ r_rot_quat = torch.zeros(data.shape[:-1] + (4,)).to(data.device)
12
+ r_rot_quat[..., 0] = torch.cos(r_rot_ang)
13
+ r_rot_quat[..., 2] = torch.sin(r_rot_ang)
14
+
15
+ r_pos = torch.zeros(data.shape[:-1] + (3,)).to(data.device)
16
+ r_pos[..., 1:, [0, 2]] = data[..., :-1, 1:3]
17
+ '''Add Y-axis rotation to root position'''
18
+ r_pos = qrot(qinv(r_rot_quat), r_pos)
19
+
20
+ r_pos = torch.cumsum(r_pos, dim=-2)
21
+
22
+ r_pos[..., 1] = data[..., 3]
23
+ return r_rot_quat, r_pos
24
+
25
+
26
+ def recover_from_rot(data, joints_num, skeleton):
27
+ r_rot_quat, r_pos = recover_root_rot_pos(data)
28
+
29
+ r_rot_cont6d = quaternion_to_cont6d(r_rot_quat)
30
+
31
+ start_indx = 1 + 2 + 1 + (joints_num - 1) * 3
32
+ end_indx = start_indx + (joints_num - 1) * 6
33
+ cont6d_params = data[..., start_indx:end_indx]
34
+ # print(r_rot_cont6d.shape, cont6d_params.shape, r_pos.shape)
35
+ cont6d_params = torch.cat([r_rot_cont6d, cont6d_params], dim=-1)
36
+ cont6d_params = cont6d_params.view(-1, joints_num, 6)
37
+
38
+ positions = skeleton.forward_kinematics_cont6d(cont6d_params, r_pos)
39
+
40
+ return positions
41
+
42
+
43
+ def recover_from_ric(data, joints_num):
44
+ r_rot_quat, r_pos = recover_root_rot_pos(data)
45
+ positions = data[..., 4:(joints_num - 1) * 3 + 4]
46
+ positions = positions.view(positions.shape[:-1] + (-1, 3))
47
+
48
+ '''Add Y-axis rotation to local joints'''
49
+ positions = qrot(qinv(r_rot_quat[..., None, :]).expand(positions.shape[:-1] + (4,)), positions)
50
+
51
+ '''Add root XZ to joints'''
52
+ positions[..., 0] += r_pos[..., 0:1]
53
+ positions[..., 2] += r_pos[..., 2:3]
54
+
55
+ '''Concatenate root and joints'''
56
+ positions = torch.cat([r_pos.unsqueeze(-2), positions], dim=-2)
57
+
58
+ return positions
59
+
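A hedged usage sketch for recover_from_ric: it maps de-normalised HumanML3D features of shape (..., 263) to joint positions of shape (..., 22, 3). The random tensor below only illustrates shapes, not a plausible motion.

import torch
from utils.motion_process import recover_from_ric

feats = torch.randn(2, 60, 263)                 # batch x frames x 263 (22-joint features)
joints = recover_from_ric(feats, joints_num=22)
print(joints.shape)                             # torch.Size([2, 60, 22, 3])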
utils/paramUtil.py ADDED
@@ -0,0 +1,63 @@
1
+ import numpy as np
2
+
3
+ # Define a kinematic tree for the skeletal structure
4
+ kit_kinematic_chain = [[0, 11, 12, 13, 14, 15], [0, 16, 17, 18, 19, 20], [0, 1, 2, 3, 4], [3, 5, 6, 7], [3, 8, 9, 10]]
5
+
6
+ kit_raw_offsets = np.array(
7
+ [
8
+ [0, 0, 0],
9
+ [0, 1, 0],
10
+ [0, 1, 0],
11
+ [0, 1, 0],
12
+ [0, 1, 0],
13
+ [1, 0, 0],
14
+ [0, -1, 0],
15
+ [0, -1, 0],
16
+ [-1, 0, 0],
17
+ [0, -1, 0],
18
+ [0, -1, 0],
19
+ [1, 0, 0],
20
+ [0, -1, 0],
21
+ [0, -1, 0],
22
+ [0, 0, 1],
23
+ [0, 0, 1],
24
+ [-1, 0, 0],
25
+ [0, -1, 0],
26
+ [0, -1, 0],
27
+ [0, 0, 1],
28
+ [0, 0, 1]
29
+ ]
30
+ )
31
+
32
+ t2m_raw_offsets = np.array([[0,0,0],
33
+ [1,0,0],
34
+ [-1,0,0],
35
+ [0,1,0],
36
+ [0,-1,0],
37
+ [0,-1,0],
38
+ [0,1,0],
39
+ [0,-1,0],
40
+ [0,-1,0],
41
+ [0,1,0],
42
+ [0,0,1],
43
+ [0,0,1],
44
+ [0,1,0],
45
+ [1,0,0],
46
+ [-1,0,0],
47
+ [0,0,1],
48
+ [0,-1,0],
49
+ [0,-1,0],
50
+ [0,-1,0],
51
+ [0,-1,0],
52
+ [0,-1,0],
53
+ [0,-1,0]])
54
+
55
+ t2m_kinematic_chain = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21], [9, 13, 16, 18, 20]]
56
+ t2m_left_hand_chain = [[20, 22, 23, 24], [20, 34, 35, 36], [20, 25, 26, 27], [20, 31, 32, 33], [20, 28, 29, 30]]
57
+ t2m_right_hand_chain = [[21, 43, 44, 45], [21, 46, 47, 48], [21, 40, 41, 42], [21, 37, 38, 39], [21, 49, 50, 51]]
58
+
59
+
60
+ kit_tgt_skel_id = '03950'
61
+
62
+ t2m_tgt_skel_id = '000021'
63
+
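The kinematic chains above list joint indices along each limb; here is a short illustrative sketch flattening t2m_kinematic_chain into parent-child bone edges, e.g. for drawing a skeleton:

from utils.paramUtil import t2m_kinematic_chain

bones = [(chain[i], chain[i + 1])
         for chain in t2m_kinematic_chain
         for i in range(len(chain) - 1)]
print(len(bones))   # 21 edges for the 22-joint skeleton
print(bones[:3])    # [(0, 2), (2, 5), (5, 8)]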
utils/quaternion.py ADDED
@@ -0,0 +1,423 @@
1
+ # Copyright (c) 2018-present, Facebook, Inc.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+
8
+ import torch
9
+ import numpy as np
10
+
11
+ _EPS4 = np.finfo(float).eps * 4.0
12
+
13
+ _FLOAT_EPS = np.finfo(float).eps
14
+
15
+ # PyTorch-backed implementations
16
+ def qinv(q):
17
+ assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
18
+ mask = torch.ones_like(q)
19
+ mask[..., 1:] = -mask[..., 1:]
20
+ return q * mask
21
+
22
+
23
+ def qinv_np(q):
24
+ assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
25
+ return qinv(torch.from_numpy(q).float()).numpy()
26
+
27
+
28
+ def qnormalize(q):
29
+ assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
30
+ return q / torch.norm(q, dim=-1, keepdim=True)
31
+
32
+
33
+ def qmul(q, r):
34
+ """
35
+ Multiply quaternion(s) q with quaternion(s) r.
36
+ Expects two equally-sized tensors of shape (*, 4), where * denotes any number of dimensions.
37
+ Returns q*r as a tensor of shape (*, 4).
38
+ """
39
+ assert q.shape[-1] == 4
40
+ assert r.shape[-1] == 4
41
+
42
+ original_shape = q.shape
43
+
44
+ # Compute outer product
45
+ terms = torch.bmm(r.view(-1, 4, 1), q.view(-1, 1, 4))
46
+
47
+ w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3]
48
+ x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2]
49
+ y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1]
50
+ z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0]
51
+ return torch.stack((w, x, y, z), dim=1).view(original_shape)
52
+
53
+
54
+ def qrot(q, v):
55
+ """
56
+ Rotate vector(s) v about the rotation described by quaternion(s) q.
57
+ Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v,
58
+ where * denotes any number of dimensions.
59
+ Returns a tensor of shape (*, 3).
60
+ """
61
+ assert q.shape[-1] == 4
62
+ assert v.shape[-1] == 3
63
+ assert q.shape[:-1] == v.shape[:-1]
64
+
65
+ original_shape = list(v.shape)
66
+ # print(q.shape)
67
+ q = q.contiguous().view(-1, 4)
68
+ v = v.contiguous().view(-1, 3)
69
+
70
+ qvec = q[:, 1:]
71
+ uv = torch.cross(qvec, v, dim=1)
72
+ uuv = torch.cross(qvec, uv, dim=1)
73
+ return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape)
74
+
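A small check of qrot using the (w, x, y, z) convention assumed throughout this file: a 90 degree rotation about +Y should send the +X axis to -Z.

import math
import torch
from utils.quaternion import qrot

half = math.pi / 4                                              # half of 90 degrees
q = torch.tensor([math.cos(half), 0.0, math.sin(half), 0.0])    # (w, x, y, z)
v = torch.tensor([1.0, 0.0, 0.0])
print(qrot(q, v))                                               # ~tensor([0., 0., -1.])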
75
+
76
+ def qeuler(q, order, epsilon=0, deg=True):
77
+ """
78
+ Convert quaternion(s) q to Euler angles.
79
+ Expects a tensor of shape (*, 4), where * denotes any number of dimensions.
80
+ Returns a tensor of shape (*, 3).
81
+ """
82
+ assert q.shape[-1] == 4
83
+
84
+ original_shape = list(q.shape)
85
+ original_shape[-1] = 3
86
+ q = q.view(-1, 4)
87
+
88
+ q0 = q[:, 0]
89
+ q1 = q[:, 1]
90
+ q2 = q[:, 2]
91
+ q3 = q[:, 3]
92
+
93
+ if order == 'xyz':
94
+ x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
95
+ y = torch.asin(torch.clamp(2 * (q1 * q3 + q0 * q2), -1 + epsilon, 1 - epsilon))
96
+ z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
97
+ elif order == 'yzx':
98
+ x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
99
+ y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
100
+ z = torch.asin(torch.clamp(2 * (q1 * q2 + q0 * q3), -1 + epsilon, 1 - epsilon))
101
+ elif order == 'zxy':
102
+ x = torch.asin(torch.clamp(2 * (q0 * q1 + q2 * q3), -1 + epsilon, 1 - epsilon))
103
+ y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
104
+ z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q1 * q1 + q3 * q3))
105
+ elif order == 'xzy':
106
+ x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
107
+ y = torch.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
108
+ z = torch.asin(torch.clamp(2 * (q0 * q3 - q1 * q2), -1 + epsilon, 1 - epsilon))
109
+ elif order == 'yxz':
110
+ x = torch.asin(torch.clamp(2 * (q0 * q1 - q2 * q3), -1 + epsilon, 1 - epsilon))
111
+ y = torch.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2 * (q1 * q1 + q2 * q2))
112
+ z = torch.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
113
+ elif order == 'zyx':
114
+ x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
115
+ y = torch.asin(torch.clamp(2 * (q0 * q2 - q1 * q3), -1 + epsilon, 1 - epsilon))
116
+ z = torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
117
+ else:
118
+ raise ValueError('Invalid Euler order: ' + order)
119
+
120
+ if deg:
121
+ return torch.stack((x, y, z), dim=1).view(original_shape) * 180 / np.pi
122
+ else:
123
+ return torch.stack((x, y, z), dim=1).view(original_shape)
124
+
125
+
126
+ # Numpy-backed implementations
127
+
128
+ def qmul_np(q, r):
129
+ q = torch.from_numpy(q).contiguous().float()
130
+ r = torch.from_numpy(r).contiguous().float()
131
+ return qmul(q, r).numpy()
132
+
133
+
134
+ def qrot_np(q, v):
135
+ q = torch.from_numpy(q).contiguous().float()
136
+ v = torch.from_numpy(v).contiguous().float()
137
+ return qrot(q, v).numpy()
138
+
139
+
140
+ def qeuler_np(q, order, epsilon=0, use_gpu=False):
141
+ if use_gpu:
142
+ q = torch.from_numpy(q).cuda().float()
143
+ return qeuler(q, order, epsilon).cpu().numpy()
144
+ else:
145
+ q = torch.from_numpy(q).contiguous().float()
146
+ return qeuler(q, order, epsilon).numpy()
147
+
148
+
149
+ def qfix(q):
150
+ """
151
+ Enforce quaternion continuity across the time dimension by selecting
152
+ the representation (q or -q) with minimal distance (or, equivalently, maximal dot product)
153
+ between two consecutive frames.
154
+
155
+ Expects a tensor of shape (L, J, 4), where L is the sequence length and J is the number of joints.
156
+ Returns a tensor of the same shape.
157
+ """
158
+ assert len(q.shape) == 3
159
+ assert q.shape[-1] == 4
160
+
161
+ result = q.copy()
162
+ dot_products = np.sum(q[1:] * q[:-1], axis=2)
163
+ mask = dot_products < 0
164
+ mask = (np.cumsum(mask, axis=0) % 2).astype(bool)
165
+ result[1:][mask] *= -1
166
+ return result
167
+
168
+
169
+ def euler2quat(e, order, deg=True):
170
+ """
171
+ Convert Euler angles to quaternions.
172
+ """
173
+ assert e.shape[-1] == 3
174
+
175
+ original_shape = list(e.shape)
176
+ original_shape[-1] = 4
177
+
178
+ e = e.view(-1, 3)
179
+
180
+ ## if euler angles in degrees
181
+ if deg:
182
+ e = e * np.pi / 180.
183
+
184
+ x = e[:, 0]
185
+ y = e[:, 1]
186
+ z = e[:, 2]
187
+
188
+ rx = torch.stack((torch.cos(x / 2), torch.sin(x / 2), torch.zeros_like(x), torch.zeros_like(x)), dim=1)
189
+ ry = torch.stack((torch.cos(y / 2), torch.zeros_like(y), torch.sin(y / 2), torch.zeros_like(y)), dim=1)
190
+ rz = torch.stack((torch.cos(z / 2), torch.zeros_like(z), torch.zeros_like(z), torch.sin(z / 2)), dim=1)
191
+
192
+ result = None
193
+ for coord in order:
194
+ if coord == 'x':
195
+ r = rx
196
+ elif coord == 'y':
197
+ r = ry
198
+ elif coord == 'z':
199
+ r = rz
200
+ else:
201
+ raise
202
+ if result is None:
203
+ result = r
204
+ else:
205
+ result = qmul(result, r)
206
+
207
+ # Reverse antipodal representation to have a non-negative "w"
208
+ if order in ['xyz', 'yzx', 'zxy']:
209
+ result *= -1
210
+
211
+ return result.view(original_shape)
212
+
213
+
214
+ def expmap_to_quaternion(e):
215
+ """
216
+ Convert axis-angle rotations (aka exponential maps) to quaternions.
217
+ Stable formula from "Practical Parameterization of Rotations Using the Exponential Map".
218
+ Expects a tensor of shape (*, 3), where * denotes any number of dimensions.
219
+ Returns a tensor of shape (*, 4).
220
+ """
221
+ assert e.shape[-1] == 3
222
+
223
+ original_shape = list(e.shape)
224
+ original_shape[-1] = 4
225
+ e = e.reshape(-1, 3)
226
+
227
+ theta = np.linalg.norm(e, axis=1).reshape(-1, 1)
228
+ w = np.cos(0.5 * theta).reshape(-1, 1)
229
+ xyz = 0.5 * np.sinc(0.5 * theta / np.pi) * e
230
+ return np.concatenate((w, xyz), axis=1).reshape(original_shape)
231
+
232
+
233
+ def euler_to_quaternion(e, order):
234
+ """
235
+ Convert Euler angles to quaternions.
236
+ """
237
+ assert e.shape[-1] == 3
238
+
239
+ original_shape = list(e.shape)
240
+ original_shape[-1] = 4
241
+
242
+ e = e.reshape(-1, 3)
243
+
244
+ x = e[:, 0]
245
+ y = e[:, 1]
246
+ z = e[:, 2]
247
+
248
+ rx = np.stack((np.cos(x / 2), np.sin(x / 2), np.zeros_like(x), np.zeros_like(x)), axis=1)
249
+ ry = np.stack((np.cos(y / 2), np.zeros_like(y), np.sin(y / 2), np.zeros_like(y)), axis=1)
250
+ rz = np.stack((np.cos(z / 2), np.zeros_like(z), np.zeros_like(z), np.sin(z / 2)), axis=1)
251
+
252
+ result = None
253
+ for coord in order:
254
+ if coord == 'x':
255
+ r = rx
256
+ elif coord == 'y':
257
+ r = ry
258
+ elif coord == 'z':
259
+ r = rz
260
+ else:
261
+ raise
262
+ if result is None:
263
+ result = r
264
+ else:
265
+ result = qmul_np(result, r)
266
+
267
+ # Reverse antipodal representation to have a non-negative "w"
268
+ if order in ['xyz', 'yzx', 'zxy']:
269
+ result *= -1
270
+
271
+ return result.reshape(original_shape)
272
+
273
+
274
+ def quaternion_to_matrix(quaternions):
275
+ """
276
+ Convert rotations given as quaternions to rotation matrices.
277
+ Args:
278
+ quaternions: quaternions with real part first,
279
+ as tensor of shape (..., 4).
280
+ Returns:
281
+ Rotation matrices as tensor of shape (..., 3, 3).
282
+ """
283
+ r, i, j, k = torch.unbind(quaternions, -1)
284
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
285
+
286
+ o = torch.stack(
287
+ (
288
+ 1 - two_s * (j * j + k * k),
289
+ two_s * (i * j - k * r),
290
+ two_s * (i * k + j * r),
291
+ two_s * (i * j + k * r),
292
+ 1 - two_s * (i * i + k * k),
293
+ two_s * (j * k - i * r),
294
+ two_s * (i * k - j * r),
295
+ two_s * (j * k + i * r),
296
+ 1 - two_s * (i * i + j * j),
297
+ ),
298
+ -1,
299
+ )
300
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
301
+
302
+
303
+ def quaternion_to_matrix_np(quaternions):
304
+ q = torch.from_numpy(quaternions).contiguous().float()
305
+ return quaternion_to_matrix(q).numpy()
306
+
307
+
308
+ def quaternion_to_cont6d_np(quaternions):
309
+ rotation_mat = quaternion_to_matrix_np(quaternions)
310
+ cont_6d = np.concatenate([rotation_mat[..., 0], rotation_mat[..., 1]], axis=-1)
311
+ return cont_6d
312
+
313
+
314
+ def quaternion_to_cont6d(quaternions):
315
+ rotation_mat = quaternion_to_matrix(quaternions)
316
+ cont_6d = torch.cat([rotation_mat[..., 0], rotation_mat[..., 1]], dim=-1)
317
+ return cont_6d
318
+
319
+
320
+ def cont6d_to_matrix(cont6d):
321
+ assert cont6d.shape[-1] == 6, "The last dimension must be 6"
322
+ x_raw = cont6d[..., 0:3]
323
+ y_raw = cont6d[..., 3:6]
324
+
325
+ x = x_raw / torch.norm(x_raw, dim=-1, keepdim=True)
326
+ z = torch.cross(x, y_raw, dim=-1)
327
+ z = z / torch.norm(z, dim=-1, keepdim=True)
328
+
329
+ y = torch.cross(z, x, dim=-1)
330
+
331
+ x = x[..., None]
332
+ y = y[..., None]
333
+ z = z[..., None]
334
+
335
+ mat = torch.cat([x, y, z], dim=-1)
336
+ return mat
337
+
338
+
339
+ def cont6d_to_matrix_np(cont6d):
340
+ q = torch.from_numpy(cont6d).contiguous().float()
341
+ return cont6d_to_matrix(q).numpy()
342
+
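A quick consistency sketch for the 6D rotation helpers above: converting a quaternion to its continuous 6D form and back through cont6d_to_matrix should reproduce quaternion_to_matrix.

import math
import torch
from utils.quaternion import quaternion_to_matrix, quaternion_to_cont6d, cont6d_to_matrix

half = math.pi / 6                                              # a 60 degree rotation about +Y
q = torch.tensor([math.cos(half), 0.0, math.sin(half), 0.0])
R_direct = quaternion_to_matrix(q)
R_via_6d = cont6d_to_matrix(quaternion_to_cont6d(q))
assert torch.allclose(R_direct, R_via_6d, atol=1e-6)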
343
+
344
+ def qpow(q0, t, dtype=torch.float):
345
+ ''' q0 : tensor of quaternions
346
+ t: tensor of powers
347
+ '''
348
+ q0 = qnormalize(q0)
349
+ theta0 = torch.acos(q0[..., 0])
350
+
351
+ ## if theta0 is close to zero, add epsilon to avoid NaNs
352
+ mask = (theta0 <= 10e-10) * (theta0 >= -10e-10)
353
+ theta0 = (1 - mask) * theta0 + mask * 10e-10
354
+ v0 = q0[..., 1:] / torch.sin(theta0).view(-1, 1)
355
+
356
+ if isinstance(t, torch.Tensor):
357
+ q = torch.zeros(t.shape + q0.shape)
358
+ theta = t.view(-1, 1) * theta0.view(1, -1)
359
+ else: ## if t is a number
360
+ q = torch.zeros(q0.shape)
361
+ theta = t * theta0
362
+
363
+ q[..., 0] = torch.cos(theta)
364
+ q[..., 1:] = v0 * torch.sin(theta).unsqueeze(-1)
365
+
366
+ return q.to(dtype)
367
+
368
+
369
+ def qslerp(q0, q1, t):
370
+ '''
371
+ q0: starting quaternion
372
+ q1: ending quaternion
373
+ t: array of points along the way
374
+
375
+ Returns:
376
+ Tensor of Slerps: t.shape + q0.shape
377
+ '''
378
+
379
+ q0 = qnormalize(q0)
380
+ q1 = qnormalize(q1)
381
+ q_ = qpow(qmul(q1, qinv(q0)), t)
382
+
383
+ return qmul(q_,
384
+ q0.contiguous().view(torch.Size([1] * len(t.shape)) + q0.shape).expand(t.shape + q0.shape).contiguous())
385
+
386
+
387
+ def qbetween(v0, v1):
388
+ '''
389
+ find the quaternion used to rotate v0 to v1
390
+ '''
391
+ assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)'
392
+ assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)'
393
+
394
+ v = torch.cross(v0, v1)
395
+ w = torch.sqrt((v0 ** 2).sum(dim=-1, keepdim=True) * (v1 ** 2).sum(dim=-1, keepdim=True)) + (v0 * v1).sum(dim=-1,
396
+ keepdim=True)
397
+ return qnormalize(torch.cat([w, v], dim=-1))
398
+
399
+
400
+ def qbetween_np(v0, v1):
401
+ '''
402
+ find the quaternion used to rotate v0 to v1
403
+ '''
404
+ assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)'
405
+ assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)'
406
+
407
+ v0 = torch.from_numpy(v0).float()
408
+ v1 = torch.from_numpy(v1).float()
409
+ return qbetween(v0, v1).numpy()
410
+
411
+
412
+ def lerp(p0, p1, t):
413
+ if not isinstance(t, torch.Tensor):
414
+ t = torch.Tensor([t])
415
+
416
+ new_shape = t.shape + p0.shape
417
+ new_view_t = t.shape + torch.Size([1] * len(p0.shape))
418
+ new_view_p = torch.Size([1] * len(t.shape)) + p0.shape
419
+ p0 = p0.view(new_view_p).expand(new_shape)
420
+ p1 = p1.view(new_view_p).expand(new_shape)
421
+ t = t.view(new_view_t).expand(new_shape)
422
+
423
+ return p0 + t * (p1 - p0)
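Finally, an illustrative check for qbetween: the returned quaternion should rotate the first direction onto the second, which can be verified with qrot.

import torch
from utils.quaternion import qbetween, qrot

v0 = torch.tensor([1.0, 0.0, 0.0])
v1 = torch.tensor([0.0, 1.0, 0.0])
q = qbetween(v0, v1)      # ~(0.7071, 0, 0, 0.7071), a 90 degree turn about +Z
print(qrot(q, v0))        # ~tensor([0., 1., 0.])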
utils/utils_model.py ADDED
@@ -0,0 +1,66 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.optim as optim
4
+ import logging
5
+ import os
6
+ import sys
7
+
8
+ def getCi(accLog):
9
+
10
+ mean = np.mean(accLog)
11
+ std = np.std(accLog)
12
+ ci95 = 1.96*std/np.sqrt(len(accLog))
13
+
14
+ return mean, ci95
15
+
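getCi above returns the mean of a list of scores together with a 95% confidence half-width, 1.96 * std / sqrt(n). A tiny illustrative example:

from utils.utils_model import getCi

acc_log = [0.78, 0.81, 0.79, 0.80, 0.82]
mean, ci95 = getCi(acc_log)
print(f"{mean:.3f} +/- {ci95:.3f}")   # 0.800 +/- 0.012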
16
+ def get_logger(out_dir):
17
+ logger = logging.getLogger('Exp')
18
+ logger.setLevel(logging.INFO)
19
+ formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s")
20
+
21
+ file_path = os.path.join(out_dir, "run.log")
22
+ file_hdlr = logging.FileHandler(file_path)
23
+ file_hdlr.setFormatter(formatter)
24
+
25
+ strm_hdlr = logging.StreamHandler(sys.stdout)
26
+ strm_hdlr.setFormatter(formatter)
27
+
28
+ logger.addHandler(file_hdlr)
29
+ logger.addHandler(strm_hdlr)
30
+ return logger
31
+
32
+ ## Optimizer
33
+ def initial_optim(decay_option, lr, weight_decay, net, optimizer) :
34
+
35
+ if optimizer == 'adamw' :
36
+ optimizer_adam_family = optim.AdamW
37
+ elif optimizer == 'adam' :
38
+ optimizer_adam_family = optim.Adam
39
+ if decay_option == 'all':
40
+ #optimizer = optimizer_adam_family(net.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=weight_decay)
41
+ optimizer = optimizer_adam_family(net.parameters(), lr=lr, betas=(0.5, 0.9), weight_decay=weight_decay)
42
+
43
+ elif decay_option == 'noVQ':
44
+ all_params = set(net.parameters())
45
+ no_decay = set([net.vq_layer])
46
+
47
+ decay = all_params - no_decay
48
+ optimizer = optimizer_adam_family([
49
+ {'params': list(no_decay), 'weight_decay': 0},
50
+ {'params': list(decay), 'weight_decay' : weight_decay}], lr=lr)
51
+
52
+ return optimizer
53
+
54
+
55
+ def get_motion_with_trans(motion, velocity) :
56
+ '''
57
+ motion : torch.tensor, shape (batch_size, T, 72), with the global translation set to 0
58
+ velocity : torch.tensor, shape (batch_size, T, 3), the global root velocity per frame
59
+
60
+ '''
61
+ trans = torch.cumsum(velocity, dim=1)
62
+ trans = trans - trans[:, :1] ## the first root is initialized at 0 (just for visualization)
63
+ trans = trans.repeat((1, 1, 21))
64
+ motion_with_trans = motion + trans
65
+ return motion_with_trans
66
+
utils/word_vectorizer.py ADDED
@@ -0,0 +1,99 @@
1
+ import numpy as np
2
+ import pickle
3
+ from os.path import join as pjoin
4
+
5
+ POS_enumerator = {
6
+ 'VERB': 0,
7
+ 'NOUN': 1,
8
+ 'DET': 2,
9
+ 'ADP': 3,
10
+ 'NUM': 4,
11
+ 'AUX': 5,
12
+ 'PRON': 6,
13
+ 'ADJ': 7,
14
+ 'ADV': 8,
15
+ 'Loc_VIP': 9,
16
+ 'Body_VIP': 10,
17
+ 'Obj_VIP': 11,
18
+ 'Act_VIP': 12,
19
+ 'Desc_VIP': 13,
20
+ 'OTHER': 14,
21
+ }
22
+
23
+ Loc_list = ('left', 'right', 'clockwise', 'counterclockwise', 'anticlockwise', 'forward', 'back', 'backward',
24
+ 'up', 'down', 'straight', 'curve')
25
+
26
+ Body_list = ('arm', 'chin', 'foot', 'feet', 'face', 'hand', 'mouth', 'leg', 'waist', 'eye', 'knee', 'shoulder', 'thigh')
27
+
28
+ Obj_List = ('stair', 'dumbbell', 'chair', 'window', 'floor', 'car', 'ball', 'handrail', 'baseball', 'basketball')
29
+
30
+ Act_list = ('walk', 'run', 'swing', 'pick', 'bring', 'kick', 'put', 'squat', 'throw', 'hop', 'dance', 'jump', 'turn',
31
+ 'stumble', 'dance', 'stop', 'sit', 'lift', 'lower', 'raise', 'wash', 'stand', 'kneel', 'stroll',
32
+ 'rub', 'bend', 'balance', 'flap', 'jog', 'shuffle', 'lean', 'rotate', 'spin', 'spread', 'climb')
33
+
34
+ Desc_list = ('slowly', 'carefully', 'fast', 'careful', 'slow', 'quickly', 'happy', 'angry', 'sad', 'happily',
35
+ 'angrily', 'sadly')
36
+
37
+ VIP_dict = {
38
+ 'Loc_VIP': Loc_list,
39
+ 'Body_VIP': Body_list,
40
+ 'Obj_VIP': Obj_List,
41
+ 'Act_VIP': Act_list,
42
+ 'Desc_VIP': Desc_list,
43
+ }
44
+
45
+
46
+ class WordVectorizer(object):
47
+ def __init__(self, meta_root, prefix):
48
+ vectors = np.load(pjoin(meta_root, '%s_data.npy'%prefix))
49
+ words = pickle.load(open(pjoin(meta_root, '%s_words.pkl'%prefix), 'rb'))
50
+ self.word2idx = pickle.load(open(pjoin(meta_root, '%s_idx.pkl'%prefix), 'rb'))
51
+ self.word2vec = {w: vectors[self.word2idx[w]] for w in words}
52
+
53
+ def _get_pos_ohot(self, pos):
54
+ pos_vec = np.zeros(len(POS_enumerator))
55
+ if pos in POS_enumerator:
56
+ pos_vec[POS_enumerator[pos]] = 1
57
+ else:
58
+ pos_vec[POS_enumerator['OTHER']] = 1
59
+ return pos_vec
60
+
61
+ def __len__(self):
62
+ return len(self.word2vec)
63
+
64
+ def __getitem__(self, item):
65
+ word, pos = item.split('/')
66
+ if word in self.word2vec:
67
+ word_vec = self.word2vec[word]
68
+ vip_pos = None
69
+ for key, values in VIP_dict.items():
70
+ if word in values:
71
+ vip_pos = key
72
+ break
73
+ if vip_pos is not None:
74
+ pos_vec = self._get_pos_ohot(vip_pos)
75
+ else:
76
+ pos_vec = self._get_pos_ohot(pos)
77
+ else:
78
+ word_vec = self.word2vec['unk']
79
+ pos_vec = self._get_pos_ohot('OTHER')
80
+ return word_vec, pos_vec
81
+
82
+
83
+ class WordVectorizerV2(WordVectorizer):
84
+ def __init__(self, meta_root, prefix):
85
+ super(WordVectorizerV2, self).__init__(meta_root, prefix)
86
+ self.idx2word = {self.word2idx[w]: w for w in self.word2idx}
87
+
88
+ def __getitem__(self, item):
89
+ word_vec, pose_vec = super(WordVectorizerV2, self).__getitem__(item)
90
+ word, pos = item.split('/')
91
+ if word in self.word2vec:
92
+ return word_vec, pose_vec, self.word2idx[word]
93
+ else:
94
+ return word_vec, pose_vec, self.word2idx['unk']
95
+
96
+ def itos(self, idx):
97
+ if idx == len(self.idx2word):
98
+ return "pad"
99
+ return self.idx2word[idx]
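A minimal usage sketch for WordVectorizer; it assumes GloVe metadata files named our_vab_data.npy, our_vab_words.pkl and our_vab_idx.pkl exist under a local glove/ directory, and that keys take the 'word/POS' form handled above.

from utils.word_vectorizer import WordVectorizer

w_vec = WordVectorizer(meta_root='./glove', prefix='our_vab')
word_vec, pos_vec = w_vec['walk/VERB']
print(word_vec.shape)     # the GloVe embedding, e.g. (300,)
print(pos_vec.argmax())   # 12 -> 'Act_VIP', since 'walk' is listed in Act_list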