jaskaran Singh committed on
Commit
22a7887
·
1 Parent(s): 92172e2
Files changed (46) hide show
  1. .gitattributes +4 -0
  2. LICENSE +201 -0
  3. README.md +1 -3
  4. maha_tts/__init__.py +1 -0
  5. maha_tts/__pycache__/__init__.cpython-311.pyc +0 -0
  6. maha_tts/__pycache__/config.cpython-311.pyc +0 -0
  7. maha_tts/__pycache__/inference.cpython-311.pyc +0 -0
  8. maha_tts/config.py +23 -0
  9. maha_tts/dataloaders/__init__.py +0 -0
  10. maha_tts/inference.py +254 -0
  11. maha_tts/models/__init__.py +0 -0
  12. maha_tts/models/__pycache__/__init__.cpython-311.pyc +0 -0
  13. maha_tts/models/__pycache__/autoregressive.cpython-311.pyc +0 -0
  14. maha_tts/models/__pycache__/diff_model.cpython-311.pyc +0 -0
  15. maha_tts/models/__pycache__/modules.cpython-311.pyc +0 -0
  16. maha_tts/models/__pycache__/vocoder.cpython-311.pyc +0 -0
  17. maha_tts/models/autoregressive.py +135 -0
  18. maha_tts/models/diff_model.py +303 -0
  19. maha_tts/models/modules.py +406 -0
  20. maha_tts/models/vocoder.py +342 -0
  21. maha_tts/pretrained_models/.DS_Store +0 -0
  22. maha_tts/pretrained_models/hifigan/config.json +3 -0
  23. maha_tts/pretrained_models/hifigan/g_02500000 +3 -0
  24. maha_tts/pretrained_models/smolie/S2A/s2a_latest.pt +3 -0
  25. maha_tts/pretrained_models/smolie/T2S/t2s_best.pt +3 -0
  26. maha_tts/text/__init__.py +0 -0
  27. maha_tts/text/__pycache__/__init__.cpython-311.pyc +0 -0
  28. maha_tts/text/__pycache__/cleaners.cpython-311.pyc +0 -0
  29. maha_tts/text/__pycache__/symbols.cpython-311.pyc +0 -0
  30. maha_tts/text/cleaners.py +143 -0
  31. maha_tts/text/symbols.py +28 -0
  32. maha_tts/utils/__init__.py +0 -0
  33. maha_tts/utils/__pycache__/__init__.cpython-311.pyc +0 -0
  34. maha_tts/utils/__pycache__/audio.cpython-311.pyc +0 -0
  35. maha_tts/utils/__pycache__/diffusion.cpython-311.pyc +0 -0
  36. maha_tts/utils/__pycache__/stft.cpython-311.pyc +0 -0
  37. maha_tts/utils/audio.py +109 -0
  38. maha_tts/utils/diffusion.py +1283 -0
  39. maha_tts/utils/stft.py +109 -0
  40. ref_clips/2971_4275_000003_000007.wav +0 -0
  41. ref_clips/2971_4275_000020_000001.wav +0 -0
  42. ref_clips/2971_4275_000023_000010.wav +0 -0
  43. ref_clips/2971_4275_000049_000000.wav +0 -0
  44. ref_clips/2971_4275_000049_000004.wav +0 -0
  45. ref_clips/2971_4275_000050_000000.wav +0 -0
  46. tts.py +14 -0
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ maha_tts/pretrained_models/smolie/T2S/t2s_best.pt filter=lfs diff=lfs merge=lfs -text
37
+ maha_tts/pretrained_models/smolie/S2A/s2a_latest.pt filter=lfs diff=lfs merge=lfs -text
38
+ maha_tts/pretrained_models/hifigan/config.json filter=lfs diff=lfs merge=lfs -text
39
+ maha_tts/pretrained_models/hifigan/g_02500000 filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,3 +1 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
1
+ # MahaTTS
 
 
maha_tts/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .inference import load_models,load_diffuser,infer_tts
maha_tts/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (285 Bytes). View file
 
maha_tts/__pycache__/config.cpython-311.pyc ADDED
Binary file (817 Bytes). View file
 
maha_tts/__pycache__/inference.cpython-311.pyc ADDED
Binary file (17.1 kB). View file
 
maha_tts/config.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ class config:
3
+
4
+ semantic_model_centroids = 10000 + 1
5
+ seed_value = 3407
6
+
7
+ # Text to Semantic
8
+ t2s_position = 2048
9
+
10
+ # Semantic to acoustic
11
+ sa_timesteps_max = 1000
12
+
13
+ #Acoustic Properties
14
+ CLIP_LENGTH = 500
15
+ MAX_WAV_VALUE=32768.0
16
+ filter_length=1024
17
+ hop_length=256 #256
18
+ window = 'hann'
19
+ win_length=1024
20
+ n_mel_channels=80
21
+ sampling_rate=22050
22
+ mel_fmin=0.0
23
+ mel_fmax=8000.0
maha_tts/dataloaders/__init__.py ADDED
File without changes
maha_tts/inference.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch,glob,os
2
+ import numpy as np
3
+ import torch.nn.functional as F
4
+
5
+ from librosa.filters import mel as librosa_mel_fn
6
+ from scipy.io.wavfile import write
7
+ from scipy.special import softmax
8
+ from maha_tts.models.diff_model import load_diff_model
9
+ from maha_tts.models.autoregressive import load_TS_model
10
+ from maha_tts.models.vocoder import load_vocoder_model,infer_wav
11
+ from maha_tts.utils.audio import denormalize_tacotron_mel,normalize_tacotron_mel,load_wav_to_torch,dynamic_range_compression
12
+ from maha_tts.utils.stft import STFT
13
+ from maha_tts.utils.diffusion import SpacedDiffusion,get_named_beta_schedule,space_timesteps
14
+ from maha_tts.text.symbols import labels,text_labels,code_labels,text_enc,text_dec,code_enc,code_dec
15
+ from maha_tts.text.cleaners import english_cleaners
16
+ from maha_tts.config import config
17
+
18
+ stft_fn = STFT(config.filter_length, config.hop_length, config.win_length)
19
+
20
+ mel_basis = librosa_mel_fn(
21
+ sr=config.sampling_rate, n_fft=config.filter_length, n_mels=config.n_mel_channels, fmin=config.mel_fmin, fmax=config.mel_fmax)
22
+
23
+ mel_basis = torch.from_numpy(mel_basis).float()
24
+
25
+ model_dirs= {
26
+ 'Smolie':'asdf',
27
+ 'hifigan':'asdf'
28
+ }
29
+
30
+ def download_model(name):
31
+ pass
32
+
33
+
34
+ def load_models(name,device=torch.device('cpu')):
35
+ '''
36
+ Load pre-trained models for different components of a text-to-speech system.
37
+
38
+ Args:
39
+ name (str): Key into `model_dirs` (e.g. 'Smolie'); also used as the sub-directory name under maha_tts/pretrained_models/.
40
+ device (torch.device): The target device for model loading. Defaults to torch.device('cpu').
41
+ Note: checkpoint paths are not arguments; they are built inside the function.
42
+ The diffusion (S2A/s2a_latest.pt) and text-to-semantic (T2S/t2s_best.pt) checkpoints are read from maha_tts/pretrained_models/<name>/.
43
+ The vocoder checkpoint (g_02500000) and its config.json are read from maha_tts/pretrained_models/hifigan/.
44
+
45
+ Returns:
46
+ diff_model (object): Loaded diffusion model for semantic-to-acoustic tokens.
47
+ ts_model (object): Loaded text-to-semantic model for converting text-to-semantic tokens.
48
+ vocoder (object): Loaded vocoder model for generating waveform from acoustic tokens.
49
+ diffuser (object): Configured diffuser object for use in the diffusion model.
50
+ '''
51
+
52
+ assert name in model_dirs, "no model name "+name
53
+
54
+ checkpoint_diff = 'maha_tts/pretrained_models/'+str(name)+'/S2A/s2a_latest.pt'
55
+ checkpoint_ts = 'maha_tts/pretrained_models/'+str(name)+'/T2S/t2s_best.pt'
56
+ checkpoint_voco = 'maha_tts/pretrained_models/hifigan/g_02500000'
57
+ voco_config_path = 'maha_tts/pretrained_models/hifigan/config.json'
58
+
59
+ # for i in [checkpoint_diff,checkpoint_ts,checkpoint_voco,voco_config_path]:
60
+ if not os.path.exists(checkpoint_diff) or not os.path.exists(checkpoint_ts):
61
+ download_model(name)
62
+
63
+ if not os.path.exists(checkpoint_voco) or not os.path.exists(voco_config_path):
64
+ download_model('hifigan')
65
+
66
+ diff_model = load_diff_model(checkpoint_diff,device)
67
+ ts_model = load_TS_model(checkpoint_ts,device)
68
+ vocoder = load_vocoder_model(voco_config_path,checkpoint_voco,device)
69
+ diffuser = load_diffuser()
70
+
71
+ return diff_model,ts_model,vocoder,diffuser
72
+
73
+ def infer_mel(model,timeshape,code,ref_mel,diffuser,temperature=0.1):
74
+ device = next(model.parameters()).device
75
+ code = code.to(device)
76
+ output_shape = (1,80,timeshape)
77
+ noise = torch.randn(output_shape, device=code.device) * temperature
78
+ mel = diffuser.p_sample_loop(model, output_shape, noise=noise,
79
+ model_kwargs={'code_emb': code,'ref_clips':ref_mel},
80
+ progress=True)
81
+ return denormalize_tacotron_mel(mel)
82
+
83
+ def generate_semantic_tokens(
84
+ text,
85
+ model,
86
+ ref_mels,
87
+ temp = 0.7,
88
+ top_p= None,
89
+ top_k= None,
90
+ n_tot_steps = 1000,
91
+ device = None
92
+ ):
93
+ semb = []
94
+ with torch.no_grad():
95
+ for n in range(n_tot_steps):
96
+ x = get_inputs(text,semb,ref_mels,device)
97
+ _,result = model(**x)
98
+ relevant_logits = result[0,:,-1]
99
+ if top_p is not None:
100
+ # faster to convert to numpy
101
+ original_device = relevant_logits.device
102
+ relevant_logits = relevant_logits.detach().cpu().type(torch.float32).numpy()
103
+ sorted_indices = np.argsort(relevant_logits)[::-1]
104
+ sorted_logits = relevant_logits[sorted_indices]
105
+ cumulative_probs = np.cumsum(softmax(sorted_logits))
106
+ sorted_indices_to_remove = cumulative_probs > top_p
107
+ sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].copy()
108
+ sorted_indices_to_remove[0] = False
109
+ relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
110
+ relevant_logits = torch.from_numpy(relevant_logits)
111
+ relevant_logits = relevant_logits.to(original_device)
112
+
113
+ if top_k is not None:
114
+ v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
115
+ relevant_logits[relevant_logits < v[-1]] = -float("Inf")
116
+
117
+ probs = F.softmax(relevant_logits / temp, dim=-1)
118
+ item_next = torch.multinomial(probs, num_samples=1).to(torch.int32)
119
+ semb.append(str(code_dec[item_next.item()]))
120
+ if semb[-1] == '<EST>' or semb[-1] == '<PAD>':
121
+ break
122
+
123
+ del relevant_logits, probs, item_next
124
+
125
+ semb = torch.tensor([int(i) for i in semb[:-1]])
126
+ return semb,result
127
+
128
+ def get_inputs(text,semb=[],ref_mels=[],device=torch.device('cpu')):
129
+ text = text.lower()
130
+ text_ids=[text_enc['<S>']]+[text_enc[i] for i in text.strip()]+[text_enc['<E>']]
131
+ semb_ids=[code_enc['<SST>']]+[code_enc[i] for i in semb]#+[tok_enc['<EST>']]
132
+
133
+ input_ids = text_ids+semb_ids
134
+ # pad_length = config.t2s_position-(len(text_ids)+len(semb_ids))
135
+
136
+ token_type_ids = [0]*len(text_ids)+[1]*len(semb_ids)#+[0]*pad_length
137
+ positional_ids = [i for i in range(len(text_ids))]+[i for i in range(len(semb_ids))]#+[0]*pad_length
138
+ # labels = [-100]*len(text_ids)+semb_ids+[-100]*pad_length
139
+ attention_mask = [1]*len(input_ids)#+[0]*pad_length
140
+ # input_ids += [tok_enc['<PAD>']]*pad_length
141
+ return {'text_ids':torch.tensor(text_ids).unsqueeze(0).to(device),'codes_ids':torch.tensor(semb_ids).unsqueeze(0).to(device),'ref_clips':normalize_tacotron_mel(ref_mels).to(device)}
142
+
143
+ def get_ref_mels(ref_clips):
144
+ ref_mels = []
145
+ for i in ref_clips:
146
+ ref_mels.append(get_mel(i)[0][:,:500])
147
+
148
+ ref_mels_padded = (torch.randn((len(ref_mels), 80, 500)))*1e-8
149
+ for i,mel in enumerate(ref_mels):
150
+ ref_mels_padded[i, :, :mel.size(1)] = mel
151
+ return ref_mels_padded.unsqueeze(0)
152
+
153
+ def get_mel(filepath):
154
+ audio, sampling_rate = load_wav_to_torch(filepath)
155
+ audio_norm = audio / config.MAX_WAV_VALUE
156
+ audio_norm = audio_norm.unsqueeze(0)
157
+ y = torch.autograd.Variable(audio_norm, requires_grad=False)
158
+
159
+ assert(torch.min(y.data) >= -1)
160
+ assert(torch.max(y.data) <= 1)
161
+ magnitudes, phases = stft_fn.transform(y)
162
+ magnitudes = magnitudes.data
163
+ mel_output = torch.matmul(mel_basis, magnitudes)
164
+ mel_output = dynamic_range_compression(mel_output)
165
+ melspec = torch.squeeze(mel_output, 0)
166
+ energy = torch.norm(magnitudes, dim=1).squeeze(0)
167
+ return melspec,list(energy)
168
+
169
+ def infer_tts(text,ref_clips,diffuser,diff_model,ts_model,vocoder):
170
+ '''
171
+ Generate audio from the given text using a text-to-speech (TTS) pipeline.
172
+
173
+ Args:
174
+ text (str): The input text to be synthesized into speech.
175
+ ref_clips (list): A list of paths to reference audio clips, preferably more than 3 clips.
176
+ diffuser (object): A diffusion object used for denoising and guidance in the diffusion model. It should be obtained using load_diffuser.
177
+ diff_model: diffusion model for semantic-to-acoustic tokens.
178
+ ts_model: text-to-semantic model for converting text-to-semantic tokens.
179
+ vocoder: vocoder model for generating waveform from acoustic tokens.
180
+
181
+ Returns:
182
+ audio (numpy.ndarray): Generated audio waveform.
183
+ sampling_rate (int): Sampling rate of the generated audio.
184
+
185
+ Description:
186
+ The `infer_tts` function takes input text and reference audio clips, and processes them through a TTS pipeline.
187
+ It first performs text preprocessing and generates semantic tokens using the specified text synthesis model.
188
+ Then, it infers mel-spectrogram features using the diffusion model and the provided diffuser.
189
+ Finally, it generates audio from the mel-spectrogram using the vocoder.
190
+
191
+ Note: The function requires properly configured diff_model, ts_model, and vocoder objects for successful TTS.
192
+
193
+ Example usage:
194
+ audio, sampling_rate = infer_tts("Hello, how are you?", ref_clips, diffuser, diff_model, ts_model, vocoder)
195
+ '''
196
+ text = english_cleaners(text)
197
+ ref_mels = get_ref_mels(ref_clips)
198
+ with torch.no_grad():
199
+ sem_tok,_ = generate_semantic_tokens(
200
+ text,
201
+ ts_model,
202
+ ref_mels,
203
+ temp = 0.7,
204
+ top_p= 0.8,
205
+ top_k= 5,
206
+ n_tot_steps = 1000,
207
+ device = None
208
+ )
209
+ mel = infer_mel(diff_model,int(((sem_tok.shape[-1] * 320 / 16000) * 22050/256)+1),sem_tok.unsqueeze(0) + 1,
210
+ ref_mels,diffuser,temperature=1.0)
211
+
212
+ audio = infer_wav(mel,vocoder)
213
+
214
+ return audio,config.sampling_rate
215
+
216
+ def load_diffuser(timesteps = 100, gudiance=3):
217
+ '''
218
+ Load and configure a diffuser for denoising and guidance in the diffusion model.
219
+
220
+ Args:
221
+ timesteps (int): Number of denoising steps out of 1000. Default is 100.
222
+ gudiance (int): Conditioning-free guidance parameter, passed as `conditioning_free_k`. Default is 3. (The keyword is spelled `gudiance` in the signature.)
223
+
224
+ Returns:
225
+ diffuser (object): Configured diffuser object for use in the diffusion model.
226
+
227
+ Description:
228
+ The `load_diffuser` function initializes a diffuser with specific settings for denoising and guidance.
229
+ '''
230
+ betas = get_named_beta_schedule('cosine',config.sa_timesteps_max)
231
+ diffuser = SpacedDiffusion(use_timesteps=space_timesteps(1000, [timesteps]), model_mean_type='epsilon',
232
+ model_var_type='learned_range', loss_type='rescaled_mse', betas=betas,
233
+ conditioning_free=True, conditioning_free_k=gudiance)
234
+ diffuser.training=False
235
+ return diffuser
236
+
237
+ if __name__ == '__main__':
238
+
239
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
240
+ print(device)
241
+ text = 'Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition.'
242
+ ref_clips = glob.glob('/Users/jaskaransingh/Desktop/maha_tts/ref_clips/*.wav')
243
+
244
+ checkpoint_diff = 'maha_tts/pretrained_models/S2A/s2a_latest.pt'
245
+ checkpoint_ts = 'maha_tts/pretrained_models/T2S/t2s_best.pt'
246
+ checkpoint_voco = 'maha_tts/pretrained_models/hifigan/g_02500000'
247
+ voco_config_path = 'maha_tts/pretrained_models/hifigan/config.json'
248
+
249
+ diffuser = load_diffuser()
250
+ diff_model,ts_model,vocoder = load_models(device,checkpoint_diff,checkpoint_ts,checkpoint_voco,voco_config_path)
251
+ audio,sr = infer_tts(text,ref_clips,diffuser,diff_model,ts_model,vocoder)
252
+ write('test.wav',sr,audio)
253
+
254
+
maha_tts/models/__init__.py ADDED
File without changes
maha_tts/models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (177 Bytes). View file
 
