TenzinGayche committed on
Commit
a6157cf
1 Parent(s): 31ba694

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +79 -21
handler.py CHANGED
@@ -1,36 +1,94 @@
1
- from typing import Dict
2
- from transformers.pipelines.audio_utils import ffmpeg_read
 
3
  import torch
4
  import pyewts
5
- from transformers import pipeline
 
 
 
 
 
 
6
  # Shared pyewts converter; the handler calls converter.toUnicode() on the
  # text produced by the speech-recognition pipeline.
  converter = pyewts.pyewts()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
# Sampling rate handed to ffmpeg_read when decoding the request audio.
SAMPLE_RATE = 16000


class EndpointHandler():
    """Endpoint handler that transcribes audio with a transformers pipeline.

    The request audio is decoded with ``ffmpeg_read``, transcribed by the
    ``TenzinGayche/whisper-small-3`` pipeline, and the resulting text is
    passed through ``converter.toUnicode`` before being returned.
    """

    def __init__(self, path=""):
        """Create the transcription pipeline.

        Args:
            path: Accepted but not used; the model is loaded by its hub name.
        """
        # load the model
        self.pipe = pipeline(
            model="TenzinGayche/whisper-small-3",
            chunk_length_s=30,
            device='cuda',
        )

    def __call__(self, data: Dict[str, bytes]) -> Dict[str, str]:
        """Transcribe the audio carried by ``data``.

        Args:
            data (:obj:`dict`):
                includes the audio file as bytes under ``"inputs"``; when the
                key is absent the payload itself is handed to the decoder.
        Return:
            A :obj:`dict` whose ``"text"`` entry holds the transcription after
            conversion with ``converter.toUnicode``.
        """
        # process input: decode the payload into a waveform at SAMPLE_RATE
        inputs = data.pop("inputs", data)
        audio_nparray = ffmpeg_read(inputs, SAMPLE_RATE)

        # run inference pipeline (it takes the NumPy waveform directly, so no
        # tensor round trip is needed)
        text = self.pipe(audio_nparray)["text"]

        # postprocess the prediction
        result = converter.toUnicode(text)
        return {"text": result}
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Any,Union
2
+ import librosa
3
+ import numpy as np
4
  import torch
5
  import pyewts
6
+ import noisereduce as nr
7
+ from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
8
+ from num2tib.core import convert
9
+ from num2tib.core import convert2text
10
+ import base64
11
+ import re
12
+ import requests
13
  # Shared pyewts converter; the handler calls converter.toWylie() on the
  # incoming text before it is cleaned up and tokenized.
  converter = pyewts.pyewts()
14
def download_file(url, destination, timeout=60):
    """Download ``url`` and write the response body to ``destination``.

    Args:
        url: Address fetched with an HTTP GET request.
        destination: Path of the file to create or overwrite.
        timeout: Seconds to wait for the server before giving up.

    Raises:
        requests.HTTPError: If the server answers with an error status. The
            destination file is not written in that case, so an error page is
            never saved in place of the requested file.
        requests.RequestException: On connection problems or a timeout.
    """
    response = requests.get(url, timeout=timeout)
    # Fail here rather than later, when the saved file turns out to be unreadable.
    response.raise_for_status()
    with open(destination, 'wb') as file:
        file.write(response.content)


# Fetch the speaker embedding at import time: speaker_embeddings points at
# 'female_2.npy' and the handler loads that file with np.load.
download_file('https://huggingface.co/openpecha/speecht5-tts-01/resolve/main/female_2.npy', 'female_2.npy')
21
def replace_numbers_with_convert(sentence, wylie=True):
    """Rewrite every number found in ``sentence`` using ``convert``.

    A number is a run of digits, optionally followed by a dot and more digits.
    Each match is replaced by ``convert(matched_text, wylie)``.

    Args:
        sentence: Text to scan for numbers.
        wylie: Flag forwarded unchanged as the second argument of ``convert``.

    Returns:
        The sentence with every matched number substituted.
    """
    number_re = re.compile(r'\d+(\.\d+)?')
    return number_re.sub(lambda found: convert(found.group(), wylie), sentence)
28
+
29
def cleanup_text(inputs):
    """Apply each pair from the module-level ``replacements`` table to ``inputs``.

    The pairs are applied one after another in table order, each replacing
    every occurrence of its source string with its target string.

    Args:
        inputs: Text to clean.

    Returns:
        The text after all substitutions have been applied.
    """
    cleaned = inputs
    for pair in replacements:
        source, target = pair
        cleaned = cleaned.replace(source, target)
    return cleaned
33
+
34
# Speaker label -> file name of the saved speaker-embedding array.
speaker_embeddings = {"Lhasa(female)": "female_2.npy"}

# Ordered (source, target) substitutions consumed by cleanup_text.
replacements = [
    ('_', '_'),  # maps the underscore to itself
    ('*', 'v'),
    ('`', ';'),
    ('~', ','),
    ('+', ','),
    ('\\', ';'),
    ('|', ';'),
    ('╚', ''),  # removed from the text
    ('╗', ''),  # removed from the text
]
50
+
51
 
 
52
 
53
 
54
 
55
  class EndpointHandler():
56
  def __init__(self, path=""):
57
  # load the model
58
+ self.processor = SpeechT5Processor.from_pretrained("TenzinGayche/TTS_run3_ep20_174k_b")
59
+ self.model = SpeechT5ForTextToSpeech.from_pretrained("TenzinGayche/TTS_run3_ep20_174k_b")
60
+ self.model.to('cuda')
61
+ self.vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
62
 
63
 
64
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, Union[int, str]]:
65
+ """_summary_
66
+
67
  Args:
68
+ data (Dict[str, Any]): _description_
69
+
70
+ Returns:
71
+ bytes: _description_
72
  """
73
+ text = data.pop("inputs",data)
74
+
75
  # process input
76
+
77
+ if len(text.strip()) == 0:
78
+ return (16000, np.zeros(0).astype(np.int16))
79
+ text = converter.toWylie(text)
80
+ text=cleanup_text(text)
81
+ text=replace_numbers_with_convert(text)
82
+ inputs = self.processor(text=text, return_tensors="pt")
83
+ # limit input length
84
+ input_ids = inputs["input_ids"]
85
+ input_ids = input_ids[..., :self.model.config.max_text_positions]
86
+ speaker_embedding = np.load(speaker_embeddings['Lhasa(female)'])
87
+ speaker_embedding = torch.tensor(speaker_embedding)
88
+ speech = self.model.generate_speech(input_ids.to('cuda'), speaker_embedding.to('cuda'), vocoder=self.vocoder.to('cuda'))
89
+ speech = nr.reduce_noise(y=speech.to('cpu'), sr=16000)
90
+ return {
91
+ "sample_rate": 16000,
92
+ "audio": base64.b64encode(speech.tostring()).decode("utf-8"),
93
+
94
+ }