jaskaran Singh committed on
Commit
390d94d
1 Parent(s): d0a22d2
.gitattributes CHANGED
@@ -37,3 +37,6 @@ maha_tts/pretrained_models/smolie/T2S/t2s_best.pt filter=lfs diff=lfs merge=lfs
37
  maha_tts/pretrained_models/smolie/S2A/s2a_latest.pt filter=lfs diff=lfs merge=lfs -text
38
  maha_tts/pretrained_models/hifigan/config.json filter=lfs diff=lfs merge=lfs -text
39
  maha_tts/pretrained_models/hifigan/g_02500000 filter=lfs diff=lfs merge=lfs -text
 
 
 
 
37
  maha_tts/pretrained_models/smolie/S2A/s2a_latest.pt filter=lfs diff=lfs merge=lfs -text
38
  maha_tts/pretrained_models/hifigan/config.json filter=lfs diff=lfs merge=lfs -text
39
  maha_tts/pretrained_models/hifigan/g_02500000 filter=lfs diff=lfs merge=lfs -text
40
+ maha_tts/pretrained_models/hifigan filter=lfs diff=lfs merge=lfs -text
41
+ maha_tts/pretrained_models/Smolie-en filter=lfs diff=lfs merge=lfs -text
42
+ maha_tts/pretrained_models/Smolie-in filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,37 +1,37 @@
1
  <div align="center">
2
 
3
  <h1>MahaTTS: An Open-Source Large Speech Generation Model in the making</h1>
4
- a Dubverse Black initiative <br> <br>
5
 
6
  [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1-eOQqznKWwAfMdusJ_LDtDhjIyAlSMrG?usp=sharing)
7
  [![Discord Shield](https://discordapp.com/api/guilds/1162007551987171410/widget.png?style=shield)](https://discord.gg/4VGnrgpBN)
8
-
9
  </div>
10
 
11
  ------
12
 
13
  ## Description
14
- MahaTTS (Maha means 'Great' in sanskrit), is a speech generation model which is inspired from tortoise-tts, except it uses seamless M4t wav2vec2 to extract semantic tokens.
15
- Since seamless M4t wav2vec2 is trained on multilingual data, it makes this model easier to scale on multilingual data.
16
 
17
- <img width="993" alt="Screenshot 2023-11-19 at 11 53 52 PM" src="https://github.com/dubverse-ai/MahaTTS/assets/32906806/7429d3b6-3f19-4bd8-9005-ff9e16a698f8">
18
 
19
- ### Architecture
20
- | Model (Smolie) | Parameters | Model Type | Output |
21
- |:-------------------------:|:----------:|------------|:-----------------:|
22
- | Text to Semantic (M1) | 69 M | Causal LM | 10,001 Tokens |
23
- | Semantic to MelSpec(M2) | 108 M | Diffusion | 2x 80x Melspec |
24
- | Hifi Gan Vocoder | 13 M | GAN | Audio Waveform |
 
 
 
 
25
 
26
  ## Features
27
- 1. Multilinguality
 
28
  2. Realistic Prosody and intonation
29
  3. Multi-voice capabilities
30
 
31
- ## Current Progress
32
- Trained on 200 hours of LibriTTS model -> 'Smolie'
33
-
34
  ## Installation
 
35
  ```bash
36
  pip install git+https://github.com/dubverse-ai/MahaTTS.git
37
  ```
@@ -39,8 +39,88 @@ pip install git+https://github.com/dubverse-ai/MahaTTS.git
39
  ```bash
40
  pip install maha-tts
41
  ```
42
  ## Roadmap
43
- - [x] Smolie - eng
44
- - [ ] Smolie - indic
45
- - [ ] Optimizations for inference
1
  <div align="center">
2
 
3
  <h1>MahaTTS: An Open-Source Large Speech Generation Model in the making</h1>
4
+ a <a href = "https://black.dubverse.ai">Dubverse Black</a> initiative <br> <br>
5
 
6
  [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1-eOQqznKWwAfMdusJ_LDtDhjIyAlSMrG?usp=sharing)
7
  [![Discord Shield](https://discordapp.com/api/guilds/1162007551987171410/widget.png?style=shield)](https://discord.gg/4VGnrgpBN)
 
8
  </div>
9
 
10
  ------
11
 
12
  ## Description
 
 
13
 
14
+ MahaTTS, with Maha signifying 'Great' in Sanskrit, is a text-to-speech model developed by [Dubverse.ai](https://dubverse.ai). We drew inspiration from the [tortoise-tts](https://github.com/neonbjb/tortoise-tts) model, but our model uniquely uses Seamless M4T's wav2vec2 for semantic token extraction. Since this variant of wav2vec2 is trained on multilingual data, it makes our model easier to scale across languages.
15
 
16
+ We are providing access to pretrained model checkpoints, which are ready for inference and available for commercial use.
17
+
18
+ <img width="993" alt="MahaTTS Architecture" src="https://github.com/dubverse-ai/MahaTTS/assets/32906806/7429d3b6-3f19-4bd8-9005-ff9e16a698f8">
19
+
20
+ ## Updates
21
+
22
+ **2023-11-13**
23
+
24
+ - MahaTTS released! Open-sourced Smolie
25
+ - Join our [Discord](https://discord.gg/uFPrzBqyF2) for community support and access to new features
26
 
27
  ## Features
28
+
29
+ 1. Multilinguality (coming soon)
30
  2. Realistic Prosody and intonation
31
  3. Multi-voice capabilities
32
 
 
 
 
33
  ## Installation
34
+
35
  ```bash
36
  pip install git+https://github.com/dubverse-ai/MahaTTS.git
37
  ```
 
39
  ```bash
40
  pip install maha-tts
41
  ```
42
+
43
+ ## API Usage
44
+
45
+ ```python
46
+ !gdown --folder 1-HEc3V4f6X93I8_IfqExLfL3s8I_dXGZ -q # download speakers ref files
47
+
48
+ import torch,glob
49
+ from maha_tts import load_models,infer_tts
50
+ from scipy.io.wavfile import write
51
+ from IPython.display import Audio,display
52
+
53
+ # PATH TO THE SPEAKERS WAV FILES
54
+ speaker =['/content/infer_ref_wavs/2272_152282_000019_000001/',
55
+ '/content/infer_ref_wavs/2971_4275_000049_000000/',
56
+ '/content/infer_ref_wavs/4807_26852_000062_000000/',
57
+ '/content/infer_ref_wavs/6518_66470_000014_000002/']
58
+
59
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
60
+ diff_model,ts_model,vocoder,diffuser = load_models('Smolie',device)
61
+ print('Using:',device)
62
+
63
+ speaker_num = 0 # @param ["0", "1", "2", "3"] {type:"raw"}
64
+ text = "I freakin love how Elon came to life the moment they started talking about gaming and specifically diablo, you can tell that he didn't want that part of the discussion to end, while Lex to move on to the next subject! Once a true gamer, always a true gamer!" # @param {type:"string"}
65
+
66
+ ref_clips = glob.glob(speaker[speaker_num]+'*.wav')
67
+ audio,sr = infer_tts(text,ref_clips,diffuser,diff_model,ts_model,vocoder)
68
+
69
+ write('/content/test.wav',sr,audio)
70
+ ```
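The `Audio` and `display` imports in the snippet above can also be used to preview the result inline when running in Colab or Jupyter; this is an optional addition, not part of the original example:

```python
# Preview the generated clip inline (Colab / Jupyter only).
# `audio` and `sr` come from the infer_tts(...) call above.
display(Audio(audio, rate=sr))

# Or reload the file that was written to disk.
display(Audio('/content/test.wav'))
```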
71
  ## Roadmap
72
+ - [x] Smolie - eng (trained on 200 hours of LibriTTS)
73
+ - [ ] Smolie - indic (training on Indian languages, coming soon)
74
+ - [ ] Optimizations for inference (looking for contributors, check issues)
75
+
76
+ ## Some Generated Samples
77
+ 0 -> "I seriously laughed so much hahahaha (seals with headphones...) and appreciate both the interviewer and the subject. Major respect for two extraordinary humans - and in this time of gratefulness, I'm thankful for you both and this forum!"
78
+
79
+ 1 -> "I freakin love how Elon came to life the moment they started talking about gaming and specifically diablo, you can tell that he didn't want that part of the discussion to end, while Lex to move on to the next subject! Once a true gamer, always a true gamer!"
80
+
81
+ 2 -> "hello there! how are you?" (this one didn't work well; the M1 model hallucinated)
82
+
83
+ 3 -> "Who doesn't love a good scary story, something to send a chill across your skin in the middle of summer's heat or really, any other time? And this year, we're celebrating the two hundredth birthday of one of the most famous scary stories of all time: Frankenstein."
84
+
85
+
86
+
87
+ https://github.com/dubverse-ai/MahaTTS/assets/32906806/462ee134-5d8c-43c8-a425-3b6cabd2ff85
88
+
89
+
90
+
91
+
92
+ https://github.com/dubverse-ai/MahaTTS/assets/32906806/40c62402-7f65-4a35-b739-d8b8a082ad62
93
+
94
+
95
+
96
+ https://github.com/dubverse-ai/MahaTTS/assets/32906806/f0a9628c-ef81-450d-ab82-2f4c4626864e
97
+
98
+
99
+
100
+ https://github.com/dubverse-ai/MahaTTS/assets/32906806/15476151-72ea-410d-bcdc-177433df7884
101
+
102
+
103
+ ## Technical Details
104
+
105
+ ### Model Params
106
+ | Model (Smolie) | Parameters | Model Type | Output |
107
+ |:-------------------------:|:----------:|------------|:-----------------:|
108
+ | Text to Semantic (M1) | 69 M | Causal LM | 10,001 Tokens |
109
+ | Semantic to MelSpec(M2) | 108 M | Diffusion | 2x 80x Melspec |
110
+ | HiFi-GAN Vocoder          | 13 M       | GAN        | Audio Waveform    |
111
+
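To make the table concrete, the sketch below mirrors how `infer_tts` chains the three stages internally in `maha_tts/inference.py`. Only `load_models`, `load_diffuser`, and `infer_tts` are exported from the package, so treat the staged calls as an illustrative walkthrough of internal helpers rather than a supported API:

```python
import glob, torch
from maha_tts import load_models
from maha_tts.inference import generate_semantic_tokens, infer_mel, infer_wav, get_ref_mels
from maha_tts.text.cleaners import english_cleaners
from maha_tts.utils.audio import normalize_tacotron_mel

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
diff_model, ts_model, vocoder, diffuser = load_models('Smolie', device)

text = english_cleaners("Hello there! How are you?")
ref_clips = glob.glob('/content/infer_ref_wavs/2272_152282_000019_000001/*.wav')
ref_mels = get_ref_mels(ref_clips)

# M1 (Causal LM): text -> discrete semantic tokens from a ~10k-entry vocabulary
sem_tok, _ = generate_semantic_tokens(text, ts_model, ref_mels,
                                      top_p=0.8, top_k=5, device=device)

# M2 (Diffusion): semantic tokens + speaker reference -> 80-bin mel spectrogram;
# the frame count converts the 16 kHz token rate to the 22.05 kHz / hop-256 mel rate
n_frames = int(((sem_tok.shape[-1] * 320 / 16000) * 22050 / 256) + 1)
mel = infer_mel(diff_model, n_frames, sem_tok.unsqueeze(0) + 1,
                normalize_tacotron_mel(ref_mels), diffuser)

# HiFi-GAN vocoder: mel spectrogram -> audio waveform at config.sampling_rate
audio = infer_wav(mel, vocoder)
```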
112
+ ### Languages Supported
113
+ | Language | Status |
114
+ | --- | :---: |
115
+ | English (en) | ✅ |
116
+
117
+ ## License
118
+
119
+ MahaTTS is licensed under the Apache 2.0 License.
120
+
121
+ ## 🙏 Appreciation
122
 
123
+ - [tortoise-tts](https://github.com/neonbjb/tortoise-tts)
124
+ - [Seamless M4T](https://github.com/facebookresearch/seamless_communication), [AudioLM](https://arxiv.org/abs/2209.03143), and many other ground-breaking papers that enabled the development of MahaTTS
125
+ - [Diffusion training](https://github.com/openai/guided-diffusion) for training the diffusion model
126
+ - [Hugging Face](https://huggingface.co/docs/transformers/index) for related training and inference code
maha_tts/__init__.py CHANGED
@@ -1 +1,3 @@
1
- from .inference import load_models,load_diffuser,infer_tts
 
 
 
1
+ from maha_tts.inference import load_models,load_diffuser,infer_tts
2
+ from maha_tts.config import config
3
+ __version__ = '1.0.0'
maha_tts/config.py CHANGED
@@ -5,8 +5,9 @@ class config:
5
  seed_value = 3407
6
 
7
  # Text to Semantic
8
- t2s_position = 2048
9
-
 
10
  # Semantic to acoustic
11
  sa_timesteps_max = 1000
12
 
 
5
  seed_value = 3407
6
 
7
  # Text to Semantic
8
+ t2s_position = 4096
9
+ langs = ['english','tamil', 'telugu', 'punjabi', 'marathi', 'hindi', 'gujarati', 'bengali', 'assamese']
10
+ lang_index = {i:j for j,i in enumerate(langs)}
11
  # Semantic to acoustic
12
  sa_timesteps_max = 1000
13
 
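For clarity, the `lang_index` comprehension added above keys each language name to its position in `langs`; a quick illustrative check:

```python
from maha_tts.config import config

# {'english': 0, 'tamil': 1, 'telugu': 2, 'punjabi': 3, 'marathi': 4,
#  'hindi': 5, 'gujarati': 6, 'bengali': 7, 'assamese': 8}
print(config.lang_index)
print(config.lang_index['hindi'])  # -> 5
```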
maha_tts/inference.py CHANGED
@@ -1,7 +1,8 @@
1
- import torch,glob,os
2
  import numpy as np
3
  import torch.nn.functional as F
4
 
 
5
  from librosa.filters import mel as librosa_mel_fn
6
  from scipy.io.wavfile import write
7
  from scipy.special import softmax
@@ -11,10 +12,12 @@ from maha_tts.models.vocoder import load_vocoder_model,infer_wav
11
  from maha_tts.utils.audio import denormalize_tacotron_mel,normalize_tacotron_mel,load_wav_to_torch,dynamic_range_compression
12
  from maha_tts.utils.stft import STFT
13
  from maha_tts.utils.diffusion import SpacedDiffusion,get_named_beta_schedule,space_timesteps
14
- from maha_tts.text.symbols import labels,text_labels,code_labels,text_enc,text_dec,code_enc,code_dec
15
  from maha_tts.text.cleaners import english_cleaners
16
  from maha_tts.config import config
17
 
 
 
18
  stft_fn = STFT(config.filter_length, config.hop_length, config.win_length)
19
 
20
  mel_basis = librosa_mel_fn(
@@ -23,13 +26,52 @@ mel_basis = librosa_mel_fn(
23
  mel_basis = torch.from_numpy(mel_basis).float()
24
 
25
  model_dirs= {
26
- 'Smolie':'asdf',
27
- 'hifigan':'asdf'
 
 
 
 
28
  }
29
 
30
- def download_model(name):
31
- pass
33
 
34
  def load_models(name,device=torch.device('cpu')):
35
  '''
@@ -51,10 +93,10 @@ def load_models(name,device=torch.device('cpu')):
51
 
52
  assert name in model_dirs, "no model name "+name
53
 
54
- checkpoint_diff = 'maha_tts/pretrained_models/'+str(name)+'/S2A/s2a_latest.pt'
55
- checkpoint_ts = 'maha_tts/pretrained_models/'+str(name)+'/T2S/t2s_best.pt'
56
- checkpoint_voco = 'maha_tts/pretrained_models/hifigan/g_02500000'
57
- voco_config_path = 'maha_tts/pretrained_models/hifigan/config.json'
58
 
59
  # for i in [checkpoint_diff,checkpoint_ts,checkpoint_voco,voco_config_path]:
60
  if not os.path.exists(checkpoint_diff) or not os.path.exists(checkpoint_ts):
@@ -64,15 +106,16 @@ def load_models(name,device=torch.device('cpu')):
64
  download_model('hifigan')
65
 
66
  diff_model = load_diff_model(checkpoint_diff,device)
67
- ts_model = load_TS_model(checkpoint_ts,device)
68
  vocoder = load_vocoder_model(voco_config_path,checkpoint_voco,device)
69
  diffuser = load_diffuser()
70
 
71
  return diff_model,ts_model,vocoder,diffuser
72
 
73
- def infer_mel(model,timeshape,code,ref_mel,diffuser,temperature=0.1):
74
  device = next(model.parameters()).device
75
  code = code.to(device)
 
76
  output_shape = (1,80,timeshape)
77
  noise = torch.randn(output_shape, device=code.device) * temperature
78
  mel = diffuser.p_sample_loop(model, output_shape, noise=noise,
@@ -84,17 +127,18 @@ def generate_semantic_tokens(
84
  text,
85
  model,
86
  ref_mels,
 
87
  temp = 0.7,
88
  top_p= None,
89
- top_k= None,
90
  n_tot_steps = 1000,
91
  device = None
92
  ):
93
  semb = []
94
  with torch.no_grad():
95
- for n in range(n_tot_steps):
96
- x = get_inputs(text,semb,ref_mels,device)
97
- _,result = model(**x)
98
  relevant_logits = result[0,:,-1]
99
  if top_p is not None:
100
  # faster to convert to numpy
@@ -125,9 +169,13 @@ def generate_semantic_tokens(
125
  semb = torch.tensor([int(i) for i in semb[:-1]])
126
  return semb,result
127
 
128
- def get_inputs(text,semb=[],ref_mels=[],device=torch.device('cpu')):
129
  text = text.lower()
130
- text_ids=[text_enc['<S>']]+[text_enc[i] for i in text.strip()]+[text_enc['<E>']]
 
 
 
 
131
  semb_ids=[code_enc['<SST>']]+[code_enc[i] for i in semb]#+[tok_enc['<EST>']]
132
 
133
  input_ids = text_ids+semb_ids
@@ -166,7 +214,7 @@ def get_mel(filepath):
166
  energy = torch.norm(magnitudes, dim=1).squeeze(0)
167
  return melspec,list(energy)
168
 
169
- def infer_tts(text,ref_clips,diffuser,diff_model,ts_model,vocoder):
170
  '''
171
  Generate audio from the given text using a text-to-speech (TTS) pipeline.
172
 
@@ -193,6 +241,7 @@ def infer_tts(text,ref_clips,diffuser,diff_model,ts_model,vocoder):
193
  Example usage:
194
  audio, sampling_rate = infer_tts("Hello, how are you?", ref_clips, diffuser, diff_model, ts_model, vocoder)
195
  '''
 
196
  text = english_cleaners(text)
197
  ref_mels = get_ref_mels(ref_clips)
198
  with torch.no_grad():
@@ -200,20 +249,21 @@ def infer_tts(text,ref_clips,diffuser,diff_model,ts_model,vocoder):
200
  text,
201
  ts_model,
202
  ref_mels,
 
203
  temp = 0.7,
204
  top_p= 0.8,
205
  top_k= 5,
206
  n_tot_steps = 1000,
207
- device = None
208
  )
209
  mel = infer_mel(diff_model,int(((sem_tok.shape[-1] * 320 / 16000) * 22050/256)+1),sem_tok.unsqueeze(0) + 1,
210
- ref_mels,diffuser,temperature=1.0)
211
 
212
  audio = infer_wav(mel,vocoder)
213
 
214
  return audio,config.sampling_rate
215
 
216
- def load_diffuser(timesteps = 100, gudiance=3):
217
  '''
218
  Load and configure a diffuser for denoising and guidance in the diffusion model.
219
 
@@ -227,10 +277,10 @@ def load_diffuser(timesteps = 100, gudiance=3):
227
  Description:
228
  The `load_diffuser` function initializes a diffuser with specific settings for denoising and guidance.
229
  '''
230
- betas = get_named_beta_schedule('cosine',config.sa_timesteps_max)
231
  diffuser = SpacedDiffusion(use_timesteps=space_timesteps(1000, [timesteps]), model_mean_type='epsilon',
232
  model_var_type='learned_range', loss_type='rescaled_mse', betas=betas,
233
- conditioning_free=True, conditioning_free_k=gudiance)
234
  diffuser.training=False
235
  return diffuser
236
 
 
1
+ import torch,glob,os,requests
2
  import numpy as np
3
  import torch.nn.functional as F
4
 
5
+ from tqdm import tqdm
6
  from librosa.filters import mel as librosa_mel_fn
7
  from scipy.io.wavfile import write
8
  from scipy.special import softmax
 
12
  from maha_tts.utils.audio import denormalize_tacotron_mel,normalize_tacotron_mel,load_wav_to_torch,dynamic_range_compression
13
  from maha_tts.utils.stft import STFT
14
  from maha_tts.utils.diffusion import SpacedDiffusion,get_named_beta_schedule,space_timesteps
15
+ from maha_tts.text.symbols import labels,text_labels,text_labels_en,code_labels,text_enc,text_dec,code_enc,code_dec,text_enc_en,text_dec_en
16
  from maha_tts.text.cleaners import english_cleaners
17
  from maha_tts.config import config
18
 
19
+ DEFAULT_MODELS_DIR = os.path.join(os.path.expanduser('~'), '.cache', 'maha_tts', 'models')
20
+ # DEFAULT_MODELS_DIR = '/Users/jaskaransingh/Desktop/MahaTTS/models/'  # local development override; keep commented so the cache path above is used
21
  stft_fn = STFT(config.filter_length, config.hop_length, config.win_length)
22
 
23
  mel_basis = librosa_mel_fn(
 
26
  mel_basis = torch.from_numpy(mel_basis).float()
27
 
28
  model_dirs= {
29
+ 'Smolie':['https://huggingface.co/Dubverse/MahaTTS/resolve/main/maha_tts/pretrained_models/smolie/S2A/s2a_latest.pt',
30
+ 'https://huggingface.co/Dubverse/MahaTTS/resolve/main/maha_tts/pretrained_models/smolie/T2S/t2s_best.pt'],
31
+ 'Smolie-en':[''],
32
+ 'Smolie-in':[''],
33
+ 'hifigan':['https://huggingface.co/Dubverse/MahaTTS/resolve/main/maha_tts/pretrained_models/hifigan/g_02500000',
34
+ 'https://huggingface.co/Dubverse/MahaTTS/resolve/main/maha_tts/pretrained_models/hifigan/config.json']
35
  }
36
 
37
+ def download_file(url, filename):
38
+ response = requests.get(url, stream=True)
39
+ total_size = int(response.headers.get('content-length', 0))
40
+
41
+ # Check if the response was successful (status code 200)
42
+ response.raise_for_status()
43
+
44
+ with open(filename, 'wb') as file, tqdm(
45
+ desc=filename,
46
+ total=total_size,
47
+ unit='B',
48
+ unit_scale=True,
49
+ unit_divisor=1024,
50
+ ) as bar:
51
+ for data in response.iter_content(chunk_size=1024):
52
+ # Write data to the file
53
+ file.write(data)
54
+ # Update the progress bar
55
+ bar.update(len(data))
56
 
57
+ print(f"Download complete: {filename}\n")
58
+
59
+ def download_model(name):
60
+ print('Downloading ',name," ....")
61
+ checkpoint_diff = os.path.join(DEFAULT_MODELS_DIR,name,'s2a_latest.pt')
62
+ checkpoint_ts = os.path.join(DEFAULT_MODELS_DIR,name,'t2s_best.pt')
63
+ checkpoint_voco = os.path.join(DEFAULT_MODELS_DIR,'hifigan','g_02500000')
64
+ voco_config_path = os.path.join(DEFAULT_MODELS_DIR,'hifigan','config.json')
65
+
66
+ os.makedirs(os.path.join(DEFAULT_MODELS_DIR,name),exist_ok=True)
67
+
68
+ if name == 'hifigan':
69
+ download_file(model_dirs[name][0],checkpoint_voco)
70
+ download_file(model_dirs[name][1],voco_config_path)
71
+
72
+ else:
73
+ download_file(model_dirs[name][0],checkpoint_diff)
74
+ download_file(model_dirs[name][1],checkpoint_ts)
75
 
76
  def load_models(name,device=torch.device('cpu')):
77
  '''
 
93
 
94
  assert name in model_dirs, "no model name "+name
95
 
96
+ checkpoint_diff = os.path.join(DEFAULT_MODELS_DIR,name,'s2a_latest.pt')
97
+ checkpoint_ts = os.path.join(DEFAULT_MODELS_DIR,name,'t2s_best.pt')
98
+ checkpoint_voco = os.path.join(DEFAULT_MODELS_DIR,'hifigan','g_02500000')
99
+ voco_config_path = os.path.join(DEFAULT_MODELS_DIR,'hifigan','config.json')
100
 
101
  # for i in [checkpoint_diff,checkpoint_ts,checkpoint_voco,voco_config_path]:
102
  if not os.path.exists(checkpoint_diff) or not os.path.exists(checkpoint_ts):
 
106
  download_model('hifigan')
107
 
108
  diff_model = load_diff_model(checkpoint_diff,device)
109
+ ts_model = load_TS_model(checkpoint_ts,device,name)
110
  vocoder = load_vocoder_model(voco_config_path,checkpoint_voco,device)
111
  diffuser = load_diffuser()
112
 
113
  return diff_model,ts_model,vocoder,diffuser
114
 
115
+ def infer_mel(model,timeshape,code,ref_mel,diffuser,temperature=1.0):
116
  device = next(model.parameters()).device
117
  code = code.to(device)
118
+ ref_mel =ref_mel.to(device)
119
  output_shape = (1,80,timeshape)
120
  noise = torch.randn(output_shape, device=code.device) * temperature
121
  mel = diffuser.p_sample_loop(model, output_shape, noise=noise,
 
127
  text,
128
  model,
129
  ref_mels,
130
+ language=None,
131
  temp = 0.7,
132
  top_p= None,
133
+ top_k= 1,
134
  n_tot_steps = 1000,
135
  device = None
136
  ):
137
  semb = []
138
  with torch.no_grad():
139
+ for n in tqdm(range(n_tot_steps)):
140
+ x = get_inputs(text,semb,ref_mels,device,model.name)
141
+ _,result = model(**x,language=language)
142
  relevant_logits = result[0,:,-1]
143
  if top_p is not None:
144
  # faster to convert to numpy
 
169
  semb = torch.tensor([int(i) for i in semb[:-1]])
170
  return semb,result
171
 
172
+ def get_inputs(text,semb=[],ref_mels=[],device=torch.device('cpu'),name = 'Smolie-in'):
173
  text = text.lower()
174
+ if name=='Smolie-en':
175
+ text_ids=[text_enc_en['<S>']]+[text_enc_en[i] for i in text.strip()]+[text_enc_en['<E>']]
176
+ else:
177
+ text_ids=[text_enc['<S>']]+[text_enc[i] for i in text.strip()]+[text_enc['<E>']]
178
+
179
  semb_ids=[code_enc['<SST>']]+[code_enc[i] for i in semb]#+[tok_enc['<EST>']]
180
 
181
  input_ids = text_ids+semb_ids
 
214
  energy = torch.norm(magnitudes, dim=1).squeeze(0)
215
  return melspec,list(energy)
216
 
217
+ def infer_tts(text,ref_clips,diffuser,diff_model,ts_model,vocoder,language=None):
218
  '''
219
  Generate audio from the given text using a text-to-speech (TTS) pipeline.
220
 
 
241
  Example usage:
242
  audio, sampling_rate = infer_tts("Hello, how are you?", ref_clips, diffuser, diff_model, ts_model, vocoder)
243
  '''
244
+ device = next(ts_model.parameters()).device
245
  text = english_cleaners(text)
246
  ref_mels = get_ref_mels(ref_clips)
247
  with torch.no_grad():
 
249
  text,
250
  ts_model,
251
  ref_mels,
252
+ language,
253
  temp = 0.7,
254
  top_p= 0.8,
255
  top_k= 5,
256
  n_tot_steps = 1000,
257
+ device = device
258
  )
259
  mel = infer_mel(diff_model,int(((sem_tok.shape[-1] * 320 / 16000) * 22050/256)+1),sem_tok.unsqueeze(0) + 1,
260
+ normalize_tacotron_mel(ref_mels),diffuser,temperature=0.5)
261
 
262
  audio = infer_wav(mel,vocoder)
263
 
264
  return audio,config.sampling_rate
265
 
266
+ def load_diffuser(timesteps = 100, guidance=3):
267
  '''
268
  Load and configure a diffuser for denoising and guidance in the diffusion model.
269
 
 
277
  Description:
278
  The `load_diffuser` function initializes a diffuser with specific settings for denoising and guidance.
279
  '''
280
+ betas = get_named_beta_schedule('linear',config.sa_timesteps_max)
281
  diffuser = SpacedDiffusion(use_timesteps=space_timesteps(1000, [timesteps]), model_mean_type='epsilon',
282
  model_var_type='learned_range', loss_type='rescaled_mse', betas=betas,
283
+ conditioning_free=True, conditioning_free_k=guidance)
284
  diffuser.training=False
285
  return diffuser
286
 
maha_tts/models/__init__.py DELETED
File without changes
maha_tts/models/autoregressive.py DELETED
@@ -1,135 +0,0 @@
1
- '''
2
- Inspiration taken from https://github.com/neonbjb/tortoise-tts/blob/main/tortoise/models/autoregressive.py
3
- '''
4
- import os,sys
5
- import torch
6
- import torch.nn as nn
7
- import torch.nn.functional as F
8
- import torch.optim as optim
9
- import functools
10
-
11
- from typing import Any
12
- from torch.utils.data import Dataset,DataLoader
13
- from transformers import GPT2Tokenizer,GPT2Config, GPT2Model, GPT2LMHeadModel
14
- from tqdm import tqdm
15
- from maha_tts.config import config
16
- from maha_tts.text.symbols import labels,code_labels,text_labels
17
- from maha_tts.models.modules import GST
18
-
19
- def null_position_embeddings(range, dim):
20
- return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
21
-
22
- class TS_model(nn.Module):
23
- def __init__(self,n_embed = 512, n_layer = 16, n_head = 8):
24
- super(TS_model,self).__init__()
25
-
26
- self.vocab_size=len(labels)
27
- self.n_positions=config.t2s_position
28
- self.n_embed=n_embed
29
- self.n_layer=n_layer
30
- self.n_head=n_head
31
-
32
- self.config = GPT2Config(vocab_size=self.vocab_size,n_positions=self.n_positions,n_embd=self.n_embed,n_layer=self.n_layer,n_head=self.n_head)
33
- self.gpt = GPT2Model(self.config)
34
- del self.gpt.wpe
35
- self.gpt.wpe = functools.partial(null_position_embeddings, dim=self.n_embed)
36
- # Built-in token embeddings are unused.
37
- del self.gpt.wte
38
- self.GST = GST(model_channels=self.n_embed,num_heads=self.n_head,in_channels=config.n_mel_channels,k=1)
39
- self.text_head = nn.Linear(self.n_embed,len(text_labels))
40
- self.code_head = nn.Linear(self.n_embed,len(code_labels))
41
-
42
- self.text_positional_embed = LearnedPositionEmbeddings(self.n_positions,self.n_embed)
43
- self.code_positional_embed = LearnedPositionEmbeddings(self.n_positions,self.n_embed)
44
-
45
- self.text_embed = nn.Embedding(len(text_labels),self.n_embed)
46
- self.code_embed = nn.Embedding(len(code_labels),self.n_embed)
47
- self.final_norm = nn.LayerNorm(self.n_embed)
48
-
49
- def get_speaker_latent(self, ref_mels):
50
- ref_mels = ref_mels.unsqueeze(1) if len(
51
- ref_mels.shape) == 3 else ref_mels
52
-
53
- conds = []
54
- for j in range(ref_mels.shape[1]):
55
- conds.append(self.GST(ref_mels[:, j,:,:]))
56
-
57
- conds = torch.cat(conds, dim=-1)
58
- conds = conds.mean(dim=-1)
59
-
60
- return conds.unsqueeze(1)
61
-
62
- def forward(self,text_ids,codes_ids = None,speaker_embed=None,ref_clips=None,return_loss = False):
63
- assert speaker_embed is not None or ref_clips is not None
64
- text_embed = self.text_embed(text_ids)
65
- text_embed += self.text_positional_embed(text_embed)
66
-
67
- code_embed = None
68
- code_probs= None
69
-
70
- if codes_ids is not None:
71
- code_embed = self.code_embed(codes_ids)
72
- code_embed+= self.code_positional_embed(code_embed)
73
-
74
- if ref_clips is not None:
75
- speaker_embed = self.get_speaker_latent(ref_clips)
76
-
77
- text_embed,code_embed = self.get_logits(speaker_embed=speaker_embed,text_embed=text_embed,code_embed=code_embed)
78
-
79
- text_probs = self.text_head(text_embed).permute(0,2,1)
80
-
81
- if codes_ids is not None:
82
- code_probs = self.code_head(code_embed).permute(0,2,1)
83
-
84
- if return_loss:
85
- loss_text = F.cross_entropy(text_probs[:,:,:-1], text_ids[:,1:].long(), reduce=False)
86
- loss_mel = F.cross_entropy(code_probs[:,:,:-1], codes_ids[:,1:].long(), reduce=False)
87
- return loss_text,loss_mel,code_probs
88
-
89
- return text_probs,code_probs
90
-
91
-
92
- def get_logits(self,speaker_embed,text_embed,code_embed=None):
93
-
94
- if code_embed is not None:
95
- embed = torch.cat([speaker_embed,text_embed,code_embed],dim=1)
96
- else:
97
- embed = torch.cat([speaker_embed,text_embed],dim=1)
98
-
99
- gpt_output = self.gpt(inputs_embeds=embed, return_dict=True)
100
- enc = gpt_output.last_hidden_state[:, 1:]
101
- enc = self.final_norm(enc)
102
- if code_embed is not None:
103
- return enc[:,:text_embed.shape[1]],enc[:,-code_embed.shape[1]:]
104
-
105
- return enc[:,:text_embed.shape[1]],None
106
-
107
- class LearnedPositionEmbeddings(nn.Module):
108
- def __init__(self, seq_len, model_dim, init=.02):
109
- super().__init__()
110
- self.emb = nn.Embedding(seq_len, model_dim)
111
- # Initializing this way is standard for GPT-2
112
- self.emb.weight.data.normal_(mean=0.0, std=init)
113
-
114
- def forward(self, x):
115
- sl = x.shape[1]
116
- return self.emb(torch.arange(0, sl, device=x.device))
117
-
118
- def get_fixed_embedding(self, ind, dev):
119
- return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
120
-
121
- def load_TS_model(checkpoint,device):
122
- sem_model= TS_model(n_embed = 512, n_layer = 16, n_head = 8)
123
- sem_model.load_state_dict(torch.load(checkpoint,map_location=torch.device('cpu')),strict=False)
124
- sem_model.eval().to(device)
125
-
126
- return sem_model
127
-
128
- if __name__ == '__main__':
129
- model=TS_model(n_embed = 256, n_layer = 6, n_head = 4)
130
-
131
- text_ids = torch.randint(0,100,(5,20))
132
- code_ids = torch.randint(0,100,(5,200))
133
- speaker_embed = torch.randn((5,1,256))
134
-
135
- output=model(text_ids=text_ids,speaker_embed=speaker_embed,codes_ids=code_ids,return_loss=True)
 
 
maha_tts/models/diff_model.py DELETED
@@ -1,303 +0,0 @@
1
- '''
2
- inspiration taken from https://github.com/neonbjb/tortoise-tts/blob/main/tortoise/models/diffusion_decoder.py
3
- '''
4
- import sys
5
- import torch
6
- import torch.nn as nn
7
- import torch.nn.functional as F
8
- import math
9
-
10
- from maha_tts.config import config
11
- from torch import autocast
12
- from maha_tts.models.modules import QuartzNetBlock,AttentionBlock,mySequential,normalization,SCBD,SqueezeExcite,GST
13
-
14
- def timestep_embedding(timesteps, dim, max_period=10000):
15
- """
16
- Create sinusoidal timestep embeddings.
17
-
18
- :param timesteps: a 1-D Tensor of N indices, one per batch element.
19
- These may be fractional.
20
- :param dim: the dimension of the output.
21
- :param max_period: controls the minimum frequency of the embeddings.
22
- :return: an [N x dim] Tensor of positional embeddings.
23
- """
24
- half = dim // 2
25
- freqs = torch.exp(
26
- -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
27
- ).to(device=timesteps.device)
28
- args = timesteps[:, None].float() * freqs[None]
29
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
30
- if dim % 2:
31
- embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
32
- return embedding
33
-
34
- class TimestepBlock(nn.Module):
35
- def forward(self, x, emb):
36
- """
37
- Apply the module to `x` given `emb` timestep embeddings.
38
- """
39
-
40
- class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
41
- def forward(self, x, emb):
42
- for layer in self:
43
- if isinstance(layer, TimestepBlock):
44
- x = layer(x, emb)
45
- else:
46
- x = layer(x)
47
- return x
48
-
49
- class QuartzNetBlock(TimestepBlock):
50
- '''Similar to Resnet block with Batchnorm and dropout, and using Separable conv in the middle.
51
- if its the last layer,set se = False and separable = False, and use a projection layer on top of this.
52
- '''
53
- def __init__(self,nin,nout,emb_channels,kernel_size=3,dropout=0.1,R=1,se=True,ratio=8,separable=False,bias=True,use_scale_shift_norm=True):
54
- super(QuartzNetBlock,self).__init__()
55
- self.use_scale_shift_norm = use_scale_shift_norm
56
- self.se=se
57
- self.in_layers = mySequential(
58
- nn.Conv1d(nin,nout,kernel_size=1,padding='same',bias=bias),
59
- normalization(nout) #nn.BatchNorm1d(nout,eps)
60
- )
61
-
62
- self.residual=mySequential(
63
- nn.Conv1d(nin,nout,kernel_size=1,padding='same',bias=bias),
64
- normalization(nout) #nn.BatchNorm1d(nout,eps)
65
- )
66
-
67
- nin=nout
68
- model=[]
69
-
70
- self.emb_layers = nn.Sequential(
71
- nn.SiLU(),
72
- nn.Linear(
73
- emb_channels,
74
- 2 * nout if use_scale_shift_norm else nout,
75
- ),
76
- )
77
-
78
- for i in range(R-1):
79
- model.append(SCBD(nin,nout,kernel_size,dropout,bias=bias))
80
- nin=nout
81
-
82
- if separable:
83
- model.append(SCBD(nin,nout,kernel_size,dropout,rd=False,bias=bias))
84
- else:
85
- model.append(SCBD(nin,nout,kernel_size,dropout,rd=False,separable=False,bias=bias))
86
-
87
- self.model=mySequential(*model)
88
- if self.se:
89
- self.se_layer=SqueezeExcite(nin,ratio)
90
-
91
- self.mout= mySequential(nn.SiLU(),nn.Dropout(dropout))
92
-
93
- def forward(self,x,emb,mask=None):
94
- x_new=self.in_layers(x)
95
- emb = self.emb_layers(emb)
96
- while len(emb.shape) < len(x_new.shape):
97
- emb = emb[..., None]
98
- scale, shift = torch.chunk(emb, 2, dim=1)
99
- x_new = x_new * (1 + scale) + shift
100
- y,_=self.model(x_new)
101
-
102
- if self.se:
103
- y,_=self.se_layer(y,mask)
104
- y+=self.residual(x)
105
- y=self.mout(y)
106
-
107
- return y
108
-
109
- class QuartzAttn(TimestepBlock):
110
- def __init__(self, model_channels, dropout, num_heads):
111
- super().__init__()
112
- self.resblk = QuartzNetBlock(model_channels, model_channels, model_channels,dropout=dropout,use_scale_shift_norm=True)
113
- self.attn = AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True)
114
-
115
- def forward(self, x, time_emb):
116
- y = self.resblk(x, time_emb)
117
- return self.attn(y)
118
-
119
- class QuartzNet9x5(nn.Module):
120
- def __init__(self,model_channels,num_heads,enable_fp16=False):
121
- super(QuartzNet9x5,self).__init__()
122
- self.enable_fp16 = enable_fp16
123
-
124
- self.conv1=QuartzNetBlock(model_channels,model_channels,model_channels,kernel_size=3,dropout=0.1,R=3)
125
- kernels=[5,7,9,13,15,17]
126
- quartznet=[]
127
- attn=[]
128
- for i in kernels:
129
- quartznet.append(QuartzNetBlock(model_channels,model_channels,model_channels,kernel_size=i,dropout=0.1,R=5,se=True))
130
- attn.append(AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True))
131
- kernels=[21,23,25]
132
- quartznet.append(QuartzNetBlock(model_channels,model_channels,model_channels,kernel_size=21,dropout=0.1,R=5,se=True))
133
- attn.append(AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True))
134
-
135
- for i in kernels[1:]:
136
- quartznet.append(QuartzNetBlock(model_channels,model_channels,model_channels,kernel_size=i,dropout=0.1,R=5,se=True))
137
- attn.append(AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True))
138
- self.quartznet= nn.ModuleList(quartznet)
139
- self.attn = nn.ModuleList(attn)
140
- self.conv3=nn.Conv1d(model_channels, model_channels, 1, padding='same')
141
-
142
-
143
- def forward(self, x, time_emb):
144
- x = self.conv1(x,time_emb)
145
- # with autocast(x.device.type, enabled=self.enable_fp16):
146
- for n,(layer,attn) in enumerate(zip(self.quartznet,self.attn)):
147
- x = layer(x,time_emb) #256 dim
148
- x = attn(x)
149
- x = self.conv3(x.float())
150
- return x
151
-
152
- class DiffModel(nn.Module):
153
-
154
- def __init__(
155
- self,
156
- input_channels=80,
157
- output_channels=160,
158
- model_channels=512,
159
- num_heads=8,
160
- dropout=0.0,
161
- multispeaker = True,
162
- condition_free_per=0.1,
163
- training = False,
164
- ar_active = False,
165
- in_latent_channels = 10004
166
- ):
167
-
168
- super().__init__()
169
- self.input_channels = input_channels
170
- self.model_channels = model_channels
171
- self.output_channels = output_channels
172
- self.num_heads = num_heads
173
- self.dropout = dropout
174
- self.condition_free_per = condition_free_per
175
- self.training = training
176
- self.multispeaker = multispeaker
177
- self.ar_active = ar_active
178
- self.in_latent_channels = in_latent_channels
179
-
180
- if not self.ar_active:
181
- self.code_emb = nn.Embedding(config.semantic_model_centroids+1,model_channels)
182
- self.code_converter = mySequential(
183
- AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
184
- AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
185
- AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
186
- )
187
- else:
188
- self.code_converter = mySequential(
189
- nn.Conv1d(self.in_latent_channels, model_channels, 3, padding=1),
190
- AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
191
- AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
192
- AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
193
- AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
194
- )
195
- if self.multispeaker:
196
- self.GST = GST(model_channels,num_heads)
197
-
198
- self.code_norm = normalization(model_channels)
199
- self.time_norm = normalization(model_channels)
200
- self.noise_norm = normalization(model_channels)
201
- self.code_time_norm = normalization(model_channels)
202
-
203
- # self.code_latent = []
204
- self.time_embed = mySequential(
205
- nn.Linear(model_channels, model_channels),
206
- nn.SiLU(),
207
- nn.Linear(model_channels, model_channels),)
208
-
209
- self.input_block = nn.Conv1d(input_channels,model_channels,3,1,1)
210
- self.unconditioned_embedding = nn.Parameter(torch.randn(1,model_channels,1))
211
-
212
- self.code_time = TimestepEmbedSequential(QuartzAttn(model_channels, dropout, num_heads),QuartzAttn(model_channels, dropout, num_heads),QuartzAttn(model_channels, dropout, num_heads))
213
- self.layers = QuartzNet9x5(model_channels,num_heads)
214
-
215
- self.out = nn.Sequential(
216
- normalization(model_channels),
217
- nn.SiLU(),
218
- nn.Conv1d(model_channels, output_channels, 3, padding=1),
219
- )
220
-
221
- def get_speaker_latent(self, ref_mels):
222
- ref_mels = ref_mels.unsqueeze(1) if len(
223
- ref_mels.shape) == 3 else ref_mels
224
-
225
- conds = []
226
- for j in range(ref_mels.shape[1]):
227
- conds.append(self.GST(ref_mels[:, j,:,:]))
228
-
229
- conds = torch.cat(conds, dim=-1)
230
- conds = conds.mean(dim=-1)
231
-
232
- return conds.unsqueeze(2)
233
-
234
- def forward(self ,x,t,code_emb,ref_clips=None,speaker_latents=None,conditioning_free=False):
235
- time_embed = self.time_norm(self.time_embed(timestep_embedding(t.unsqueeze(-1),self.model_channels)).permute(0,2,1)).squeeze(2)
236
- if conditioning_free:
237
- code_embed = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
238
- else:
239
- if not self.ar_active:
240
- code_embed = self.code_norm(self.code_converter(self.code_emb(code_emb).permute(0,2,1)))
241
- else:
242
- code_embed = self.code_norm(self.code_converter(code_emb))
243
- if self.multispeaker:
244
- assert speaker_latents is not None or ref_clips is not None
245
- if ref_clips is not None:
246
- speaker_latents = self.get_speaker_latent(ref_clips)
247
- cond_scale, cond_shift = torch.chunk(speaker_latents, 2, dim=1)
248
- code_embed = code_embed * (1 + cond_scale) + cond_shift
249
- if self.training and self.condition_free_per > 0:
250
- unconditioned_batches = torch.rand((code_embed.shape[0], 1, 1),
251
- device=code_embed.device) < self.condition_free_per
252
- code_embed = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(code_embed.shape[0], 1, 1),
253
- code_embed)
254
-
255
- expanded_code_emb = F.interpolate(code_embed, size=x.shape[-1], mode='nearest') #try different modes
256
-
257
- x_cond = self.code_time_norm(self.code_time(expanded_code_emb,time_embed))
258
-
259
- x = self.noise_norm(self.input_block(x))
260
- x += x_cond
261
- x = self.layers(x, time_embed)
262
- out = self.out(x)
263
- return out
264
-
265
- def load_diff_model(checkpoint,device,model_channels=512,ar_active=False,len_code_labels=10004):
266
- diff_model = DiffModel(input_channels=80,
267
- output_channels=160,
268
- model_channels=512,
269
- num_heads=8,
270
- dropout=0.15,
271
- condition_free_per=0.15,
272
- multispeaker=True,
273
- training=False,
274
- ar_active=ar_active,
275
- in_latent_channels = len_code_labels)
276
-
277
- # diff_model.load_state_dict(torch.load('/content/LibriTTS_fp64_10k/S2A/_latest.pt',map_location=torch.device('cpu')),strict=True)
278
- diff_model.load_state_dict(torch.load(checkpoint,map_location=torch.device('cpu')),strict=True)
279
- diff_model=diff_model.eval().to(device)
280
- return diff_model
281
-
282
-
283
- if __name__ == '__main__':
284
-
285
- device = torch.device('cpu')
286
- diff_model = DiffModel(input_channels=80,
287
- output_channels=160,
288
- model_channels=1024,
289
- num_heads=8,
290
- dropout=0.1,
291
- num_layers=8,
292
- enable_fp16=True,
293
- condition_free_per=0.1,
294
- multispeaker=True,
295
- training=True).to(device)
296
-
297
- batch_Size = 32
298
- timeseries = 800
299
- from torchinfo import summary
300
- summary(diff_model, input_data={'x': torch.randn(batch_Size, 80, timeseries).to(device),
301
- 'ref_clips': torch.randn(batch_Size,3, 80, timeseries).to(device),
302
- 't':torch.LongTensor(size=[batch_Size,]).to(device),
303
- 'code_emb':torch.randint(0,201,(batch_Size,timeseries)).to(device)})
 
 
maha_tts/models/modules.py DELETED
@@ -1,406 +0,0 @@
1
- import torch,math
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- import torch.nn.init as init
5
- from einops import rearrange, repeat
6
-
7
- def zero_module(module):
8
- """
9
- Zero out the parameters of a module and return it.
10
- Using it for Zero Convolutions
11
- """
12
- for p in module.parameters():
13
- p.detach().zero_()
14
- return module
15
-
16
-
17
- class GroupNorm32(nn.GroupNorm):
18
- def forward(self, x):
19
- return super().forward(x.float()).type(x.dtype)
20
-
21
-
22
- def normalization(channels):
23
- """
24
- Make a standard normalization layer. of groups ranging from 2 to 32.
25
-
26
- :param channels: number of input channels.
27
- :return: an nn.Module for normalization.
28
- """
29
- groups = 32
30
- if channels <= 16:
31
- groups = 8
32
- elif channels <= 64:
33
- groups = 16
34
- while channels % groups != 0:
35
- groups = int(groups / 2)
36
- assert groups > 2
37
- return GroupNorm32(groups, channels)
38
-
39
-
40
- class mySequential(nn.Sequential):
41
- '''Using this to pass mask variable to nn layers
42
- '''
43
- def forward(self, *inputs):
44
- for module in self._modules.values():
45
- if type(inputs) == tuple:
46
- inputs = module(*inputs)
47
- else:
48
- inputs = module(inputs)
49
- return inputs
50
-
51
- class SepConv1D(nn.Module):
52
- '''Depth wise separable Convolution layer with mask
53
- '''
54
- def __init__(self,nin,nout,kernel_size,stride=1,dilation=1,padding_mode='same',bias=True):
55
- super(SepConv1D,self).__init__()
56
- self.conv1=nn.Conv1d(nin, nin, kernel_size=kernel_size, stride=stride,groups=nin,dilation=dilation,padding=padding_mode,bias=bias)
57
- self.conv2=nn.Conv1d(nin,nout,kernel_size=1,stride=1,padding=padding_mode,bias=bias)
58
-
59
- def forward(self,x,mask=None):
60
- if mask is not None:
61
- x = x * mask.unsqueeze(1).to(device=x.device)
62
- x=self.conv1(x)
63
- x=self.conv2(x)
64
- return x,mask
65
-
66
- class Conv1DBN(nn.Module):
67
- def __init__(self,nin,nout,kernel_size,stride=1,dilation=1,dropout=0.1,padding_mode='same',bias=False):
68
- super(Conv1DBN,self).__init__()
69
- self.conv1=nn.Conv1d(nin, nout, kernel_size=kernel_size, stride=stride,padding=padding_mode,dilation=dilation,bias=bias)
70
- self.bn=nn.BatchNorm1d(nout)
71
- self.drop=nn.Dropout(dropout)
72
-
73
- def forward(self,x,mask=None):
74
- if mask is not None:
75
- x = x * mask.unsqueeze(1).to(device=x.device)
76
- x=self.conv1(x)
77
- x=self.bn(x)
78
- x=F.relu(x)
79
- x=self.drop(x)
80
- return x,mask
81
-
82
- class Conv1d(nn.Module):
83
- '''normal conv1d with mask
84
- '''
85
- def __init__(self,nin,nout,kernel_size,padding,bias=True):
86
- super(Conv1d,self).__init__()
87
- self.l=nn.Conv1d(nin,nout,kernel_size,padding=padding,bias=bias)
88
- def forward(self,x,mask):
89
- if mask is not None:
90
- x = x * mask.unsqueeze(1).to(device=x.device)
91
- y=self.l(x)
92
- return y,mask
93
-
94
- class SqueezeExcite(nn.Module):
95
- '''Let the CNN decide how to add across channels
96
- '''
97
- def __init__(self,nin,ratio=8):
98
- super(SqueezeExcite,self).__init__()
99
- self.nin=nin
100
- self.ratio=ratio
101
-
102
- self.fc=mySequential(
103
- nn.Linear(nin,nin//ratio,bias=True),nn.SiLU(inplace=True),nn.Linear(nin//ratio,nin,bias=True)
104
- )
105
-
106
- def forward(self,x,mask=None):
107
- if mask is None:
108
- mask = torch.ones((x.shape[0],x.shape[-1]),dtype=torch.bool).to(x.device)
109
- mask=~mask
110
- x=x.float()
111
- x.masked_fill_(mask.unsqueeze(1), 0.0)
112
- mask=~mask
113
- y = (torch.sum(x, dim=-1, keepdim=True) / mask.unsqueeze(1).sum(dim=-1, keepdim=True)).type(x.dtype)
114
- # y=torch.mean(x,-1,keepdim=True)
115
- y=y.transpose(1, -1)
116
- y=self.fc(y)
117
- y=torch.sigmoid(y)
118
- y=y.transpose(1, -1)
119
- y= x * y
120
- return y,mask
121
-
122
-
123
-
124
- class SCBD(nn.Module):
125
- '''SeparableConv1D + Batchnorm + Dropout, Generally use it for middle layers and resnet
126
- '''
127
- def __init__(self,nin,nout,kernel_size,p=0.1,rd=True,separable=True,bias=True):
128
- super(SCBD,self).__init__()
129
- if separable:
130
- self.SC=SepConv1D(nin,nout,kernel_size,bias=bias)
131
- else:
132
- self.SC=Conv1d(nin,nout,kernel_size,padding='same',bias=bias)
133
-
134
- if rd: #relu and Dropout
135
- self.mout=mySequential(normalization(nout),nn.SiLU(), # nn.BatchNorm1d(nout,eps)
136
- nn.Dropout(p))
137
- else:
138
- self.mout=normalization(nout) # nn.BatchNorm1d(nout,eps)
139
-
140
- def forward(self,x,mask=None):
141
- if mask is not None:
142
- x = x * mask.unsqueeze(1).to(device=x.device)
143
- x,_= self.SC(x,mask)
144
- y = self.mout(x)
145
- return y,mask
146
-
147
- class QuartzNetBlock(nn.Module):
148
- '''Similar to Resnet block with Batchnorm and dropout, and using Separable conv in the middle.
149
- if its the last layer,set se = False and separable = False, and use a projection layer on top of this.
150
- '''
151
- def __init__(self,nin,nout,kernel_size,dropout=0.1,R=5,se=False,ratio=8,separable=False,bias=True):
152
- super(QuartzNetBlock,self).__init__()
153
- self.se=se
154
- self.residual=mySequential(
155
- nn.Conv1d(nin,nout,kernel_size=1,padding='same',bias=bias),
156
- normalization(nout) #nn.BatchNorm1d(nout,eps)
157
- )
158
- model=[]
159
-
160
- for i in range(R-1):
161
- model.append(SCBD(nin,nout,kernel_size,dropout,eps=0.001,bias=bias))
162
- nin=nout
163
-
164
- if separable:
165
- model.append(SCBD(nin,nout,kernel_size,dropout,eps=0.001,rd=False,bias=bias))
166
- else:
167
- model.append(SCBD(nin,nout,kernel_size,dropout,eps=0.001,rd=False,separable=False,bias=bias))
168
- self.model=mySequential(*model)
169
-
170
- if self.se:
171
- self.se_layer=SqueezeExcite(nin,ratio)
172
-
173
- self.mout= mySequential(nn.SiLU(),nn.Dropout(dropout))
174
-
175
- def forward(self,x,mask=None):
176
- if mask is not None:
177
- x = x * mask.unsqueeze(1).to(device=x.device)
178
- y,_=self.model(x,mask)
179
- if self.se:
180
- y,_=self.se_layer(y,mask)
181
- y+=self.residual(x)
182
- y=self.mout(y)
183
- return y,mask
184
-
185
- class QKVAttentionLegacy(nn.Module):
186
- """
187
- A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
188
- """
189
-
190
- def __init__(self, n_heads):
191
- super().__init__()
192
- self.n_heads = n_heads
193
-
194
- def forward(self, qkv, mask=None, rel_pos=None):
195
- """
196
- Apply QKV attention.
197
-
198
- :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
199
- :return: an [N x (H * C) x T] tensor after attention.
200
- """
201
- bs, width, length = qkv.shape
202
- assert width % (3 * self.n_heads) == 0
203
- ch = width // (3 * self.n_heads)
204
- q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
205
- scale = 1 / math.sqrt(math.sqrt(ch))
206
- weight = torch.einsum(
207
- "bct,bcs->bts", q * scale, k * scale
208
- ) # More stable with f16 than dividing afterwards
209
- if rel_pos is not None:
210
- weight = rel_pos(weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])).reshape(bs * self.n_heads, weight.shape[-2], weight.shape[-1])
211
- weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
212
- if mask is not None:
213
- # The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
214
- mask = mask.repeat(self.n_heads, 1).unsqueeze(1)
215
- weight = weight * mask
216
- a = torch.einsum("bts,bcs->bct", weight, v)
217
-
218
- return a.reshape(bs, -1, length)
219
-
220
- class AttentionBlock(nn.Module):
221
- """
222
- An attention block that allows spatial positions to attend to each other.
223
-
224
- Originally ported from here, but adapted to the N-d case.
225
- https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
226
- """
227
-
228
- def __init__(
229
- self,
230
- channels,
231
- num_heads=1,
232
- num_head_channels=-1,
233
- do_checkpoint=True,
234
- relative_pos_embeddings=False,
235
- ):
236
- super().__init__()
237
- self.channels = channels
238
- self.do_checkpoint = do_checkpoint
239
- if num_head_channels == -1:
240
- self.num_heads = num_heads
241
- else:
242
- assert (
243
- channels % num_head_channels == 0
244
- ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
245
- self.num_heads = channels // num_head_channels
246
- self.norm = normalization(channels)
247
- self.qkv = nn.Conv1d(channels, channels * 3, 1)
248
- # split heads before split qkv
249
- self.attention = QKVAttentionLegacy(self.num_heads)
250
-
251
- self.proj_out = zero_module(nn.Conv1d(channels, channels, 1)) # no effect of attention in the inital stages.
252
- # if relative_pos_embeddings:
253
- self.relative_pos_embeddings = RelativePositionBias(scale=(channels // self.num_heads) ** .5, causal=False, heads=num_heads, num_buckets=32, max_distance=64) #need to read about this, vit and swin transformers
254
- # self.relative_pos_embeddings = FixedPositionalEmbedding(dim=channels)
255
- # else:
256
- # self.relative_pos_embeddings = None
257
-
258
- def forward(self, x, mask=None):
259
- b, c, *spatial = x.shape
260
- x = x.reshape(b, c, -1)
261
- qkv = self.qkv(self.norm(x))
262
- h = self.attention(qkv, mask, self.relative_pos_embeddings)
263
- h = self.proj_out(h)
264
- return (x + h).reshape(b, c, *spatial)
265
-
266
- class AbsolutePositionalEmbedding(nn.Module):
267
- def __init__(self, dim, max_seq_len):
268
- super().__init__()
269
- self.scale = dim ** -0.5
270
- self.emb = nn.Embedding(max_seq_len, dim)
271
-
272
- def forward(self, x):
273
- n = torch.arange(x.shape[1], device=x.device)
274
- pos_emb = self.emb(n)
275
- pos_emb = rearrange(pos_emb, 'n d -> () n d')
276
- return pos_emb * self.scale
277
-
278
-
279
- class FixedPositionalEmbedding(nn.Module):
280
- def __init__(self, dim):
281
- super().__init__()
282
- inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
283
- self.register_buffer('inv_freq', inv_freq)
284
-
285
- def forward(self, x, seq_dim=1, offset=0):
286
- t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
287
- sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)
288
- emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
289
- return rearrange(emb, 'n d -> () n d')
290
-
291
- class RelativePositionBias(nn.Module):
292
- def __init__(self, scale, causal=False, num_buckets=32, max_distance=128, heads=8):
293
- super().__init__()
294
- self.scale = scale
295
- self.causal = causal
296
- self.num_buckets = num_buckets
297
- self.max_distance = max_distance
298
- self.relative_attention_bias = nn.Embedding(num_buckets, heads)
299
-
300
- @staticmethod
301
- def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128):
302
- ret = 0
303
- n = -relative_position
304
- if not causal:
305
- num_buckets //= 2
306
- ret += (n < 0).long() * num_buckets
307
- n = torch.abs(n)
308
- else:
309
- n = torch.max(n, torch.zeros_like(n))
310
-
311
- max_exact = num_buckets // 2
312
- is_small = n < max_exact
313
-
314
- val_if_large = max_exact + (
315
- torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
316
- ).long()
317
- val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
318
-
319
- ret += torch.where(is_small, n, val_if_large)
320
- return ret
321
-
322
- def forward(self, qk_dots):
323
- i, j, device = *qk_dots.shape[-2:], qk_dots.device
324
- q_pos = torch.arange(i, dtype=torch.long, device=device)
325
- k_pos = torch.arange(j, dtype=torch.long, device=device)
326
- rel_pos = k_pos[None, :] - q_pos[:, None]
327
- rp_bucket = self._relative_position_bucket(rel_pos, causal=self.causal, num_buckets=self.num_buckets,
328
- max_distance=self.max_distance)
329
- values = self.relative_attention_bias(rp_bucket)
330
- bias = rearrange(values, 'i j h -> () h i j')
331
- return qk_dots + (bias * self.scale)
332
-
333
-
334
-
335
- class MultiHeadAttention(nn.Module):
336
- '''
337
- only for GST
338
- input:
339
- query --- [N, T_q, query_dim]
340
- key --- [N, T_k, key_dim]
341
- output:
342
- out --- [N, T_q, num_units]
343
- '''
344
- def __init__(self, query_dim, key_dim, num_units, num_heads):
345
- super().__init__()
346
- self.num_units = num_units
347
- self.num_heads = num_heads
348
- self.key_dim = key_dim
349
-
350
- self.W_query = nn.Linear(in_features=query_dim, out_features=num_units, bias=False)
351
- self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
352
- self.W_value = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
353
-
354
- def forward(self, query, key):
355
- querys = self.W_query(query) # [N, T_q, num_units]
356
- keys = self.W_key(key) # [N, T_k, num_units]
357
- values = self.W_value(key)
358
-
359
- split_size = self.num_units // self.num_heads
360
- querys = torch.stack(torch.split(querys, split_size, dim=2), dim=0) # [h, N, T_q, num_units/h]
361
- keys = torch.stack(torch.split(keys, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h]
362
- values = torch.stack(torch.split(values, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h]
363
-
364
- # score = softmax(QK^T / (d_k ** 0.5))
365
- scores = torch.matmul(querys, keys.transpose(2, 3)) # [h, N, T_q, T_k]
366
- scores = scores / (self.key_dim ** 0.5)
367
- scores = F.softmax(scores, dim=3)
368
-
369
- # out = score * V
370
- out = torch.matmul(scores, values) # [h, N, T_q, num_units/h]
371
- out = torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(0) # [N, T_q, num_units]
372
-
373
- return out
374
-
375
-
376
- class GST(nn.Module):
377
- def __init__(self,model_channels=512,num_heads=8,in_channels=80,k=2):
378
- super(GST,self).__init__()
379
- self.model_channels=model_channels
380
- self.num_heads=num_heads
381
-
382
- self.reference_encoder=nn.Sequential(
383
- nn.Conv1d(in_channels,model_channels,3,padding=1,stride=2),
384
- nn.Conv1d(model_channels, model_channels*k,3,padding=1,stride=2),
385
- AttentionBlock(model_channels*k, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
386
- AttentionBlock(model_channels*k, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
387
- AttentionBlock(model_channels*k, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
388
- AttentionBlock(model_channels*k, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
389
- AttentionBlock(model_channels*k, num_heads, relative_pos_embeddings=True, do_checkpoint=False)
390
- )
391
-
392
- def forward(self,x):
393
- x=self.reference_encoder(x)
394
- return x
395
-
396
-
397
- if __name__ == '__main__':
398
- device = torch.device('cpu')
399
- m = GST(512,10).to(device)
400
- mels = torch.rand((16,80,1000)).to(device)
401
-
402
- o = m(mels)
403
- print(o.shape,'final output')
404
-
405
- from torchinfo import summary
406
- summary(m, input_data={'x': torch.randn(16,80,500).to(device)})
 
 
maha_tts/models/vocoder.py DELETED
@@ -1,342 +0,0 @@
1
- '''
2
- copde from https://github.com/jik876/hifi-gan/blob/master/models.py
3
- '''
4
-
5
- import json,os
6
- import torch
7
- import torch.nn.functional as F
8
- import torch.nn as nn
9
- from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
10
- from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
11
- # from utils import init_weights, get_padding
12
-
13
- LRELU_SLOPE = 0.1
14
-
15
- class AttrDict(dict):
16
- def __init__(self, *args, **kwargs):
17
- super(AttrDict, self).__init__(*args, **kwargs)
18
- self.__dict__ = self
19
-
20
- def init_weights(m, mean=0.0, std=0.01):
21
- classname = m.__class__.__name__
22
- if classname.find("Conv") != -1:
23
- m.weight.data.normal_(mean, std)
24
-
25
-
26
- def apply_weight_norm(m):
27
- classname = m.__class__.__name__
28
- if classname.find("Conv") != -1:
29
- weight_norm(m)
30
-
31
-
32
- def get_padding(kernel_size, dilation=1):
33
- return int((kernel_size*dilation - dilation)/2)
34
-
35
-
36
- class ResBlock1(torch.nn.Module):
37
- def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
38
- super(ResBlock1, self).__init__()
39
- self.h = h
40
- self.convs1 = nn.ModuleList([
41
- weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
42
- padding=get_padding(kernel_size, dilation[0]))),
43
- weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
44
- padding=get_padding(kernel_size, dilation[1]))),
45
- weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
46
- padding=get_padding(kernel_size, dilation[2])))
47
- ])
48
- self.convs1.apply(init_weights)
49
-
50
- self.convs2 = nn.ModuleList([
51
- weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
52
- padding=get_padding(kernel_size, 1))),
53
- weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
54
- padding=get_padding(kernel_size, 1))),
55
- weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
56
- padding=get_padding(kernel_size, 1)))
57
- ])
58
- self.convs2.apply(init_weights)
59
-
60
- def forward(self, x):
61
- for c1, c2 in zip(self.convs1, self.convs2):
62
- xt = F.leaky_relu(x, LRELU_SLOPE)
63
- xt = c1(xt)
64
- xt = F.leaky_relu(xt, LRELU_SLOPE)
65
- xt = c2(xt)
66
- x = xt + x
67
- return x
68
-
69
- def remove_weight_norm(self):
70
- for l in self.convs1:
71
- remove_weight_norm(l)
72
- for l in self.convs2:
73
- remove_weight_norm(l)
74
-
75
-
76
- class ResBlock2(torch.nn.Module):
77
- def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
78
- super(ResBlock2, self).__init__()
79
- self.h = h
80
- self.convs = nn.ModuleList([
81
- weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
82
- padding=get_padding(kernel_size, dilation[0]))),
83
- weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
84
- padding=get_padding(kernel_size, dilation[1])))
85
- ])
86
- self.convs.apply(init_weights)
87
-
88
- def forward(self, x):
89
- for c in self.convs:
90
- xt = F.leaky_relu(x, LRELU_SLOPE)
91
- xt = c(xt)
92
- x = xt + x
93
- return x
94
-
95
- def remove_weight_norm(self):
96
- for l in self.convs:
97
- remove_weight_norm(l)
98
-
99
-
100
- class Generator(torch.nn.Module):
101
- def __init__(self, h):
102
- super(Generator, self).__init__()
103
- self.h = h
104
- self.num_kernels = len(h.resblock_kernel_sizes)
105
- self.num_upsamples = len(h.upsample_rates)
106
- self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3))
107
- resblock = ResBlock1 if h.resblock == '1' else ResBlock2
108
-
109
- self.ups = nn.ModuleList()
110
- for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
111
- self.ups.append(weight_norm(
112
- ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),
113
- k, u, padding=(k-u)//2)))
114
-
115
- self.resblocks = nn.ModuleList()
116
- for i in range(len(self.ups)):
117
- ch = h.upsample_initial_channel//(2**(i+1))
118
- for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
119
- self.resblocks.append(resblock(h, ch, k, d))
120
-
121
- self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
122
- self.ups.apply(init_weights)
123
- self.conv_post.apply(init_weights)
124
-
125
- def forward(self, x):
126
- x = self.conv_pre(x)
127
- for i in range(self.num_upsamples):
128
- x = F.leaky_relu(x, LRELU_SLOPE)
129
- x = self.ups[i](x)
130
- xs = None
131
- for j in range(self.num_kernels):
132
- if xs is None:
133
- xs = self.resblocks[i*self.num_kernels+j](x)
134
- else:
135
- xs += self.resblocks[i*self.num_kernels+j](x)
136
- x = xs / self.num_kernels
137
- x = F.leaky_relu(x)
138
- x = self.conv_post(x)
139
- x = torch.tanh(x)
140
-
141
- return x
142
-
143
- def remove_weight_norm(self):
144
- # print('Removing weight norm...')
145
- for l in self.ups:
146
- remove_weight_norm(l)
147
- for l in self.resblocks:
148
- l.remove_weight_norm()
149
- remove_weight_norm(self.conv_pre)
150
- remove_weight_norm(self.conv_post)
151
-
152
-
153
- class DiscriminatorP(torch.nn.Module):
154
- def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
155
- super(DiscriminatorP, self).__init__()
156
- self.period = period
157
- norm_f = weight_norm if use_spectral_norm == False else spectral_norm
158
- self.convs = nn.ModuleList([
159
- norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
160
- norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
161
- norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
162
- norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
163
- norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
164
- ])
165
- self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
166
-
167
- def forward(self, x):
168
- fmap = []
169
-
170
- # 1d to 2d
171
- b, c, t = x.shape
172
- if t % self.period != 0: # pad first
173
- n_pad = self.period - (t % self.period)
174
- x = F.pad(x, (0, n_pad), "reflect")
175
- t = t + n_pad
176
- x = x.view(b, c, t // self.period, self.period)
177
-
178
- for l in self.convs:
179
- x = l(x)
180
- x = F.leaky_relu(x, LRELU_SLOPE)
181
- fmap.append(x)
182
- x = self.conv_post(x)
183
- fmap.append(x)
184
- x = torch.flatten(x, 1, -1)
185
-
186
- return x, fmap
187
-
188
-
189
- class MultiPeriodDiscriminator(torch.nn.Module):
190
- def __init__(self):
191
- super(MultiPeriodDiscriminator, self).__init__()
192
- self.discriminators = nn.ModuleList([
193
- DiscriminatorP(2),
194
- DiscriminatorP(3),
195
- DiscriminatorP(5),
196
- DiscriminatorP(7),
197
- DiscriminatorP(11),
198
- ])
199
-
200
- def forward(self, y, y_hat):
201
- y_d_rs = []
202
- y_d_gs = []
203
- fmap_rs = []
204
- fmap_gs = []
205
- for i, d in enumerate(self.discriminators):
206
- y_d_r, fmap_r = d(y)
207
- y_d_g, fmap_g = d(y_hat)
208
- y_d_rs.append(y_d_r)
209
- fmap_rs.append(fmap_r)
210
- y_d_gs.append(y_d_g)
211
- fmap_gs.append(fmap_g)
212
-
213
- return y_d_rs, y_d_gs, fmap_rs, fmap_gs
214
-
215
-
216
- class DiscriminatorS(torch.nn.Module):
217
- def __init__(self, use_spectral_norm=False):
218
- super(DiscriminatorS, self).__init__()
219
- norm_f = weight_norm if use_spectral_norm == False else spectral_norm
220
- self.convs = nn.ModuleList([
221
- norm_f(Conv1d(1, 128, 15, 1, padding=7)),
222
- norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
223
- norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
224
- norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
225
- norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
226
- norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
227
- norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
228
- ])
229
- self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
230
-
231
- def forward(self, x):
232
- fmap = []
233
- for l in self.convs:
234
- x = l(x)
235
- x = F.leaky_relu(x, LRELU_SLOPE)
236
- fmap.append(x)
237
- x = self.conv_post(x)
238
- fmap.append(x)
239
- x = torch.flatten(x, 1, -1)
240
-
241
- return x, fmap
242
-
243
-
244
- class MultiScaleDiscriminator(torch.nn.Module):
245
- def __init__(self):
246
- super(MultiScaleDiscriminator, self).__init__()
247
- self.discriminators = nn.ModuleList([
248
- DiscriminatorS(use_spectral_norm=True),
249
- DiscriminatorS(),
250
- DiscriminatorS(),
251
- ])
252
- self.meanpools = nn.ModuleList([
253
- AvgPool1d(4, 2, padding=2),
254
- AvgPool1d(4, 2, padding=2)
255
- ])
256
-
257
- def forward(self, y, y_hat):
258
- y_d_rs = []
259
- y_d_gs = []
260
- fmap_rs = []
261
- fmap_gs = []
262
- for i, d in enumerate(self.discriminators):
263
- if i != 0:
264
- y = self.meanpools[i-1](y)
265
- y_hat = self.meanpools[i-1](y_hat)
266
- y_d_r, fmap_r = d(y)
267
- y_d_g, fmap_g = d(y_hat)
268
- y_d_rs.append(y_d_r)
269
- fmap_rs.append(fmap_r)
270
- y_d_gs.append(y_d_g)
271
- fmap_gs.append(fmap_g)
272
-
273
- return y_d_rs, y_d_gs, fmap_rs, fmap_gs
274
-
275
-
276
- def feature_loss(fmap_r, fmap_g):
277
- loss = 0
278
- for dr, dg in zip(fmap_r, fmap_g):
279
- for rl, gl in zip(dr, dg):
280
- loss += torch.mean(torch.abs(rl - gl))
281
-
282
- return loss*2
283
-
284
-
285
- def discriminator_loss(disc_real_outputs, disc_generated_outputs):
286
- loss = 0
287
- r_losses = []
288
- g_losses = []
289
- for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
290
- r_loss = torch.mean((1-dr)**2)
291
- g_loss = torch.mean(dg**2)
292
- loss += (r_loss + g_loss)
293
- r_losses.append(r_loss.item())
294
- g_losses.append(g_loss.item())
295
-
296
- return loss, r_losses, g_losses
297
-
298
-
299
- def generator_loss(disc_outputs):
300
- loss = 0
301
- gen_losses = []
302
- for dg in disc_outputs:
303
- l = torch.mean((1-dg)**2)
304
- gen_losses.append(l)
305
- loss += l
306
-
307
- return loss, gen_losses
308
-
309
- def load_checkpoint(filepath, device):
310
- assert os.path.isfile(filepath)
311
- checkpoint_dict = torch.load(filepath, map_location=device)
312
- return checkpoint_dict
313
-
314
- def load_vocoder_model(config_path,checkpoint_path,device):
315
- # config_file = os.path.join(os.path.split(checkpoint_file)[0], 'config.json')
316
- with open(config_path) as f:
317
- data = f.read()
318
-
319
- global h
320
- json_config = json.loads(data)
321
- h = AttrDict(json_config)
322
-
323
- torch.manual_seed(h.seed)
324
-
325
- generator = Generator(h).to(device)
326
-
327
- state_dict_g = load_checkpoint(checkpoint_path, device)
328
- generator.load_state_dict(state_dict_g['generator'])
329
-
330
- generator.eval()
331
- generator.remove_weight_norm()
332
-
333
- return generator
334
-
335
- def infer_wav(mel,generator):
336
- MAX_WAV_VALUE =32768.0
337
- with torch.no_grad():
338
- y_g_hat = generator(mel)
339
- audio = y_g_hat.squeeze()
340
- audio = audio * MAX_WAV_VALUE
341
- audio = audio.cpu().numpy().astype('int16')
342
- return audio
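
For context, the module removed here wrapped the stock HiFi-GAN generator and exposed `load_vocoder_model` and `infer_wav`. A rough sketch of how that API was used, against the deleted signatures; the paths are placeholders and the 22050 Hz rate is an assumption taken from typical HiFi-GAN configs, not from this repo:

```python
import torch
from scipy.io.wavfile import write

# Sketch against the deleted maha_tts/models/vocoder.py API (load_vocoder_model / infer_wav).
from maha_tts.models.vocoder import load_vocoder_model, infer_wav

device = torch.device('cpu')
generator = load_vocoder_model('path/to/hifigan/config.json',   # placeholder config path
                               'path/to/hifigan/generator.pt',  # placeholder checkpoint path
                               device)

mel = torch.rand(1, 80, 500).to(device)   # [batch, n_mels, frames] mel spectrogram
audio = infer_wav(mel, generator)         # int16 numpy waveform scaled by 32768
write('vocoder_out.wav', 22050, audio)    # sampling rate assumed, comes from the HiFi-GAN config
```
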
maha_tts/pretrained_models/.DS_Store CHANGED
Binary files a/maha_tts/pretrained_models/.DS_Store and b/maha_tts/pretrained_models/.DS_Store differ
 
maha_tts/pretrained_models/Smolie-en/.DS_Store ADDED
Binary file (6.15 kB)
 
maha_tts/pretrained_models/Smolie-en/s2a_latest.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a1cb3aef9bebda0535dce135de3ae5f23f62ec3890eed87469dfe4a9a07f0f98
3
+ size 1720934888
maha_tts/pretrained_models/{smolie/T2S → Smolie-en}/t2s_best.pt RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:67a10c3bf12a8bca3dd67075ccbfbd79887b244109bd9c96013b0f348d9e2570
3
- size 276146627
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6be1b489366ebbd35e55404be875804d380b0430587319f67e592da7ba1b5240
3
+ size 276143363
maha_tts/pretrained_models/Smolie-in/.DS_Store ADDED
Binary file (6.15 kB)
 
maha_tts/pretrained_models/Smolie-in/s2a_latest.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ce73f611d8071f69111b71363fdad9465d84c4431fc26dc6b6de4595591c3305
3
+ size 1720934441
maha_tts/pretrained_models/{smolie/S2A/s2a_latest.pt → Smolie-in/t2s_best.pt} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:bf359fab98b047ef89d79a99a78fee9c38880e307630d3b3af7bc9cb170f366b
3
- size 432971673
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9c867f8a11f364b4cf543b42335e0a7f0450078693c66539296de1adcf2f27e6
3
+ size 823446386
maha_tts/text/cleaners.py CHANGED
@@ -135,8 +135,8 @@ def transliteration_cleaners(text):
135
 
136
  def english_cleaners(text):
137
  '''Pipeline for English text, including number and abbreviation expansion.'''
138
- text = convert_to_ascii(text)
139
- text = lowercase(text)
140
  text = expand_numbers(text)
141
  text = expand_abbreviations(text)
142
  text = collapse_whitespace(text)
 
135
 
136
  def english_cleaners(text):
137
  '''Pipeline for English text, including number and abbreviation expansion.'''
138
+ # text = convert_to_ascii(text)
139
+ # text = lowercase(text)
140
  text = expand_numbers(text)
141
  text = expand_abbreviations(text)
142
  text = collapse_whitespace(text)
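
With the ASCII-conversion and lowercasing steps commented out, non-Latin and accented characters now survive cleaning while number/abbreviation expansion and whitespace collapsing still run. A small illustrative call (output depends on the module's expansion tables):

```python
from maha_tts.text.cleaners import english_cleaners

# Illustrative only: case and accented characters now pass through unchanged.
print(english_cleaners("Chapter 2,   señor."))
```
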
maha_tts/text/symbols.py CHANGED
@@ -2,12 +2,18 @@ import sys
2
  from maha_tts.config import config
3
 
4
  labels=" abcdefghijklmnopqrstuvwxyz.,:;'()?!\""
5
- labels=" !\"'(),-.:;?[]abcdefghijklmnopqrstuvwxyzàâèéêü’“”"
 
 
6
  labels= [i for i in labels]
 
7
 
8
  text_labels = [i for i in labels]
9
  text_labels+='<S>','<E>','<PAD>'
10
 
 
 
 
11
  code_labels= [str(i) for i in range(config.semantic_model_centroids)]
12
  labels+=code_labels
13
  code_labels+='<SST>','<EST>','<PAD>'
@@ -21,6 +27,10 @@ tok_dec = {i:j for i,j in enumerate(labels)}
21
  text_enc = {j:i for i,j in enumerate(text_labels)}
22
  text_dec = {i:j for i,j in enumerate(text_labels)}
23
 
 
 
 
 
24
  #code encdec
25
  code_enc = {j:i for i,j in enumerate(code_labels)}
26
  code_dec = {i:j for i,j in enumerate(code_labels)}
 
2
  from maha_tts.config import config
3
 
4
  labels=" abcdefghijklmnopqrstuvwxyz.,:;'()?!\""
5
+ labels_en=" !\"'(),-.:;?[]abcdefghijklmnopqrstuvwxyzàâèéêü’“”"
6
+ labels='''ଊతూിਮ০य़లഢਪਟକఝૂएड‌`यঢअచଢ଼ਧ—ତলશರଖच,பવड़ષंಈಮਤਇଥkखഗబ= इਸಣਹછ™ୟ.ोೀৎುഊଳંർਘମഴఙसଗൃlଝਜఇഓਐভയಅಠభാടਔಒ೧পஜaૅૠএଲ৯eകँ৭àৱऊટഒਗহિేயీെஈଓഭೊাੌಙ१ଈःസठખm‘ొऍಿcശrట।ऱଋઘਛெਬಂङಹஞ਼ભ১"એੂചಸગಷ়ଁമಓtஒઉಪs్-pଛ›ढ+ಆ'বનধৰউીଅઝ੍ೂʼൂఔfતषഖঢ়৬ਖक़ਵషணझപળଔઞੇವௗઁത২xెഥख़iটਲધಔೇீથ*ഝॅঃஓूఒীనਜ਼எુுహौ९ൗౌফഔોhஔণంफ़ఋçଯઊൽଆ’ୁைഛ२&ঁണ़ైৌআஆোਠਭजொમळಘஷഏি/ચਾ“ਯ$ଐീवऩ८ઢఛఎেథഠ[औಳରथୃൈಝnজਥऑଷੱल೯wओଵढ़மവरడఊbೖਈૃपdêଉఐ;ै ఢ ઔકচ৩‎ਊൾഉਕ೦ಏj€:ਦಗાളੁशफുழൻಊगફఏఅ?णറഘಞ४ಡಫଠ್ড೨ൊঞमਂસૉॉઅരஙલঘନ്ఠॄvઋృষऎகೕଘઆఞലେূஊఉૈദఫఈदকज़!ధઠవଞறಟਖ਼ਫ਼ইਢഡঠஃஸୂटঅହఆளోईৃಜ॥(ઈଏੀഈक્গ ಚಢഹೃिஏಯyশேଡೋੈਣડఃഷഇਸ਼நখಋோনૐਏgहৗೈृவੰଜग़ੋ୍)ൌరమൺংञਓપయധஇോ५ઃಲళঊತॽ­ന…ঙಭाಇउਅଶরઓି্ூমuపബ\ૌଟबਆुಕଫதছ३దਿದణஐௌ்ৈqఘலહಾ०ಛঐிওऋ‍ి৮ेਨଇүଧഞಶéਚ्৫ୋశఓદঈୀ৪ପüুങਗ਼ઑજথఖঝಐऽਰାആജीઇੜ]आବଡ଼ഫಥుಎણଃયछஅેஹംଢબoদഎగଭాേഅঋসഐಃzਡಬਝன–உಖಉഃযସୈೆకॐನഋয়సசଙড়ୱऒऐઐतଂாতરâèनಧ॑டঔभர”​జ৷ਫଣଚଦधघೌୌਉ'''
7
+
8
  labels= [i for i in labels]
9
+ labels_en= [i for i in labels_en]
10
 
11
  text_labels = [i for i in labels]
12
  text_labels+='<S>','<E>','<PAD>'
13
 
14
+ text_labels_en = [i for i in labels_en]
15
+ text_labels_en+='<S>','<E>','<PAD>'
16
+
17
  code_labels= [str(i) for i in range(config.semantic_model_centroids)]
18
  labels+=code_labels
19
  code_labels+='<SST>','<EST>','<PAD>'
 
27
  text_enc = {j:i for i,j in enumerate(text_labels)}
28
  text_dec = {i:j for i,j in enumerate(text_labels)}
29
 
30
+
31
+ text_enc_en = {j:i for i,j in enumerate(text_labels_en)}
32
+ text_dec_en = {i:j for i,j in enumerate(text_labels_en)}
33
+
34
  #code encdec
35
  code_enc = {j:i for i,j in enumerate(code_labels)}
36
  code_dec = {i:j for i,j in enumerate(code_labels)}
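
The new `*_en` tables added above map the English character set to integer ids and back. A minimal sketch of how they behave (names taken directly from the updated `maha_tts/text/symbols.py`):

```python
from maha_tts.text.symbols import text_enc_en, text_dec_en, text_labels_en

ids = [text_enc_en[c] for c in "hello!"]          # characters -> ids
print(ids)
print(''.join(text_dec_en[i] for i in ids))       # ids -> "hello!"
print(len(text_labels_en))                        # English charset + <S>, <E>, <PAD>
```
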
maha_tts/utils/audio.py CHANGED
@@ -6,8 +6,8 @@ from scipy.signal import get_window
6
  from scipy.io.wavfile import read
7
  from maha_tts.config import config
8
 
9
- TACOTRON_MEL_MAX = 2.3143386840820312
10
- TACOTRON_MEL_MIN = -11.512925148010254
11
 
12
 
13
  def denormalize_tacotron_mel(norm_mel):
 
6
  from scipy.io.wavfile import read
7
  from maha_tts.config import config
8
 
9
+ TACOTRON_MEL_MAX = 2.4
10
+ TACOTRON_MEL_MIN = -11.5130
11
 
12
 
13
  def denormalize_tacotron_mel(norm_mel):
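
Only the bounds change in this hunk; the body of `denormalize_tacotron_mel` is not shown. For reference, the conventional Tacotron-style mapping between a [-1, 1] normalized mel and the [MIN, MAX] log-mel range looks like the sketch below; the formulas are an assumption, not copied from the repo:

```python
import torch

TACOTRON_MEL_MAX = 2.4
TACOTRON_MEL_MIN = -11.5130

def normalize_tacotron_mel(mel):
    # assumed convention (not shown in this hunk): [MIN, MAX] -> [-1, 1]
    return 2 * ((mel - TACOTRON_MEL_MIN) / (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN)) - 1

def denormalize_tacotron_mel(norm_mel):
    # assumed convention: [-1, 1] -> [MIN, MAX]
    return ((norm_mel + 1) / 2) * (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN) + TACOTRON_MEL_MIN

mel = torch.empty(1, 80, 100).uniform_(TACOTRON_MEL_MIN, TACOTRON_MEL_MAX)
assert torch.allclose(denormalize_tacotron_mel(normalize_tacotron_mel(mel)), mel, atol=1e-5)
```
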
ref_clips/2971_4275_000003_000007.wav DELETED
Binary file (392 kB)
 
ref_clips/2971_4275_000020_000001.wav DELETED
Binary file (386 kB)
 
ref_clips/2971_4275_000023_000010.wav DELETED
Binary file (435 kB)
 
ref_clips/2971_4275_000049_000000.wav DELETED
Binary file (366 kB)
 
ref_clips/2971_4275_000049_000004.wav DELETED
Binary file (321 kB)
 
ref_clips/2971_4275_000050_000000.wav DELETED
Binary file (385 kB)
 
requirements.txt ADDED
@@ -0,0 +1,46 @@
1
+ annotated-types==0.6.0
2
+ audioread==3.0.1
3
+ certifi==2023.11.17
4
+ cffi==1.16.0
5
+ charset-normalizer==3.3.2
6
+ decorator==5.1.1
7
+ einops==0.7.0
8
+ filelock==3.13.1
9
+ fsspec==2023.10.0
10
+ huggingface-hub==0.19.4
11
+ idna==3.4
12
+ inflect==7.0.0
13
+ Jinja2==3.1.2
14
+ joblib==1.3.2
15
+ lazy_loader==0.3
16
+ librosa==0.10.1
17
+ llvmlite==0.41.1
18
+ MarkupSafe==2.1.3
19
+ mpmath==1.3.0
20
+ msgpack==1.0.7
21
+ networkx==3.2.1
22
+ numba==0.58.1
23
+ numpy==1.26.2
24
+ packaging==23.2
25
+ platformdirs==4.0.0
26
+ pooch==1.8.0
27
+ pycparser==2.21
28
+ pydantic==2.5.1
29
+ pydantic_core==2.14.3
30
+ PyYAML==6.0.1
31
+ regex==2023.10.3
32
+ requests==2.31.0
33
+ safetensors==0.4.0
34
+ scikit-learn==1.3.2
35
+ scipy==1.11.3
36
+ soundfile==0.12.1
37
+ soxr==0.3.7
38
+ sympy==1.12
39
+ threadpoolctl==3.2.0
40
+ tokenizers==0.15.0
41
+ torch==2.1.1
42
+ tqdm==4.66.1
43
+ transformers==4.35.2
44
+ typing_extensions==4.8.0
45
+ Unidecode==1.3.7
46
+ urllib3==2.1.0
setup.py ADDED
@@ -0,0 +1,23 @@
1
+ import os
2
+ from setuptools import setup, find_packages
3
+
4
+ __version__ = '1.0.0'
5
+ cwd = os.path.dirname(os.path.abspath(__file__))
6
+ # requirements = open(os.path.join(cwd, "requirements.txt"), "r").readlines()
7
+
8
+ setup(
9
+ name='maha_tts',
10
+ version=__version__,
11
+
12
+ url='https://github.com/dubverse-ai/MahaTTS/tree/main',
13
+ author='Dubverse AI',
14
+ author_email='jaskaran@dubverse.ai',
15
+ install_requires = [
16
+ 'einops',
17
+ 'transformers',
18
+ 'unidecode',
19
+ 'inflect'
20
+ ],
21
+ packages=find_packages(),
22
+ py_modules=['maha_tts'],
23
+ )
tts.py CHANGED
@@ -1,14 +1,16 @@
1
  import torch,glob
2
- from maha_tts import load_diffuser,load_models,infer_tts
3
  from scipy.io.wavfile import write
4
 
5
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
6
  print('Using:',device)
7
  text = 'Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition.'
8
- ref_clips = glob.glob('/Users/jaskaransingh/Desktop/NeuralSpeak/ref_clips/*.wav')
 
 
9
  # print(len(ref_clips))
10
 
11
  # diffuser = load_diffuser()
12
- diff_model,ts_model,vocoder,diffuser = load_models('Smolie',device)
13
- audio,sr = infer_tts(text,ref_clips,diffuser,diff_model,ts_model,vocoder)
14
  write('test.wav',sr,audio)
 
1
  import torch,glob
2
+ from maha_tts import load_diffuser,load_models,infer_tts,config
3
  from scipy.io.wavfile import write
4
 
5
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
6
  print('Using:',device)
7
  text = 'Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition.'
8
+ langauge = 'english'
9
+ language = torch.tensor(config.lang_index[langauge]).to(device).unsqueeze(0)
10
+ ref_clips = glob.glob('models/Smolie-en/ref_clips/part0_1_1/*.wav')
11
  # print(len(ref_clips))
12
 
13
  # diffuser = load_diffuser()
14
+ diff_model,ts_model,vocoder,diffuser = load_models('Smolie-in',device)
15
+ audio,sr = infer_tts(text,ref_clips,diffuser,diff_model,ts_model,vocoder,language)
16
  write('test.wav',sr,audio)
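
Putting the updated entry point together, a minimal sketch of multilingual inference with the renamed checkpoints. The `'english'` key of `config.lang_index` and the `infer_tts(..., language)` signature come from the diff above; the reference-clip path is a placeholder, and `'Smolie-en'` as a `load_models` key is an assumption based on the renamed checkpoint folders:

```python
import torch, glob
from maha_tts import load_models, infer_tts, config
from scipy.io.wavfile import write

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

text = 'Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition.'
language = torch.tensor(config.lang_index['english']).to(device).unsqueeze(0)
ref_clips = glob.glob('ref_clips/*.wav')   # placeholder: a few short clips of the target speaker

# 'Smolie-en' for the English checkpoint, 'Smolie-in' for the Indic one (per the rename above)
diff_model, ts_model, vocoder, diffuser = load_models('Smolie-en', device)
audio, sr = infer_tts(text, ref_clips, diffuser, diff_model, ts_model, vocoder, language)
write('test.wav', sr, audio)
```
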