Add Number Normalization and other fix

#8
by ylacombe HF staff - opened
Files changed (1) hide show
  1. app.py +27 -3
app.py CHANGED
@@ -1,6 +1,10 @@
1
  import spaces
2
  import gradio as gr
3
  import torch
 
 
 
 
4
 
5
  from parler_tts import ParlerTTSForConditionalGeneration
6
  from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
@@ -38,11 +42,31 @@ examples = [
38
  ],
39
  ]
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
- @spaces.GPU
43
  def gen_tts(text, description):
44
  inputs = tokenizer(description, return_tensors="pt").to(device)
45
- prompt = tokenizer(text, return_tensors="pt").to(device)
46
 
47
  set_seed(SEED)
48
  generation = model.generate(
@@ -145,4 +169,4 @@ with gr.Blocks(css=css) as block:
145
  )
146
 
147
  block.queue()
148
- block.launch(share=True)
 
1
  import spaces
2
  import gradio as gr
3
  import torch
4
+ from transformers.models.speecht5.number_normalizer import EnglishNumberNormalizer
5
+ from string import punctuation
6
+ import re
7
+
8
 
9
  from parler_tts import ParlerTTSForConditionalGeneration
10
  from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
 
42
  ],
43
  ]
44
 
45
+ number_normalizer = EnglishNumberNormalizer()
46
+
47
+ def preprocess(text):
48
+ text = number_normalizer(text).strip()
49
+ text = text.replace("-", " ")
50
+ if text[-1] not in punctuation:
51
+ text = f"{text}."
52
+
53
+ abbreviations_pattern = r'\b[A-Z][A-Z\.]+\b'
54
+
55
+ def separate_abb(chunk):
56
+ chunk = chunk.replace(".","")
57
+ print(chunk)
58
+ return " ".join(chunk)
59
+
60
+ abbreviations = re.findall(abbreviations_pattern, text)
61
+ for abv in abbreviations:
62
+ if abv in text:
63
+ text = text.replace(abv, separate_abb(abv))
64
+ return text
65
+
66
 
 
67
  def gen_tts(text, description):
68
  inputs = tokenizer(description, return_tensors="pt").to(device)
69
+ prompt = tokenizer(preprocess(text), return_tensors="pt").to(device)
70
 
71
  set_seed(SEED)
72
  generation = model.generate(
 
169
  )
170
 
171
  block.queue()
172
+ block.launch(share=True)