Spaces: Running on Zero
Use grammar to avoid generation errors
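This commit constrains llama.cpp decoding with a GBNF grammar built from the words the model is asked to speak, so the sampler can no longer emit malformed OuteTTS token sequences. For intuition only, here is a minimal sketch of the underlying llama-cpp-python mechanism; the model path and toy grammar are placeholders and not part of this Space:

# Illustrative sketch: grammar-constrained decoding with llama-cpp-python.
# "model.gguf" and the toy grammar are placeholders, not files from this repo.
from llama_cpp import Llama
from llama_cpp.llama import LlamaGrammar

grammar = LlamaGrammar.from_string('root ::= "yes" | "no"')
llm = Llama(model_path="model.gguf", verbose=False)
out = llm("Is water wet? Answer: ", grammar=grammar, max_tokens=4)
print(out["choices"][0]["text"])  # decoding is constrained to "yes" or "no"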
app.py CHANGED
@@ -6,13 +6,18 @@ import json
 import tempfile
 import hashlib
 import os
+import re
 from typing import Optional
+from llama_cpp.llama import LlamaGrammar
+from outetts.version.interface import InterfaceLLAMACPP
 from outetts.models.info import MODEL_INFO
 from outetts.utils import helpers
 from huggingface_hub import hf_hub_download
 import torch
 from transformers import BitsAndBytesConfig
 import spaces
+import numpy as np
+from collections import OrderedDict
 
 # Available OuteTTS models based on the documentation
 MODELS = {v.value: v for _, v in outetts.Models.__members__.items()}
@@ -26,6 +31,77 @@ MODEL_QUANTIZATION = {
 # Cache for speaker profiles to avoid re-transcribing the same audio
 speaker_cache = {}
 
+SPLIT_SYMBOL = {
+    outetts.InterfaceVersion.V1: '<|space|>',
+    outetts.InterfaceVersion.V2: '<|space|>',
+    outetts.InterfaceVersion.V3: ' ',
+}
+
+def word_to_grammar(word):
+    if all(ord(c) < 128 for c in word):
+        return f'"{word}"'
+    return f'[{"".join(OrderedDict.fromkeys(word))}]+'
+
+# patch InterfaceLLAMACPP, inject new _generate method
+InterfaceLLAMACPP._orig_generate = InterfaceLLAMACPP._generate
+def ggml_generate(self, input_ids, config):
+    tokenizer = self.prompt_processor.tokenizer
+    split = SPLIT_SYMBOL.get(self.config.interface_version, ' ')
+    prompt = tokenizer.decode(input_ids, skip_special_tokens=False)
+    prompt_no_special = tokenizer.decode(input_ids, skip_special_tokens=True).strip()
+    if '<|text_start|>' not in prompt:
+        return self._orig_generate(input_ids, config)
+    speaker_text_last = prompt_no_special.split('\n').pop()
+    text = prompt[prompt.index('<|text_start|>')+14:prompt.index('<|text_end|>')]
+    gen_text = text[text.index(speaker_text_last)+len(speaker_text_last):].strip(split) if speaker_text_last in text else text
+    words = [word_to_grammar(word) for word in gen_text.split(split)]
+    if self.config.interface_version == outetts.InterfaceVersion.V2:
+        config.additional_gen_config["grammar"] = LlamaGrammar.from_string(f"""\
+root ::= NL? {' audioBlock '.join(words)} audioEnd NL EOS?
+audioBlock ::= TIME CODE* space NL?
+TEXT ::= [A-Za-z0-9 .,?!]+
+EOS ::= "<|im_end|>"
+emotionStart ::= "<|emotion_start|>"
+emotionEnd ::= "<|emotion_end|>"
+audioEnd ::= "<|audio_end|>"
+space ::= "<|space|>"
+WORD ::= {' | '.join(words)}
+NL ::= [\\n]
+TIME ::= "<|t_" DECIMAL "|>"
+CODE ::= "<|" DIGITS "|>"
+DIGITS ::= [0-9]+
+DECIMAL ::= [0-9]+ "." [0-9]+
+punch ::= "<|" [a-z_]+ "|>"
+""")
+    elif self.config.interface_version == outetts.InterfaceVersion.V3:
+        config.additional_gen_config["grammar"] = LlamaGrammar.from_string(f"""\
+root ::= leadWord wordBlock* audioEnd NL EOS?
+leadWord ::= WORD audioBlock
+wordBlock ::= wordStart WORD audioBlock
+audioBlock ::= codeBlock wordEnd NL?
+codeBlock ::= features TIME energy spectralCentroid pitch CODE CODES*
+TEXT ::= [A-Za-z0-9.,!?]+
+EOS ::= "<|im_end|>"
+audioEnd ::= "<|audio_end|>"
+wordStart ::= "<|word_start|>"
+wordEnd ::= "<|word_end|>"
+features ::= "<|features|>"
+energy ::= "<|energy_" DIGITS "|>"
+spectralCentroid ::= "<|spectral_centroid_" DIGITS "|>"
+pitch ::= "<|pitch_" DIGITS "|>"
+WORD ::= {' | '.join(words)}
+NL ::= [\\n]
+TIME ::= "<|t_" DECIMAL "|>"
+CODE ::= "<|code|>"
+CODES ::= CODE1 CODE2
+CODE1 ::= "<|c1_" DIGITS "|>"
+CODE2 ::= "<|c2_" DIGITS "|>"
+DIGITS ::= [0-9]+
+DECIMAL ::= [0-9]+ "." [0-9]+
+""")
+    return self._orig_generate(input_ids, config)
+InterfaceLLAMACPP._generate = ggml_generate
+
 def get_file_hash(file_path):
     """Calculate MD5 hash of a file for caching purposes."""
     hash_md5 = hashlib.md5()
@@ -34,7 +110,7 @@ def get_file_hash(file_path):
             hash_md5.update(chunk)
     return hash_md5.hexdigest()
 
-def try_ggml_model(model: outetts.Models, backend: outetts.Backend, quantization: outetts.LlamaCppQuantization):
+def try_ggml_model(model: outetts.Models, quantization: outetts.LlamaCppQuantization):
     model_config = MODEL_INFO[model]
     repo = f"OuteAI/{model.value}-GGUF"
     filename = f"{model.value}-{quantization.value}.gguf"
@@ -45,12 +121,12 @@ def try_ggml_model(model: outetts.Models, backend: outetts.Backend, quantization
         local_files_only=False
     )
     generation_type = outetts.GenerationType.CHUNKED
-    if model_config['interface_version'] == outetts.InterfaceVersion.V3:
-        generation_type = outetts.GenerationType.GUIDED_WORDS
+    # if model_config['interface_version'] == outetts.InterfaceVersion.V3:
+    #     generation_type = outetts.GenerationType.GUIDED_WORDS
    return outetts.ModelConfig(
         model_path=model_path,
         tokenizer_path=f"OuteAI/{model.value}",
-        backend=backend,
+        backend=outetts.Backend.LLAMACPP,
         n_gpu_layers=99,
         verbose=False,
         device=None,
@@ -67,7 +143,7 @@ def get_interface(model_name: str):
 
     try:
         quantization = MODEL_QUANTIZATION.get(model, outetts.LlamaCppQuantization.Q8_0)
-        config = try_ggml_model(model, outetts.Backend.LLAMACPP, quantization)
+        config = try_ggml_model(model, quantization)
     except:
         has_cuda = torch.cuda.is_available()
         model_config = MODEL_INFO[model]
@@ -98,7 +174,7 @@ def get_or_create_speaker(interface, audio_file):
     # Check if speaker profile is already cached
     if cache_key in speaker_cache:
         print(f"✅ Using cached speaker profile for {os.path.basename(audio_file)}")
-        return speaker_cache[cache_key]
+        return json.loads(speaker_cache[cache_key])
 
     device = "cuda" if torch.cuda.is_available() else "cpu"
 
@@ -108,7 +184,7 @@ def get_or_create_speaker(interface, audio_file):
     speaker = interface.create_speaker(audio_file, whisper_model="large-v3-turbo", whisper_device=device)
 
     # Cache the speaker profile
-    speaker_cache[cache_key] = speaker
+    speaker_cache[cache_key] = json.dumps(speaker)
     print(f"💾 Cached speaker profile ({len(speaker_cache)} total cached)")
 
     return speaker
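For reference, a small standalone demo (not part of app.py) of how the word_to_grammar helper above turns the text to be generated into grammar terminals; the example words are made up:

# Standalone demo of the word_to_grammar helper from the patch above.
from collections import OrderedDict

def word_to_grammar(word):
    if all(ord(c) < 128 for c in word):
        return f'"{word}"'                              # ASCII word -> exact string literal
    return f'[{"".join(OrderedDict.fromkeys(word))}]+'  # otherwise -> character-class rule

words = [word_to_grammar(w) for w in ["Hello,", "héllo"]]
print(words)                             # ['"Hello,"', '[hélo]+']  (duplicate 'l' removed by OrderedDict.fromkeys)
print(f"WORD ::= {' | '.join(words)}")   # WORD ::= "Hello," | [hélo]+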