MoritzLaurer HF staff committed on
Commit
479ac18
·
verified ·
1 Parent(s): 9403396

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +35 -7
handler.py CHANGED
@@ -1,14 +1,34 @@
1
  from typing import Dict, List, Any
2
  from parler_tts import ParlerTTSForConditionalGeneration
3
- from transformers import AutoTokenizer
 
4
  import torch
 
 
 
 
5
 
6
  class EndpointHandler:
7
  def __init__(self, path=""):
8
  # load model and processor from path
9
  self.tokenizer = AutoTokenizer.from_pretrained(path)
10
- self.model = ParlerTTSForConditionalGeneration.from_pretrained(path, torch_dtype=torch.float16).to("cuda")
 
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
13
  """
14
  Args:
@@ -24,19 +44,27 @@ class EndpointHandler:
24
  if parameters is not None:
25
  gen_kwargs.update(parameters)
26
 
 
 
27
  # preprocess
28
  inputs = self.tokenizer(
29
- text=[inputs],
30
  padding=True,
31
- return_tensors="pt",).to("cuda")
 
32
  voice_description = self.tokenizer(
33
  text=[voice_description],
34
  padding=True,
35
- return_tensors="pt",).to("cuda")
 
36
 
37
  # pass inputs with all kwargs in data
38
- with torch.autocast("cuda"):
39
- outputs = self.model.generate(**voice_description, prompt_input_ids=inputs.input_ids, prompt_attention_mask=inputs.attention_mask, **gen_kwargs)
 
 
 
 
40
 
41
  # postprocess the prediction
42
  prediction = outputs[0].cpu().numpy().tolist()
 
1
  from typing import Dict, List, Any
2
  from parler_tts import ParlerTTSForConditionalGeneration
3
+ from transformers import AutoTokenizer, AutoFeatureExtractor
4
+ from transformers.models.speecht5.number_normalizer import EnglishNumberNormalizer
5
  import torch
6
+ import re
7
+ from string import punctuation
8
+
9
+ device = "cuda" if torch.cuda.is_available() else "cpu"
10
 
11
  class EndpointHandler:
12
  def __init__(self, path=""):
13
  # load model and processor from path
14
  self.tokenizer = AutoTokenizer.from_pretrained(path)
15
+ #self.feature_extractor = AutoFeatureExtractor.from_pretrained(path)
16
+ self.model = ParlerTTSForConditionalGeneration.from_pretrained(path).to(device) #torch_dtype=torch.float16
17
 
18
def preprocess_text(self, text):
    """Implement the same preprocessing as the Gradio app.

    Steps, in order:
      1. Run ``self.number_normalizer`` over the text and strip surrounding
         whitespace.
      2. Replace hyphens with spaces.
      3. Append a trailing period when the text does not already end in a
         punctuation character.
      4. Spell out upper-case abbreviations letter by letter, dropping
         their dots (e.g. ``"NASA"`` -> ``"N A S A"``).

    Args:
        text: Raw prompt text to be spoken.

    Returns:
        The normalized text. An input that is empty after normalization is
        returned as an empty string instead of raising ``IndexError``.
    """
    text = self.number_normalizer(text).strip()
    if not text:
        # Nothing to normalize; the last-character check below would
        # otherwise index into an empty string.
        return text

    text = text.replace("-", " ")
    if not text.endswith(tuple(punctuation)):
        text = f"{text}."

    def _spell_out(match):
        # "U.S.A" -> "U S A": drop the dots, then separate the letters.
        return " ".join(match.group().replace(".", ""))

    # A single regex pass rewrites each abbreviation exactly where it was
    # matched. Replacing by substring would let a short abbreviation ("US")
    # corrupt a longer one that contains it ("USA").
    return re.sub(r"\b[A-Z][A-Z\.]+\b", _spell_out, text)
31
+
32
  def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
33
  """
34
  Args:
 
44
  if parameters is not None:
45
  gen_kwargs.update(parameters)
46
 
47
+ processed_text = self.preprocess_text(inputs)
48
+
49
  # preprocess
50
  inputs = self.tokenizer(
51
+ text=[processed_text],
52
  padding=True,
53
+ return_tensors="pt",
54
+ ).to(device)
55
  voice_description = self.tokenizer(
56
  text=[voice_description],
57
  padding=True,
58
+ return_tensors="pt",
59
+ ).to(device)
60
 
61
  # pass inputs with all kwargs in data
62
+ with torch.autocast(device):
63
+ outputs = self.model.generate(
64
+ **voice_description, prompt_input_ids=inputs.input_ids,
65
+ prompt_attention_mask=inputs.attention_mask, attention_mask=inputs.attention_mask,
66
+ **parameters
67
+ )
68
 
69
  # postprocess the prediction
70
  prediction = outputs[0].cpu().numpy().tolist()