zhzluke96 commited on
Commit
f367757
·
1 Parent(s): bb4ceb3
modules/SentenceSplitter.py CHANGED
@@ -2,87 +2,95 @@ import re
2
 
3
  import zhon
4
 
 
5
  from modules.utils.detect_lang import guess_lang
6
 
7
 
8
- def split_zhon_sentence(text):
9
- result = []
10
- pattern = re.compile(zhon.hanzi.sentence)
11
- start = 0
12
- for match in pattern.finditer(text):
13
- # 获取匹配的中文句子
14
- end = match.end()
15
- result.append(text[start:end])
16
- start = end
17
-
18
- # 最后一个中文句子后面的内容(如果有)也需要添加到结果中
19
- if start < len(text):
20
- result.append(text[start:])
21
-
22
- result = [t for t in result if t.strip()]
23
- return result
24
-
25
-
26
- def split_en_sentence(text):
27
- """
28
- Split English text into sentences.
29
- """
30
- # Define a regex pattern for English sentence splitting
31
- pattern = re.compile(r"(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|\!)\s")
32
- result = pattern.split(text)
33
-
34
- # Filter out any empty strings or strings that are just whitespace
35
- result = [sentence.strip() for sentence in result if sentence.strip()]
36
-
37
- return result
38
 
 
 
 
 
39
 
40
- def is_eng_sentence(text):
41
- return guess_lang(text) == "en"
42
 
 
 
43
 
44
- def split_zhon_paragraph(text):
45
- lines = text.split("\n")
46
- result = []
47
- for line in lines:
48
- if is_eng_sentence(line):
49
- result.extend(split_en_sentence(line))
50
- else:
51
- result.extend(split_zhon_sentence(line))
52
- return result
53
 
 
54
 
55
- # 解析文本 并根据停止符号分割成句子
56
- # 可以设置最大阈值,即如果分割片段小于这个阈值会与下一段合并
57
- class SentenceSplitter:
58
- def __init__(self, threshold=100):
59
- self.sentence_threshold = threshold
60
 
61
- def parse(self, text):
62
- sentences = split_zhon_paragraph(text)
63
-
64
- # 合并小于最大阈值的片段
65
- merged_sentences = []
66
- temp_sentence = []
67
- for sentence in sentences:
68
- if len(sentence) < self.sentence_threshold:
69
- temp_sentence.extend(sentence)
70
- if len(temp_sentence) >= self.sentence_threshold:
71
- merged_sentences.append(temp_sentence)
72
- temp_sentence = []
73
  else:
74
- if temp_sentence:
75
- merged_sentences.append(temp_sentence)
76
- temp_sentence = []
77
- merged_sentences.append(sentence)
78
 
79
  if temp_sentence:
80
  merged_sentences.append(temp_sentence)
81
-
82
- joind_sentences = [
83
- "".join(sentence) for sentence in merged_sentences if sentence
84
- ]
85
- return joind_sentences
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
 
88
  if __name__ == "__main__":
 
2
 
3
  import zhon
4
 
5
+ from modules.models import get_tokenizer
6
  from modules.utils.detect_lang import guess_lang
7
 
8
 
9
+ # 解析文本 并根据停止符号分割成句子
10
+ # 可以设置最大阈值,即如果分割片段小于这个阈值会与下一段合并
11
+ class SentenceSplitter:
12
+ SEP_TOKEN = " "
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
+ def __init__(self, threshold=100):
15
+ assert (
16
+ isinstance(threshold, int) and threshold > 0
17
+ ), "Threshold must be greater than 0."
18
 
19
+ self.sentence_threshold = threshold
20
+ self.tokenizer = get_tokenizer()
21
 
22
+ def count_tokens(self, text: str):
23
+ return len(self.tokenizer.tokenize(text))
24
 
25
+ def parse(self, text: str):
26
+ sentences = self.split_paragraph(text)
27
+ sentences = self.merge_text_by_threshold(sentences)
 
 
 
 
 
 
28
 
29
+ return sentences
30
 
31
+ def merge_text_by_threshold(self, setences: list[str]):
32
+ """
33
+ Merge text by threshold.
 
 
34
 
35
+ If the length of the text is less than the threshold, merge it with the next text.
36
+ """
37
+ merged_sentences: list[str] = []
38
+ temp_sentence = ""
39
+ for sentence in setences:
40
+ if len(temp_sentence) + len(sentence) < self.sentence_threshold:
41
+ temp_sentence += SentenceSplitter.SEP_TOKEN + sentence
 
 
 
 
 