maha_tts/models/__pycache__/autoregressive.cpython-311.pyc ADDED
Binary file (9.89 kB). View file
 
maha_tts/models/__pycache__/diff_model.cpython-311.pyc ADDED
Binary file (18.9 kB). View file
 
maha_tts/models/__pycache__/modules.cpython-311.pyc ADDED
Binary file (28.6 kB). View file
 
maha_tts/models/__pycache__/vocoder.cpython-311.pyc ADDED
Binary file (22.8 kB). View file
 
maha_tts/models/autoregressive.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ Inspiration taken from https://github.com/neonbjb/tortoise-tts/blob/main/tortoise/models/autoregressive.py
3
+ '''
4
+ import os,sys
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import torch.optim as optim
9
+ import functools
10
+
11
+ from typing import Any
12
+ from torch.utils.data import Dataset,DataLoader
13
+ from transformers import GPT2Tokenizer,GPT2Config, GPT2Model, GPT2LMHeadModel
14
+ from tqdm import tqdm
15
+ from maha_tts.config import config
16
+ from maha_tts.text.symbols import labels,code_labels,text_labels
17
+ from maha_tts.models.modules import GST
18
+
19
+ def null_position_embeddings(range, dim):
20
+ return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
21
+
22
+ class TS_model(nn.Module):
23
+ def __init__(self,n_embed = 512, n_layer = 16, n_head = 8):
24
+ super(TS_model,self).__init__()
25
+
26
+ self.vocab_size=len(labels)
27
+ self.n_positions=config.t2s_position
28
+ self.n_embed=n_embed
29
+ self.n_layer=n_layer
30
+ self.n_head=n_head
31
+
32
+ self.config = GPT2Config(vocab_size=self.vocab_size,n_positions=self.n_positions,n_embd=self.n_embed,n_layer=self.n_layer,n_head=self.n_head)
33
+ self.gpt = GPT2Model(self.config)
34
+ del self.gpt.wpe
35
+ self.gpt.wpe = functools.partial(null_position_embeddings, dim=self.n_embed)
36
+ # Built-in token embeddings are unused.
37
+ del self.gpt.wte
38
+ self.GST = GST(model_channels=self.n_embed,num_heads=self.n_head,in_channels=config.n_mel_channels,k=1)
39
+ self.text_head = nn.Linear(self.n_embed,len(text_labels))
40
+ self.code_head = nn.Linear(self.n_embed,len(code_labels))
41
+
42
+ self.text_positional_embed = LearnedPositionEmbeddings(self.n_positions,self.n_embed)
43
+ self.code_positional_embed = LearnedPositionEmbeddings(self.n_positions,self.n_embed)
44
+
45
+ self.text_embed = nn.Embedding(len(text_labels),self.n_embed)
46
+ self.code_embed = nn.Embedding(len(code_labels),self.n_embed)
47
+ self.final_norm = nn.LayerNorm(self.n_embed)
48
+
49
+ def get_speaker_latent(self, ref_mels):
50
+ ref_mels = ref_mels.unsqueeze(1) if len(
51
+ ref_mels.shape) == 3 else ref_mels
52
+
53
+ conds = []
54
+ for j in range(ref_mels.shape[1]):
55
+ conds.append(self.GST(ref_mels[:, j,:,:]))
56
+
57
+ conds = torch.cat(conds, dim=-1)
58
+ conds = conds.mean(dim=-1)
59
+
60
+ return conds.unsqueeze(1)
61
+
62
+ def forward(self,text_ids,codes_ids = None,speaker_embed=None,ref_clips=None,return_loss = False):
63
+ assert speaker_embed is not None or ref_clips is not None
64
+ text_embed = self.text_embed(text_ids)
65
+ text_embed += self.text_positional_embed(text_embed)
66
+
67
+ code_embed = None
68
+ code_probs= None
69
+
70
+ if codes_ids is not None:
71
+ code_embed = self.code_embed(codes_ids)
72
+ code_embed+= self.code_positional_embed(code_embed)
73
+
74
+ if ref_clips is not None:
75
+ speaker_embed = self.get_speaker_latent(ref_clips)
76
+
77
+ text_embed,code_embed = self.get_logits(speaker_embed=speaker_embed,text_embed=text_embed,code_embed=code_embed)
78
+
79
+ text_probs = self.text_head(text_embed).permute(0,2,1)
80
+
81
+ if codes_ids is not None:
82
+ code_probs = self.code_head(code_embed).permute(0,2,1)
83
+
84
+ if return_loss:
85
+ loss_text = F.cross_entropy(text_probs[:,:,:-1], text_ids[:,1:].long(), reduce=False)
86
+ loss_mel = F.cross_entropy(code_probs[:,:,:-1], codes_ids[:,1:].long(), reduce=False)
87
+ return loss_text,loss_mel,code_probs
88
+
89
+ return text_probs,code_probs
90
+
91
+
92
+ def get_logits(self,speaker_embed,text_embed,code_embed=None):
93
+
94
+ if code_embed is not None:
95
+ embed = torch.cat([speaker_embed,text_embed,code_embed],dim=1)
96
+ else:
97
+ embed = torch.cat([speaker_embed,text_embed],dim=1)
98
+
99
+ gpt_output = self.gpt(inputs_embeds=embed, return_dict=True)
100
+ enc = gpt_output.last_hidden_state[:, 1:]
101
+ enc = self.final_norm(enc)
102
+ if code_embed is not None:
103
+ return enc[:,:text_embed.shape[1]],enc[:,-code_embed.shape[1]:]
104
+
105
+ return enc[:,:text_embed.shape[1]],None
106
+
107
+ class LearnedPositionEmbeddings(nn.Module):
108
+ def __init__(self, seq_len, model_dim, init=.02):
109
+ super().__init__()
110
+ self.emb = nn.Embedding(seq_len, model_dim)
111
+ # Initializing this way is standard for GPT-2
112
+ self.emb.weight.data.normal_(mean=0.0, std=init)
113
+
114
+ def forward(self, x):
115
+ sl = x.shape[1]
116
+ return self.emb(torch.arange(0, sl, device=x.device))
117
+
118
+ def get_fixed_embedding(self, ind, dev):
119
+ return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
120
+
121
+ def load_TS_model(checkpoint,device):
122
+ sem_model= TS_model(n_embed = 512, n_layer = 16, n_head = 8)
123
+ sem_model.load_state_dict(torch.load(checkpoint,map_location=torch.device('cpu')),strict=False)
124
+ sem_model.eval().to(device)
125
+
126
+ return sem_model
127
+
128
+ if __name__ == '__main__':
129
+ model=TS_model(n_embed = 256, n_layer = 6, n_head = 4)
130
+
131
+ text_ids = torch.randint(0,100,(5,20))
132
+ code_ids = torch.randint(0,100,(5,200))
133
+ speaker_embed = torch.randn((5,1,256))
134
+
135
+ output=model(text_ids=text_ids,speaker_embed=speaker_embed,codes_ids=code_ids,return_loss=True)
maha_tts/models/diff_model.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ inspiration taken from https://github.com/neonbjb/tortoise-tts/blob/main/tortoise/models/diffusion_decoder.py
3
+ '''
4
+ import sys
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import math
9
+
10
+ from maha_tts.config import config
11
+ from torch import autocast
12
+ from maha_tts.models.modules import QuartzNetBlock,AttentionBlock,mySequential,normalization,SCBD,SqueezeExcite,GST
13
+
14
def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a Tensor of N indices, one per batch element.
                      These may be fractional. Normally 1-D; extra
                      trailing dims (e.g. [N, 1]) are kept in the output.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings
             ([N x 1 x dim] for an [N, 1] input).
    """
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Pad one zero channel on the last axis. Build the pad from the shape
        # instead of slicing `embedding[:, :1]`: the slice is empty when
        # dim == 1 and indexes the wrong axis for inputs with extra dims.
        pad = embedding.new_zeros(tuple(embedding.shape[:-1]) + (1,))
        embedding = torch.cat([embedding, pad], dim=-1)
    return embedding
33
+
34
class TimestepBlock(nn.Module):
    """Base class marking modules whose ``forward`` also takes a timestep embedding."""

    def forward(self, x, emb):
        """Apply the module to `x` given `emb` timestep embeddings."""
        return None
39
+
40
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """Sequential container that hands the timestep embedding to children that accept it."""

    def forward(self, x, emb):
        for child in self:
            # Only TimestepBlock children take the embedding as a second argument.
            x = child(x, emb) if isinstance(child, TimestepBlock) else child(x)
        return x
48
+
49
class QuartzNetBlock(TimestepBlock):
    '''Similar to Resnet block with normalization and dropout, conditioned on a timestep embedding.
    if its the last layer,set se = False and separable = False, and use a projection layer on top of this.

    The input is projected by a 1x1 conv + norm, modulated by the timestep
    embedding, passed through R conv blocks (the last one without
    activation/dropout), optionally squeeze-excited, summed with a 1x1
    residual projection of the input and finally passed through SiLU + dropout.
    '''
    def __init__(self,nin,nout,emb_channels,kernel_size=3,dropout=0.1,R=1,se=True,ratio=8,separable=False,bias=True,use_scale_shift_norm=True):
        super(QuartzNetBlock,self).__init__()
        self.use_scale_shift_norm = use_scale_shift_norm
        self.se = se
        self.in_layers = mySequential(
            nn.Conv1d(nin, nout, kernel_size=1, padding='same', bias=bias),
            normalization(nout)
        )

        self.residual = mySequential(
            nn.Conv1d(nin, nout, kernel_size=1, padding='same', bias=bias),
            normalization(nout)
        )

        # Everything after the input projection works on `nout` channels.
        nin = nout
        model = []

        # Projects the timestep embedding to (scale, shift) or to a plain shift.
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(
                emb_channels,
                2 * nout if use_scale_shift_norm else nout,
            ),
        )

        for _ in range(R - 1):
            model.append(SCBD(nin, nout, kernel_size, dropout, bias=bias))
            nin = nout

        # Last conv block: normalization only (rd=False), separable on request.
        if separable:
            model.append(SCBD(nin, nout, kernel_size, dropout, rd=False, bias=bias))
        else:
            model.append(SCBD(nin, nout, kernel_size, dropout, rd=False, separable=False, bias=bias))

        self.model = mySequential(*model)
        if self.se:
            self.se_layer = SqueezeExcite(nin, ratio)

        self.mout = mySequential(nn.SiLU(), nn.Dropout(dropout))

    def forward(self,x,emb,mask=None):
        """Apply the block to `x` given timestep embedding `emb`.

        `mask` is only forwarded to the squeeze-excite layer; the conv stack
        runs unmasked.
        """
        x_new = self.in_layers(x)
        emb = self.emb_layers(emb)
        # Add trailing singleton dims so the embedding broadcasts over x_new.
        while len(emb.shape) < len(x_new.shape):
            emb = emb[..., None]
        if self.use_scale_shift_norm:
            scale, shift = torch.chunk(emb, 2, dim=1)
            x_new = x_new * (1 + scale) + shift
        else:
            # The projection is only `nout` wide here, so it cannot be split
            # into scale and shift; use it as an additive shift.
            x_new = x_new + emb
        y, _ = self.model(x_new)

        if self.se:
            y, _ = self.se_layer(y, mask)
        y += self.residual(x)
        y = self.mout(y)

        return y
108
+
109
class QuartzAttn(TimestepBlock):
    """A timestep-conditioned QuartzNetBlock followed by an AttentionBlock, all at ``model_channels`` width."""

    def __init__(self, model_channels, dropout, num_heads):
        super().__init__()
        self.resblk = QuartzNetBlock(
            model_channels,
            model_channels,
            model_channels,
            dropout=dropout,
            use_scale_shift_norm=True,
        )
        self.attn = AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True)

    def forward(self, x, time_emb):
        return self.attn(self.resblk(x, time_emb))
118
+
119
class QuartzNet9x5(nn.Module):
    """Stem block followed by nine (QuartzNetBlock, AttentionBlock) pairs and a 1x1 output conv.

    The nine blocks use kernel sizes 5, 7, 9, 13, 15, 17, 21, 23, 25 with
    R=5 each. ``enable_fp16`` is stored but not used by the code in this class.
    """

    def __init__(self, model_channels, num_heads, enable_fp16=False):
        super(QuartzNet9x5, self).__init__()
        self.enable_fp16 = enable_fp16

        self.conv1 = QuartzNetBlock(model_channels, model_channels, model_channels, kernel_size=3, dropout=0.1, R=3)

        blocks, attentions = [], []
        for width in (5, 7, 9, 13, 15, 17, 21, 23, 25):
            blocks.append(
                QuartzNetBlock(model_channels, model_channels, model_channels, kernel_size=width, dropout=0.1, R=5, se=True)
            )
            attentions.append(AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True))

        self.quartznet = nn.ModuleList(blocks)
        self.attn = nn.ModuleList(attentions)
        self.conv3 = nn.Conv1d(model_channels, model_channels, 1, padding='same')

    def forward(self, x, time_emb):
        x = self.conv1(x, time_emb)
        for block, attention in zip(self.quartznet, self.attn):
            x = attention(block(x, time_emb))
        # The output conv always runs on a float32 tensor.
        return self.conv3(x.float())
151
+
152
+ class DiffModel(nn.Module):
+ # Denoising network: combines a noisy input `x`, a timestep embedding, a code
+ # conditioning signal (embedded ids, or continuous latents when ar_active) and,
+ # when multispeaker, a speaker latent produced by GST, then runs QuartzNet9x5.
153
+
154
+ def __init__(
155
+ self,
156
+ input_channels=80,
157
+ output_channels=160,
158
+ model_channels=512,
159
+ num_heads=8,
160
+ dropout=0.0,
161
+ multispeaker = True,
162
+ condition_free_per=0.1,
163
+ training = False,
164
+ ar_active = False,
165
+ in_latent_channels = 10004
166
+ ):
167
+
168
+ super().__init__()
169
+ self.input_channels = input_channels
170
+ self.model_channels = model_channels
171
+ self.output_channels = output_channels
172
+ self.num_heads = num_heads
173
+ self.dropout = dropout
174
+ self.condition_free_per = condition_free_per
+ # NOTE: this overwrites nn.Module's own `training` flag on this module only;
+ # a later .train()/.eval() call replaces the value set here.
175
+ self.training = training
176
+ self.multispeaker = multispeaker
177
+ self.ar_active = ar_active
178
+ self.in_latent_channels = in_latent_channels
179
+
+ # Code pathway: without ar_active, ids are embedded and refined by 3 attention
+ # blocks; with ar_active, inputs with `in_latent_channels` channels go through
+ # a conv followed by 4 attention blocks.
180
+ if not self.ar_active:
181
+ self.code_emb = nn.Embedding(config.semantic_model_centroids+1,model_channels)
182
+ self.code_converter = mySequential(
183
+ AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
184
+ AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
185
+ AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
186
+ )
187
+ else:
188
+ self.code_converter = mySequential(
189
+ nn.Conv1d(self.in_latent_channels, model_channels, 3, padding=1),
190
+ AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
191
+ AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
192
+ AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
193
+ AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
194
+ )
195
+ if self.multispeaker:
196
+ self.GST = GST(model_channels,num_heads)
197
+
198
+ self.code_norm = normalization(model_channels)
199
+ self.time_norm = normalization(model_channels)
200
+ self.noise_norm = normalization(model_channels)
201
+ self.code_time_norm = normalization(model_channels)
202
+
203
+ # self.code_latent = []
204
+ self.time_embed = mySequential(
205
+ nn.Linear(model_channels, model_channels),
206
+ nn.SiLU(),
207
+ nn.Linear(model_channels, model_channels),)
208
+
209
+ self.input_block = nn.Conv1d(input_channels,model_channels,3,1,1)
+ # Learned stand-in for the code conditioning; forward uses it when
+ # conditioning_free is set and for randomly selected batch items in training.
210
+ self.unconditioned_embedding = nn.Parameter(torch.randn(1,model_channels,1))
211
+
212
+ self.code_time = TimestepEmbedSequential(QuartzAttn(model_channels, dropout, num_heads),QuartzAttn(model_channels, dropout, num_heads),QuartzAttn(model_channels, dropout, num_heads))
213
+ self.layers = QuartzNet9x5(model_channels,num_heads)
214
+
215
+ self.out = nn.Sequential(
216
+ normalization(model_channels),
217
+ nn.SiLU(),
218
+ nn.Conv1d(model_channels, output_channels, 3, padding=1),
219
+ )
220
+
221
+ def get_speaker_latent(self, ref_mels):
+ # Encode reference clips with GST and pool them into one latent per batch item.
+ # A 3-D input is treated as a single clip (a clip axis is inserted at dim 1).
+ # Each clip's GST output is concatenated along the last axis, averaged over
+ # that axis, and returned with a trailing singleton dim.
222
+ ref_mels = ref_mels.unsqueeze(1) if len(
223
+ ref_mels.shape) == 3 else ref_mels
224
+
225
+ conds = []
226
+ for j in range(ref_mels.shape[1]):
227
+ conds.append(self.GST(ref_mels[:, j,:,:]))
228
+
229
+ conds = torch.cat(conds, dim=-1)
230
+ conds = conds.mean(dim=-1)
231
+
232
+ return conds.unsqueeze(2)
233
+
234
+ def forward(self ,x,t,code_emb,ref_clips=None,speaker_latents=None,conditioning_free=False):
+ # x: noisy input with `input_channels` channels (fed to input_block);
+ # t: timesteps, one per batch item (assumed 1-D - TODO confirm);
+ # code_emb: ids (not ar_active) or latents (ar_active);
+ # ref_clips / speaker_latents: speaker conditioning, one of them is required
+ # when multispeaker; ref_clips takes precedence when both are given.
235
+ time_embed = self.time_norm(self.time_embed(timestep_embedding(t.unsqueeze(-1),self.model_channels)).permute(0,2,1)).squeeze(2)
236
+ if conditioning_free:
237
+ code_embed = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
238
+ else:
239
+ if not self.ar_active:
240
+ code_embed = self.code_norm(self.code_converter(self.code_emb(code_emb).permute(0,2,1)))
241
+ else:
242
+ code_embed = self.code_norm(self.code_converter(code_emb))
243
+ if self.multispeaker:
244
+ assert speaker_latents is not None or ref_clips is not None
245
+ if ref_clips is not None:
246
+ speaker_latents = self.get_speaker_latent(ref_clips)
+ # The speaker latent is split in half along channels into a scale and a shift.
247
+ cond_scale, cond_shift = torch.chunk(speaker_latents, 2, dim=1)
248
+ code_embed = code_embed * (1 + cond_scale) + cond_shift
249
+ if self.training and self.condition_free_per > 0:
+ # Replace the conditioning of a random subset of batch items
+ # (probability condition_free_per each) with the unconditioned embedding.
250
+ unconditioned_batches = torch.rand((code_embed.shape[0], 1, 1),
251
+ device=code_embed.device) < self.condition_free_per
252
+ code_embed = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(code_embed.shape[0], 1, 1),
253
+ code_embed)
254
+
+ # Resize the conditioning along the last axis to the length of x.
255
+ expanded_code_emb = F.interpolate(code_embed, size=x.shape[-1], mode='nearest') #try different modes
256
+
257
+ x_cond = self.code_time_norm(self.code_time(expanded_code_emb,time_embed))
258
+
259
+ x = self.noise_norm(self.input_block(x))
260
+ x += x_cond
261
+ x = self.layers(x, time_embed)
262
+ out = self.out(x)
263
+ return out
264
+
265
def load_diff_model(checkpoint,device,model_channels=512,ar_active=False,len_code_labels=10004):
    """Build a ``DiffModel`` for inference and load its weights from ``checkpoint``.

    :param checkpoint: path (or file-like object) accepted by ``torch.load``.
    :param device: device the model is moved to after loading.
    :param model_channels: width of the model; must match the checkpoint.
    :param ar_active: forwarded to ``DiffModel``.
    :param len_code_labels: forwarded as ``in_latent_channels``.
    :return: the model in eval mode on ``device``.
    :raises RuntimeError: from ``load_state_dict`` when the checkpoint keys or
        shapes do not match (loading is strict).
    """
    diff_model = DiffModel(input_channels=80,
                           output_channels=160,
                           # Honour the caller's width; it used to be hard-coded
                           # to 512, silently ignoring the parameter.
                           model_channels=model_channels,
                           num_heads=8,
                           dropout=0.15,
                           condition_free_per=0.15,
                           multispeaker=True,
                           training=False,
                           ar_active=ar_active,
                           in_latent_channels=len_code_labels)

    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    state_dict = torch.load(checkpoint, map_location=torch.device('cpu'))
    diff_model.load_state_dict(state_dict, strict=True)
    diff_model = diff_model.eval().to(device)
    return diff_model
281
+
282
+
283
if __name__ == '__main__':
    # Smoke test: print a torchinfo summary of a training-mode model on random inputs.
    device = torch.device('cpu')
    # `num_layers` and `enable_fp16` are not DiffModel parameters; passing them
    # raised TypeError, so they are no longer passed.
    diff_model = DiffModel(input_channels=80,
                           output_channels=160,
                           model_channels=1024,
                           num_heads=8,
                           dropout=0.1,
                           condition_free_per=0.1,
                           multispeaker=True,
                           training=True).to(device)

    batch_Size = 32
    timeseries = 800
    from torchinfo import summary
    summary(diff_model, input_data={'x': torch.randn(batch_Size, 80, timeseries).to(device),
                                    'ref_clips': torch.randn(batch_Size, 3, 80, timeseries).to(device),
                                    # Random but well-defined timesteps instead of an
                                    # uninitialised LongTensor; the range is arbitrary.
                                    't': torch.randint(0, 1000, (batch_Size,)).to(device),
                                    'code_emb': torch.randint(0, 201, (batch_Size, timeseries)).to(device)})
maha_tts/models/modules.py ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch,math
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torch.nn.init as init
5
+ from einops import rearrange, repeat
6
+
7
def zero_module(module):
    """
    Set every parameter of `module` to zero in place and return the module.
    Used for zero-initialised convolutions.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
