jbetker committed
Commit ab9b68f
1 Parent(s): 01877d2

add option to specify model directory to API

Files changed (1)
  1. api.py +21 -16
api.py CHANGED
@@ -170,35 +170,40 @@ class TextToSpeech:
     :param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
                                       GPU OOM errors. Larger numbers generates slightly faster.
     """
-    def __init__(self, autoregressive_batch_size=16):
+    def __init__(self, autoregressive_batch_size=16, models_dir='.models'):
         self.autoregressive_batch_size = autoregressive_batch_size
         self.tokenizer = VoiceBpeTokenizer()
         download_models()
 
-        self.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30,
-                                           model_dim=1024,
-                                           heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False,
-                                           train_solo_embeddings=False,
-                                           average_conditioning_embeddings=True).cpu().eval()
-        self.autoregressive.load_state_dict(torch.load('.models/autoregressive.pth'))
+        if os.path.exists(f'{models_dir}/autoregressive.ptt'):
+            # Assume this is a traced directory.
+            self.autoregressive = torch.jit.load(f'{models_dir}/autoregressive.ptt')
+            self.diffusion = torch.jit.load(f'{models_dir}/diffusion_decoder.ptt')
+        else:
+            self.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30,
+                                               model_dim=1024,
+                                               heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False,
+                                               train_solo_embeddings=False,
+                                               average_conditioning_embeddings=True).cpu().eval()
+            self.autoregressive.load_state_dict(torch.load(f'{models_dir}/autoregressive.pth'))
+
+            self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
+                                          in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16,
+                                          layer_drop=0, unconditioned_percentage=0).cpu().eval()
+            self.diffusion.load_state_dict(torch.load(f'{models_dir}/diffusion_decoder.pth'))
 
         self.clvp = CLVP(dim_text=512, dim_speech=512, dim_latent=512, num_text_tokens=256, text_enc_depth=12,
                          text_seq_len=350, text_heads=8,
                          num_speech_tokens=8192, speech_enc_depth=12, speech_heads=8, speech_seq_len=430,
                          use_xformers=True).cpu().eval()
-        self.clvp.load_state_dict(torch.load('.models/clvp.pth'))
+        self.clvp.load_state_dict(torch.load(f'{models_dir}/clvp.pth'))
 
         self.cvvp = CVVP(model_dim=512, transformer_heads=8, dropout=0, mel_codes=8192, conditioning_enc_depth=8, cond_mask_percentage=0,
                          speech_enc_depth=8, speech_mask_percentage=0, latent_multiplier=1).cpu().eval()
-        self.cvvp.load_state_dict(torch.load('.models/cvvp.pth'))
-
-        self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
-                                      in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16,
-                                      layer_drop=0, unconditioned_percentage=0).cpu().eval()
-        self.diffusion.load_state_dict(torch.load('.models/diffusion_decoder.pth'))
+        self.cvvp.load_state_dict(torch.load(f'{models_dir}/cvvp.pth'))
 
         self.vocoder = UnivNetGenerator().cpu()
-        self.vocoder.load_state_dict(torch.load('.models/vocoder.pth')['model_g'])
+        self.vocoder.load_state_dict(torch.load(f'{models_dir}/vocoder.pth')['model_g'])
         self.vocoder.eval(inference=True)
 
     def tts_with_preset(self, text, voice_samples, preset='fast', **kwargs):
@@ -216,7 +221,7 @@ class TextToSpeech:
                         'cond_free_k': 2.0, 'diffusion_temperature': 1.0})
         # Presets are defined here.
         presets = {
-            'ultra_fast': {'num_autoregressive_samples': 32, 'diffusion_iterations': 16, 'cond_free': False},
+            'ultra_fast': {'num_autoregressive_samples': 16, 'diffusion_iterations': 32, 'cond_free': False},
             'fast': {'num_autoregressive_samples': 96, 'diffusion_iterations': 32},
            'standard': {'num_autoregressive_samples': 256, 'diffusion_iterations': 128},
            'high_quality': {'num_autoregressive_samples': 512, 'diffusion_iterations': 1024},
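
For readers wanting to try the new option, a minimal usage sketch follows. It assumes api.py is importable as a module from the repository root and that the checkpoint files named in the diff already sit in the chosen directory; the directory path below is purely a placeholder, not something this commit prescribes.

from api import TextToSpeech  # assumed import path; run from the repository checkout

# Old behaviour is preserved: with no argument, weights are still read from '.models'.
tts_default = TextToSpeech()

# New in this commit: point the loader at any directory holding the *.pth checkpoints
# (or traced *.ptt modules, see the sketch below).
tts_custom = TextToSpeech(autoregressive_batch_size=16, models_dir='/data/tortoise/models')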
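
The commit does not show how the traced autoregressive.ptt / diffusion_decoder.ptt files are produced. The snippet below is only a sketch of the general TorchScript save/load round trip that the new fast path relies on, using a toy module as a stand-in; exporting the real UnifiedVoice or DiffusionTts modules this way would likely need torch.jit.trace with suitable example inputs, which is outside this diff.

import os
import torch
import torch.nn as nn

class ToyModule(nn.Module):
    # Stand-in for a real model; the actual tortoise modules are far more complex.
    def forward(self, x):
        return x * 2

models_dir = 'example_traced'  # hypothetical directory, named only to mirror what the new loader expects
os.makedirs(models_dir, exist_ok=True)

scripted = torch.jit.script(ToyModule().eval())  # or torch.jit.trace(module, example_inputs)
torch.jit.save(scripted, os.path.join(models_dir, 'autoregressive.ptt'))

# The new constructor detects a .ptt file and takes the torch.jit.load branch
# instead of rebuilding the module and calling load_state_dict.
restored = torch.jit.load(os.path.join(models_dir, 'autoregressive.ptt'))
print(restored(torch.ones(2)))  # tensor([2., 2.])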