hans00 committed · Commit 86b23e4 · unverified · 1 Parent(s): 216ef9a

Use grammar to avoid generation errors
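For context: llama.cpp can constrain sampling with a GBNF grammar, and this commit builds such a grammar from the words in the prompt so the model can only emit the expected word/audio-token structure. A minimal sketch of loading a grammar with llama-cpp-python (toy grammar; assumes only that llama_cpp is installed, not part of the commit):

from llama_cpp import LlamaGrammar

# Toy GBNF grammar: generation may only ever produce "yes" or "no".
grammar = LlamaGrammar.from_string('root ::= "yes" | "no"')
# Passed to generation (e.g. llm("Q: is water wet? A:", grammar=grammar)),
# sampling is restricted to token sequences the grammar accepts.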

Files changed (1): app.py +83 -7
app.py CHANGED
@@ -6,13 +6,18 @@ import json
 import tempfile
 import hashlib
 import os
+import re
 from typing import Optional
+from llama_cpp.llama import LlamaGrammar
+from outetts.version.interface import InterfaceLLAMACPP
 from outetts.models.info import MODEL_INFO
 from outetts.utils import helpers
 from huggingface_hub import hf_hub_download
 import torch
 from transformers import BitsAndBytesConfig
 import spaces
+import numpy as np
+from collections import OrderedDict

 # Available OuteTTS models based on the documentation
 MODELS = {v.value: v for _, v in outetts.Models.__members__.items()}
@@ -26,6 +31,77 @@ MODEL_QUANTIZATION = {
 # Cache for speaker profiles to avoid re-transcribing the same audio
 speaker_cache = {}

+SPLIT_SYMBOL = {
+    outetts.InterfaceVersion.V1: '<|space|>',
+    outetts.InterfaceVersion.V2: '<|space|>',
+    outetts.InterfaceVersion.V3: ' ',
+}
+
+def word_to_grammar(word):
+    if all(ord(c) < 128 for c in word):
+        return f'"{word}"'
+    return f'[{"".join(OrderedDict.fromkeys(word))}]+'
+
+# patch InterfaceLLAMACPP, inject new _generate method
+InterfaceLLAMACPP._orig_generate = InterfaceLLAMACPP._generate
+def ggml_generate(self, input_ids, config):
+    tokenizer = self.prompt_processor.tokenizer
+    split = SPLIT_SYMBOL.get(self.config.interface_version, ' ')
+    prompt = tokenizer.decode(input_ids, skip_special_tokens=False)
+    prompt_no_special = tokenizer.decode(input_ids, skip_special_tokens=True).strip()
+    if '<|text_start|>' not in prompt:
+        return self._orig_generate(input_ids, config)
+    speaker_text_last = prompt_no_special.split('\n').pop()
+    text = prompt[prompt.index('<|text_start|>')+14:prompt.index('<|text_end|>')]
+    gen_text = text[text.index(speaker_text_last)+len(speaker_text_last):].strip(split) if speaker_text_last in text else text
+    words = [word_to_grammar(word) for word in gen_text.split(split)]
+    if self.config.interface_version == outetts.InterfaceVersion.V2:
+        config.additional_gen_config["grammar"] = LlamaGrammar.from_string(f"""\
+root ::= NL? {' audioBlock '.join(words)} audioEnd NL EOS?
+audioBlock ::= TIME CODE* space NL?
+TEXT ::= [A-Za-z0-9 .,?!]+
+EOS ::= "<|im_end|>"
+emotionStart ::= "<|emotion_start|>"
+emotionEnd ::= "<|emotion_end|>"
+audioEnd ::= "<|audio_end|>"
+space ::= "<|space|>"
+WORD ::= {' | '.join(words)}
+NL ::= [\\n]
+TIME ::= "<|t_" DECIMAL "|>"
+CODE ::= "<|" DIGITS "|>"
+DIGITS ::= [0-9]+
+DECIMAL ::= [0-9]+ "." [0-9]+
+punch ::= "<|" [a-z_]+ "|>"
+""")
+    elif self.config.interface_version == outetts.InterfaceVersion.V3:
+        config.additional_gen_config["grammar"] = LlamaGrammar.from_string(f"""\
+root ::= leadWord wordBlock* audioEnd NL EOS?
+leadWord ::= WORD audioBlock
+wordBlock ::= wordStart WORD audioBlock
+audioBlock ::= codeBlock wordEnd NL?
+codeBlock ::= features TIME energy spectralCentroid pitch CODE CODES*
+TEXT ::= [A-Za-z0-9.,!?]+
+EOS ::= "<|im_end|>"
+audioEnd ::= "<|audio_end|>"
+wordStart ::= "<|word_start|>"
+wordEnd ::= "<|word_end|>"
+features ::= "<|features|>"
+energy ::= "<|energy_" DIGITS "|>"
+spectralCentroid ::= "<|spectral_centroid_" DIGITS "|>"
+pitch ::= "<|pitch_" DIGITS "|>"
+WORD ::= {' | '.join(words)}
+NL ::= [\\n]
+TIME ::= "<|t_" DECIMAL "|>"
+CODE ::= "<|code|>"
+CODES ::= CODE1 CODE2
+CODE1 ::= "<|c1_" DIGITS "|>"
+CODE2 ::= "<|c2_" DIGITS "|>"
+DIGITS ::= [0-9]+
+DECIMAL ::= [0-9]+ "." [0-9]+
+""")
+    return self._orig_generate(input_ids, config)
+InterfaceLLAMACPP._generate = ggml_generate
+
 def get_file_hash(file_path):
     """Calculate MD5 hash of a file for caching purposes."""
     hash_md5 = hashlib.md5()
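To make the constraint concrete, here is a standalone sketch (not part of the commit) that reproduces word_to_grammar and shows what it emits for the V3 path, where the split symbol is a plain space; the sample sentence is illustrative only:

from collections import OrderedDict

def word_to_grammar(word):
    # ASCII words become literal terminals; non-ASCII words fall back to a
    # character class over the word's unique characters (order preserved).
    if all(ord(c) < 128 for c in word):
        return f'"{word}"'
    return f'[{"".join(OrderedDict.fromkeys(word))}]+'

words = [word_to_grammar(w) for w in "Hello world 你好".split(' ')]
print(words)                          # ['"Hello"', '"world"', '[你好]+']
print(f"WORD ::= {' | '.join(words)}")
# WORD ::= "Hello" | "world" | [你好]+

The root rule then chains these word terminals in prompt order, interleaved with audio-token blocks, so the model cannot skip, repeat, or invent words.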
@@ -34,7 +110,7 @@ def get_file_hash(file_path):
             hash_md5.update(chunk)
     return hash_md5.hexdigest()

-def try_ggml_model(model: outetts.Models, backend: outetts.Backend, quantization: outetts.LlamaCppQuantization):
+def try_ggml_model(model: outetts.Models, quantization: outetts.LlamaCppQuantization):
     model_config = MODEL_INFO[model]
     repo = f"OuteAI/{model.value}-GGUF"
     filename = f"{model.value}-{quantization.value}.gguf"
@@ -45,12 +121,12 @@ def try_ggml_model(model: outetts.Models, backend: outetts.Backend, quantization
         local_files_only=False
     )
     generation_type = outetts.GenerationType.CHUNKED
-    if model_config['interface_version'] == outetts.InterfaceVersion.V3:
-        generation_type = outetts.GenerationType.GUIDED_WORDS
+    # if model_config['interface_version'] == outetts.InterfaceVersion.V3:
+    #     generation_type = outetts.GenerationType.GUIDED_WORDS
     return outetts.ModelConfig(
         model_path=model_path,
         tokenizer_path=f"OuteAI/{model.value}",
-        backend=backend,
+        backend=outetts.Backend.LLAMACPP,
         n_gpu_layers=99,
         verbose=False,
         device=None,
@@ -67,7 +143,7 @@ def get_interface(model_name: str):

     try:
         quantization = MODEL_QUANTIZATION.get(model, outetts.LlamaCppQuantization.Q8_0)
-        config = try_ggml_model(model, outetts.Backend.LLAMACPP, quantization)
+        config = try_ggml_model(model, quantization)
     except:
         has_cuda = torch.cuda.is_available()
         model_config = MODEL_INFO[model]
@@ -98,7 +174,7 @@ def get_or_create_speaker(interface, audio_file):
     # Check if speaker profile is already cached
     if cache_key in speaker_cache:
         print(f"✅ Using cached speaker profile for {os.path.basename(audio_file)}")
-        return speaker_cache[cache_key]
+        return json.loads(speaker_cache[cache_key])

     device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -108,7 +184,7 @@ def get_or_create_speaker(interface, audio_file):
     speaker = interface.create_speaker(audio_file, whisper_model="large-v3-turbo", whisper_device=device)

     # Cache the speaker profile
-    speaker_cache[cache_key] = speaker
+    speaker_cache[cache_key] = json.dumps(speaker)
     print(f"💾 Cached speaker profile ({len(speaker_cache)} total cached)")

     return speaker
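One note on the cache change in the last two hunks: storing profiles with json.dumps and decoding with json.loads on every hit returns a fresh copy each time, so callers cannot mutate the cached profile in place (this rationale is inferred, not stated in the commit). A minimal sketch of the pattern with a hypothetical profile payload:

import hashlib
import json

cache = {}

def file_hash(path):
    # MD5 over the file contents, used as the cache key (as in app.py).
    h = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(4096), b""):
            h.update(chunk)
    return h.hexdigest()

def get_profile(path):
    key = file_hash(path)
    if key in cache:
        return json.loads(cache[key])        # fresh copy on every cache hit
    profile = {"words": [], "source": path}  # hypothetical profile payload
    cache[key] = json.dumps(profile)
    return profile

This assumes, as the commit does, that create_speaker returns a JSON-serializable object.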
 