15
+
16
+
17
class GroupNorm32(nn.GroupNorm):
    """GroupNorm that computes on a float32 copy of the input and casts the result back to the input dtype."""

    def forward(self, x):
        normed = super().forward(x.float())
        return normed.type(x.dtype)
20
+
21
+
22
def normalization(channels):
    """
    Make a standard group-normalization layer.

    Starts from 8 groups (channels <= 16), 16 groups (channels <= 64) or 32
    groups, and halves the count until it divides `channels`. The resulting
    group count must be greater than 2.

    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    if channels <= 16:
        groups = 8
    elif channels <= 64:
        groups = 16
    else:
        groups = 32
    while channels % groups != 0:
        groups = int(groups / 2)
    assert groups > 2
    return GroupNorm32(groups, channels)
38
+
39
+
40
class mySequential(nn.Sequential):
    '''Sequential container whose layers may take and return several values
    (used to thread a mask through the layers): a tuple result is unpacked
    into the next layer's arguments, any other result is passed as one argument.
    '''
    def forward(self, *inputs):
        for layer in self._modules.values():
            inputs = layer(*inputs) if type(inputs) is tuple else layer(inputs)
        return inputs
50
+
51
class SepConv1D(nn.Module):
    '''Depthwise-separable 1-D convolution (per-channel conv followed by a 1x1 conv) with an optional mask.
    '''
    def __init__(self, nin, nout, kernel_size, stride=1, dilation=1, padding_mode='same', bias=True):
        super().__init__()
        # Depthwise: one filter per input channel (groups=nin).
        self.conv1 = nn.Conv1d(
            nin, nin, kernel_size=kernel_size, stride=stride, groups=nin,
            dilation=dilation, padding=padding_mode, bias=bias,
        )
        # Pointwise: mixes channels and maps nin -> nout.
        self.conv2 = nn.Conv1d(nin, nout, kernel_size=1, stride=1, padding=padding_mode, bias=bias)

    def forward(self, x, mask=None):
        """Return ``(conv(x), mask)``; when a mask is given, x is multiplied by it first."""
        if mask is not None:
            x = x * mask.unsqueeze(1).to(device=x.device)
        return self.conv2(self.conv1(x)), mask
65
+
66
class Conv1DBN(nn.Module):
    '''Conv1d + BatchNorm + ReLU + Dropout with an optional mask.
    '''
    def __init__(self, nin, nout, kernel_size, stride=1, dilation=1, dropout=0.1, padding_mode='same', bias=False):
        super().__init__()
        self.conv1 = nn.Conv1d(
            nin, nout, kernel_size=kernel_size, stride=stride,
            padding=padding_mode, dilation=dilation, bias=bias,
        )
        self.bn = nn.BatchNorm1d(nout)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Return ``(output, mask)``; when a mask is given, x is multiplied by it first."""
        if mask is not None:
            x = x * mask.unsqueeze(1).to(device=x.device)
        activated = F.relu(self.bn(self.conv1(x)))
        return self.drop(activated), mask
81
+
82
class Conv1d(nn.Module):
    '''Plain nn.Conv1d that also accepts a mask and returns it alongside the output.
    '''
    def __init__(self, nin, nout, kernel_size, padding, bias=True):
        super().__init__()
        self.l = nn.Conv1d(nin, nout, kernel_size, padding=padding, bias=bias)

    def forward(self, x, mask):
        """Return ``(conv(x), mask)``; when mask is not None, x is multiplied by it first."""
        masked = x if mask is None else x * mask.unsqueeze(1).to(device=x.device)
        return self.l(masked), mask
93
+
94
class SqueezeExcite(nn.Module):
    '''Let the CNN decide how to add across channels.

    Averages the input over the unmasked positions of the last axis, maps the
    per-channel averages through a bottleneck MLP + sigmoid and rescales the
    channels of the (masked) input with the result.
    '''
    def __init__(self,nin,ratio=8):
        super(SqueezeExcite,self).__init__()
        self.nin = nin
        self.ratio = ratio

        self.fc = mySequential(
            nn.Linear(nin, nin // ratio, bias=True), nn.SiLU(inplace=True), nn.Linear(nin // ratio, nin, bias=True)
        )

    def forward(self,x,mask=None):
        """
        :param x: input whose dim 1 is channels and whose last dim is pooled over.
        :param mask: optional boolean tensor, True at valid positions; all
            positions are valid when omitted.
        :return: ``(rescaled input, mask)``; masked positions are zero in the
            output. The caller's tensor is not modified.
        """
        if mask is None:
            mask = torch.ones((x.shape[0], x.shape[-1]), dtype=torch.bool).to(x.device)
        # Out-of-place fill: `x.float()` returns the same tensor for float32
        # input, so an in-place fill would overwrite the caller's data.
        x = x.float().masked_fill(~mask.unsqueeze(1), 0.0)
        # Count of valid positions; clamped so a fully masked row gives 0
        # instead of 0/0 = NaN.
        valid = mask.unsqueeze(1).sum(dim=-1, keepdim=True).clamp(min=1)
        y = (torch.sum(x, dim=-1, keepdim=True) / valid).type(x.dtype)
        # The linear layers act on the last axis, so put channels there.
        y = y.transpose(1, -1)
        y = self.fc(y)
        y = torch.sigmoid(y)
        y = y.transpose(1, -1)
        y = x * y
        return y, mask
121
+
122
+
123
+
124
class SCBD(nn.Module):
    '''Convolution (separable by default) + normalization, optionally followed by SiLU and dropout.
    Generally used for middle layers and resnet-style blocks.
    '''
    def __init__(self, nin, nout, kernel_size, p=0.1, rd=True, separable=True, bias=True):
        super().__init__()
        if separable:
            self.SC = SepConv1D(nin, nout, kernel_size, bias=bias)
        else:
            self.SC = Conv1d(nin, nout, kernel_size, padding='same', bias=bias)

        norm = normalization(nout)
        # rd: follow the normalization with activation and dropout.
        self.mout = mySequential(norm, nn.SiLU(), nn.Dropout(p)) if rd else norm

    def forward(self, x, mask=None):
        """Return ``(output, mask)``; when a mask is given, x is multiplied by it first."""
        if mask is not None:
            x = x * mask.unsqueeze(1).to(device=x.device)
        conv_out, _ = self.SC(x, mask)
        return self.mout(conv_out), mask
146
+
147
class QuartzNetBlock(nn.Module):
    '''Similar to Resnet block with normalization and dropout, and using Separable conv in the middle.
    if its the last layer,set se = False and separable = False, and use a projection layer on top of this.

    R conv blocks (the last one without activation/dropout), optional
    squeeze-excite, a 1x1 residual projection of the input, then SiLU + dropout.
    '''
    def __init__(self,nin,nout,kernel_size,dropout=0.1,R=5,se=False,ratio=8,separable=False,bias=True):
        super(QuartzNetBlock,self).__init__()
        self.se = se
        self.residual = mySequential(
            nn.Conv1d(nin, nout, kernel_size=1, padding='same', bias=bias),
            normalization(nout)
        )
        model = []

        # SCBD takes no `eps` argument; passing one raised TypeError on
        # construction, so it is no longer passed.
        for _ in range(R - 1):
            model.append(SCBD(nin, nout, kernel_size, dropout, bias=bias))
            nin = nout

        # Last conv block: normalization only (rd=False), separable on request.
        if separable:
            model.append(SCBD(nin, nout, kernel_size, dropout, rd=False, bias=bias))
        else:
            model.append(SCBD(nin, nout, kernel_size, dropout, rd=False, separable=False, bias=bias))
        self.model = mySequential(*model)

        if self.se:
            # The conv stack always outputs `nout` channels, also when R == 1
            # (where `nin` still holds the input width).
            self.se_layer = SqueezeExcite(nout, ratio)

        self.mout = mySequential(nn.SiLU(), nn.Dropout(dropout))

    def forward(self,x,mask=None):
        """Return ``(output, mask)``; when a mask is given, x is multiplied by it first."""
        if mask is not None:
            x = x * mask.unsqueeze(1).to(device=x.device)
        y, _ = self.model(x, mask)
        if self.se:
            y, _ = self.se_layer(y, mask)
        y += self.residual(x)
        y = self.mout(y)
        return y, mask
184
+
185
class QKVAttentionLegacy(nn.Module):
    """
    Multi-head QKV attention on a packed tensor; heads are split before q, k and v are separated.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv, mask=None, rel_pos=None):
        """
        Apply QKV attention.

        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :param mask: optional [N x T] tensor multiplied onto the attention
                     weights (after the softmax) along the key axis.
        :param rel_pos: optional callable applied to the [N x H x T x T] logits.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        heads = self.n_heads
        assert width % (3 * heads) == 0
        ch = width // (3 * heads)
        q, k, v = qkv.reshape(bs * heads, ch * 3, length).split(ch, dim=1)

        # Scale q and k separately: more stable with f16 than dividing afterwards.
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)

        if rel_pos is not None:
            t_q, t_k = weight.shape[-2], weight.shape[-1]
            weight = rel_pos(weight.reshape(bs, heads, t_q, t_k)).reshape(bs * heads, t_q, t_k)

        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        if mask is not None:
            # The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
            weight = weight * mask.repeat(heads, 1).unsqueeze(1)

        attended = torch.einsum("bts,bcs->bct", weight, v)
        return attended.reshape(bs, -1, length)
219
+
220
class AttentionBlock(nn.Module):
    """
    Residual self-attention block over the flattened positions of a [N x C x ...] tensor.

    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        do_checkpoint=True,
        relative_pos_embeddings=False,
    ):
        super().__init__()
        self.channels = channels
        self.do_checkpoint = do_checkpoint
        if num_head_channels != -1:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        else:
            self.num_heads = num_heads

        self.norm = normalization(channels)
        self.qkv = nn.Conv1d(channels, channels * 3, 1)
        # Heads are split before q, k and v are separated.
        self.attention = QKVAttentionLegacy(self.num_heads)
        # Zero-initialised output projection: the block starts as an identity.
        self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
        # NOTE: the relative position bias is always created; the
        # `relative_pos_embeddings` argument is not consulted.
        self.relative_pos_embeddings = RelativePositionBias(
            scale=(channels // self.num_heads) ** .5,
            causal=False,
            heads=num_heads,
            num_buckets=32,
            max_distance=64,
        )

    def forward(self, x, mask=None):
        b, c, *spatial = x.shape
        flat = x.reshape(b, c, -1)
        attended = self.attention(self.qkv(self.norm(flat)), mask, self.relative_pos_embeddings)
        return (flat + self.proj_out(attended)).reshape(b, c, *spatial)
265
+
266
class AbsolutePositionalEmbedding(nn.Module):
    """Learned absolute position embeddings scaled by ``dim ** -0.5``."""

    def __init__(self, dim, max_seq_len):
        super().__init__()
        self.scale = dim ** -0.5
        self.emb = nn.Embedding(max_seq_len, dim)

    def forward(self, x):
        """Return a [1 x n x dim] tensor for the ``n = x.shape[1]`` leading positions."""
        positions = torch.arange(x.shape[1], device=x.device)
        # Add a leading batch axis of size 1.
        return self.emb(positions)[None] * self.scale
277
+
278
+
279
class FixedPositionalEmbedding(nn.Module):
    """Non-learned sinusoidal position embeddings (sines first, then cosines)."""

    def __init__(self, dim):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer('inv_freq', 1. / (10000 ** exponents))

    def forward(self, x, seq_dim=1, offset=0):
        """Return a [1 x n x d] tensor for positions ``offset .. offset + n - 1`` with ``n = x.shape[seq_dim]``."""
        positions = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
        # Outer product of positions and inverse frequencies.
        angles = positions[:, None] * self.inv_freq[None, :]
        emb = torch.cat((angles.sin(), angles.cos()), dim=-1)
        return emb[None]
290
+
291
class RelativePositionBias(nn.Module):
    """Learned per-head bias added to attention logits, indexed by bucketed relative position."""

    def __init__(self, scale, causal=False, num_buckets=32, max_distance=128, heads=8):
        super().__init__()
        self.scale = scale
        self.causal = causal
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128):
        """Map signed relative positions to bucket ids.

        Small distances get one bucket each; larger ones share logarithmically
        spaced buckets, capped at the last bucket. In the non-causal case half
        of the buckets are reserved for each direction.
        """
        ret = 0
        n = -relative_position
        if causal:
            n = torch.max(n, torch.zeros_like(n))
        else:
            num_buckets //= 2
            ret += (n < 0).long() * num_buckets
            n = torch.abs(n)

        max_exact = num_buckets // 2
        is_small = n < max_exact

        log_ratio = torch.log(n.float() / max_exact) / math.log(max_distance / max_exact)
        val_if_large = max_exact + (log_ratio * (num_buckets - max_exact)).long()
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))

        ret += torch.where(is_small, n, val_if_large)
        return ret

    def forward(self, qk_dots):
        """Add the scaled bias to logits whose last two dims are (query, key) positions."""
        n_query, n_key = qk_dots.shape[-2:]
        device = qk_dots.device
        q_pos = torch.arange(n_query, dtype=torch.long, device=device)
        k_pos = torch.arange(n_key, dtype=torch.long, device=device)
        rel_pos = k_pos[None, :] - q_pos[:, None]
        rp_bucket = self._relative_position_bucket(
            rel_pos, causal=self.causal, num_buckets=self.num_buckets, max_distance=self.max_distance
        )
        values = self.relative_attention_bias(rp_bucket)
        # (query, key, head) -> (1, head, query, key)
        bias = values.permute(2, 0, 1).unsqueeze(0)
        return qk_dots + (bias * self.scale)
332
+
333
+
334
+
335
class MultiHeadAttention(nn.Module):
    '''
    Multi-head dot-product attention, only for GST.
    input:
        query --- [N, T_q, query_dim]
        key --- [N, T_k, key_dim]
    output:
        out --- [N, T_q, num_units]
    '''
    def __init__(self, query_dim, key_dim, num_units, num_heads):
        super().__init__()
        self.num_units = num_units
        self.num_heads = num_heads
        self.key_dim = key_dim

        self.W_query = nn.Linear(in_features=query_dim, out_features=num_units, bias=False)
        self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
        self.W_value = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)

    def forward(self, query, key):
        head_width = self.num_units // self.num_heads

        def to_heads(projected):
            # [N, T, num_units] -> [h, N, T, num_units/h]
            return torch.stack(torch.split(projected, head_width, dim=2), dim=0)

        querys = to_heads(self.W_query(query))
        keys = to_heads(self.W_key(key))
        values = to_heads(self.W_value(key))

        # score = softmax(QK^T / (key_dim ** 0.5)), shape [h, N, T_q, T_k]
        scores = torch.matmul(querys, keys.transpose(2, 3)) / (self.key_dim ** 0.5)
        scores = F.softmax(scores, dim=3)

        # out = score * V, then heads are concatenated back on the last axis.
        per_head = torch.matmul(scores, values)
        return torch.cat(torch.split(per_head, 1, dim=0), dim=3).squeeze(0)
374
+
375
+
376
class GST(nn.Module):
    """Reference encoder: two stride-2 convolutions followed by five attention blocks.

    The output has ``model_channels * k`` channels.
    """

    def __init__(self, model_channels=512, num_heads=8, in_channels=80, k=2):
        super(GST, self).__init__()
        self.model_channels = model_channels
        self.num_heads = num_heads

        wide = model_channels * k
        layers = [
            nn.Conv1d(in_channels, model_channels, 3, padding=1, stride=2),
            nn.Conv1d(model_channels, wide, 3, padding=1, stride=2),
        ]
        for _ in range(5):
            layers.append(AttentionBlock(wide, num_heads, relative_pos_embeddings=True, do_checkpoint=False))
        self.reference_encoder = nn.Sequential(*layers)

    def forward(self, x):
        return self.reference_encoder(x)
395
+
396
+
397
if __name__ == '__main__':
    # Smoke test: run GST once and print a torchinfo summary.
    device = torch.device('cpu')
    # GST runs attention on 512 * 2 = 1024 channels and QKVAttentionLegacy
    # asserts that the packed width is divisible by 3 * heads; 10 heads fail
    # that check, 8 heads pass.
    m = GST(512, 8).to(device)
    mels = torch.rand((16, 80, 1000)).to(device)

    o = m(mels)
    print(o.shape,'final output')

    from torchinfo import summary
    summary(m, input_data={'x': torch.randn(16, 80, 500).to(device)})
maha_tts/models/vocoder.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ code from https://github.com/jik876/hifi-gan/blob/master/models.py
3
+ '''
4
+
5
+ import json,os
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import torch.nn as nn
9
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
10
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
11
+ # from utils import init_weights, get_padding
12
+
13
+ LRELU_SLOPE = 0.1
14
+
15
class AttrDict(dict):
    """Dictionary whose items are also readable and writable as attributes."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Share storage: the instance dict *is* the mapping itself.
        self.__dict__ = self
