ekwek committed on
Commit
a459f6e
·
verified ·
1 Parent(s): bf7fc52

Upload 10 files

Browse files
Files changed (1) hide show
  1. soprano/tts.py +26 -4
soprano/tts.py CHANGED
@@ -45,15 +45,16 @@ class SopranoTTS:
45
 
46
  self.infer("Hello world!") # warmup
47
 
48
- def _preprocess_text(self, texts):
49
  '''
50
  adds prompt format and sentence/part index
 
51
  '''
52
  res = []
53
  for text_idx, text in enumerate(texts):
54
  text = text.strip()
55
  sentences = re.split(r"(?<=[.!?])\s+", text)
56
- processed_sentences = []
57
  for sentence_idx, sentence in enumerate(sentences):
58
  old_len = len(sentence)
59
  new_sentence = re.sub(r"[^A-Za-z !\$%&'*+,-./0123456789<>?_]", "", sentence)
@@ -64,8 +65,29 @@ class SopranoTTS:
64
  if old_len != new_len:
65
  print(f"Warning: unsupported characters found in sentence: {sentence}\n\tThese characters have been removed.")
66
  new_sentence = unidecode(new_sentence.strip())
67
- processed_sentences.append((f'[STOP][TEXT]{new_sentence}[START]', text_idx, sentence_idx))
68
- res.extend(processed_sentences)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  return res
70
 
71
  def infer(self,
 
45
 
46
  self.infer("Hello world!") # warmup
47
 
48
+ def _preprocess_text(self, texts, min_length=30):
49
  '''
50
  adds prompt format and sentence/part index
51
+ Enforces a minimum sentence length by merging short sentences.
52
  '''
53
  res = []
54
  for text_idx, text in enumerate(texts):
55
  text = text.strip()
56
  sentences = re.split(r"(?<=[.!?])\s+", text)
57
+ processed = []
58
  for sentence_idx, sentence in enumerate(sentences):
59
  old_len = len(sentence)
60
  new_sentence = re.sub(r"[^A-Za-z !\$%&'*+,-./0123456789<>?_]", "", sentence)
 
65
  if old_len != new_len:
66
  print(f"Warning: unsupported characters found in sentence: {sentence}\n\tThese characters have been removed.")
67
  new_sentence = unidecode(new_sentence.strip())
68
+ processed.append({
69
+ "text": new_sentence,
70
+ "text_idx": text_idx,
71
+ })
72
+
73
+ if min_length > 0 and len(processed) > 1:
74
+ merged = []
75
+ i = 0
76
+ while i < len(processed):
77
+ cur = processed[i]
78
+ if len(cur["text"]) < min_length:
79
+ if merged: merged[-1]["text"] = (merged[-1]["text"] + " " + cur["text"]).strip()
80
+ else:
81
+ if i + 1 < len(processed): processed[i + 1]["text"] = (cur["text"] + " " + processed[i + 1]["text"]).strip()
82
+ else: merged.append(cur)
83
+ else: merged.append(cur)
84
+ i += 1
85
+ processed = merged
86
+ sentence_idxes = {}
87
+ for item in processed:
88
+ if item['text_idx'] not in sentence_idxes: sentence_idxes[item['text_idx']] = 0
89
+ res.append((f'[STOP][TEXT]{item["text"]}[START]', item["text_idx"], sentence_idxes[item['text_idx']]))
90
+ sentence_idxes[item['text_idx']] += 1
91
  return res
92
 
93
  def infer(self,