theodotus commited on
Commit
9b0dd40
1 Parent(s): 879beec

Added inference pipeline

Browse files
Files changed (2) hide show
  1. README.md +3 -0
  2. pipeline.py +36 -0
README.md CHANGED
@@ -1,3 +1,6 @@
1
  ---
 
 
 
2
  license: bsd-3-clause
3
  ---
 
1
  ---
2
+ tags:
3
+ - text-to-speech
4
+ library_name: generic
5
  license: bsd-3-clause
6
  ---
pipeline.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+
3
+ from torch import no_grad, package
4
+ import numpy as np
5
+ import os
6
+
7
+
8
+
9
+
10
+
11
+ class PreTrainedPipeline():
12
+ def __init__(self, path: str):
13
+ model_path = os.path.join(path, "model.pt")
14
+ importer = package.PackageImporter(model_path)
15
+ synt = importer.load_pickle("tts_models", "model")
16
+ self.synt = synt
17
+
18
+ self.tts_kwargs = {
19
+ "speaker_name": "uk",
20
+ "language_name": "uk",
21
+ }
22
+
23
+ self.sampling_rate = self.synt.output_sample_rate
24
+
25
+ def __call__(self, inputs: str) -> Tuple[np.array, int]:
26
+ """
27
+ Args:
28
+ inputs (:obj:`str`):
29
+ The text to generate audio from
30
+ Return:
31
+ A :obj:`np.array` and a :obj:`int`: The raw waveform as a numpy array, and the sampling rate as an int.
32
+ """
33
+ with no_grad():
34
+ waveforms = self.synt.tts(inputs, **self.tts_kwargs)
35
+ waveforms = np.array(waveforms, dtype=np.float32)
36
+ return waveforms, self.sampling_rate