19
+
20
def init_weights(m, mean=0.0, std=0.01):
    """Fill the weight of any module whose class name contains "Conv" with normal(mean, std) samples."""
    if "Conv" in type(m).__name__:
        m.weight.data.normal_(mean, std)
24
+
25
+
26
def apply_weight_norm(m):
    """Apply weight normalization to any module whose class name contains "Conv"."""
    if "Conv" in type(m).__name__:
        weight_norm(m)
30
+
31
+
32
def get_padding(kernel_size, dilation=1):
    """Return ``(kernel_size * dilation - dilation) / 2`` truncated to an int."""
    span = kernel_size * dilation - dilation
    return int(span / 2)
34
+
35
+
36
class ResBlock1(torch.nn.Module):
    """Residual block made of three (dilated conv, plain conv) pairs, each added back to its input."""

    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock1, self).__init__()
        self.h = h

        def make_conv(rate):
            # Weight-normalised conv whose padding is derived from kernel size and dilation.
            return weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=rate,
                                      padding=get_padding(kernel_size, rate)))

        self.convs1 = nn.ModuleList([
            make_conv(dilation[0]),
            make_conv(dilation[1]),
            make_conv(dilation[2]),
        ])
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList([make_conv(1), make_conv(1), make_conv(1)])
        self.convs2.apply(init_weights)

    def forward(self, x):
        for dilated, plain in zip(self.convs1, self.convs2):
            branch = dilated(F.leaky_relu(x, LRELU_SLOPE))
            branch = plain(F.leaky_relu(branch, LRELU_SLOPE))
            x = branch + x
        return x

    def remove_weight_norm(self):
        """Strip weight normalization from every conv in the block."""
        for layer in self.convs1:
            remove_weight_norm(layer)
        for layer in self.convs2:
            remove_weight_norm(layer)
74
+
75
+
76
class ResBlock2(torch.nn.Module):
    """Residual block made of two dilated convs, each added back to its input."""

    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
        super(ResBlock2, self).__init__()
        self.h = h

        def make_conv(rate):
            # Weight-normalised conv whose padding is derived from kernel size and dilation.
            return weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=rate,
                                      padding=get_padding(kernel_size, rate)))

        self.convs = nn.ModuleList([make_conv(dilation[0]), make_conv(dilation[1])])
        self.convs.apply(init_weights)

    def forward(self, x):
        for conv in self.convs:
            x = conv(F.leaky_relu(x, LRELU_SLOPE)) + x
        return x

    def remove_weight_norm(self):
        """Strip weight normalization from every conv in the block."""
        for layer in self.convs:
            remove_weight_norm(layer)
