Zihan428 committed on
Commit 49abc70 · 1 Parent(s): 0aafff4

Update analyzer modules and tokenizer

app.py CHANGED
@@ -13,7 +13,7 @@ MODEL = None
 
 LANGUAGE_CONFIG = {
     "ar": {
-        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ar_m1.flac",
+        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ar_f/ar_prompts2.flac",
         "text": "في الشهر الماضي، وصلنا إلى معلم جديد بمليارين من المشاهدات على قناتنا على يوتيوب."
     },
     "da": {
@@ -57,7 +57,7 @@ LANGUAGE_CONFIG = {
         "text": "Il mese scorso abbiamo raggiunto un nuovo traguardo: due miliardi di visualizzazioni sul nostro canale YouTube."
     },
     "ja": {
-        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ja_f.flac",
+        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ja/ja_prompts1.flac",
         "text": "先月、私たちのYouTubeチャンネルで二十億回の再生回数という新たなマイルストーンに到達しました。"
     },
     "ko": {
@@ -101,8 +101,8 @@ LANGUAGE_CONFIG = {
         "text": "Geçen ay YouTube kanalımızda iki milyar görüntüleme ile yeni bir dönüm noktasına ulaştık."
     },
     "zh": {
-        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/zh_f.flac",
-        "text": "上个月,我们达到了一个新的里程碑,我们的YouTube频道观看次数达到了二十亿次,这绝对令人难以置信。"
+        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/zh_f2.flac",
+        "text": "上个月,我们达到了一个新的里程碑. 我们的YouTube频道观看次数达到了二十亿次,这绝对令人难以置信。"
     },
 }
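Each entry pairs a reference audio prompt with demo text for that language. A hypothetical lookup sketch (the helper below is illustrative and not part of the commit; the "zh" values are copied from the new side of the diff):

# Illustrative only: how a demo entry might be read from LANGUAGE_CONFIG.
LANGUAGE_CONFIG = {
    "zh": {
        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/zh_f2.flac",
        "text": "上个月,我们达到了一个新的里程碑. 我们的YouTube频道观看次数达到了二十亿次,这绝对令人难以置信。",
    },
}

def get_demo_prompt(lang: str):
    # Hypothetical helper: return (reference audio URL, demo text) for a language code.
    entry = LANGUAGE_CONFIG[lang]
    return entry["audio"], entry["text"]

print(get_demo_prompt("zh")[0])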
src/chatterbox/models/t3/inference/alignment_stream_analyzer.py CHANGED
@@ -155,12 +155,12 @@ class AlignmentStreamAnalyzer:
         token_repetition = (
             # self.complete and
             len(self.generated_tokens) >= 3 and
-            len(set(self.generated_tokens[-3:])) == 1
+            len(set(self.generated_tokens[-2:])) == 1
         )
 
         if token_repetition:
             repeated_token = self.generated_tokens[-1]
-            logger.warning(f"🚨 Detected 3x repetition of token {repeated_token}")
+            logger.warning(f"🚨 Detected 2x repetition of token {repeated_token}")
 
             # Suppress EoS to prevent early termination
             if cur_text_posn < S - 3 and S > 5:  # Only suppress if text is longer than 5 tokens
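For reference, the tightened condition now fires as soon as the last two generated tokens are identical (three in a row were previously required). A self-contained sketch of just this check, with hypothetical token IDs:

# Hypothetical token history; only the condition itself mirrors the diff above.
generated_tokens = [512, 98, 98]

token_repetition = (
    len(generated_tokens) >= 3 and           # require some history first
    len(set(generated_tokens[-2:])) == 1     # last two tokens are identical
)
print(token_repetition)  # True -> the analyzer would log the 2x-repetition warning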
src/chatterbox/models/t3/modules/t3_config.py CHANGED
@@ -25,6 +25,10 @@ class T3Config:
     @property
     def n_channels(self):
         return LLAMA_CONFIGS[self.llama_config_name]["hidden_size"]
+
+    @property
+    def is_multilingual(self):
+        return self.text_tokens_dict_size == 2352
 
     @classmethod
     def english_only(cls):
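The new property keys the multilingual switch to the text-token vocabulary size of the multilingual checkpoint (2352 per this commit). A minimal sanity sketch, using a stand-in class rather than the real T3Config:

# Stand-in for T3Config; only the attribute name and the 2352 comparison come from the diff.
class _Cfg:
    text_tokens_dict_size = 2352

    @property
    def is_multilingual(self):
        return self.text_tokens_dict_size == 2352

assert _Cfg().is_multilingual  # an English-only vocab size would make this False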
src/chatterbox/models/t3/t3.py CHANGED
@@ -257,14 +257,17 @@ class T3(nn.Module):
         # TODO? synchronize the expensive compile function
         # with self.compile_lock:
         if not self.compiled:
-            alignment_stream_analyzer = AlignmentStreamAnalyzer(
-                self.tfmr,
-                None,
-                text_tokens_slice=(len_cond, len_cond + text_tokens.size(-1)),
-                alignment_layer_idx=9,  # TODO: hparam or something?
-                eos_idx=self.hp.stop_speech_token,
-            )
-            assert alignment_stream_analyzer.eos_idx == self.hp.stop_speech_token
+            # Default to None for English models, only create for multilingual
+            alignment_stream_analyzer = None
+            if self.hp.is_multilingual:
+                alignment_stream_analyzer = AlignmentStreamAnalyzer(
+                    self.tfmr,
+                    None,
+                    text_tokens_slice=(len_cond, len_cond + text_tokens.size(-1)),
+                    alignment_layer_idx=9,  # TODO: hparam or something?
+                    eos_idx=self.hp.stop_speech_token,
+                )
+                assert alignment_stream_analyzer.eos_idx == self.hp.stop_speech_token
 
             patched_model = T3HuggingfaceBackend(
                 config=self.cfg,
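In effect, English-only configurations now skip the alignment hooks: alignment_stream_analyzer stays None unless hp.is_multilingual is true. A condensed, runnable sketch of just the branch (attribute names from the diff; the analyzer construction is stubbed out):

# Stubbed control-flow sketch; object() stands in for AlignmentStreamAnalyzer(...).
class _HParams:
    is_multilingual = False  # hypothetical English-only config

hp = _HParams()
alignment_stream_analyzer = None
if hp.is_multilingual:
    alignment_stream_analyzer = object()  # only multilingual models build the analyzer
print(alignment_stream_analyzer)  # None for English-only models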
src/chatterbox/models/tokenizers/tokenizer.py CHANGED
@@ -151,9 +151,7 @@ def korean_normalize(text: str) -> str:
         return initial + medial + final
 
     # Decompose syllables and normalize punctuation
-    result = ''.join(decompose_hangul(char) for char in text)
-    result = re.sub(r'[…~?!,:;()「」『』]', '.', result)  # Korean punctuation
-
+    result = ''.join(decompose_hangul(char) for char in text)
     return result.strip()
 
 
@@ -201,81 +199,39 @@ class ChineseCangjieConverter:
 
     def _cangjie_encode(self, glyph: str):
         """Encode a single Chinese glyph to Cangjie code."""
-        code = self.word2cj.get(glyph, None)
-        if code is None:
+        normed_glyph = glyph
+        code = self.word2cj.get(normed_glyph, None)
+        if code is None:  # e.g. Japanese hiragana
             return None
-
-        index = self.cj2word[code].index(glyph)
-        index_suffix = str(index) if index > 0 else ""
-        return code + index_suffix
+        index = self.cj2word[code].index(normed_glyph)
+        index = str(index) if index > 0 else ""
+        return code + str(index)
 
-    def _normalize_numbers(self, text: str) -> str:
-        """Convert Arabic numerals (1-99) to Chinese characters."""
-        digit_map = {'0': '零', '1': '一', '2': '二', '3': '三', '4': '四',
-                     '5': '五', '6': '六', '7': '七', '8': '八', '9': '九'}
-
-        pattern = re.compile(r'(?<!\d)(\d{1,2})(?!\d)')
-
-        def convert_number(match):
-            num = int(match.group(1))
-
-            if num == 0:
-                return '零'
-            elif 1 <= num <= 9:
-                return digit_map[str(num)]
-            elif num == 10:
-                return '十'
-            elif 11 <= num <= 19:
-                return '十' + digit_map[str(num % 10)]
-            elif 20 <= num <= 99:
-                tens, ones = divmod(num, 10)
-                if ones == 0:
-                    return digit_map[str(tens)] + '十'
-                else:
-                    return digit_map[str(tens)] + '十' + digit_map[str(ones)]
-            else:
-                return match.group(1)
-
-        return pattern.sub(convert_number, text)
 
-    def convert_chinese_text(self, text: str) -> str:
+    def __call__(self, text):
         """Convert Chinese characters in text to Cangjie tokens."""
-        text = re.sub('[、,:;〜-()⦅⦆]', ',', text)
-        text = re.sub('(。|…)', '.', text)
-        text = self._normalize_numbers(text)
-
-        # Skip segmentation for simple sequences (numbers, punctuation, short phrases)
+        output = []
         if self.segmenter is not None:
-            # This avoids over-segmenting number sequences like "一, 二, 三"
-            is_simple_sequence = (
-                len([c for c in text if category(c) == "Lo"]) <= 15 and  # Max 15 Chinese chars
-                text.count(',') >= 2  # Contains multiple commas (likely enumeration)
-            )
-
-            # Only segment complex Chinese text (longer sentences without enumeration patterns)
-            if not is_simple_sequence and len(text) > 10:
-                chinese_chars = sum(1 for c in text if category(c) == "Lo")
-                total_chars = len([c for c in text if c.strip()])
-
-                if chinese_chars > 5 and chinese_chars / total_chars > 0.7:
-                    segmented_words = self.segmenter.cut(text)
-                    text = " ".join(segmented_words)
+            segmented_words = self.segmenter.cut(text)
+            full_text = " ".join(segmented_words)
+        else:
+            full_text = text
 
-        output = []
-        for char in text:
-            if category(char) == "Lo":  # Chinese character
-                cangjie = self._cangjie_encode(char)
+        for t in full_text:
+            if category(t) == "Lo":
+                cangjie = self._cangjie_encode(t)
                 if cangjie is None:
-                    output.append(char)
+                    output.append(t)
                     continue
-
-                code_tokens = [f"[cj_{c}]" for c in cangjie]
-                code_tokens.append("[cj_.]")
-
-                output.append("".join(code_tokens))
+                code = []
+                for c in cangjie:
+                    code.append(f"[cj_{c}]")
+                code.append("[cj_.]")
+                code = "".join(code)
+                output.append(code)
             else:
-                output.append(char)
-
+                output.append(t)
         return "".join(output)
 
 
@@ -299,7 +255,7 @@ class MTLTokenizer:
     def encode(self, txt: str, language_id: str = None):
         # Language-specific text processing
         if language_id == 'zh':
-            txt = self.cangjie_converter.convert_chinese_text(txt)
+            txt = self.cangjie_converter(txt)
         elif language_id == 'ja':
             txt = hiragana_normalize(txt)
         elif language_id == 'he':
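For reference, the per-glyph output format is unchanged by the rewrite: each letter of a glyph's Cangjie code becomes a [cj_x] token and the glyph is terminated with [cj_.]. A self-contained sketch with a made-up Cangjie code (real codes come from the word2cj table):

cangjie = "jd"  # hypothetical Cangjie code for one glyph
code_tokens = [f"[cj_{c}]" for c in cangjie]
code_tokens.append("[cj_.]")   # per-glyph terminator, as in __call__ above
print("".join(code_tokens))    # [cj_j][cj_d][cj_.]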
src/chatterbox/mtl_tts.py CHANGED
@@ -83,7 +83,7 @@ def punc_norm(text: str) -> str:
 
     # Add full stop if no ending punc
     text = text.rstrip(" ")
-    sentence_enders = {".", "!", "?", "-", ","}
+    sentence_enders = {".", "!", "?", "-", ",","、",",","。","?","!"}
     if not any(text.endswith(p) for p in sentence_enders):
         text += "."
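With the widened set, text that already ends in CJK punctuation no longer gets an extra ASCII period appended. A minimal sketch of just the ending check (punc_norm's other normalization steps are omitted):

sentence_enders = {".", "!", "?", "-", ",", "、", ",", "。", "?", "!"}

def ends_with_punc(text: str) -> bool:
    # Same membership test as punc_norm's final step.
    return any(text.endswith(p) for p in sentence_enders)

print(ends_with_punc("你好。"))   # True: the CJK full stop now counts
print(ends_with_punc("Hello"))    # False: punc_norm would still append "."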