# NOTE: "Spaces: Running on Zero" - stray web-page header text, not part of the module.
import re
import time

import inflect
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn.attention import SDPBackend, sdpa_kernel
from torchaudio.transforms import Resample

from vui.model import Vui
from vui.sampling import multinomial, sample_top_k, sample_top_p, sample_top_p_top_k
# Module-level resampler from 22050 Hz to 16000 Hz. It is moved to the GPU at
# import time, so importing this module requires a working CUDA device.
resample = Resample(22050, 16000).cuda()
def ensure_spaces_around_tags(text: str):
    """Pad square-bracket tags with spaces so they stand apart from the text.

    A space is inserted before every ``[`` that is not already preceded by
    whitespace, ``<`` or ``[``, and after every ``]`` that is not followed by
    whitespace, ``>`` or ``]`` (a ``]`` directly preceded by ``<digit>]`` is
    left alone). Leading and trailing whitespace is stripped from the result.

    Args:
        text: Input text that may contain ``[tag]`` markers.

    Returns:
        The text with tags separated from neighbouring characters.
    """
    # The lookbehind already rejects a preceding newline (it is whitespace),
    # so a plain " [" replacement covers every position that can match.
    text = re.sub(r"(?<![<\[\s])(\[)", r" \1", text)
    # Likewise the lookahead rejects a following newline, so "] " is the only
    # replacement that can ever be produced here.
    text = re.sub(r"(?<!\d\])(\])(?![>\]\s])", r"\1 ", text)
    return text.strip()
# Punctuation normalisation table of (old, new) pairs.
# NOTE(review): the first key was mojibake ("β") in the source; it is
# reconstructed here as an em dash (U+2014) - confirm against the upstream file.
REPLACE = [
    ("\u2014", ","),
    ("'", "'"),
    (":", ","),
    (";", ","),
]

# Lazily created singletons: `engine` is filled in by
# replace_numbers_with_words() and `wm` by asr() on first use.
engine = None
wm = None
def asr(chunk, model=None, prefix=None):
    """Transcribe one audio chunk to English text with Whisper.

    The Whisper model is cached in the module-level ``wm``: a ``model`` passed
    in replaces the cached one, otherwise the "turbo" checkpoint is loaded onto
    CUDA the first time it is needed.

    Args:
        chunk: Audio handed to ``whisper.pad_or_trim`` and
            ``whisper.log_mel_spectrogram`` (presumably a 16 kHz mono
            waveform - TODO confirm against callers).
        model: Optional pre-loaded Whisper model to use and cache.
        prefix: Optional text forwarded as ``DecodingOptions(prefix=...)``.

    Returns:
        The decoded text of the first (and only) item in the batch.
    """
    # Imported lazily so the module can be used without whisper installed
    # as long as asr() is never called.
    import whisper

    global wm
    if model is not None:
        wm = model
    elif wm is None:
        wm = whisper.load_model("turbo", "cuda")

    chunk = whisper.pad_or_trim(chunk)
    mel = whisper.log_mel_spectrogram(chunk, n_mels=wm.dims.n_mels).to(wm.device)
    options = whisper.DecodingOptions(
        language="en", without_timestamps=True, prefix=prefix
    )
    # mel[None] adds a leading batch dimension, so decode returns a list.
    result = whisper.decode(wm, mel[None], options)
    return result[0].text
def replace_numbers_with_words(text):
    """Spell out every run of digits in ``text`` as English words.

    Each ``\\d+`` match is converted with ``inflect``'s ``number_to_words``
    and followed by a single space. The inflect engine is created on first use
    and cached in the module-level ``engine``.

    Args:
        text: Input text.

    Returns:
        The text with digit runs replaced by their word form.
    """
    global engine
    if engine is None:
        engine = inflect.engine()

    def number_to_words(match):
        # Trailing space keeps the words apart from whatever followed the digits.
        return engine.number_to_words(match.group()) + " "

    return re.sub(r"\d+", number_to_words, text)
# Bracketed non-speech markers that are kept in the text, e.g. "[breath]".
valid_non_speech = [
    f"[{name}]"
    for name in ("breath", "sigh", "laugh", "tut", "hesitate", "clearthroat")
]