98
+
99
+
100
class Generator(torch.nn.Module):
    """Vocoder generator: pre-conv, upsampling stages with residual blocks, post-conv.

    The input convolution expects 80 channels. Each upsampling stage halves
    the channel count and is followed by ``num_kernels`` residual blocks whose
    outputs are averaged. The output is a single channel squashed by ``tanh``.
    """

    def __init__(self, h):
        """Build all layers from the configuration object ``h``.

        Reads ``resblock``, ``resblock_kernel_sizes``, ``resblock_dilation_sizes``,
        ``upsample_rates``, ``upsample_kernel_sizes`` and
        ``upsample_initial_channel`` from ``h``.
        """
        super(Generator, self).__init__()
        self.h = h
        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)
        self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3))
        resblock = ResBlock1 if h.resblock == '1' else ResBlock2

        # One transposed convolution per upsampling stage; channels halve each stage.
        self.ups = nn.ModuleList()
        for stage, (rate, kernel) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
            in_channels = h.upsample_initial_channel // (2 ** stage)
            out_channels = h.upsample_initial_channel // (2 ** (stage + 1))
            self.ups.append(weight_norm(
                ConvTranspose1d(in_channels, out_channels,
                                kernel, rate, padding=(kernel - rate) // 2)))

        # Stored flat: stage i owns indices [i * num_kernels, (i + 1) * num_kernels).
        self.resblocks = nn.ModuleList()
        for stage in range(len(self.ups)):
            ch = h.upsample_initial_channel // (2 ** (stage + 1))
            for kernel, dilations in zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes):
                self.resblocks.append(resblock(h, ch, kernel, dilations))

        # `ch` is the channel count of the last stage.
        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x):
        """Run the generator on ``x`` and return the tanh-bounded output."""
        x = self.conv_pre(x)
        for stage in range(self.num_upsamples):
            x = self.ups[stage](F.leaky_relu(x, LRELU_SLOPE))
            first_block = stage * self.num_kernels
            xs = None
            for offset in range(self.num_kernels):
                block_out = self.resblocks[first_block + offset](x)
                if xs is None:
                    xs = block_out
                else:
                    xs += block_out
            # Average the residual blocks of this stage.
            x = xs / self.num_kernels
        # Final activation uses F.leaky_relu's default slope, not LRELU_SLOPE.
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        return torch.tanh(x)

    def remove_weight_norm(self):
        """Strip weight normalisation from every layer of the generator."""
        for up in self.ups:
            remove_weight_norm(up)
        for block in self.resblocks:
            block.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
151
+
152
+
153
class DiscriminatorP(torch.nn.Module):
    """Period discriminator: folds a 1-D signal into 2-D by ``period`` and convolves."""

    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        """Build five 2-D convolutions plus a one-channel output convolution.

        Args:
            period: width of the folded 2-D view used in ``forward``.
            kernel_size: kernel height of the five main convolutions.
            stride: stride along the first spatial axis for the first four convs.
            use_spectral_norm: selects spectral_norm instead of weight_norm.
        """
        super(DiscriminatorP, self).__init__()
        self.period = period
        if use_spectral_norm == False:
            norm_f = weight_norm
        else:
            norm_f = spectral_norm

        # NOTE(review): padding is computed for a kernel of 5 (and fixed to 2 on
        # the last layer) regardless of `kernel_size` - confirm before changing
        # kernel_size from its default.
        channel_plan = [(1, 32), (32, 128), (128, 512), (512, 1024)]
        layers = [
            norm_f(Conv2d(c_in, c_out, (kernel_size, 1), (stride, 1),
                          padding=(get_padding(5, 1), 0)))
            for c_in, c_out in channel_plan
        ]
        layers.append(norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))))
        self.convs = nn.ModuleList(layers)
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        """Return ``(flattened_score, feature_maps)`` for a 3-D input ``x``.

        ``feature_maps`` holds the activation after every conv, including the
        output of ``conv_post`` as its last element.
        """
        fmap = []

        # Fold 1-D into 2-D: reflect-pad the last axis up to a multiple of period.
        b, c, t = x.shape
        leftover = t % self.period
        if leftover != 0:
            n_pad = self.period - leftover
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for conv in self.convs:
            x = F.leaky_relu(conv(x), LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap
187
+
188
+
189
class MultiPeriodDiscriminator(torch.nn.Module):
    """Bundle of DiscriminatorP instances with periods 2, 3, 5, 7 and 11."""

    def __init__(self):
        super(MultiPeriodDiscriminator, self).__init__()
        self.discriminators = nn.ModuleList(
            [DiscriminatorP(period) for period in (2, 3, 5, 7, 11)]
        )

    def forward(self, y, y_hat):
        """Score ``y`` and ``y_hat`` with every sub-discriminator.

        Returns:
            Four lists, one entry per sub-discriminator: scores for ``y``,
            scores for ``y_hat``, feature maps for ``y``, feature maps for
            ``y_hat``.
        """
        y_d_rs, y_d_gs = [], []
        fmap_rs, fmap_gs = [], []
        for disc in self.discriminators:
            real_score, real_fmap = disc(y)
            gen_score, gen_fmap = disc(y_hat)
            y_d_rs.append(real_score)
            fmap_rs.append(real_fmap)
            y_d_gs.append(gen_score)
            fmap_gs.append(gen_fmap)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
214
+
215
+
216
class DiscriminatorS(torch.nn.Module):
    """Scale discriminator: a stack of (grouped) 1-D convolutions over the signal."""

    def __init__(self, use_spectral_norm=False):
        """Build seven 1-D convolutions plus a one-channel output convolution.

        Args:
            use_spectral_norm: selects spectral_norm instead of weight_norm.
        """
        super(DiscriminatorS, self).__init__()
        if use_spectral_norm == False:
            norm_f = weight_norm
        else:
            norm_f = spectral_norm

        # (in_channels, out_channels, kernel, stride, groups, padding)
        layer_specs = [
            (1, 128, 15, 1, 1, 7),
            (128, 128, 41, 2, 4, 20),
            (128, 256, 41, 2, 16, 20),
            (256, 512, 41, 4, 16, 20),
            (512, 1024, 41, 4, 16, 20),
            (1024, 1024, 41, 1, 16, 20),
            (1024, 1024, 5, 1, 1, 2),
        ]
        self.convs = nn.ModuleList([
            norm_f(Conv1d(c_in, c_out, kernel, stride, groups=groups, padding=pad))
            for c_in, c_out, kernel, stride, groups, pad in layer_specs
        ])
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        """Return ``(flattened_score, feature_maps)`` for input ``x``.

        ``feature_maps`` holds the activation after every conv, including the
        output of ``conv_post`` as its last element.
        """
        fmap = []
        for conv in self.convs:
            x = F.leaky_relu(conv(x), LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap
242
+
243
+
244
class MultiScaleDiscriminator(torch.nn.Module):
    """Three DiscriminatorS instances applied to progressively average-pooled input."""

    def __init__(self):
        super(MultiScaleDiscriminator, self).__init__()
        # Only the first discriminator uses spectral normalisation.
        self.discriminators = nn.ModuleList([
            DiscriminatorS(use_spectral_norm=True),
            DiscriminatorS(),
            DiscriminatorS(),
        ])
        self.meanpools = nn.ModuleList(
            [AvgPool1d(4, 2, padding=2) for _ in range(2)]
        )

    def forward(self, y, y_hat):
        """Score ``y`` and ``y_hat`` at each scale.

        The first discriminator sees the inputs unchanged; before each later
        one both inputs are pooled again, so pooling accumulates across scales.

        Returns:
            Four lists, one entry per sub-discriminator: scores for ``y``,
            scores for ``y_hat``, feature maps for ``y``, feature maps for
            ``y_hat``.
        """
        y_d_rs, y_d_gs = [], []
        fmap_rs, fmap_gs = [], []
        for index, disc in enumerate(self.discriminators):
            if index != 0:
                pool = self.meanpools[index - 1]
                y = pool(y)
                y_hat = pool(y_hat)
            real_score, real_fmap = disc(y)
            gen_score, gen_fmap = disc(y_hat)
            y_d_rs.append(real_score)
            fmap_rs.append(real_fmap)
            y_d_gs.append(gen_score)
            fmap_gs.append(gen_fmap)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
274
+
275
+
276
def feature_loss(fmap_r, fmap_g):
    """Return twice the summed mean absolute difference between feature maps.

    Args:
        fmap_r: iterable of per-discriminator lists of feature maps.
        fmap_g: iterable with the same nesting, paired element-wise with fmap_r.

    Returns:
        The accumulated loss times 2 (the integer 0 when there is nothing to
        compare).
    """
    loss = 0
    for real_maps, gen_maps in zip(fmap_r, fmap_g):
        for real_map, gen_map in zip(real_maps, gen_maps):
            gap = torch.abs(real_map - gen_map)
            loss += torch.mean(gap)

    return loss * 2
283
+
284
+
285
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    """Least-squares discriminator loss over paired real/generated scores.

    For each pair the real term is ``mean((1 - dr) ** 2)`` and the generated
    term is ``mean(dg ** 2)``.

    Returns:
        ``(total_loss, real_losses, generated_losses)`` where the two lists
        hold the per-pair terms as Python floats.
    """
    loss = 0
    r_losses, g_losses = [], []
    for real_out, gen_out in zip(disc_real_outputs, disc_generated_outputs):
        real_term = torch.mean((1 - real_out) ** 2)
        gen_term = torch.mean(gen_out ** 2)
        loss += (real_term + gen_term)
        r_losses.append(real_term.item())
        g_losses.append(gen_term.item())

    return loss, r_losses, g_losses
297
+
298
+
299
def generator_loss(disc_outputs):
    """Least-squares generator loss: sum of ``mean((1 - dg) ** 2)`` per output.

    Returns:
        ``(total_loss, per_output_losses)``; the list holds the per-output
        terms as tensors (they are not converted to floats).
    """
    loss = 0
    gen_losses = []
    for gen_out in disc_outputs:
        term = torch.mean((1 - gen_out) ** 2)
        gen_losses.append(term)
        loss += term

    return loss, gen_losses
308
+
309
def load_checkpoint(filepath, device):
    """Load a checkpoint dictionary from ``filepath`` onto ``device``.

    Args:
        filepath: path to the checkpoint file.
        device: passed to ``torch.load`` as ``map_location``.

    Returns:
        Whatever object ``torch.load`` deserialises from the file.

    Raises:
        FileNotFoundError: if ``filepath`` is not an existing regular file.
    """
    # An explicit check instead of `assert`, which is stripped under `python -O`.
    if not os.path.isfile(filepath):
        raise FileNotFoundError(f"checkpoint file not found: {filepath}")
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint_dict = torch.load(filepath, map_location=device)
    return checkpoint_dict
313
+
314
def load_vocoder_model(config_path, checkpoint_path, device):
    """Build a Generator from a JSON config and load its weights for inference.

    Side effects: rebinds the module-level global ``h`` to the parsed config
    and seeds torch's RNG with ``h.seed``.

    Args:
        config_path: path to the JSON configuration file.
        checkpoint_path: path to a checkpoint containing a ``'generator'`` entry.
        device: device the model and checkpoint are placed on.

    Returns:
        The generator in eval mode with weight normalisation removed.
    """
    global h
    with open(config_path) as config_file:
        h = AttrDict(json.load(config_file))

    torch.manual_seed(h.seed)

    generator = Generator(h).to(device)
    checkpoint = load_checkpoint(checkpoint_path, device)
    generator.load_state_dict(checkpoint['generator'])

    generator.eval()
    generator.remove_weight_norm()

    return generator
334
+
335
def infer_wav(mel, generator):
    """Run ``generator`` on ``mel`` and return the result as int16 audio samples.

    Args:
        mel: input passed unchanged to ``generator``.
        generator: callable returning a tensor of samples; values are expected
            in [-1, 1] (the Generator in this file ends with tanh).

    Returns:
        A NumPy int16 array of the squeezed generator output scaled by 32768.
    """
    MAX_WAV_VALUE = 32768.0
    with torch.no_grad():
        y_g_hat = generator(mel)
        audio = y_g_hat.squeeze() * MAX_WAV_VALUE
        # A sample of exactly +1.0 scales to 32768, which does not fit in int16;
        # clamp to the representable range before casting.
        audio = audio.clamp(min=-MAX_WAV_VALUE, max=MAX_WAV_VALUE - 1)
        audio = audio.cpu().numpy().astype('int16')
    return audio
maha_tts/pretrained_models/.DS_Store ADDED
Binary file (6.15 kB). View file
 
maha_tts/pretrained_models/hifigan/config.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c1bd98e99062ddbced38729a5252dc2aa772328d16d70097ac139dab2f269dc9
3
+ size 799
maha_tts/pretrained_models/hifigan/g_02500000 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:771eaf4876485a35e25577563d390c262e23c2421e4a8c929eacfde34a5b7a60
3
+ size 55788858
maha_tts/pretrained_models/smolie/S2A/s2a_latest.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bf359fab98b047ef89d79a99a78fee9c38880e307630d3b3af7bc9cb170f366b
3
+ size 432971673
maha_tts/pretrained_models/smolie/T2S/t2s_best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:67a10c3bf12a8bca3dd67075ccbfbd79887b244109bd9c96013b0f348d9e2570
3
+ size 276146627
maha_tts/text/__init__.py ADDED
File without changes
maha_tts/text/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (178 Bytes). View file
 
maha_tts/text/__pycache__/cleaners.cpython-311.pyc ADDED
Binary file (7.03 kB). View file
 
maha_tts/text/__pycache__/symbols.cpython-311.pyc ADDED
Binary file (2.37 kB). View file
 
maha_tts/text/cleaners.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from unidecode import unidecode
3
+ import inflect
4
+ import re
5
+
6
+
7
+ _inflect = inflect.engine()
8
+ _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
9
+ _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
10
+ _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
11
+ _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
12
+ _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
13
+ _number_re = re.compile(r'[0-9]+')
14
+
15
+
16
def _remove_commas(m):
    """Return the first capture group of match ``m`` with all commas removed."""
    digits_with_commas = m.group(1)
    return digits_with_commas.replace(',', '')
18
+
19
+
20
def _expand_decimal_point(m):
    """Return the first capture group of ``m`` with each '.' spelled as ' point '."""
    number_text = m.group(1)
    return number_text.replace('.', ' point ')
22
+
23
+
24
def _expand_dollars(m):
    """Turn a matched dollar amount into '<n> dollars, <n> cents' style text.

    The first capture group of ``m`` is split on '.'; the part before the
    point is read as dollars and the part after as cents. Amounts with more
    than one '.' are returned unchanged with ' dollars' appended.
    """
    amount = m.group(1)
    pieces = amount.split('.')
    if len(pieces) > 2:
        return amount + ' dollars'  # Unexpected format

    dollars = int(pieces[0]) if pieces[0] else 0
    cents = int(pieces[1]) if len(pieces) > 1 and pieces[1] else 0
    dollar_unit = 'dollar' if dollars == 1 else 'dollars'
    cent_unit = 'cent' if cents == 1 else 'cents'

    if dollars and cents:
        return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
    if dollars:
        return '%s %s' % (dollars, dollar_unit)
    if cents:
        return '%s %s' % (cents, cent_unit)
    return 'zero dollars'
43
+
44
+
45
def _expand_ordinal(m):
    """Spell out the whole matched ordinal text using the shared inflect engine."""
    ordinal_text = m.group(0)
    return _inflect.number_to_words(ordinal_text)
47
+
48
+
49
def _expand_number(m):
    """Spell out the matched integer, reading 1001-2999 in a year-like style.

    Numbers strictly between 1000 and 3000 get special handling (for example
    2000 becomes 'two thousand' and exact hundreds become '<n> hundred');
    everything else is passed to inflect without the word 'and'.
    """
    num = int(m.group(0))
    if not (num > 1000 and num < 3000):
        return _inflect.number_to_words(num, andword='')
    if num == 2000:
        return 'two thousand'
    if num > 2000 and num < 2010:
        return 'two thousand ' + _inflect.number_to_words(num % 100)
    if num % 100 == 0:
        return _inflect.number_to_words(num // 100) + ' hundred'
    # Read the number as two-digit groups, with zero spoken as 'oh'.
    grouped = _inflect.number_to_words(num, andword='', zero='oh', group=2)
    return grouped.replace(', ', ' ')
62
+
63
+
64
def normalize_numbers(text):
    """Rewrite numeric expressions in ``text`` as words.

    Substitutions run in a fixed order: comma removal, pounds, dollars,
    decimals, ordinals, then remaining plain integers.
    """
    substitutions = (
        (_comma_number_re, _remove_commas),
        (_pounds_re, r'\1 pounds'),
        (_dollars_re, _expand_dollars),
        (_decimal_number_re, _expand_decimal_point),
        (_ordinal_re, _expand_ordinal),
        (_number_re, _expand_number),
    )
    for pattern, replacement in substitutions:
        text = pattern.sub(replacement, text)
    return text
72
+
73
# Matches any run of whitespace characters.
_whitespace_re = re.compile(r'\s+')

# (compiled pattern, replacement) pairs. Each pattern matches the abbreviation
# at a word boundary followed by a literal period, ignoring case.
_abbreviations = [
    (re.compile(r'\b%s\.' % abbreviation, re.IGNORECASE), expansion)
    for abbreviation, expansion in (
        ('mrs', 'misess'),
        ('mr', 'mister'),
        ('dr', 'doctor'),
        ('st', 'saint'),
        ('co', 'company'),
        ('jr', 'junior'),
        ('maj', 'major'),
        ('gen', 'general'),
        ('drs', 'doctors'),
        ('rev', 'reverend'),
        ('lt', 'lieutenant'),
        ('hon', 'honorable'),
        ('sgt', 'sergeant'),
        ('capt', 'captain'),
        ('esq', 'esquire'),
        ('ltd', 'limited'),
        ('col', 'colonel'),
        ('ft', 'fort'),
    )
]
97
+
98
+
99
def expand_abbreviations(text):
    """Replace every abbreviation listed in ``_abbreviations`` with its expansion."""
    for pattern, expansion in _abbreviations:
        text = pattern.sub(expansion, text)
    return text
103
+
104
+
105
def expand_numbers(text):
    """Return ``text`` with numbers spelled out (thin wrapper over normalize_numbers)."""
    expanded = normalize_numbers(text)
    return expanded
107
+
108
+
109
def lowercase(text):
    """Return ``text`` converted to lower case."""
    lowered = text.lower()
    return lowered
111
+
112
+
113
def collapse_whitespace(text):
    """Replace every run of whitespace in ``text`` with a single space."""
    return _whitespace_re.sub(' ', text)
115
+
116
+
117
def convert_to_ascii(text):
    """Return ``text`` as transliterated by ``unidecode``."""
    transliterated = unidecode(text)
    return transliterated
119
+
120
+
121
def basic_cleaners(text):
    '''Lowercase the text, then collapse whitespace; no transliteration is applied.'''
    return collapse_whitespace(lowercase(text))
126
+
127
+
128
def transliteration_cleaners(text):
    '''Transliterate to ASCII, lowercase, then collapse whitespace (for non-English text).'''
    for step in (convert_to_ascii, lowercase, collapse_whitespace):
        text = step(text)
    return text
134
+
135
+
136
def english_cleaners(text):
    '''English pipeline: ASCII, lowercase, expand numbers and abbreviations, collapse whitespace.'''
    pipeline = (
        convert_to_ascii,
        lowercase,
        expand_numbers,
        expand_abbreviations,
        collapse_whitespace,
    )
    for step in pipeline:
        text = step(text)
    return text
maha_tts/text/symbols.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from maha_tts.config import config
3
+
4
# Character inventory for text input, one label per character. (An earlier,
# smaller inventory used to be assigned first and was immediately overwritten.)
labels = list(" !\"'(),-.:;?[]abcdefghijklmnopqrstuvwxyzàâèéêü’“”")

# Text-only vocabulary: the characters plus start / end / padding markers.
text_labels = list(labels)
text_labels += '<S>', '<E>', '<PAD>'

# One label per semantic-code index, as a decimal string.
code_labels = [str(i) for i in range(config.semantic_model_centroids)]
# Extend the joint vocabulary BEFORE the code-only markers are appended, so
# `labels` receives only the numeric code labels at this point.
labels += code_labels
code_labels += '<SST>', '<EST>', '<PAD>'

labels += '<S>', '<E>', '<SST>', '<EST>', '<PAD>'

# Joint (text + code) encode / decode tables.
tok_enc = {label: index for index, label in enumerate(labels)}
tok_dec = dict(enumerate(labels))

# Text-only encode / decode tables.
text_enc = {label: index for index, label in enumerate(text_labels)}
text_dec = dict(enumerate(text_labels))

# Code-only encode / decode tables.
code_enc = {label: index for index, label in enumerate(code_labels)}
code_dec = dict(enumerate(code_labels))
27
+
28
+ # print('length of the labels: ',len(labels))
maha_tts/utils/__init__.py ADDED
File without changes
maha_tts/utils/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (179 Bytes). View file
 
maha_tts/utils/__pycache__/audio.cpython-311.pyc ADDED
Binary file (5.3 kB). View file
 
maha_tts/utils/__pycache__/diffusion.cpython-311.pyc ADDED
Binary file (58.7 kB). View file
 
maha_tts/utils/__pycache__/stft.cpython-311.pyc ADDED
Binary file (6.9 kB). View file
 
maha_tts/utils/audio.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import librosa.util as librosa_util
4
+
5
+ from scipy.signal import get_window
6
+ from scipy.io.wavfile import read
7
+ from maha_tts.config import config
8
+
9
+ TACOTRON_MEL_MAX = 2.3143386840820312
10
+ TACOTRON_MEL_MIN = -11.512925148010254
11
+
12
+
13
def denormalize_tacotron_mel(norm_mel):
    """Map values from [-1, 1] back to [TACOTRON_MEL_MIN, TACOTRON_MEL_MAX]."""
    unit_scaled = (norm_mel + 1) / 2
    return unit_scaled * (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN) + TACOTRON_MEL_MIN
15
+
16
+
17
def normalize_tacotron_mel(mel):
    """Map values from [TACOTRON_MEL_MIN, TACOTRON_MEL_MAX] linearly onto [-1, 1]."""
    unit_scaled = (mel - TACOTRON_MEL_MIN) / (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN)
    return 2 * unit_scaled - 1
19
+
20
+
21
def get_mask_from_lengths(lengths, max_len=None):
    """Build a boolean mask that is True for positions below each length.

    Args:
        lengths: 1-D integer tensor of sequence lengths.
        max_len: mask width; when falsy (None or 0) the largest length is used.

    Returns:
        Boolean tensor of shape ``(len(lengths), max_len)`` on ``lengths.device``.
    """
    width = max_len or torch.max(lengths).item()
    positions = torch.arange(0, width, device=lengths.device, dtype=torch.long)
    return (positions < lengths.unsqueeze(1)).bool()
27
+
28
+
29
def get_mask(lengths, max_len=None):
    """Build a boolean mask that is True for positions below each length.

    Args:
        lengths: 1-D integer tensor of sequence lengths.
        max_len: mask width; when falsy (None or 0) the largest length is used.

    Returns:
        Boolean tensor of shape ``(len(lengths), max_len)`` on ``lengths.device``.
    """
    if not max_len:
        max_len = torch.max(lengths).item()
    # Create the position index on the same device as `lengths`; a default-device
    # arange cannot be compared against lengths that live on another device.
    positions = torch.arange(max_len, device=lengths.device)
    mask = positions.unsqueeze(0) < lengths.unsqueeze(1)
    return mask
35
+
36
+
37
+
38
def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """Log-compress ``x`` after flooring it at ``clip_val``.

    PARAMS
    ------
    x: tensor to compress
    C: compression factor applied before the logarithm
    clip_val: lower bound applied to ``x`` so the logarithm stays finite
    """
    floored = torch.clamp(x, min=clip_val)
    return torch.log(floored * C)
45
+
46
+
47
def dynamic_range_decompression(x, C=1):
    """Invert the log compression: exponentiate ``x`` and divide by ``C``.

    PARAMS
    ------
    x: log-compressed tensor
    C: compression factor used to compress
    """
    expanded = torch.exp(x)
    return expanded / C
54
+
55
+
56
def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
                     n_fft=800, dtype=np.float32, norm=None):
    """Compute the sum-square envelope of a window at a given hop length.

    Adapted from librosa 0.6. The squared window is overlap-added once per
    frame; the result is used to estimate the modulation that windowing
    introduces in short-time Fourier transforms.

    Parameters
    ----------
    window : string, tuple, number, callable, or list-like
        Window specification, as accepted by `get_window`
    n_frames : int > 0
        The number of analysis frames
    hop_length : int > 0
        The number of samples to advance between frames
    win_length : [optional]
        The length of the window function; `None` means use `n_fft`
    n_fft : int > 0
        The length of each analysis frame
    dtype : np.dtype
        The data type of the output
    norm : [optional]
        Forwarded to `librosa.util.normalize`

    Returns
    -------
    wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
        The sum-squared envelope of the window function
    """
    if win_length is None:
        win_length = n_fft

    total_samples = n_fft + hop_length * (n_frames - 1)
    envelope = np.zeros(total_samples, dtype=dtype)

    # Squared window, normalised and centre-padded out to n_fft samples.
    squared_window = get_window(window, win_length, fftbins=True)
    squared_window = librosa_util.normalize(squared_window, norm=norm) ** 2
    squared_window = librosa_util.pad_center(squared_window, size=n_fft)

    # Overlap-add the squared window at every frame offset.
    for frame in range(n_frames):
        start = frame * hop_length
        stop = min(total_samples, start + n_fft)
        usable = max(0, min(n_fft, total_samples - start))
        envelope[start:stop] += squared_window[:usable]
    return envelope
98
+
99
def load_wav_to_torch(full_path):
    """Read a WAV file and return ``(samples_as_FloatTensor, sampling_rate)``."""
    rate_and_samples = read(full_path)
    sampling_rate, samples = rate_and_samples
    return torch.FloatTensor(samples), sampling_rate
102
+
103
+
104
+
105
+ if __name__ == "__main__":
106
+ lens = torch.tensor([2, 3, 7, 5, 4])
107
+ mask = get_mask(lens)
108
+ print(mask)
109
+ print(mask.shape)
maha_tts/utils/diffusion.py ADDED
@@ -0,0 +1,1283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copied from Tortoise-tts
3
+ ########################################
4
+ This is an almost carbon copy of gaussian_diffusion.py from OpenAI's ImprovedDiffusion repo, which itself:
5
+
6
+ This code started out as a PyTorch port of Ho et al's diffusion models:
7
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py
8
+
9
+ Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules.
10
+ ########################################
11
+ """
12
+
13
+ import enum
14
+ import math
15
+ import torch
16
+ import torch as th
17
+ import torch.nn.functional as F
18
+ import numpy as np
19
+ from tqdm import tqdm
20
+
21
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    Compute the KL divergence between two gaussians.

    Arguments broadcast against each other, so a batch can be compared with
    scalars. At least one argument must be a Tensor; scalar log-variances are
    converted to tensors matching it.
    """
    reference = next(
        (arg for arg in (mean1, logvar1, mean2, logvar2) if isinstance(arg, th.Tensor)),
        None,
    )
    assert reference is not None, "at least one argument must be a Tensor"

    # th.exp() needs Tensor inputs, so promote scalar log-variances explicitly.
    if not isinstance(logvar1, th.Tensor):
        logvar1 = th.tensor(logvar1).to(reference)
    if not isinstance(logvar2, th.Tensor):
        logvar2 = th.tensor(logvar2).to(reference)

    variance_ratio = th.exp(logvar1 - logvar2)
    scaled_sq_gap = ((mean1 - mean2) ** 2) * th.exp(-logvar2)
    return 0.5 * (-1.0 + logvar2 - logvar1 + variance_ratio + scaled_sq_gap)
49
+
50
+
51
def approx_standard_normal_cdf(x):
    """
    Fast tanh-based approximation of the standard normal cumulative
    distribution function, evaluated element-wise on tensor ``x``.
    """
    cubic = x + 0.044715 * th.pow(x, 3)
    squashed = th.tanh(np.sqrt(2.0 / np.pi) * cubic)
    return 0.5 * (1.0 + squashed)
57
+
58
+
59
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Log-likelihood of ``x`` under a Gaussian discretized into bins of
    half-width 1/255.

    :param x: the target values. The bin width assumes these were uint8
              values rescaled to the range [-1, 1].
    :param means: the Gaussian mean Tensor.
    :param log_scales: the Gaussian log stddev Tensor.
    :return: a tensor like x of log probabilities (in nats).
    """
    assert x.shape == means.shape == log_scales.shape
    offset = x - means
    inv_stdv = th.exp(-log_scales)

    upper_cdf = approx_standard_normal_cdf(inv_stdv * (offset + 1.0 / 255.0))
    lower_cdf = approx_standard_normal_cdf(inv_stdv * (offset - 1.0 / 255.0))

    # Every logarithm is floored at 1e-12 to stay finite.
    log_upper = th.log(upper_cdf.clamp(min=1e-12))
    log_above_lower = th.log((1.0 - lower_cdf).clamp(min=1e-12))
    bin_mass = upper_cdf - lower_cdf
    log_bin_mass = th.log(bin_mass.clamp(min=1e-12))

    # The extreme bins use the open tails of the distribution.
    interior_or_top = th.where(x > 0.999, log_above_lower, log_bin_mass)
    log_probs = th.where(x < -0.999, log_upper, interior_or_top)
    assert log_probs.shape == x.shape
    return log_probs
87
+
88
+
89
def mean_flat(tensor):
    """
    Average ``tensor`` over every dimension except the first (batch) one.
    """
    non_batch_dims = list(range(1, len(tensor.shape)))
    return tensor.mean(dim=non_batch_dims)
94
+
95
+
96
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
    """
    Return the pre-defined beta schedule called ``schedule_name``.

    Supported names are "linear" and "cosine". Schedules are meant to stay
    similar as num_diffusion_timesteps grows, and existing ones should not be
    removed or changed, to keep backwards compatibility.

    :raises NotImplementedError: for any other schedule name.
    """
    if schedule_name == "cosine":
        return betas_for_alpha_bar(
            num_diffusion_timesteps,
            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
        )
    if schedule_name != "linear":
        raise NotImplementedError(f"unknown beta schedule: {schedule_name}")

    # Linear schedule from Ho et al, rescaled so it works for any step count.
    scale = 1000 / num_diffusion_timesteps
    beta_start = scale * 0.0001
    beta_end = scale * 0.02
    return np.linspace(
        beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
    )
121
+
122
+
123
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Discretize an alpha_bar function into a beta schedule.

    ``alpha_bar`` gives the cumulative product of (1 - beta) as a function of
    time t in [0, 1]; each beta is recovered from the ratio of consecutive
    alpha_bar values.

    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a callable taking t from 0 to 1 and returning the
                      cumulative product of (1 - beta) up to that point.
    :param max_beta: upper cap for each beta; keep it below 1 to prevent
                     singularities.
    """
    steps = num_diffusion_timesteps
    betas = [
        min(1 - alpha_bar((i + 1) / steps) / alpha_bar(i / steps), max_beta)
        for i in range(steps)
    ]
    return np.array(betas)
141
+
142
+
143
class ModelMeanType(enum.Enum):
    """
    The quantity the model is trained to predict.

    PREVIOUS_X: the model predicts x_{t-1}
    START_X:    the model predicts x_0
    EPSILON:    the model predicts epsilon
    """

    PREVIOUS_X = 'previous_x'
    START_X = 'start_x'
    EPSILON = 'epsilon'
151
+
152
+
153
class ModelVarType(enum.Enum):
    """
    How the variance of the model's output distribution is obtained.

    LEARNED_RANGE lets the model predict values between FIXED_SMALL and
    FIXED_LARGE, which makes its job easier.
    """

    LEARNED = 'learned'
    FIXED_SMALL = 'fixed_small'
    FIXED_LARGE = 'fixed_large'
    LEARNED_RANGE = 'learned_range'
165
+
166
+
167
class LossType(enum.Enum):
    """
    Training loss selection.

    MSE:          raw MSE loss (and KL when learning variances)
    RESCALED_MSE: raw MSE loss (with RESCALED_KL when learning variances)
    KL:           the variational lower-bound
    RESCALED_KL:  like KL, but rescaled to estimate the full VLB
    """

    MSE = 'mse'
    RESCALED_MSE = 'rescaled_mse'
    KL = 'kl'
    RESCALED_KL = 'rescaled_kl'

    def is_vb(self):
        """Return True for the two variational-bound losses (KL, RESCALED_KL)."""
        return self in (LossType.KL, LossType.RESCALED_KL)
175
+
176
+
177
+ class GaussianDiffusion:
178
+ """
179
+ Utilities for training and sampling diffusion models.
180
+
181
+ Ported directly from here, and then adapted over time to further experimentation.
182
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
183
+
184
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
185
+ starting at T and going to 1.
186
+ :param model_mean_type: a ModelMeanType determining what the model outputs.
187
+ :param model_var_type: a ModelVarType determining how variance is output.
188
+ :param loss_type: a LossType determining the loss function to use.
189
+ :param rescale_timesteps: if True, pass floating point timesteps into the
190
+ model so that they are always scaled like in the
191
+ original paper (0 to 1000).
192
+ """
193
+
194
+ def __init__(
195
+ self,
196
+ *,
197
+ betas,
198
+ model_mean_type,
199
+ model_var_type,
200
+ loss_type,
201
+ rescale_timesteps=False,
202
+ conditioning_free=False,
203
+ conditioning_free_k=1,
204
+ ramp_conditioning_free=True,
205
+ ):
206
+ self.model_mean_type = ModelMeanType(model_mean_type)
207
+ self.model_var_type = ModelVarType(model_var_type)
208
+ self.loss_type = LossType(loss_type)
209
+ self.rescale_timesteps = rescale_timesteps
210
+ self.conditioning_free = conditioning_free
211
+ self.conditioning_free_k = conditioning_free_k
212
+ self.ramp_conditioning_free = ramp_conditioning_free
213
+
214
+ # Use float64 for accuracy.
215
+ betas = np.array(betas, dtype=np.float64)
216
+ self.betas = betas
217
+ assert len(betas.shape) == 1, "betas must be 1-D"
218
+ assert (betas > 0).all() and (betas <= 1).all()
219
+
220
+ self.num_timesteps = int(betas.shape[0])
221
+
222
+ alphas = 1.0 - betas
223
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
224
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
225
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
226
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
227
+
228
+ # calculations for diffusion q(x_t | x_{t-1}) and others
229
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
230
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
231
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
232
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
233
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
234
+
235
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
236
+ self.posterior_variance = (
237
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
238
+ )
239
+ # log calculation clipped because the posterior variance is 0 at the
240
+ # beginning of the diffusion chain.
241
+ self.posterior_log_variance_clipped = np.log(
242
+ np.append(self.posterior_variance[1], self.posterior_variance[1:])
243
+ )
244
+ self.posterior_mean_coef1 = (
245
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
246
+ )
247
+ self.posterior_mean_coef2 = (
248
+ (1.0 - self.alphas_cumprod_prev)
249
+ * np.sqrt(alphas)
250
+ / (1.0 - self.alphas_cumprod)
251
+ )
252
+
253
+ def q_mean_variance(self, x_start, t):
254
+ """
255
+ Get the distribution q(x_t | x_0).
256
+
257
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
258
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
259
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape. of the sample at timestep t
260
+ """
261
+ mean = (
262
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
263
+ )
264
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
265
+ log_variance = _extract_into_tensor(
266
+ self.log_one_minus_alphas_cumprod, t, x_start.shape
267
+ )
268
+ return mean, variance, log_variance
269
+
270
def q_sample(self, x_start, t, noise=None):
    """
    Draw a sample from q(x_t | x_0), i.e. diffuse ``x_start`` to step ``t``.

    :param x_start: the initial data batch.
    :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
    :param noise: optional pre-drawn normal noise with the same shape as
                  ``x_start``; fresh noise is drawn when omitted.
    :return: A noisy version of x_start.
    """
    noise = th.randn_like(x_start) if noise is None else noise
    assert noise.shape == x_start.shape

    shape = x_start.shape
    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
    signal_part = _extract_into_tensor(self.sqrt_alphas_cumprod, t, shape) * x_start
    noise_part = (
        _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, shape) * noise
    )
    return signal_part + noise_part
289
+
290
def q_posterior_mean_variance(self, x_start, x_t, t):
    """
    Return the parameters of the diffusion posterior q(x_{t-1} | x_t, x_0).

    :param x_start: the (predicted or true) x_0 tensor.
    :param x_t: the noisy tensor at step ``t``; must match ``x_start`` in shape.
    :param t: a batch of timestep indices.
    :return: A tuple (posterior_mean, posterior_variance,
             posterior_log_variance_clipped).
    """
    assert x_start.shape == x_t.shape
    shape = x_t.shape

    # The mean is a per-timestep linear blend of x_0 and x_t.
    weight_on_start = _extract_into_tensor(self.posterior_mean_coef1, t, shape)
    weight_on_current = _extract_into_tensor(self.posterior_mean_coef2, t, shape)
    posterior_mean = weight_on_start * x_start + weight_on_current * x_t

    posterior_variance = _extract_into_tensor(self.posterior_variance, t, shape)
    posterior_log_variance_clipped = _extract_into_tensor(
        self.posterior_log_variance_clipped, t, shape
    )

    # All outputs must agree with the input on the batch dimension.
    assert (
        posterior_mean.shape[0]
        == posterior_variance.shape[0]
        == posterior_log_variance_clipped.shape[0]
        == x_start.shape[0]
    )
    return posterior_mean, posterior_variance, posterior_log_variance_clipped
313
+
314
def p_mean_variance(
    self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
):
    """
    Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
    the initial x, x_0.

    When ``self.conditioning_free`` is set, the model is evaluated a second
    time with ``conditioning_free=True`` and the two outputs are blended as
    ``(1 + k) * conditioned - k * unconditioned`` before the mean is derived.

    :param model: the model, which takes a signal and a batch of timesteps
                  as input.
    :param x: the [N x C x ...] tensor at time t.
    :param t: a 1-D Tensor of timesteps.
    :param clip_denoised: if True, clip the denoised signal into [-1, 1].
    :param denoised_fn: if not None, a function which applies to the
        x_start prediction before it is used to sample. Applies before
        clip_denoised.
    :param model_kwargs: if not None, a dict of extra keyword arguments to
        pass to the model. This can be used for conditioning.
    :return: a dict with the following keys:
             - 'mean': the model mean output.
             - 'variance': the model variance output.
             - 'log_variance': the log of 'variance'.
             - 'pred_xstart': the prediction for x_0.
    :raises NotImplementedError: if ``self.model_mean_type`` is not one of
        PREVIOUS_X, START_X or EPSILON.
    """
    if model_kwargs is None:
        model_kwargs = {}

    B, C = x.shape[:2]
    assert t.shape == (B,)
    model_output = model(x, self._scale_timesteps(t), **model_kwargs)
    if self.conditioning_free:
        # Second forward pass with the model's conditioning switched off.
        model_output_no_conditioning = model(x, self._scale_timesteps(t), conditioning_free=True, **model_kwargs)

    if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
        # The model emits 2*C channels: the first C are the mean-type
        # prediction, the last C parameterise the variance.
        assert model_output.shape == (B, C * 2, *x.shape[2:])
        model_output, model_var_values = th.split(model_output, C, dim=1)
        if self.conditioning_free:
            # Only the mean-type half of the unconditioned output is used.
            model_output_no_conditioning, _ = th.split(model_output_no_conditioning, C, dim=1)
        if self.model_var_type == ModelVarType.LEARNED:
            # The variance channels are used directly as the log-variance.
            model_log_variance = model_var_values
            model_variance = th.exp(model_log_variance)
        else:
            min_log = _extract_into_tensor(
                self.posterior_log_variance_clipped, t, x.shape
            )
            max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
            # The model_var_values is [-1, 1] for [min_var, max_var].
            frac = (model_var_values + 1) / 2
            # Interpolate in log space between the two bounds.
            model_log_variance = frac * max_log + (1 - frac) * min_log
            model_variance = th.exp(model_log_variance)
    else:
        # Fixed variance: pick the schedule table for the configured type,
        # then gather the entries for the timesteps in `t`.
        model_variance, model_log_variance = {
            # for fixedlarge, we set the initial (log-)variance like so
            # to get a better decoder log likelihood.
            ModelVarType.FIXED_LARGE: (
                np.append(self.posterior_variance[1], self.betas[1:]),
                np.log(np.append(self.posterior_variance[1], self.betas[1:])),
            ),
            ModelVarType.FIXED_SMALL: (
                self.posterior_variance,
                self.posterior_log_variance_clipped,
            ),
        }[self.model_var_type]
        model_variance = _extract_into_tensor(model_variance, t, x.shape)
        model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)

    if self.conditioning_free:
        if self.ramp_conditioning_free:
            assert t.shape[0] == 1  # This should only be used in inference.
            # Scale the guidance strength by (1 - scaled_t / num_timesteps),
            # using the first (only) timestep in the batch.
            cfk = self.conditioning_free_k * (1 - self._scale_timesteps(t)[0].item() / self.num_timesteps)
        else:
            cfk = self.conditioning_free_k
        # Extrapolate away from the unconditioned prediction.
        model_output = (1 + cfk) * model_output - cfk * model_output_no_conditioning

    def process_xstart(x):
        # Post-process an x_0 estimate: optional user hook first, then
        # optional clamping to [-1, 1].
        if denoised_fn is not None:
            x = denoised_fn(x)
        if clip_denoised:
            return x.clamp(-1, 1)
        return x

    if self.model_mean_type == ModelMeanType.PREVIOUS_X:
        # The model predicts x_{t-1} directly; x_0 is derived from it.
        pred_xstart = process_xstart(
            self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
        )
        model_mean = model_output
    elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
        if self.model_mean_type == ModelMeanType.START_X:
            pred_xstart = process_xstart(model_output)
        else:
            pred_xstart = process_xstart(
                self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
            )
        # The mean comes from the posterior evaluated at the processed
        # x_0 estimate.
        model_mean, _, _ = self.q_posterior_mean_variance(
            x_start=pred_xstart, x_t=x, t=t
        )
    else:
        raise NotImplementedError(self.model_mean_type)

    assert (
        model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
    )
    return {
        "mean": model_mean,
        "variance": model_variance,
        "log_variance": model_log_variance,
        "pred_xstart": pred_xstart,
    }