42
  else:
43
+ merged_sentences.append(temp_sentence)
44
+ temp_sentence = sentence
 
 
45
 
46
  if temp_sentence:
47
  merged_sentences.append(temp_sentence)
48
+ return merged_sentences
49
+
50
+ def split_paragraph(self, text: str):
51
+ """
52
+ Split text into sentences.
53
+ """
54
+ lines = text.split("\n")
55
+ sentences: list[str] = []
56
+ for line in lines:
57
+ if self.is_eng_sentence(line):
58
+ sentences.extend(self.split_en_sentence(line))
59
+ else:
60
+ sentences.extend(self.split_zhon_sentence(line))
61
+ return sentences
62
+
63
+ def is_eng_sentence(self, text: str):
64
+ return guess_lang(text) == "en"
65
+
66
+ def split_en_sentence(self, text: str):
67
+ """
68
+ Split English text into sentences.
69
+ """
70
+ pattern = re.compile(r"(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|\!)\s")
71
+ sentences = pattern.split(text)
72
+
73
+ sentences = [sentence.strip() for sentence in sentences if sentence.strip()]
74
+
75
+ return sentences
76
+
77
+ def split_zhon_sentence(self, text: str):
78
+ """
79
+ Split Chinese text into sentences.
80
+ """
81
+ sentences: list[str] = []
82
+ pattern = re.compile(zhon.hanzi.sentence)
83
+ start = 0
84
+ for match in pattern.finditer(text):
85
+ end = match.end()
86
+ sentences.append(text[start:end])
87
+ start = end
88
+
89
+ if start < len(text):
90
+ sentences.append(text[start:])
91
+
92
+ sentences = [t for t in sentences if t.strip()]
93
+ return sentences
94
 
95
 
96
  if __name__ == "__main__":
modules/models.py CHANGED
@@ -3,6 +3,7 @@ import logging
3
  import threading
4
 
5
  import torch
 
6
 
7
  from modules import config
8
  from modules.ChatTTS import ChatTTS
@@ -76,3 +77,9 @@ def reload_chat_tts():
76
  instance = load_chat_tts()
77
  logger.info("ChatTTS models reloaded")
78
  return instance
 
 
 
 
 
 
 
3
  import threading
4
 
5
  import torch
6
+ from transformers import LlamaTokenizer
7
 
8
  from modules import config
9
  from modules.ChatTTS import ChatTTS
 
77
  instance = load_chat_tts()
78
  logger.info("ChatTTS models reloaded")
79
  return instance
80
+
81
+
82
+ def get_tokenizer() -> LlamaTokenizer:
83
+ chat_tts = load_chat_tts()
84
+ tokenizer = chat_tts.pretrain_models["tokenizer"]
85
+ return tokenizer
modules/utils/audio.py CHANGED
@@ -2,9 +2,9 @@ import sys
2
  from io import BytesIO
3
 
4
  import numpy as np
 
5
  import soundfile as sf
6
  from pydub import AudioSegment, effects
7
- import pyrubberband as pyrb
8
 
9
  INT16_MAX = np.iinfo(np.int16).max
10
 
 
2
  from io import BytesIO
3
 
4
  import numpy as np
5
+ import pyrubberband as pyrb
6
  import soundfile as sf
7
  from pydub import AudioSegment, effects
 
8
 
9
  INT16_MAX = np.iinfo(np.int16).max
10
 
modules/utils/html.py CHANGED
@@ -1,6 +1,10 @@
 
 
1
  from html.parser import HTMLParser
2
 
3
 
 
 
4
  class HTMLTagRemover(HTMLParser):
5
  def __init__(self):
6
  super().__init__()
@@ -20,7 +24,19 @@ def remove_html_tags(text):
20
  return parser.get_data()
21
 
22
 
 
 
 
 
 
 
23
  if __name__ == "__main__":
24
- input_text = "<h1>一个标题</h1> 这是一段包含<code>标签</code>的文本。"
25
- output_text = remove_html_tags(input_text)
 
 
 
 
 
 
26
  print(output_text) # 输出: 一个标题 这是一段包含标签的文本。
 
1
+ import html
2
+ import re
3
  from html.parser import HTMLParser