def remove_all_invalid_non_speech(txt):
    """Strip every ``[...]`` marker that is not a recognised non-speech tag.

    Markers listed in ``valid_non_speech`` and the special ``[pause]`` marker
    are kept; any other non-empty bracketed span is deleted. Text outside the
    brackets, including the whitespace around a removed marker, is untouched.

    Args:
        txt: Input text.

    Returns:
        The text with unrecognised bracketed markers removed.
    """

    def keep_or_drop(match):
        tag = match.group(0)
        if tag in valid_non_speech or tag == "[pause]":
            return tag
        return ""

    # One pass over the text instead of findall plus one str.replace per tag.
    return re.sub(r"\[[^\]]+\]", keep_or_drop, txt)
def simple_clean(text):
    """Normalise raw text into the form fed to the speech model.

    Steps, in order: expand ``<n>am`` / ``<n>pm``, spell out numbers, pad
    bracket tags with spaces, drop unknown bracket tags, strip or replace a
    fixed set of punctuation, collapse repeated newlines and spaces, make sure
    the text ends with ``.`` or ``?``, and append a trailing `` [pause]``.

    Args:
        text: Raw input text.

    Returns:
        The cleaned text, always ending in `` [pause]``.
    """
    text = re.sub(r"(\d+)am", r"\1 AM", text)
    text = re.sub(r"(\d+)pm", r"\1 PM", text)
    text = replace_numbers_with_words(text)
    text = ensure_spaces_around_tags(text)
    text = remove_all_invalid_non_speech(text)

    # Applied in order: the ellipsis must be expanded before "s..." is handled.
    # NOTE(review): the non-ASCII keys were mojibake ("β", "β¦") in the source
    # and are reconstructed from their replacement values - confirm upstream.
    replacements = (
        ('"', ""),
        ("\u201d", ""),  # right double quotation mark
        ("\u201c", ""),  # left double quotation mark
        ("\u2019", "'"),  # right single quotation mark
        ("%", " percent"),
        ("*", ""),
        ("(", ""),
        (")", ""),
        (";", ""),
        ("\u2013", " "),  # en dash
        ("\u2014", ""),  # em dash
        (":", ""),
        ("\u2026", "..."),  # horizontal ellipsis
        ("s...", "s"),
    )
    for old, new in replacements:
        text = text.replace(old, new)

    # Collapse runs of newlines and runs of spaces to a single character each.
    text = re.sub(r"\n+", "\n", text)
    ntxt = re.sub(r" +", " ", text).strip()

    # Terminate with a period unless the text already ends with "." or "?".
    if not ntxt.endswith((".", "?")):
        ntxt += "."
    ntxt += " [pause]"
    return ntxt