421
+
422
def _predict_xstart_from_eps(self, x_t, t, eps):
    """
    Recover the x_0 estimate implied by a noise prediction.

    Computes ``sqrt(1/alpha_bar_t) * x_t - sqrt(1/alpha_bar_t - 1) * eps``.

    :param x_t: the noisy tensor at step ``t``.
    :param t: a batch of timestep indices.
    :param eps: the predicted noise, same shape as ``x_t``.
    :return: the x_0 estimate.
    """
    assert x_t.shape == eps.shape
    shape = x_t.shape
    scaled_state = _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, shape) * x_t
    scaled_noise = _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, shape) * eps
    return scaled_state - scaled_noise
428
+
429
def _predict_xstart_from_xprev(self, x_t, t, xprev):
    """
    Recover the x_0 estimate implied by a prediction of x_{t-1}.

    Inverts the posterior-mean blend: ``(xprev - coef2 * x_t) / coef1``.

    :param x_t: the noisy tensor at step ``t``.
    :param t: a batch of timestep indices.
    :param xprev: the predicted x_{t-1}, same shape as ``x_t``.
    :return: the x_0 estimate.
    """
    assert x_t.shape == xprev.shape
    shape = x_t.shape
    inverse_coef1 = _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, shape)
    coef_ratio = _extract_into_tensor(
        self.posterior_mean_coef2 / self.posterior_mean_coef1, t, shape
    )
    return inverse_coef1 * xprev - coef_ratio * x_t
438
+
439
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
    """
    Recover the noise estimate implied by an x_0 prediction.

    This is the inverse of ``_predict_xstart_from_eps``.

    :param x_t: the noisy tensor at step ``t``.
    :param t: a batch of timestep indices.
    :param pred_xstart: the predicted x_0.
    :return: the implied noise tensor.
    """
    shape = x_t.shape
    numerator = (
        _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, shape) * x_t
        - pred_xstart
    )
    denominator = _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, shape)
    return numerator / denominator
444
+
445
+ def _scale_timesteps(self, t):
446
+ if self.rescale_timesteps:
447
+ return t.float() * (1000.0 / self.num_timesteps)
448
+ return t
449
+
450
def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
    """
    Compute the mean for the previous step, given a function cond_fn that
    computes the gradient of a conditional log probability with respect to
    x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
    condition on y.

    This uses the conditioning strategy from Sohl-Dickstein et al. (2015).

    :param cond_fn: callable invoked as ``cond_fn(x, scaled_t, **model_kwargs)``.
    :param p_mean_var: dict with at least the keys 'mean' and 'variance'.
    :param x: the current noisy tensor.
    :param t: a batch of timestep indices.
    :param model_kwargs: optional dict of extra keyword arguments forwarded
                         to ``cond_fn``; ``None`` means no extra arguments.
    :return: the shifted mean, ``mean + variance * gradient``, in float32.
    """
    # The default is None, which cannot be unpacked with ** below.
    if model_kwargs is None:
        model_kwargs = {}
    gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
    new_mean = (
        p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
    )
    return new_mean
464
+
465
def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
    """
    Compute what the p_mean_variance output would have been, should the
    model's score function be conditioned by cond_fn.

    See condition_mean() for details on cond_fn.

    Unlike condition_mean(), this instead uses the conditioning strategy
    from Song et al (2020).

    :param cond_fn: callable invoked as ``cond_fn(x, scaled_t, **model_kwargs)``.
    :param p_mean_var: dict as returned by p_mean_variance(); it is shallow
                       copied, not modified.
    :param x: the current noisy tensor.
    :param t: a batch of timestep indices.
    :param model_kwargs: optional dict of extra keyword arguments forwarded
                         to ``cond_fn``; ``None`` means no extra arguments.
    :return: a copy of ``p_mean_var`` with 'pred_xstart' and 'mean' replaced
             by their conditioned counterparts.
    """
    # The default is None, which cannot be unpacked with ** below.
    if model_kwargs is None:
        model_kwargs = {}
    alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)

    # Shift the implied noise estimate by the scaled conditioning gradient.
    eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
    eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
        x, self._scale_timesteps(t), **model_kwargs
    )

    # Re-derive x_0 and the posterior mean from the shifted noise.
    out = p_mean_var.copy()
    out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
    out["mean"], _, _ = self.q_posterior_mean_variance(
        x_start=out["pred_xstart"], x_t=x, t=t
    )
    return out
488
+
489
def p_sample(
    self,
    model,
    x,
    t,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
):
    """
    Sample x_{t-1} from the model at the given timestep.

    :param model: the model to sample from.
    :param x: the current noisy tensor (passed to p_mean_variance as the
              state at step ``t``).
    :param t: the value of t, starting at 0 for the first diffusion step.
    :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
    :param denoised_fn: if not None, a function which applies to the
        x_start prediction before it is used to sample.
    :param cond_fn: if not None, this is a gradient function that acts
                    similarly to the model.
    :param model_kwargs: if not None, a dict of extra keyword arguments to
        pass to the model. This can be used for conditioning.
    :return: a dict containing the following keys:
             - 'sample': a random sample from the model.
             - 'pred_xstart': a prediction of x_0.
    """
    stats = self.p_mean_variance(
        model,
        x,
        t,
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        model_kwargs=model_kwargs,
    )
    noise = th.randn_like(x)
    # Broadcastable [N, 1, 1, ...] gate that is 0 where t == 0, so the
    # final step adds no noise.
    trailing_ones = [1] * (x.dim() - 1)
    nonzero_mask = (t != 0).float().view(-1, *trailing_ones)
    if cond_fn is not None:
        stats["mean"] = self.condition_mean(
            cond_fn, stats, x, t, model_kwargs=model_kwargs
        )
    std = th.exp(0.5 * stats["log_variance"])
    sample = stats["mean"] + nonzero_mask * std * noise
    return {"sample": sample, "pred_xstart": stats["pred_xstart"]}