4
 
5
 
6
+ # NOTE: 现在没用这个,因为不好解决转义字符的问题
7
+ # 除非分段处理,但是太麻烦了...
8
  class HTMLTagRemover(HTMLParser):
9
  def __init__(self):
10
  super().__init__()
 
24
  return parser.get_data()
25
 
26
 
27
+ def remove_html_tags_re(text):
28
+ text = html.unescape(text)
29
+ html_tags_pattern = re.compile(r"</?([a-zA-Z1-9]+)[^>]*>")
30
+ return re.sub(html_tags_pattern, " ", text)
31
+
32
+
33
  if __name__ == "__main__":
34
+ input_text = """
35
+ <h1>一个标题</h1> 这是一段包含<code>标签</code>的文本。 <code>&amp;</code>
36
+ <设定>
37
+ 一些文本
38
+ </设定>
39
+ """
40
+ # input_text = "我&你"
41
+ output_text = remove_html_tags_re(input_text)
42
  print(output_text) # 输出: 一个标题 这是一段包含标签的文本。
modules/webui/ssml/podcast_tab.py CHANGED
@@ -19,13 +19,18 @@ def merge_dataframe_to_ssml(msg, spk, style, df: pd.DataFrame):
19
  spk = row.get("speaker")
20
  style = row.get("style")
21
 
 
 
 
 
 
22
  ssml += f"{indent}<voice"
23
  if spk:
24
  ssml += f' spk="{spk}"'
25
  if style:
26
  ssml += f' style="{style}"'
27
  ssml += ">\n"
28
- ssml += f"{indent}{indent}{text_normalize(text)}\n"
29
  ssml += f"{indent}</voice>\n"
30
  # 原封不动输出回去是为了触发 loadding 效果
31
  return msg, spk, style, f"<speak version='0.1'>\n{ssml}</speak>"
@@ -42,6 +47,7 @@ def create_ssml_podcast_tab(ssml_input: gr.Textbox, tabs1: gr.Tabs, tabs2: gr.Ta
42
  with gr.Row():
43
  with gr.Column(scale=1):
44
  with gr.Group():
 
45
  spk_input_dropdown = gr.Dropdown(
46
  choices=get_spk_choices(),
47
  interactive=True,
@@ -55,13 +61,19 @@ def create_ssml_podcast_tab(ssml_input: gr.Textbox, tabs1: gr.Tabs, tabs2: gr.Ta
55
  show_label=False,
56
  value="*auto",
57
  )
 
58
  with gr.Group():
 
59
  msg = gr.Textbox(
60
- lines=5, label="Message", placeholder="Type speaker message here"
 
 
 
61
  )
62
  add = gr.Button("Add")
63
  undo = gr.Button("Undo")
64
  clear = gr.Button("Clear")
 
65
  with gr.Column(scale=5):
66
  with gr.Group():
67
  gr.Markdown("📔Script")
@@ -75,7 +87,7 @@ def create_ssml_podcast_tab(ssml_input: gr.Textbox, tabs1: gr.Tabs, tabs2: gr.Ta
75
  col_count=(4, "fixed"),
76
  )
77
 
78
- send_to_ssml_btn = gr.Button("📩Send to SSML", variant="primary")
79
 
80
  def add_message(msg, spk, style, sheet: pd.DataFrame):
81
  if not msg:
 
19
  spk = row.get("speaker")
20
  style = row.get("style")
21
 
22
+ text = text_normalize(text)
23
+
24
+ if text.strip() == "":
25
+ continue
26
+
27
  ssml += f"{indent}<voice"
28
  if spk:
29
  ssml += f' spk="{spk}"'
30
  if style:
31
  ssml += f' style="{style}"'
32
  ssml += ">\n"
33
+ ssml += f"{indent}{indent}{text}\n"
34
  ssml += f"{indent}</voice>\n"
35
  # 原封不动输出回去是为了触发 loadding 效果
36
  return msg, spk, style, f"<speak version='0.1'>\n{ssml}</speak>"
 
47
  with gr.Row():
48
  with gr.Column(scale=1):
49
  with gr.Group():
50
+ gr.Markdown("🗣️Speaker")
51
  spk_input_dropdown = gr.Dropdown(
52
  choices=get_spk_choices(),
53
  interactive=True,
 
61
  show_label=False,
62
  value="*auto",
63
  )
64
+
65
  with gr.Group():
66
+ gr.Markdown("📝Text Input")
67
  msg = gr.Textbox(
68
+ lines=5,
69
+ label="Message",
70
+ show_label=False,
71
+ placeholder="Type speaker message here",
72
  )
73
  add = gr.Button("Add")
74
  undo = gr.Button("Undo")
75
  clear = gr.Button("Clear")
76
+
77
  with gr.Column(scale=5):
78
  with gr.Group():
79
  gr.Markdown("📔Script")
 
87
  col_count=(4, "fixed"),
88
  )
89
 
90
+ send_to_ssml_btn = gr.Button("📩Send to SSML", variant="primary")
91
 
92
  def add_message(msg, spk, style, sheet: pd.DataFrame):
93
  if not msg:
modules/webui/ssml/spliter_tab.py CHANGED
@@ -22,6 +22,12 @@ def merge_dataframe_to_ssml(dataframe, spk, style, seed):
22
  indent = " " * 2
23
 
24
  for i, row in dataframe.iterrows():
 
 
 
 
 
 
25
  ssml += f"{indent}<voice"
26
  if spk:
27
  ssml += f' spk="{spk}"'
@@ -30,7 +36,7 @@ def merge_dataframe_to_ssml(dataframe, spk, style, seed):
30
  if seed:
31
  ssml += f' seed="{seed}"'
32
  ssml += ">\n"
33
- ssml += f"{indent}{indent}{text_normalize(row.iloc[1])}\n"
34
  ssml += f"{indent}</voice>\n"
35
  # 原封不动输出回去是为了触发 loadding 效果
36
  return dataframe, spk, style, seed, f"<speak version='0.1'>\n{ssml}</speak>"
@@ -73,8 +79,9 @@ def create_spliter_tab(ssml_input, tabs1, tabs2):
73
  show_label=False,
74
  value="*auto",
75
  )
 