def generate(
    self: Vui,
    text: str,
    prompt_codes: Tensor | None = None,
    temperature: float = 0.5,
    top_k: int | None = 150,
    top_p: float | None = None,
    max_gen_len: int = int(120 * 21.53),
):
    """Autoregressively sample audio codes for ``text``.

    The text is cleaned with ``simple_clean``, tokenised and embedded, then the
    decoder is run step by step under bfloat16 autocast with the math SDPA
    backend. Each step applies a repetition penalty over a short history
    window, samples the next codes, and stops early once the end-of-audio
    token becomes highly probable.

    Args:
        self: The ``Vui`` model (this is a free function, called as
            ``generate(model, ...)``).
        text: Text to speak.
        prompt_codes: Optional audio codes used as a prefix; only the first
            ``n_quantizers`` codebooks are used. With ``None`` generation
            starts from an empty prompt.
        temperature: Divisor applied to the logits before the softmax.
        top_k: Top-k cutoff, or ``None`` to disable.
        top_p: Top-p cutoff, or ``None`` to disable. When both ``top_p`` and
            ``top_k`` are set the combined sampler is used.
        max_gen_len: Length of the code buffer, prompt included.

    Returns:
        The generated codes for the first batch item, with the batch dimension
        kept and the prompt portion sliced off.
    """
    text = simple_clean(text)

    with (
        torch.autocast("cuda", torch.bfloat16, True),
        sdpa_kernel([SDPBackend.MATH]),
    ):
        t1 = time.perf_counter()
        batch_size = 1
        device = self.device
        self.decoder.allocate_inference_cache(batch_size, device, torch.bfloat16)

        texts = [text]
        encoded = self.tokenizer(
            texts,
            padding="longest",
            return_tensors="pt",
        )
        input_ids = encoded.input_ids.to(device)
        text_embeddings = self.token_emb(input_ids)

        B = batch_size
        Q = self.config.model.n_quantizers

        if prompt_codes is None:
            prompt_codes = torch.zeros(
                (batch_size, Q, 0), dtype=torch.int64, device=device
            )
        else:
            prompt_codes = prompt_codes[:, :Q].repeat(batch_size, 1, 1)

        start_offset = prompt_codes.size(-1)
        pattern = self.pattern_provider.get_pattern(max_gen_len)

        # this token is used as default value for codes that are not generated yet
        unknown_token = -1
        special_token_id = self.config.model.special_token_id

        # we generate codes up to the max_gen_len that will be mapped to the pattern sequence
        codes = torch.full(
            (B, Q, max_gen_len), unknown_token, dtype=torch.int64, device=device
        )
        codes[:, :, :start_offset] = prompt_codes

        sequence, indexes, mask = pattern.build_pattern_sequence(
            codes, special_token_id
        )

        # retrieve the start_offset in the sequence:
        # it is the first sequence step that contains the `start_offset` timestep
        start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
        assert start_offset_sequence is not None

        prev_offset = 0
        S = sequence.size(-1)
        do_prefill = True
        eos = self.config.model.audio_eos_id

        # Loop-invariant sampling constants.
        repetition_penalty = 1.4
        history_window = 12

        # Pre-bind so the final slice is defined even if the loop never runs.
        offset = start_offset_sequence

        for offset in range(start_offset_sequence, S):
            curr_sequence = sequence[..., prev_offset:offset]
            # Average the per-codebook embeddings into one audio embedding.
            audio_embeddings = (
                sum([self.audio_embeddings[q](curr_sequence[:, q]) for q in range(Q)])
                / Q
            )

            if do_prefill:
                # First step: feed the text plus any prompt audio in one pass.
                embeddings = torch.cat((text_embeddings, audio_embeddings), dim=1)
                T = embeddings.size(1)
                input_pos = torch.arange(0, T, device=device)
                do_prefill = False
            else:
                # Later steps: feed only the newest position.
                embeddings = audio_embeddings
                input_pos = torch.tensor([T], device=device)
                T += 1

            out = self.decoder(embeddings, input_pos)

            if offset == 15:
                print("TTFB", time.perf_counter() - t1)

            logits = torch.stack(
                [self.audio_heads[q](out[:, -1]) for q in range(Q)], dim=1
            )

            # Repetition penalty over the last `history_window` steps.
            # NOTE(review): dividing also makes negative logits *more* likely;
            # confirm that this is the intended behaviour.
            history_start = max(0, offset - history_window)
            for q in range(Q):
                token_history = sequence[0, q, history_start:offset]

                # Only real, previously generated tokens are penalised.
                unique_tokens = torch.unique(token_history)
                unique_tokens = unique_tokens[unique_tokens != special_token_id]
                unique_tokens = unique_tokens[unique_tokens != eos]
                unique_tokens = unique_tokens[unique_tokens != unknown_token]

                if len(unique_tokens) > 0:
                    logits[0, q, unique_tokens] = (
                        logits[0, q, unique_tokens] / repetition_penalty
                    )

            # Forbid end-of-audio during the first steps.
            # NOTE(review): 24.53 here vs 21.53 in the max_gen_len default -
            # presumably both are a frame rate; confirm which one is right.
            if offset < 24.53 * 4:
                logits[..., eos] = -float("inf")

            probs = F.softmax(logits / temperature, dim=-1)

            if top_p is not None and top_k is not None:
                next_codes = sample_top_p_top_k(probs, top_p, top_k)
            elif top_p is not None and top_p > 0:
                next_codes = sample_top_p(probs, top_p)
            elif top_k is not None and top_k > 0:
                next_codes = sample_top_k(probs, top_k)
            else:
                next_codes = multinomial(probs, num_samples=1)

            next_codes = next_codes.repeat(batch_size, 1, 1)

            # Stop as soon as any codebook is confident the audio has ended.
            if (probs[..., eos] > 0.95).any():
                print("breaking at", offset)
                break

            # Positions the pattern marks as invalid get the special token.
            valid_mask = mask[..., offset : offset + 1].expand(B, -1, -1)
            next_codes[~valid_mask] = special_token_id

            # Only fill positions that are still unknown (keeps prompt codes).
            sequence[..., offset : offset + 1] = torch.where(
                sequence[..., offset : offset + 1] == unknown_token,
                next_codes,
                sequence[..., offset : offset + 1],
            )

            prev_offset = offset

        out_codes, _out_indexes, _out_mask = pattern.revert_pattern_sequence(
            sequence, special_token=unknown_token
        )

        # Drop the prompt and everything past the step where generation stopped.
        out_codes = out_codes[..., prompt_codes.shape[-1] : offset]
        return out_codes[[0]]