534
+
535
def p_sample_loop(
    self,
    model,
    shape,
    noise=None,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
    device=None,
    progress=False,
):
    """
    Generate samples from the model.

    Runs p_sample_loop_progressive() to exhaustion and keeps only the
    last yielded step.

    :param model: the model module.
    :param shape: the shape of the samples, (N, C, H, W).
    :param noise: if specified, the noise from the encoder to sample.
                  Should be of the same shape as `shape`.
    :param clip_denoised: if True, clip x_start predictions to [-1, 1].
    :param denoised_fn: if not None, a function which applies to the
        x_start prediction before it is used to sample.
    :param cond_fn: if not None, this is a gradient function that acts
                    similarly to the model.
    :param model_kwargs: if not None, a dict of extra keyword arguments to
        pass to the model. This can be used for conditioning.
    :param device: if specified, the device to create the samples on.
                   If not specified, use a model parameter's device.
    :param progress: if True, show a tqdm progress bar.
    :return: a non-differentiable batch of samples.
    """
    steps = self.p_sample_loop_progressive(
        model,
        shape,
        noise=noise,
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        cond_fn=cond_fn,
        model_kwargs=model_kwargs,
        device=device,
        progress=progress,
    )
    final = None
    # Drain the generator; `final` ends up bound to the last step's dict.
    for final in steps:
        pass
    return final["sample"]
580
+
581
def p_sample_loop_progressive(
    self,
    model,
    shape,
    noise=None,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
    device=None,
    progress=False,
):
    """
    Generate samples from the model and yield intermediate samples from
    each timestep of diffusion.

    Arguments are the same as p_sample_loop().
    Returns a generator over dicts, where each dict is the return value of
    p_sample().
    """
    if device is None:
        device = next(model.parameters()).device
    assert isinstance(shape, (tuple, list))

    # Start from the caller's noise when given, otherwise from fresh
    # standard-normal noise.
    img = noise if noise is not None else th.randn(*shape, device=device)

    # Walk the chain from the last timestep down to 0.
    descending_steps = list(range(self.num_timesteps - 1, -1, -1))
    batch = shape[0]

    for step in tqdm(descending_steps, disable=not progress):
        t = th.tensor([step] * batch, device=device)
        with th.no_grad():
            out = self.p_sample(
                model,
                img,
                t,
                clip_denoised=clip_denoised,
                denoised_fn=denoised_fn,
                cond_fn=cond_fn,
                model_kwargs=model_kwargs,
            )
            yield out
            img = out["sample"]
624
+
625
def ddim_sample(
    self,
    model,
    x,
    t,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
    eta=0.0,
):
    """
    Sample x_{t-1} from the model using DDIM.

    Same usage as p_sample(); ``eta`` scales the injected noise (with the
    default 0.0, ``sigma`` is zero and no noise is added).
    """
    out = self.p_mean_variance(
        model,
        x,
        t,
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        model_kwargs=model_kwargs,
    )
    if cond_fn is not None:
        out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)

    pred_xstart = out["pred_xstart"]
    # Usually our model outputs epsilon, but we re-derive it
    # in case we used x_start or x_prev prediction.
    eps = self._predict_eps_from_xstart(x, t, pred_xstart)

    shape = x.shape
    alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, shape)
    alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, shape)
    variance_ratio = (1 - alpha_bar_prev) / (1 - alpha_bar)
    step_decay = 1 - alpha_bar / alpha_bar_prev
    sigma = eta * th.sqrt(variance_ratio) * th.sqrt(step_decay)

    # Equation 12.
    noise = th.randn_like(x)
    mean_pred = (
        pred_xstart * th.sqrt(alpha_bar_prev)
        + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
    )
    # Broadcastable gate that is 0 where t == 0: no noise on the last step.
    trailing_ones = [1] * (x.dim() - 1)
    nonzero_mask = (t != 0).float().view(-1, *trailing_ones)
    sample = mean_pred + nonzero_mask * sigma * noise
    return {"sample": sample, "pred_xstart": pred_xstart}
674
+
675
def ddim_reverse_sample(
    self,
    model,
    x,
    t,
    clip_denoised=True,
    denoised_fn=None,
    model_kwargs=None,
    eta=0.0,
):
    """
    Sample x_{t+1} from the model using DDIM reverse ODE.

    Only the deterministic path is supported, so ``eta`` must be 0.0.

    :return: a dict with 'sample' (the x_{t+1} estimate) and 'pred_xstart'.
    """
    assert eta == 0.0, "Reverse ODE only for deterministic path"
    out = self.p_mean_variance(
        model,
        x,
        t,
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        model_kwargs=model_kwargs,
    )
    pred_xstart = out["pred_xstart"]
    # Usually our model outputs epsilon, but we re-derive it
    # in case we used x_start or x_prev prediction.
    eps = self._predict_eps_from_xstart(x, t, pred_xstart)
    alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)

    # Equation 12. reversed
    signal_term = pred_xstart * th.sqrt(alpha_bar_next)
    noise_term = th.sqrt(1 - alpha_bar_next) * eps
    mean_pred = signal_term + noise_term

    return {"sample": mean_pred, "pred_xstart": pred_xstart}
712
+
713
def ddim_sample_loop(
    self,
    model,
    shape,
    noise=None,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
    device=None,
    progress=False,
    eta=0.0,
):
    """
    Generate samples from the model using DDIM.

    Same usage as p_sample_loop(), plus ``eta`` which is forwarded to each
    DDIM step. Only the 'sample' of the last yielded step is returned.
    """
    steps = self.ddim_sample_loop_progressive(
        model,
        shape,
        noise=noise,
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        cond_fn=cond_fn,
        model_kwargs=model_kwargs,
        device=device,
        progress=progress,
        eta=eta,
    )
    final = None
    # Drain the generator; `final` ends up bound to the last step's dict.
    for final in steps:
        pass
    return final["sample"]
746
+
747
def ddim_sample_loop_progressive(
    self,
    model,
    shape,
    noise=None,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
    device=None,
    progress=False,
    eta=0.0,
):
    """
    Use DDIM to sample from the model and yield intermediate samples from
    each timestep of DDIM.

    Same usage as p_sample_loop_progressive().
    """
    if device is None:
        device = next(model.parameters()).device
    assert isinstance(shape, (tuple, list))

    # Start from the caller's noise when given, otherwise from fresh
    # standard-normal noise.
    img = noise if noise is not None else th.randn(*shape, device=device)

    # Walk the chain from the last timestep down to 0.
    indices = list(range(self.num_timesteps - 1, -1, -1))

    if progress:
        # Lazy import so that we don't depend on tqdm.
        from tqdm.auto import tqdm

        indices = tqdm(indices, disable=not progress)

    batch = shape[0]
    for step in indices:
        t = th.tensor([step] * batch, device=device)
        with th.no_grad():
            out = self.ddim_sample(
                model,
                img,
                t,
                clip_denoised=clip_denoised,
                denoised_fn=denoised_fn,
                cond_fn=cond_fn,
                model_kwargs=model_kwargs,
                eta=eta,
            )
            yield out
            img = out["sample"]
796
+
797
def _vb_terms_bpd(
    self, model, x_start, x_t, t, mask=None, clip_denoised=True, model_kwargs=None
):
    """
    Get a term for the variational lower-bound.

    The resulting units are bits (rather than nats, as one might expect).
    This allows for comparison to other papers.

    :param model: the model passed through to p_mean_variance().
    :param x_start: the clean input tensor.
    :param x_t: the noised tensor at step ``t``.
    :param t: a batch of timestep indices.
    :param mask: optional validity mask. When given it is squeezed on dim 1
                 and applied after averaging the per-element terms over
                 dim -2, so it presumably has shape [N x 1 x T] for
                 [N x C x T] inputs -- TODO confirm against the caller.
                 When omitted, terms are averaged over all non-batch
                 dimensions with mean_flat(). Callers in this class that
                 pass no mask (e.g. calc_bpd_loop) rely on that default.
    :param clip_denoised: forwarded to p_mean_variance().
    :param model_kwargs: forwarded to p_mean_variance().
    :return: a dict with the following keys:
             - 'output': a shape [N] tensor of NLLs or KLs.
             - 'pred_xstart': the x_0 predictions.
    """
    true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
        x_start=x_start, x_t=x_t, t=t
    )
    out = self.p_mean_variance(
        model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
    )
    kl = normal_kl(
        true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
    )
    decoder_nll = -discretized_gaussian_log_likelihood(
        x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
    )
    assert decoder_nll.shape == x_start.shape

    if mask is None:
        # No mask supplied: plain mean over every non-batch dimension.
        def reduce_per_item(values):
            return mean_flat(values)
    else:
        frame_mask = mask.squeeze(1).float()

        def reduce_per_item(values):
            # Average over dim -2, zero the masked-out positions, then
            # normalise by the number of valid positions.
            # NOTE(review): an all-zero mask row divides by zero here.
            per_frame = values.mean(dim=-2)
            per_frame *= frame_mask
            return per_frame.sum(-1) / frame_mask.sum(-1)

    # Divide by ln(2) to convert nats to bits.
    kl = reduce_per_item(kl) / np.log(2.0)
    decoder_nll = reduce_per_item(decoder_nll) / np.log(2.0)

    # At the first timestep return the decoder NLL,
    # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
    output = th.where((t == 0), decoder_nll, kl)
    return {"output": output, "pred_xstart": out["pred_xstart"]}
843
+
844
def training_losses(self, model, x_start, t, mask, model_kwargs=None, noise=None):
    """
    Compute training losses for a single timestep.

    :param model: the model to evaluate loss on.
    :param x_start: the [N x C x ...] tensor of inputs.
    :param t: a batch of timestep indices.
    :param mask: validity mask. It is squeezed on dim 1 and multiplied into
                 the loss after averaging over dim -2, so it presumably has
                 shape [N x 1 x T] for [N x C x T] inputs -- TODO confirm
                 against the caller.
    :param model_kwargs: if not None, a dict of extra keyword arguments to
        pass to the model. This can be used for conditioning.
    :param noise: if specified, the specific Gaussian noise to try to remove.
    :return: a dict with the key "loss" containing a tensor of shape [N].
             In the MSE modes it also holds "x_start_predicted", "mse" and,
             when the variance is learned, "vb"; "mse" and "vb" are averaged
             over the batch before returning. If the model returns a tuple,
             its trailing elements are stored under "extra_outputs".
    :raises NotImplementedError: for unknown loss or mean types.
    """
    if model_kwargs is None:
        model_kwargs = {}
    if noise is None:
        noise = th.randn_like(x_start)
    x_t = self.q_sample(x_start, t, noise=noise)
    terms = {}

    if self.loss_type in (LossType.KL, LossType.RESCALED_KL):
        # TODO: support multiple model outputs for this mode.
        # _vb_terms_bpd takes the mask, so it must be forwarded here too.
        terms["loss"] = self._vb_terms_bpd(
            model=model,
            x_start=x_start,
            x_t=x_t,
            t=t,
            mask=mask,
            clip_denoised=False,
            model_kwargs=model_kwargs,
        )["output"]
        if self.loss_type == LossType.RESCALED_KL:
            terms["loss"] *= self.num_timesteps
    elif self.loss_type in (LossType.MSE, LossType.RESCALED_MSE):
        model_outputs = model(x_t, self._scale_timesteps(t), **model_kwargs)
        if isinstance(model_outputs, tuple):
            model_output = model_outputs[0]
            terms["extra_outputs"] = model_outputs[1:]
        else:
            model_output = model_outputs

        if self.model_var_type in [
            ModelVarType.LEARNED,
            ModelVarType.LEARNED_RANGE,
        ]:
            B, C = x_t.shape[:2]
            assert model_output.shape == (B, C * 2, *x_t.shape[2:])
            model_output, model_var_values = th.split(model_output, C, dim=1)
            # Learn the variance using the variational bound, but don't let
            # it affect our mean prediction.
            frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
            terms["vb"] = self._vb_terms_bpd(
                model=lambda *args, r=frozen_out: r,
                x_start=x_start,
                x_t=x_t,
                t=t,
                mask=mask,
                clip_denoised=False,
            )["output"]
            if self.loss_type == LossType.RESCALED_MSE:
                # Without a factor of 1/1000, the VB term hurts the MSE term.
                terms["vb"] *= 1 / 1000

        if self.model_mean_type == ModelMeanType.PREVIOUS_X:
            target = self.q_posterior_mean_variance(
                x_start=x_start, x_t=x_t, t=t
            )[0]
            # Not supported: report an all-zero placeholder with the same
            # shape, dtype and device as x_start.
            x_start_pred = th.zeros_like(x_start)
        elif self.model_mean_type == ModelMeanType.START_X:
            target = x_start
            x_start_pred = model_output
        elif self.model_mean_type == ModelMeanType.EPSILON:
            target = noise
            x_start_pred = self._predict_xstart_from_eps(x_t, t, model_output)
        else:
            raise NotImplementedError(self.model_mean_type)
        assert model_output.shape == target.shape == x_start.shape

        mask = mask.squeeze(1).float()

        # Masked MSE: average over dim -2, zero the masked-out positions,
        # then normalise by the number of valid positions.
        loss = F.mse_loss(target, model_output, reduction='none').mean(dim=-2)
        loss *= mask
        loss = loss.sum(-1) / mask.sum(-1)

        terms["mse"] = loss
        terms["x_start_predicted"] = x_start_pred
        if "vb" in terms:
            terms["loss"] = terms["mse"] + terms["vb"]
        else:
            terms["loss"] = terms["mse"]
    else:
        raise NotImplementedError(self.loss_type)

    # Reduce the reporting terms to batch averages. Either key may be
    # absent ("vb" without learned variance, "mse" in KL mode), so guard
    # instead of indexing unconditionally.
    for key in ("mse", "vb"):
        if key in terms:
            terms[key] = terms[key].sum() / terms[key].shape[0]
    return terms
950
+
951
def autoregressive_training_losses(self, model, x_start, t, model_output_keys, gd_out_key, model_kwargs=None, noise=None):
    """
    Compute training losses for a single timestep.

    :param model: the model to evaluate loss on; called as
                  ``model(x_t, x_start, scaled_t, **model_kwargs)``.
    :param x_start: the [N x C x ...] tensor of inputs.
    :param t: a batch of timestep indices.
    :param model_output_keys: names paired positionally with the model's
                              outputs; each output is stored in the result
                              under its name.
    :param gd_out_key: the name (from ``model_output_keys``) of the output
                       used as the diffusion prediction.
    :param model_kwargs: if not None, a dict of extra keyword arguments to
        pass to the model. This can be used for conditioning.
    :param noise: if specified, the specific Gaussian noise to try to remove.
    :return: a dict with the key "loss" containing a tensor of shape [N].
             Some mean or variance settings may also have other keys.
    :raises NotImplementedError: for the KL loss types (not supported for
        this kind of diffusion) and for unknown loss or mean types.
    """
    if model_kwargs is None:
        model_kwargs = {}
    if noise is None:
        noise = th.randn_like(x_start)
    x_t = self.q_sample(x_start, t, noise=noise)
    terms = {}
    if self.loss_type in (LossType.KL, LossType.RESCALED_KL):
        # Not currently supported for this type of diffusion. Raise
        # explicitly: a bare `assert False` disappears under `python -O`.
        raise NotImplementedError(self.loss_type)
    elif self.loss_type in (LossType.MSE, LossType.RESCALED_MSE):
        model_outputs = model(x_t, x_start, self._scale_timesteps(t), **model_kwargs)
        terms.update(zip(model_output_keys, model_outputs))
        model_output = terms[gd_out_key]
        if self.model_var_type in [
            ModelVarType.LEARNED,
            ModelVarType.LEARNED_RANGE,
        ]:
            B, C = x_t.shape[:2]
            assert model_output.shape == (B, C, 2, *x_t.shape[2:])
            model_output, model_var_values = model_output[:, :, 0], model_output[:, :, 1]
            # Learn the variance using the variational bound, but don't let
            # it affect our mean prediction.
            frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
            # _vb_terms_bpd takes a mask; this method has none, so pass an
            # all-ones [N x 1 x ...] mask, i.e. every position is valid.
            # NOTE(review): the masked reduction averages over dim -2, so
            # this assumes [N x C x T] inputs -- confirm.
            terms["vb"] = self._vb_terms_bpd(
                model=lambda *args, r=frozen_out: r,
                x_start=x_start,
                x_t=x_t,
                t=t,
                mask=th.ones_like(x_start[:, :1]),
                clip_denoised=False,
            )["output"]
            if self.loss_type == LossType.RESCALED_MSE:
                # Divide by 1000 for equivalence with initial implementation.
                # Without a factor of 1/1000, the VB term hurts the MSE term.
                terms["vb"] *= self.num_timesteps / 1000.0

        if self.model_mean_type == ModelMeanType.PREVIOUS_X:
            target = self.q_posterior_mean_variance(
                x_start=x_start, x_t=x_t, t=t
            )[0]
            # Not supported: report an all-zero placeholder with the same
            # shape, dtype and device as x_start.
            x_start_pred = th.zeros_like(x_start)
        elif self.model_mean_type == ModelMeanType.START_X:
            target = x_start
            x_start_pred = model_output
        elif self.model_mean_type == ModelMeanType.EPSILON:
            target = noise
            x_start_pred = self._predict_xstart_from_eps(x_t, t, model_output)
        else:
            raise NotImplementedError(self.model_mean_type)
        assert model_output.shape == target.shape == x_start.shape
        terms["mse"] = mean_flat((target - model_output) ** 2)
        terms["x_start_predicted"] = x_start_pred
        if "vb" in terms:
            terms["loss"] = terms["mse"] + terms["vb"]
        else:
            terms["loss"] = terms["mse"]
    else:
        raise NotImplementedError(self.loss_type)

    return terms