76
  with gr.Group():
77
- gr.Markdown("🗣️Seed")
78
  infer_seed_input = gr.Number(
79
  value=42,
80
  label="Inference Seed",
@@ -84,10 +91,23 @@ def create_spliter_tab(ssml_input, tabs1, tabs2):
84
  )
85
  infer_seed_rand_button = gr.Button(
86
  value="🎲",
 
87
  variant="secondary",
88
  )
89
 
90
- send_btn = gr.Button("📩Send to SSML", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
  with gr.Column(scale=3):
93
  with gr.Group():
@@ -102,19 +122,21 @@ def create_spliter_tab(ssml_input, tabs1, tabs2):
102
  )
103
  long_text_split_button = gr.Button("🔪Split Text")
104
 
105
- with gr.Row():
106
- with gr.Column(scale=3):
107
  with gr.Group():
108
  gr.Markdown("🎨Output")
109
  long_text_output = gr.DataFrame(
110
  headers=["index", "text", "length"],
111
  datatype=["number", "str", "number"],
112
  elem_id="long-text-output",
113
- interactive=False,
114
  wrap=True,
115
  value=[],
 
 
116
  )
117
 
 
 
118
  spk_input_dropdown.change(
119
  fn=lambda x: x.startswith("*") and "-1" or x.split(":")[-1].strip(),
120
  inputs=[spk_input_dropdown],
@@ -132,8 +154,14 @@ def create_spliter_tab(ssml_input, tabs1, tabs2):
132
  )
133
  long_text_split_button.click(
134
  split_long_text,
135
- inputs=[long_text_input],
136
- outputs=[long_text_output],
 
 
 
 
 
 
137
  )
138
 
139
  infer_seed_rand_button.click(
 
22
  indent = " " * 2
23
 
24
  for i, row in dataframe.iterrows():
25
+ text = row.iloc[1]
26
+ text = text_normalize(text)
27
+
28
+ if text.strip() == "":
29
+ continue
30
+
31
  ssml += f"{indent}<voice"
32
  if spk:
33
  ssml += f' spk="{spk}"'
 
36
  if seed:
37
  ssml += f' seed="{seed}"'
38
  ssml += ">\n"
39
+ ssml += f"{indent}{indent}{text}\n"
40
  ssml += f"{indent}</voice>\n"
41
  # 原封不动输出回去是为了触发 loadding 效果
42
  return dataframe, spk, style, seed, f"<speak version='0.1'>\n{ssml}</speak>"
 
79
  show_label=False,
80
  value="*auto",
81
  )
