keithhon committed on
Commit
dc50595
1 Parent(s): 626f208

Upload synthesizer/inference.py with huggingface_hub

Files changed (1)
  1. synthesizer/inference.py +171 -0
synthesizer/inference.py ADDED
@@ -0,0 +1,171 @@
+ import torch
+ from synthesizer import audio
+ from synthesizer.hparams import hparams
+ from synthesizer.models.tacotron import Tacotron
+ from synthesizer.utils.symbols import symbols
+ from synthesizer.utils.text import text_to_sequence
+ from vocoder.display import simple_table
+ from pathlib import Path
+ from typing import Union, List
+ import numpy as np
+ import librosa
+
+
+ class Synthesizer:
+     sample_rate = hparams.sample_rate
+     hparams = hparams
+
+     def __init__(self, model_fpath: Path, verbose=True):
+         """
+         The model isn't instantiated and loaded in memory until needed, or until load() is called.
+
+         :param model_fpath: path to the trained model file
+         :param verbose: if False, prints less information when using the model
+         """
+         self.model_fpath = model_fpath
+         self.verbose = verbose
+
+         # Check for GPU
+         if torch.cuda.is_available():
+             self.device = torch.device("cuda")
+         else:
+             self.device = torch.device("cpu")
+         if self.verbose:
+             print("Synthesizer using device:", self.device)
+
+         # Tacotron model will be instantiated later on first use.
+         self._model = None
+
+     def is_loaded(self):
+         """
+         Whether the model is loaded in memory.
+         """
+         return self._model is not None
+
+     def load(self):
+         """
+         Instantiates and loads the model given the weights file that was passed in the constructor.
+         """
+         self._model = Tacotron(embed_dims=hparams.tts_embed_dims,
+                                num_chars=len(symbols),
+                                encoder_dims=hparams.tts_encoder_dims,
+                                decoder_dims=hparams.tts_decoder_dims,
+                                n_mels=hparams.num_mels,
+                                fft_bins=hparams.num_mels,
+                                postnet_dims=hparams.tts_postnet_dims,
+                                encoder_K=hparams.tts_encoder_K,
+                                lstm_dims=hparams.tts_lstm_dims,
+                                postnet_K=hparams.tts_postnet_K,
+                                num_highways=hparams.tts_num_highways,
+                                dropout=hparams.tts_dropout,
+                                stop_threshold=hparams.tts_stop_threshold,
+                                speaker_embedding_size=hparams.speaker_embedding_size).to(self.device)
+
+         self._model.load(self.model_fpath)
+         self._model.eval()
+
+         if self.verbose:
+             print("Loaded synthesizer \"%s\" trained to step %d" % (self.model_fpath.name, self._model.state_dict()["step"]))
+
+     def synthesize_spectrograms(self, texts: List[str],
+                                 embeddings: Union[np.ndarray, List[np.ndarray]],
+                                 return_alignments=False):
+         """
+         Synthesizes mel spectrograms from texts and speaker embeddings.
+
+         :param texts: a list of N text prompts to be synthesized
+         :param embeddings: a numpy array or list of speaker embeddings of shape (N, 256)
+         :param return_alignments: if True, a matrix representing the alignments between the
+         characters and each decoder output step will be returned for each spectrogram
+         :return: a list of N mel spectrograms as numpy arrays of shape (80, Mi), where Mi is the
+         sequence length of spectrogram i, and possibly the alignments.
+         """
+         # Load the model on the first request.
+         if not self.is_loaded():
+             self.load()
+
+         # Print some info about the model when it is loaded
+         tts_k = self._model.get_step() // 1000
+
+         simple_table([("Tacotron", str(tts_k) + "k"),
+                       ("r", self._model.r)])
+
+         # Preprocess text inputs
+         inputs = [text_to_sequence(text.strip(), hparams.tts_cleaner_names) for text in texts]
+         if not isinstance(embeddings, list):
+             embeddings = [embeddings]
+
+         # Batch inputs
+         batched_inputs = [inputs[i:i+hparams.synthesis_batch_size]
+                           for i in range(0, len(inputs), hparams.synthesis_batch_size)]
+         batched_embeds = [embeddings[i:i+hparams.synthesis_batch_size]
+                           for i in range(0, len(embeddings), hparams.synthesis_batch_size)]
+
+         specs = []
+         for i, batch in enumerate(batched_inputs, 1):
+             if self.verbose:
+                 print(f"\n| Generating {i}/{len(batched_inputs)}")
+
+             # Pad texts so they are all the same length
+             text_lens = [len(text) for text in batch]
+             max_text_len = max(text_lens)
+             chars = [pad1d(text, max_text_len) for text in batch]
+             chars = np.stack(chars)
+
+             # Stack speaker embeddings into 2D array for batch processing
+             speaker_embeds = np.stack(batched_embeds[i-1])
+
+             # Convert to tensor
+             chars = torch.tensor(chars).long().to(self.device)
+             speaker_embeddings = torch.tensor(speaker_embeds).float().to(self.device)
+
+             # Inference
+             _, mels, alignments = self._model.generate(chars, speaker_embeddings)
+             mels = mels.detach().cpu().numpy()
+             for m in mels:
+                 # Trim silence from end of each spectrogram
+                 while np.max(m[:, -1]) < hparams.tts_stop_threshold:
+                     m = m[:, :-1]
+                 specs.append(m)
+
+         if self.verbose:
+             print("\n\nDone.\n")
+         return (specs, alignments) if return_alignments else specs
+
+     @staticmethod
+     def load_preprocess_wav(fpath):
+         """
+         Loads and preprocesses an audio file under the same conditions as the audio files that
+         were used to train the synthesizer.
+         """
+         wav = librosa.load(str(fpath), sr=hparams.sample_rate)[0]
+         if hparams.rescale:
+             wav = wav / np.abs(wav).max() * hparams.rescaling_max
+         return wav
+
+     @staticmethod
+     def make_spectrogram(fpath_or_wav: Union[str, Path, np.ndarray]):
+         """
+         Creates a mel spectrogram from an audio file in the same manner as the mel spectrograms
+         that were fed to the synthesizer when training.
+         """
+         if isinstance(fpath_or_wav, (str, Path)):
+             wav = Synthesizer.load_preprocess_wav(fpath_or_wav)
+         else:
+             wav = fpath_or_wav
+
+         mel_spectrogram = audio.melspectrogram(wav, hparams).astype(np.float32)
+         return mel_spectrogram
+
+     @staticmethod
+     def griffin_lim(mel):
+         """
+         Inverts a mel spectrogram using Griffin-Lim. The mel spectrogram is expected to have been
+         built with the same parameters present in hparams.py.
+         """
+         return audio.inv_mel_spectrogram(mel, hparams)
+
+
+ def pad1d(x, max_len, pad_value=0):
+     return np.pad(x, (0, max_len - len(x)), mode="constant", constant_values=pad_value)
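For context, here is a minimal usage sketch of the uploaded class, not part of the commit itself. The checkpoint path is hypothetical (the real filename depends on your training run), and the speaker embedding below is a random placeholder; in the parent project it would come from the speaker encoder. Audio from a random embedding will be garbage, but the calls show the intended flow: construct, synthesize mel spectrograms, then optionally invert with Griffin-Lim.

import numpy as np
from pathlib import Path
from synthesizer.inference import Synthesizer

# Hypothetical checkpoint path; substitute your own trained model file.
synthesizer = Synthesizer(Path("synthesizer/saved_models/pretrained/pretrained.pt"))

# Placeholder 256-dim speaker embedding (normally produced by the speaker
# encoder). Embeddings are expected to be L2-normalized.
embed = np.random.rand(256).astype(np.float32)
embed /= np.linalg.norm(embed)

texts = ["Hello world, this is a synthesized sentence."]
specs = synthesizer.synthesize_spectrograms(texts, [embed])

# Quick waveform preview via Griffin-Lim; a neural vocoder (see the vocoder
# package) would give noticeably better quality.
wav = Synthesizer.griffin_lim(specs[0])

Note that the model is only loaded on the first synthesize_spectrograms() call, so constructing a Synthesizer is cheap; calling load() up front is optional if you want to pay the loading cost early.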