1022
+
1023
def _prior_bpd(self, x_start):
    """
    Get the prior KL term for the variational lower-bound, measured in
    bits-per-dim.

    This term can't be optimized, as it only depends on the encoder.

    :param x_start: the [N x C x ...] tensor of inputs.
    :return: a batch of [N] KL values (in bits), one per batch element.
    """
    # Evaluate q(x_T | x_0) at the last timestep for every batch element.
    final_step = self.num_timesteps - 1
    t = th.tensor([final_step] * x_start.shape[0], device=x_start.device)
    qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)

    # KL against a standard normal (mean 0, log-variance 0).
    kl_prior = normal_kl(
        mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
    )
    # Divide by ln(2) to convert nats to bits.
    return mean_flat(kl_prior) / np.log(2.0)
1040
+
1041
def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
    """
    Compute the entire variational lower-bound, measured in bits-per-dim,
    as well as other related quantities.

    :param model: the model to evaluate loss on.
    :param x_start: the [N x C x ...] tensor of inputs.
    :param clip_denoised: if True, clip denoised samples.
    :param model_kwargs: if not None, a dict of extra keyword arguments to
        pass to the model. This can be used for conditioning.

    :return: a dict containing the following keys:
             - total_bpd: the total variational lower-bound, per batch element.
             - prior_bpd: the prior term in the lower-bound.
             - vb: an [N x T] tensor of terms in the lower-bound.
             - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
             - mse: an [N x T] tensor of epsilon MSEs for each timestep.
    """
    device = x_start.device
    batch_size = x_start.shape[0]
    # _vb_terms_bpd takes a mask; this method has none, so use an all-ones
    # [N x 1 x ...] mask, i.e. every position is valid.
    # NOTE(review): the masked reduction averages over dim -2, so this
    # assumes [N x C x T] inputs -- confirm.
    full_mask = th.ones_like(x_start[:, :1])

    vb = []
    xstart_mse = []
    mse = []
    # Walk the chain from the last timestep down to 0.
    for t in list(range(self.num_timesteps))[::-1]:
        t_batch = th.tensor([t] * batch_size, device=device)
        noise = th.randn_like(x_start)
        x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
        # Calculate VLB term at the current timestep
        with th.no_grad():
            out = self._vb_terms_bpd(
                model,
                x_start=x_start,
                x_t=x_t,
                t=t_batch,
                mask=full_mask,
                clip_denoised=clip_denoised,
                model_kwargs=model_kwargs,
            )
        vb.append(out["output"])
        xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
        eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
        mse.append(mean_flat((eps - noise) ** 2))

    # Stack the per-timestep [N] results into [N x T] tensors.
    vb = th.stack(vb, dim=1)
    xstart_mse = th.stack(xstart_mse, dim=1)
    mse = th.stack(mse, dim=1)

    prior_bpd = self._prior_bpd(x_start)
    total_bpd = vb.sum(dim=1) + prior_bpd
    return {
        "total_bpd": total_bpd,
        "prior_bpd": prior_bpd,
        "vb": vb,
        "xstart_mse": xstart_mse,
        "mse": mse,
    }
1097
+
1098
+
1099
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
    """
    Look up one of the pre-defined beta schedules by name.

    The beta schedule library consists of beta schedules which remain similar
    in the limit of num_diffusion_timesteps.
    Beta schedules may be added, but should not be removed or changed once
    they are committed to maintain backwards compatibility.

    :param schedule_name: "linear" or "cosine".
    :param num_diffusion_timesteps: the number of betas to produce.
    :return: the schedule; for "linear" a float64 numpy array of length
             ``num_diffusion_timesteps``, for "cosine" whatever
             betas_for_alpha_bar() returns.
    :raises NotImplementedError: if the schedule name is not recognised.
    """
    if schedule_name == "linear":
        # Linear schedule from Ho et al, extended to work for any number of
        # diffusion steps: the 1000-step endpoints are rescaled.
        scale = 1000 / num_diffusion_timesteps
        return np.linspace(
            scale * 0.0001,
            scale * 0.02,
            num_diffusion_timesteps,
            dtype=np.float64,
        )

    if schedule_name == "cosine":
        def squared_cosine_alpha_bar(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

        return betas_for_alpha_bar(num_diffusion_timesteps, squared_cosine_alpha_bar)

    raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
1124
+
1125
+
1126
class SpacedDiffusion(GaussianDiffusion):
    """
    A diffusion process which can skip steps in a base diffusion process.

    The retained steps get new betas derived from the base process's
    cumulative alphas, and ``timestep_map`` records which original step
    each retained step corresponds to.

    :param use_timesteps: a collection (sequence or set) of timesteps from the
                          original diffusion process to retain.
    :param kwargs: the kwargs to create the base diffusion process.
    """

    def __init__(self, use_timesteps, **kwargs):
        self.use_timesteps = set(use_timesteps)
        self.timestep_map = []
        self.original_num_steps = len(kwargs["betas"])

        # Build the full process once, only to read its cumulative alphas.
        base_diffusion = GaussianDiffusion(**kwargs)  # pylint: disable=missing-kwoa
        previous_kept_alpha = 1.0
        spaced_betas = []
        for step, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
            if step not in self.use_timesteps:
                continue
            # Beta that takes the previously kept cumulative alpha to this one.
            spaced_betas.append(1 - alpha_cumprod / previous_kept_alpha)
            previous_kept_alpha = alpha_cumprod
            self.timestep_map.append(step)
        kwargs["betas"] = np.array(spaced_betas)
        super().__init__(**kwargs)

    def p_mean_variance(
        self, model, *args, **kwargs
    ):  # pylint: disable=signature-differs
        wrapped = self._wrap_model(model)
        return super().p_mean_variance(wrapped, *args, **kwargs)

    def training_losses(
        self, model, *args, **kwargs
    ):  # pylint: disable=signature-differs
        wrapped = self._wrap_model(model)
        return super().training_losses(wrapped, *args, **kwargs)

    def autoregressive_training_losses(
        self, model, *args, **kwargs
    ):  # pylint: disable=signature-differs
        wrapped = self._wrap_model(model, True)
        return super().autoregressive_training_losses(wrapped, *args, **kwargs)

    def condition_mean(self, cond_fn, *args, **kwargs):
        wrapped = self._wrap_model(cond_fn)
        return super().condition_mean(wrapped, *args, **kwargs)

    def condition_score(self, cond_fn, *args, **kwargs):
        wrapped = self._wrap_model(cond_fn)
        return super().condition_score(wrapped, *args, **kwargs)

    def _wrap_model(self, model, autoregressive=False):
        """Wrap ``model`` for timestep remapping unless it is already wrapped."""
        if isinstance(model, (_WrappedModel, _WrappedAutoregressiveModel)):
            return model
        if autoregressive:
            wrapper_cls = _WrappedAutoregressiveModel
        else:
            wrapper_cls = _WrappedModel
        return wrapper_cls(
            model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
        )

    def _scale_timesteps(self, t):
        # Scaling is done by the wrapped model.
        return t
1183
+
1184
+
1185
def space_timesteps(num_timesteps, section_counts):
    """
    Create a list of timesteps to use from an original diffusion process,
    given the number of timesteps we want to take from equally-sized portions
    of the original process.

    For example, if there's 300 timesteps and the section counts are [10,15,20]
    then the first 100 timesteps are strided to be 10 timesteps, the second 100
    are strided to be 15 timesteps, and the final 100 are strided to be 20.

    If the stride is a string starting with "ddim", then the fixed striding
    from the DDIM paper is used, and only one section is allowed.

    :param num_timesteps: the number of diffusion steps in the original
                          process to divide up.
    :param section_counts: either a list of numbers, or a string containing
                           comma-separated numbers, indicating the step count
                           per section. As a special case, use "ddimN" where N
                           is a number of steps to use the striding from the
                           DDIM paper.
    :return: a set of diffusion steps from the original process to use.
    :raises ValueError: if no integer stride yields exactly N steps for
                        "ddimN", or if a section is asked for more steps than
                        it contains.
    """
    if isinstance(section_counts, str):
        if section_counts.startswith("ddim"):
            desired_count = int(section_counts[len("ddim"):])
            # Search for an integer stride that yields exactly the requested
            # number of steps.
            for stride in range(1, num_timesteps):
                candidate = range(0, num_timesteps, stride)
                if len(candidate) == desired_count:
                    return set(candidate)
            # Report the requested step count (not the size of the original
            # process) so the message describes what could not be satisfied.
            raise ValueError(
                f"cannot create exactly {desired_count} steps with an integer stride"
            )
        section_counts = [int(x) for x in section_counts.split(",")]

    # Split the original process into len(section_counts) contiguous sections;
    # the first `extra` sections absorb the remainder, one step each.
    size_per = num_timesteps // len(section_counts)
    extra = num_timesteps % len(section_counts)
    start_idx = 0
    all_steps = []
    for i, section_count in enumerate(section_counts):
        size = size_per + (1 if i < extra else 0)
        if size < section_count:
            raise ValueError(
                f"cannot divide section of {size} steps into {section_count}"
            )
        if section_count <= 1:
            frac_stride = 1
        else:
            # Fractional stride so the first and last step of the section are
            # both included.
            frac_stride = (size - 1) / (section_count - 1)
        cur_idx = 0.0
        taken_steps = []
        for _ in range(section_count):
            taken_steps.append(start_idx + round(cur_idx))
            cur_idx += frac_stride
        all_steps += taken_steps
        start_idx += size
    return set(all_steps)
1239
+
1240
+
1241
class _WrappedModel:
    """Callable adapter that translates respaced timestep indices.

    Incoming ``ts`` values are used as indices into ``timestep_map`` before
    the wrapped ``model`` is called. When ``rescale_timesteps`` is truthy the
    mapped values are converted to float and multiplied by
    ``1000.0 / original_num_steps``.
    """

    def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
        self.model, self.timestep_map = model, timestep_map
        self.rescale_timesteps = rescale_timesteps
        self.original_num_steps = original_num_steps

    def __call__(self, x, ts, **kwargs):
        # Lookup table lives on the same device / dtype as the indices.
        lookup = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
        original_ts = lookup[ts]
        if not self.rescale_timesteps:
            return self.model(x, original_ts, **kwargs)
        factor = 1000.0 / self.original_num_steps
        return self.model(x, original_ts.float() * factor, **kwargs)
1254
+
1255
+
1256
class _WrappedAutoregressiveModel:
    """Timestep-translating adapter for models called as ``model(x, x0, ts)``.

    Incoming ``ts`` values are used as indices into ``timestep_map`` before
    the wrapped ``model`` is called. When ``rescale_timesteps`` is truthy the
    mapped values are converted to float and multiplied by
    ``1000.0 / original_num_steps``.
    """

    def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
        self.model, self.timestep_map = model, timestep_map
        self.rescale_timesteps = rescale_timesteps
        self.original_num_steps = original_num_steps

    def __call__(self, x, x0, ts, **kwargs):
        # Lookup table lives on the same device / dtype as the indices.
        lookup = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
        original_ts = lookup[ts]
        if not self.rescale_timesteps:
            return self.model(x, x0, original_ts, **kwargs)
        factor = 1000.0 / self.original_num_steps
        return self.model(x, x0, original_ts.float() * factor, **kwargs)
1269
+
1270
def _extract_into_tensor(arr, timesteps, broadcast_shape):
    """
    Look up per-sample values in a 1-D numpy array and broadcast them.

    :param arr: the 1-D numpy array to read from.
    :param timesteps: a tensor of indices into ``arr``.
    :param broadcast_shape: a larger shape of K dimensions with the batch
                            dimension equal to the length of timesteps.
    :return: a float tensor expanded to ``broadcast_shape``.
    """
    table = th.from_numpy(arr).to(device=timesteps.device)
    picked = table[timesteps].float()
    # Append trailing singleton axes until the rank matches the target shape.
    missing_dims = len(broadcast_shape) - len(picked.shape)
    for _ in range(max(missing_dims, 0)):
        picked = picked.unsqueeze(-1)
    return picked.expand(broadcast_shape)
maha_tts/utils/stft.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn.functional as F
4
+ from torch.autograd import Variable
5
+ from scipy.signal import get_window
6
+ from librosa.util import pad_center, tiny
7
+ from maha_tts.utils.audio import window_sumsquare
8
+
9
+
10
class STFT(torch.nn.Module):
    """Short-time Fourier transform built from 1-D (transposed) convolutions.

    Adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft

    :param filter_length: size of the Fourier basis (FFT size).
    :param hop_length: stride between successive frames.
    :param win_length: length of the analysis window before it is
                       centre-padded to ``filter_length``.
    :param window: window name passed to ``scipy.signal.get_window``, or
                   ``None`` to leave the bases unwindowed.
    """

    def __init__(self, filter_length=800, hop_length=200, win_length=800,
                 window='hann'):
        super(STFT, self).__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        self.forward_transform = None
        scale = self.filter_length / self.hop_length
        fourier_basis = np.fft.fft(np.eye(self.filter_length))

        # Keep the non-redundant half of the spectrum and stack the real and
        # imaginary parts as separate convolution filters.
        cutoff = int((self.filter_length / 2 + 1))
        fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
                                   np.imag(fourier_basis[:cutoff, :])])

        forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
        inverse_basis = torch.FloatTensor(
            np.linalg.pinv(scale * fourier_basis).T[:, None, :])

        if window is not None:
            assert filter_length >= win_length
            # get window and zero center pad it to filter_length
            fft_window = get_window(window, win_length, fftbins=True)
            fft_window = pad_center(fft_window, size=filter_length)
            fft_window = torch.from_numpy(fft_window).float()

            # window the bases
            forward_basis *= fft_window
            inverse_basis *= fft_window

        # Buffers follow the module across devices but are not trained.
        self.register_buffer('forward_basis', forward_basis.float())
        self.register_buffer('inverse_basis', inverse_basis.float())

    def transform(self, input_data):
        """Return ``(magnitude, phase)`` for a ``(batch, samples)`` signal.

        The phase is computed from detached tensors, so no gradient flows
        through it.
        """
        num_batches = input_data.size(0)
        num_samples = input_data.size(1)

        self.num_samples = num_samples

        # similar to librosa, reflect-pad the input
        half_filter = int(self.filter_length / 2)
        input_data = input_data.view(num_batches, 1, num_samples)
        input_data = F.pad(
            input_data.unsqueeze(1),
            (half_filter, half_filter, 0, 0),
            mode='reflect')
        input_data = input_data.squeeze(1)

        # The basis is a registered buffer, so it already carries no gradient.
        forward_transform = F.conv1d(
            input_data,
            self.forward_basis,
            stride=self.hop_length,
            padding=0)

        cutoff = int((self.filter_length / 2) + 1)
        real_part = forward_transform[:, :cutoff, :]
        imag_part = forward_transform[:, cutoff:, :]

        magnitude = torch.sqrt(real_part**2 + imag_part**2)
        phase = torch.atan2(imag_part.detach(), real_part.detach())

        return magnitude, phase

    def inverse(self, magnitude, phase):
        """Reconstruct a ``(batch, 1, samples)`` signal from magnitude/phase."""
        recombine_magnitude_phase = torch.cat(
            [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1)

        inverse_transform = F.conv_transpose1d(
            recombine_magnitude_phase,
            self.inverse_basis,
            stride=self.hop_length,
            padding=0)

        if self.window is not None:
            window_sum = window_sumsquare(
                self.window, magnitude.size(-1), hop_length=self.hop_length,
                win_length=self.win_length, n_fft=self.filter_length,
                dtype=np.float32)
            # remove modulation effects
            approx_nonzero_indices = torch.from_numpy(
                np.where(window_sum > tiny(window_sum))[0])
            # Follow the signal's own device instead of assuming the default
            # CUDA device, so non-default GPUs and other accelerators work.
            target_device = inverse_transform.device
            window_sum = torch.from_numpy(window_sum).to(target_device)
            approx_nonzero_indices = approx_nonzero_indices.to(target_device)
            inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]

            # scale by hop ratio
            inverse_transform *= float(self.filter_length) / self.hop_length

        # Drop the padding that transform() added on both sides.
        half_filter = int(self.filter_length / 2)
        inverse_transform = inverse_transform[:, :, half_filter:]
        inverse_transform = inverse_transform[:, :, :-half_filter]

        return inverse_transform

    def forward(self, input_data):
        """Run transform() then inverse(), caching magnitude and phase on self."""
        self.magnitude, self.phase = self.transform(input_data)
        reconstruction = self.inverse(self.magnitude, self.phase)
        return reconstruction
ref_clips/2971_4275_000003_000007.wav ADDED
Binary file (392 kB). View file
 
ref_clips/2971_4275_000020_000001.wav ADDED
Binary file (386 kB). View file
 
ref_clips/2971_4275_000023_000010.wav ADDED
Binary file (435 kB). View file
 
ref_clips/2971_4275_000049_000000.wav ADDED
Binary file (366 kB). View file
 
ref_clips/2971_4275_000049_000004.wav ADDED
Binary file (321 kB). View file
 
ref_clips/2971_4275_000050_000000.wav ADDED
Binary file (385 kB). View file
 
tts.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import glob
import os

import torch
from scipy.io.wavfile import write

from maha_tts import load_diffuser,load_models,infer_tts

# Prefer CUDA when torch reports it as available; otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using:',device)
text = 'Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition.'

# Locate the reference clips relative to this script rather than through a
# machine-specific absolute path, so the example runs from any checkout.
ref_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'ref_clips')
# Sort so the clip order does not depend on filesystem listing order.
ref_clips = sorted(glob.glob(os.path.join(ref_dir, '*.wav')))
if not ref_clips:
    # Fail early with a clear message instead of handing infer_tts an empty list.
    raise FileNotFoundError(f'no reference .wav clips found in {ref_dir}')

diff_model,ts_model,vocoder,diffuser = load_models('Smolie',device)
audio,sr = infer_tts(text,ref_clips,diffuser,diff_model,ts_model,vocoder)
write('test.wav',sr,audio)
+ write('test.wav',sr,audio)