Commit 44a4b98 (verified) by ekwek
Parent(s): 9402b2c

Upload 11 files
soprano/backends/lmdeploy.py CHANGED
@@ -11,11 +11,9 @@ class LMDeployModel(BaseModel):
         assert device == 'cuda', "lmdeploy only supports cuda devices, consider changing device or using a different backend instead."
         cache_size_ratio = cache_size_mb * 1024**2 / torch.cuda.get_device_properties('cuda').total_memory
         backend_config = TurbomindEngineConfig(cache_max_entry_count=cache_size_ratio)
-        print("Loaded config.")
         self.pipeline = pipeline('ekwek/Soprano-80M',
                                  log_level='ERROR',
                                  backend_config=backend_config)
-        print("Loaded pipeline.")
 
     def infer(self,
               prompts,
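For reference, a minimal sketch (with made-up numbers, not part of this commit) of how the cache_size_mb argument maps to the cache_max_entry_count ratio computed above:

# Illustrative only: assumes a hypothetical 24 GB GPU; in the actual code the
# denominator comes from torch.cuda.get_device_properties('cuda').total_memory.
total_memory_bytes = 24 * 1024**3   # 24 GB in bytes
cache_size_mb = 2048                # requested KV-cache budget in MB
cache_size_ratio = cache_size_mb * 1024**2 / total_memory_bytes
print(round(cache_size_ratio, 3))   # 0.083 -> roughly 8% of GPU memory reserved for the KV cache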
soprano/tts.py CHANGED
@@ -1,4 +1,5 @@
 from .vocos.decoder import SopranoDecoder
+from .utils.text import clean_text
 import torch
 import re
 from unidecode import unidecode
@@ -31,9 +32,7 @@ class SopranoTTS:
 
         if backend == 'lmdeploy':
             from .backends.lmdeploy import LMDeployModel
-            print("Imported lmdeploy.")
             self.pipeline = LMDeployModel(device=device, cache_size_mb=cache_size_mb)
-            print("Loaded model.")
         elif backend == 'transformers':
             from .backends.transformers import TransformersModel
             self.pipeline = TransformersModel(device=device)
@@ -55,20 +54,12 @@ class SopranoTTS:
         res = []
         for text_idx, text in enumerate(texts):
             text = text.strip()
-            sentences = re.split(r"(?<=[.!?])\s+", text)
+            cleaned_text = clean_text(text)
+            sentences = re.split(r"(?<=[.!?])\s+", cleaned_text)
             processed = []
-            for sentence_idx, sentence in enumerate(sentences):
-                old_len = len(sentence)
-                new_sentence = re.sub(r"[^A-Za-z !\$%&'*+,-./0123456789<>?_]", "", sentence)
-                new_sentence = re.sub(r"[<>/_+]", "", new_sentence)
-                new_sentence = re.sub(r"\.\.[^\.]", ".", new_sentence)
-                new_sentence = re.sub(r"\s+", " ", new_sentence)
-                new_len = len(new_sentence)
-                if old_len != new_len:
-                    print(f"Warning: unsupported characters found in sentence: {sentence}\n\tThese characters have been removed.")
-                new_sentence = unidecode(new_sentence.strip())
+            for sentence in sentences:
                 processed.append({
-                    "text": new_sentence,
+                    "text": sentence,
                     "text_idx": text_idx,
                 })
 
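Taken together, infer now delegates all text normalization to clean_text before splitting into sentences. A simplified, standalone sketch of the new preprocessing path (the sample string and flat script form are illustrative, not from this commit):

import re
from soprano.utils.text import clean_text  # module added in this commit

text = "Dr. Smith paid $2.47 at 8:05! Incredible..."
cleaned = clean_text(text)                       # ASCII, lowercased, numbers/abbreviations spelled out
sentences = re.split(r"(?<=[.!?])\s+", cleaned)  # same sentence splitter used in infer
processed = [{"text": s, "text_idx": 0} for s in sentences]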
soprano/utils/text.py ADDED
@@ -0,0 +1,388 @@
+"""
+Normalize input text to a format that Soprano recognizes.
+Adapted from https://github.com/neonbjb/tortoise-tts/blob/main/tortoise/utils/tokenizer.py
+"""
+import os
+import re
+
+import inflect
+from unidecode import unidecode
+
+
+_inflect = inflect.engine()
+
+####################################################################################################
+# Abbreviations
+
+_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
+    ('mrs', 'misuss'),
+    ('ms', 'miss'),
+    ('mr', 'mister'),
+    ('dr', 'doctor'),
+    ('st', 'saint'),
+    ('co', 'company'),
+    ('jr', 'junior'),
+    ('maj', 'major'),
+    ('gen', 'general'),
+    ('drs', 'doctors'),
+    ('rev', 'reverend'),
+    ('lt', 'lieutenant'),
+    ('hon', 'honorable'),
+    ('sgt', 'sergeant'),
+    ('capt', 'captain'),
+    ('esq', 'esquire'),
+    ('ltd', 'limited'),
+    ('col', 'colonel'),
+    ('ft', 'fort'),
+]]
+_cased_abbreviations = [(re.compile('\\b%s\\b' % x[0]), x[1]) for x in [
+    ('TTS', 'text to speech'),
+    ('Hz', 'hertz'),
+    ('kHz', 'kilohertz'),
+    ('KBs', 'kilobytes'),
+    ('KB', 'kilobyte'),
+    ('MBs', 'megabytes'),
+    ('MB', 'megabyte'),
+    ('GBs', 'gigabytes'),
+    ('GB', 'gigabyte'),
+    ('TBs', 'terabytes'),
+    ('TB', 'terabyte'),
+    ('APIs', 'a p i\'s'),
+    ('API', 'a p i'),
+    ('CLIs', 'c l i\'s'),
+    ('CLI', 'c l i'),
+    ('CPUs', 'c p u\'s'),
+    ('CPU', 'c p u'),
+    ('GPUs', 'g p u\'s'),
+    ('GPU', 'g p u'),
+    ('Ave', 'avenue'),
+]]
+
+def expand_abbreviations(text):
+    for regex, replacement in _abbreviations + _cased_abbreviations:
+        text = re.sub(regex, replacement, text)
+    return text
+
+####################################################################################################
+# Numbers
+
+_num_prefix_re = re.compile(r'#\d')
+_num_suffix_re = re.compile(r'\d(K|M|B|T)', re.IGNORECASE)
+_num_letter_split_re = re.compile(r'(\d[a-z]|[a-z]\d)', re.IGNORECASE)
+
+_comma_number_re = re.compile(r'(\d[\d\,]+\d)')
+_date_re = re.compile(r'(^|[^/])(\d\d?[/-]\d\d?[/-]\d\d(?:\d\d)?)($|[^/])')
+_phone_number_re = re.compile(r'(\(?\d{3}\)?[-.\s]\d{3}[-.\s]?\d{4})')
+_time_re = re.compile(r'(\d\d?:\d\d(?::\d\d)?)')
+_pounds_re = re.compile(r'£([\d\,]*\d+)')
+_dollars_re = re.compile(r'\$([\d\.\,]*\d+)')
+_decimal_number_re = re.compile(r'(\d+(?:\.\d+)+)')
+_multiply_re = re.compile(r'(\d\s?\*\s?\d)')
+_divide_re = re.compile(r'(\d\s?/\s?\d)')
+_add_re = re.compile(r'(\d\s?\+\s?\d)')
+_subtract_re = re.compile(r'(\d?\s?-\s?\d)') # also does negative numbers
+_fraction_re = re.compile(r'(\d+(?:/\d+)+)')
+_ordinal_re = re.compile(r'\d+(st|nd|rd|th)')
+_number_re = re.compile(r'\d+')
+
+def _expand_num_prefix(m):
+    match = m.group(0)
+    return f"number {match[1]}"
+
+def _expand_num_suffix(m):
+    match = m.group(0)
+    if match[1].upper() == 'K': return f"{match[0]} thousand"
+    elif match[1].upper() == 'M': return f"{match[0]} million"
+    elif match[1].upper() == 'B': return f"{match[0]} billion"
+    elif match[1].upper() == 'T': return f"{match[0]} trillion"
+    return match # unexpected format
+
+def _split_alphanumeric(m):
+    match = m.group(1)
+    return f"{match[0]} {match[1]}"
+
+def _remove_commas(m):
+    return m.group(1).replace(',', '')
+
+def _expand_date(m):
+    match = m.group(2)
+    match = re.split('[./-]', match)
+    return m.group(1) + ' dash '.join(match) + m.group(3)
+
+def _expand_phone_number(m):
+    match = m.group(1)
+    match = re.sub(r'\D', '', match)
+    assert len(match) == 10
+    match = f"{' '.join(list(match[:3]))}, {' '.join(list(match[3:6]))}, {' '.join(list(match[6:]))}"
+    return match
+
+def _expand_time(m):
+    match = m.group(1)
+    match = match.split(':')
+    if len(match) == 2:
+        hours, minutes = match
+        if minutes == '00':
+            if int(hours) == 0:
+                return '0'
+            elif int(hours) > 12: return f"{hours} minutes"
+            return f"{hours} o'clock"
+        elif minutes.startswith('0'):
+            minutes = f'oh {minutes[1:]}'
+        return f"{hours} {minutes}"
+    else:
+        hours, minutes, seconds = match
+        if int(hours) != 0:
+            return f"{hours} {'oh oh' if minutes == '00' else f'oh {minutes}' if minutes.startswith('0') else {minutes}} {'' if seconds == '00' else f'oh {seconds}' if seconds.startswith('0') else seconds}"
+        elif minutes != '00':
+            return f"{minutes} {'oh oh' if seconds == '00' else f'oh {seconds}' if seconds.startswith('0') else seconds}"
+        else:
+            return seconds
+
+def _expand_dollars(m):
+    match = m.group(1)
+    parts = match.split('.')
+    if len(parts) > 2:
+        return match + ' dollars' # Unexpected format
+    dollars = int(parts[0]) if parts[0] else 0
+    cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
+    if dollars and cents:
+        dollar_unit = 'dollar' if dollars == 1 else 'dollars'
+        cent_unit = 'cent' if cents == 1 else 'cents'
+        return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
+    elif dollars:
+        dollar_unit = 'dollar' if dollars == 1 else 'dollars'
+        return '%s %s' % (dollars, dollar_unit)
+    elif cents:
+        cent_unit = 'cent' if cents == 1 else 'cents'
+        return '%s %s' % (cents, cent_unit)
+    else:
+        return 'zero dollars'
+
+def _expand_decimal_point(m):
+    match = m.group(1)
+    match = match.split('.')
+    return match[0] + ' point ' + ' point '.join(' '.join(list(match[i])) for i in range(1, len(match)))
+
+def _expand_fraction(m):
+    match = m.group(1)
+    match = match.split('/')
+    return ' over '.join(match) if len(match)==2 else ' slash '.join(match)
+
+def _expand_multiply(m):
+    return ' times '.join(m.group(1).split('*'))
+
+def _expand_divide(m):
+    return ' over '.join(m.group(1).split('/'))
+
+def _expand_add(m):
+    return ' plus '.join(m.group(1).split('+'))
+
+def _expand_subtract(m):
+    return ' minus '.join(m.group(1).split('-'))
+
+def _expand_ordinal(m):
+    return _inflect.number_to_words(m.group(0), andword='')
+
+def _expand_number(m):
+    num = int(m.group(0))
+    if num > 1000 and num < 3000:
+        if num == 2000:
+            return 'two thousand'
+        elif num > 2000 and num < 2010:
+            return 'two thousand ' + _inflect.number_to_words(num % 100)
+        elif num % 100 == 0:
+            return _inflect.number_to_words(num // 100) + ' hundred'
+        else:
+            return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
+    else:
+        return _inflect.number_to_words(num, andword='')
+
+def normalize_numbers(text):
+    text = re.sub(_num_prefix_re, _expand_num_prefix, text)
+    text = re.sub(_num_suffix_re, _expand_num_suffix, text)
+    for _ in range(2): # need to do this twice to find all matches
+        text = re.sub(_num_letter_split_re, _split_alphanumeric, text)
+    text = re.sub(_comma_number_re, _remove_commas, text)
+    text = re.sub(_date_re, _expand_date, text)
+    text = re.sub(_phone_number_re, _expand_phone_number, text)
+    text = re.sub(_time_re, _expand_time, text)
+    text = re.sub(_pounds_re, r'\1 pounds', text)
+    text = re.sub(_dollars_re, _expand_dollars, text)
+    text = re.sub(_decimal_number_re, _expand_decimal_point, text)
+    text = re.sub(_multiply_re, _expand_multiply, text)
+    text = re.sub(_divide_re, _expand_divide, text)
+    text = re.sub(_add_re, _expand_add, text)
+    text = re.sub(_subtract_re, _expand_subtract, text)
+
+    text = re.sub(_fraction_re, _expand_fraction, text)
+    text = re.sub(_ordinal_re, _expand_ordinal, text)
+    text = re.sub(_number_re, _expand_number, text)
+    return text
+
+####################################################################################################
+# Special characters & other patterns
+
+_special_characters = [(re.compile(x[0]), x[1]) for x in [
+    ('@', ' at '),
+    ('&', ' and '),
+    ('%', ' percent '),
+    (':', '.'),
+    (';', ','),
+    (r'\+', ' plus '),
+    (r'\\', ' backslash '),
+    ('~', ' about '),
+    ('(^| )<3', ' heart '),
+    ('<=', ' less than or equal to '),
+    ('>=', ' greater than or equal to '),
+    ('<', ' less than '),
+    ('>', ' greater than '),
+    ('=', ' equals '),
+    ('/', ' slash '),
+    ('_', ' '),
+]]
+_link_header_re = re.compile(r'(https?://)')
+_dash_re = re.compile(r'(. - .)')
+_dot_re = re.compile(r'([A-Z]\.[A-Z])', re.IGNORECASE)
+_parentheses_re = re.compile(r'[\(\[\{].*[\)\]\}](.|$)')
+
+def expand_special_characters(text):
+    for regex, replacement in _special_characters:
+        text = re.sub(regex, replacement, text)
+    return text
+
+def _expand_link_header(m):
+    return 'h t t p s colon slash slash '
+
+def _expand_dash(m):
+    match = m.group(0)
+    return f"{match[0]}, {match[4]}"
+
+def _expand_dot(m):
+    match = m.group(0)
+    return f"{match[0]} dot {match[2]}"
+
+def _expand_parantheses(m):
+    match = m.group(0)
+    match = re.sub(r'[\(\[\{]', ', ', match)
+    match = re.sub(r'[\)\]\}][^$.!?,]', ', ', match)
+    match = re.sub(r'[\)\]\}]', '', match)
+    return match
+
+def normalize_special(text):
+    text = re.sub(_link_header_re, _expand_link_header, text)
+    text = re.sub(_dash_re, _expand_dash, text)
+    text = re.sub(_dot_re, _expand_dot, text)
+    text = re.sub(_parentheses_re, _expand_parantheses, text)
+    return text
+
+####################################################################################################
+# Misc
+
+def lowercase(text):
+    return text.lower()
+
+def convert_to_ascii(text):
+    return unidecode(text)
+
+def normalize_newlines(text):
+    text = text.split('\n')
+    for i in range(len(text)):
+        if not text[i]: continue
+        text[i] = text[i].strip()
+        if text[i][-1] not in '.!?':
+            text[i] = f"{text[i]}."
+    return ' '.join(text)
+
+def remove_unknown_characters(text):
+    text = re.sub(r"[^A-Za-z !\$%&'\*\+,-./0123456789<>\?_]", "", text)
+    text = re.sub(r"[<>/_+]", "", text)
+    return text
+
+def collapse_whitespace(text):
+    text = re.sub(r'\s+', ' ', text)
+    text = re.sub(r' [.\?!,]', lambda m: m.group(0)[1], text)
+    return text
+
+def dedup_punctuation(text):
+    text = re.sub(r"\.\.\.+", "[ELLIPSIS]", text)
+    text = re.sub(r",+", ",", text)
+    text = re.sub(r"[\.,]*\.[\.,]*", ".", text)
+    text = re.sub(r"[\.,!]*![\.,!]*", "!", text)
+    text = re.sub(r"[\.,!\?]*\?[\.,!\?]*", "?", text)
+    text = re.sub("[ELLIPSIS]", "...", text)
+    return text
+
+def clean_text(text):
+    text = convert_to_ascii(text)
+    text = normalize_newlines(text)
+    text = normalize_numbers(text)
+    text = normalize_special(text)
+    text = expand_abbreviations(text)
+    text = expand_special_characters(text)
+    text = lowercase(text)
+    text = remove_unknown_characters(text)
+    text = collapse_whitespace(text)
+    text = dedup_punctuation(text)
+    return text
+
+
+if __name__ == '__main__':
+    print(normalize_numbers('1,2,3,456,176'))
+    print(normalize_numbers('123,456,789'))
+    print(normalize_numbers('123,456,789th'))
+    print(normalize_numbers('123-456-7890'))
+    print(normalize_numbers('111-111-1111'))
+    print(normalize_numbers('(111) 111-1111'))
+    print(normalize_numbers('A(111) 111-1111'))
+    print(normalize_numbers('A (111) 111-1111'))
+    print(normalize_numbers('$2.47'))
+    print(normalize_numbers('$247'))
+    print(normalize_numbers('$0.27'))
+    print(normalize_numbers('$1.00'))
+    print(normalize_numbers('£20'))
+    for i in range(1990, 2030):
+        print(normalize_numbers(str(i)))
+    print(normalize_numbers('2656'))
+    print(normalize_numbers('1024'))
+    print(normalize_numbers('2.47023'))
+    print(normalize_numbers('20.47023'))
+    print(normalize_numbers('1.17.1.1'))
+    print(normalize_numbers('111.111.1111'))
+    print(normalize_numbers('1/1/2025'))
+    print(normalize_numbers('1-1-2025'))
+    print(normalize_numbers('1-1-25'))
+    print(normalize_numbers('A 1/1/11 A'))
+    print(normalize_numbers('A 1/1 A'))
+    print(normalize_numbers('1/1'))
+    print(normalize_numbers('1/10'))
+    print(normalize_numbers('1/1/10'))
+    print(normalize_numbers('11/1/1/10'))
+
+    print(normalize_numbers('0:00'))
+    print(normalize_numbers('12:00'))
+    print(normalize_numbers('13:00'))
+    print(normalize_numbers('8:00'))
+    print(normalize_numbers('8:05'))
+    print(normalize_numbers('8:15'))
+    print(normalize_numbers('0:00:00'))
+    print(normalize_numbers('00:01:10'))
+    print(normalize_numbers('00:10:01'))
+    print(normalize_numbers('01:01:01'))
+    print(normalize_numbers('00:01:00'))
+    print(normalize_numbers('01:00:00'))
+
+    print(normalize_numbers('-1 + 2 * 3 - 4 / 5'))
+    print(normalize_numbers('-1+2*3-5/4/25'))
+
+    print(normalize_numbers('100x1'))
+    print(normalize_numbers('100k'))
+    print(normalize_numbers('100m'))
+    print(normalize_numbers('100b'))
+    print(normalize_numbers('100t'))
+
+    print(normalize_numbers('#1'))
+
+    print(normalize_numbers('12:00'))
+    print(normalize_numbers('11:59'))
+    print(normalize_numbers('01:00'))
+    print(normalize_numbers('0100'))
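The __main__ block above only exercises normalize_numbers; a minimal end-to-end check of clean_text (the input string here is chosen purely for illustration, not taken from this commit) might look like:

from soprano.utils.text import clean_text  # import path assumed from this upload

raw = "Dr. Smith's GPU hit 100% load at 8:05 - see https://example.com (details)."
# Expect lowercase ASCII with the abbreviation, percentage, time, URL and
# parentheses rewritten into speakable words by the pipeline above.
print(clean_text(raw))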