82
+
83
  with gr.Group():
84
+ gr.Markdown("💃Inference Seed")
85
  infer_seed_input = gr.Number(
86
  value=42,
87
  label="Inference Seed",
 
91
  )
92
  infer_seed_rand_button = gr.Button(
93
  value="🎲",
94
+ # tooltip="Random Seed",
95
  variant="secondary",
96
  )
97
 
98
+ with gr.Group():
99
+ gr.Markdown("🎛️Spliter")
100
+ eos_input = gr.Textbox(
101
+ label="eos",
102
+ value="[uv_break]",
103
+ )
104
+ spliter_thr_input = gr.Slider(
105
+ label="Spliter Threshold",
106
+ value=100,
107
+ minimum=50,
108
+ maximum=1000,
109
+ step=1,
110
+ )
111
 
112
  with gr.Column(scale=3):
113
  with gr.Group():
 
122
  )
123
  long_text_split_button = gr.Button("🔪Split Text")
124
 
 
 
125
  with gr.Group():
126
  gr.Markdown("🎨Output")
127
  long_text_output = gr.DataFrame(
128
  headers=["index", "text", "length"],
129
  datatype=["number", "str", "number"],
130
  elem_id="long-text-output",
131
+ interactive=True,
132
  wrap=True,
133
  value=[],
134
+ row_count=(0, "dynamic"),
135
+ col_count=(3, "fixed"),
136
  )
137
 
138
+ send_btn = gr.Button("📩Send to SSML", variant="primary")
139
+
140
  spk_input_dropdown.change(
141
  fn=lambda x: x.startswith("*") and "-1" or x.split(":")[-1].strip(),
142
  inputs=[spk_input_dropdown],
 
154
  )
155
  long_text_split_button.click(
156
  split_long_text,
157
+ inputs=[
158
+ long_text_input,
159
+ spliter_thr_input,
160
+ eos_input,
161
+ ],
162
+ outputs=[
163
+ long_text_output,
164
+ ],
165
  )
166
 
167
  infer_seed_rand_button.click(
modules/webui/webui_utils.py CHANGED
@@ -276,11 +276,12 @@ def refine_text(
276
 
277
  @torch.inference_mode()
278
  @spaces.GPU(duration=120)
279
- def split_long_text(long_text_input):
280
- spliter = SentenceSplitter(webui_config.spliter_threshold)
281
  sentences = spliter.parse(long_text_input)
282
- sentences = [text_normalize(s) for s in sentences]
283
  data = []
284
  for i, text in enumerate(sentences):
285
- data.append([i, text, len(text)])
 
286
  return data
 
276
 
277
  @torch.inference_mode()
278
  @spaces.GPU(duration=120)
279
+ def split_long_text(long_text_input, spliter_threshold=100, eos=""):
280
+ spliter = SentenceSplitter(threshold=spliter_threshold)
281
  sentences = spliter.parse(long_text_input)
282
+ sentences = [text_normalize(s) + eos for s in sentences]
283
  data = []
284
  for i, text in enumerate(sentences):
285
+ token_length = spliter.count_tokens(text)
286
+ data.append([i, text, token_length])
287
  return data
requirements.txt CHANGED
@@ -26,4 +26,5 @@ cn2an
26
  python-box
27
  ftfy
28
  librosa
29
- pyrubberband
 
 
26
  python-box
27
  ftfy
28
  librosa
29
+ pyrubberband
30
+ https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.9.post1/flash_attn-2.5.9.post1+cu118torch1.12cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
webui.py CHANGED
@@ -30,14 +30,6 @@ from modules.webui.app import create_interface, webui_init
30
  dcls_patch()
31
  ignore_useless_warnings()
32
 
33
- import subprocess
34
-
35
- subprocess.run(
36
- "pip install flash-attn --no-build-isolation",
37
- env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
38
- shell=True,
39
- )
40
-
41
 
42
  def setup_webui_args(parser: argparse.ArgumentParser):
43
  parser.add_argument("--server_name", type=str, help="server name")
 
30
  dcls_patch()
31
  ignore_useless_warnings()
32
 
 
 
 
 
 
 
 
 
33
 
34
  def setup_webui_args(parser: argparse.ArgumentParser):
35
  parser.add_argument("--server_name", type=str, help="server name")