jbetker committed
Commit 17af2df
Parent: 8215af8

support presets for generation

Files changed (3):
  1. api.py +15 -0
  2. eval_multiple.py +3 -7
  3. read.py +3 -2
api.py CHANGED
@@ -160,6 +160,21 @@ class TextToSpeech:
         self.vocoder.load_state_dict(torch.load('.models/vocoder.pth')['model_g'])
         self.vocoder.eval(inference=True)
 
+    def tts_with_preset(self, text, voice_samples, preset='intelligible', **kwargs):
+        """
+        Calls TTS with one of a set of preset generation parameters. Options:
+            'intelligible': Maximizes the probability of understandable words at the cost of diverse voices, intonation and prosody.
+            'realistic': Increases the diversity of spoken voices and improves realism of vocal characteristics at the cost of intelligibility.
+            'mid': Somewhere between 'intelligible' and 'realistic'.
+        """
+        presets = {
+            'intelligible': {'temperature': .5, 'length_penalty': 2.0, 'repetition_penalty': 2.0, 'top_p': .5, 'diffusion_iterations': 100, 'cond_free': True, 'cond_free_k': .7, 'diffusion_temperature': .7},
+            'mid': {'temperature': .7, 'length_penalty': 1.0, 'repetition_penalty': 2.0, 'top_p': .7, 'diffusion_iterations': 100, 'cond_free': True, 'cond_free_k': 1.5, 'diffusion_temperature': .8},
+            'realistic': {'temperature': .9, 'length_penalty': 1.0, 'repetition_penalty': 1.3, 'top_p': .9, 'diffusion_iterations': 100, 'cond_free': True, 'cond_free_k': 2, 'diffusion_temperature': 1},
+        }
+        kwargs.update(presets[preset])
+        return self.tts(text, voice_samples, **kwargs)
+
     def tts(self, text, voice_samples, k=1,
             # autoregressive generation parameters follow
             num_autoregressive_samples=512, temperature=.5, length_penalty=1, repetition_penalty=2.0, top_p=.5,
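
For orientation, a minimal usage sketch of the new entry point (not part of the commit), assuming TextToSpeech() constructs with its defaults. The conditioning-clip path and output filename are placeholders; the call pattern mirrors eval_multiple.py and read.py below.

    import torchaudio
    from api import TextToSpeech
    from utils.audio import load_audio

    tts = TextToSpeech()

    # Placeholder path: any 22.05 kHz clip of the target speaker works as conditioning.
    cond = load_audio('voices/dotrice/1.wav', 22050)

    # The preset pins the sampling/diffusion knobs listed above; unrelated kwargs such as
    # num_autoregressive_samples still pass straight through to tts().
    gen = tts.tts_with_preset('Hello world.', [cond, cond], preset='realistic',
                              num_autoregressive_samples=256)

    # tts() returns 24 kHz audio; save it the same way read.py does.
    torchaudio.save('preset_demo.wav', gen.squeeze(0).cpu(), 24000)
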
eval_multiple.py CHANGED
@@ -7,7 +7,7 @@ from utils.audio import load_audio
 
 if __name__ == '__main__':
     fname = 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv'
-    outpath = 'D:\\tmp\\tortoise-tts-eval\\compare_vocoders'
+    outpath = 'D:\\tmp\\tortoise-tts-eval\\eval_new_autoregressive'
     outpath_real = 'D:\\tmp\\tortoise-tts-eval\\real'
 
     os.makedirs(outpath, exist_ok=True)
@@ -24,16 +24,12 @@ if __name__ == '__main__':
         path = os.path.join(os.path.dirname(fname), line[1])
         cond_audio = load_audio(path, 22050)
         torchaudio.save(os.path.join(outpath_real, os.path.basename(line[1])), cond_audio, 22050)
-        sample, sample2 = tts.tts(transcript, [cond_audio, cond_audio], num_autoregressive_samples=512, k=1,
+        sample = tts.tts(transcript, [cond_audio, cond_audio], num_autoregressive_samples=512, k=1,
                           repetition_penalty=2.0, length_penalty=2, temperature=.5, top_p=.5,
                           diffusion_temperature=.7, cond_free_k=2, diffusion_iterations=200)
 
         down = torchaudio.functional.resample(sample, 24000, 22050)
-        fout_path = os.path.join(outpath, 'old', os.path.basename(line[1]))
-        torchaudio.save(fout_path, down.squeeze(0), 22050)
-
-        down = torchaudio.functional.resample(sample2, 24000, 22050)
-        fout_path = os.path.join(outpath, 'new', os.path.basename(line[1]))
+        fout_path = os.path.join(outpath, os.path.basename(line[1]))
         torchaudio.save(fout_path, down.squeeze(0), 22050)
 
         recorder.write(f'{transcript}\t{fout_path}\n')
read.py CHANGED
@@ -48,9 +48,10 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('-textfile', type=str, help='A file containing the text to read.', default="data/riding_hood.txt")
     parser.add_argument('-voice', type=str, help='Use a preset conditioning voice (defined above). Overrides cond_path.', default='dotrice')
-    parser.add_argument('-num_samples', type=int, help='How many total outputs the autoregressive transformer should produce.', default=256)
+    parser.add_argument('-num_samples', type=int, help='How many total outputs the autoregressive transformer should produce.', default=512)
     parser.add_argument('-batch_size', type=int, help='How many samples to process at once in the autoregressive model.', default=16)
     parser.add_argument('-output_path', type=str, help='Where to store outputs.', default='results/longform/')
+    parser.add_argument('-generation_preset', type=str, help='Preset to use for generation', default='intelligible')
     args = parser.parse_args()
     os.makedirs(args.output_path, exist_ok=True)
 
@@ -67,7 +68,7 @@ if __name__ == '__main__':
     for cond_path in cond_paths:
         c = load_audio(cond_path, 22050)
         conds.append(c)
-    gen = tts.tts(text, conds, num_autoregressive_samples=args.num_samples, temperature=.7, top_p=.7)
+    gen = tts.tts_with_preset(text, conds, preset=args.generation_preset, num_autoregressive_samples=args.num_samples)
     torchaudio.save(os.path.join(args.output_path, f'{j}.wav'), gen.squeeze(0).cpu(), 24000)
 
     priors.append(torchaudio.functional.resample(gen, 24000, 22050).squeeze(0))
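
With this change a preset can be chosen from the command line, e.g. python read.py -voice dotrice -generation_preset realistic, where the preset is any of 'intelligible', 'mid' or 'realistic' defined in api.py and 'dotrice' is simply the existing default voice.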