def render(
    self: Vui,
    text: str,
    prompt_codes: Tensor | None = None,
    temperature: float = 0.5,
    top_k: int | None = 100,
    top_p: float | None = None,
    max_secs: int = 100,
):
    """Render audio from text.

    Cleaned text shorter than 1400 characters is generated in one call.
    Longer text is split on newlines and rendered line by line, each line
    being conditioned on the previous line's text and generated codes.

    In the chunked path a line that raises ``RuntimeError`` is retried once
    with the context reset to the caller's ``prompt_codes``; if it fails again
    it is skipped. ``KeyboardInterrupt`` stops rendering and returns the audio
    produced so far.

    Args:
        self: The ``Vui`` model (this is a free function, called as
            ``render(model, ...)``).
        text: Text to speak.
        prompt_codes: Optional audio codes used as the initial prompt.
        temperature: Sampling temperature passed to ``generate``.
        top_k: Top-k cutoff passed to ``generate``.
        top_p: Top-p cutoff passed to ``generate``.
        max_secs: Upper bound on audio length for the single-call path only;
            the chunked path derives its limit from each chunk's text length.

    Returns:
        The decoded audio as a CPU tensor.

    Raises:
        RuntimeError: If the chunked path produced no audio at all.
    """
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    text = remove_all_invalid_non_speech(text)
    text = simple_clean(text)

    SR = self.codec.config.sample_rate
    HZ = self.codec.hz
    max_gen_len = int(HZ * max_secs)
    t1 = time.perf_counter()

    if len(text) < 1400:
        codes = generate(
            self, text, prompt_codes, temperature, top_k, top_p, max_gen_len
        )
        # Trim the last 10 code frames before decoding.
        codes = codes[..., :-10]
        audio = self.codec.from_indices(codes)
        print("RTF", (audio.numel() / SR) / (time.perf_counter() - t1))
        return audio.cpu()

    # Otherwise chain line by line, carrying the previous line as context.
    orig_codes = prompt_codes
    audios = []
    prev_codes = prompt_codes
    prev_text = ""
    interrupted = False

    for line in text.split("\n"):
        if not line.strip():
            # An empty line would only re-render the previous line's text.
            continue

        retried = False
        while True:
            current_text = prev_text + "\n" + line if prev_text else line
            current_text = current_text.strip()
            current_text = current_text.replace("...", "")
            current_text = current_text + " [pause]"

            # Calculate max length based on text length
            maxlen = int(HZ * int(60 * len(current_text) / 500))

            try:
                print("rendering", current_text)
                codes = generate(
                    self,
                    current_text,
                    prompt_codes=prev_codes,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    max_gen_len=maxlen,
                )
                codes = codes[..., :-10]
                paudio = self.codec.from_indices(codes)
            except KeyboardInterrupt:
                interrupted = True
                break
            except RuntimeError as e:
                print(e)
                # Drop the chained context and fall back to the original prompt.
                prev_codes = orig_codes
                prev_text = ""
                if retried:
                    # Already failed without context: give up on this line.
                    break
                retried = True
                continue

            prev_text = line
            prev_codes = codes
            audios.append(paudio)
            break

        if interrupted:
            break

    if not audios:
        raise RuntimeError("render produced no audio for the given text")

    # Match the single-call path, which also returns a CPU tensor.
    return torch.cat(audios, dim=-1).cpu()