import gradio as gr
import torch
import torchaudio
from transformers import AutoTokenizer, AutoModelForCausalLM
from speechtokenizer import SpeechTokenizer
from audiotools import AudioSignal
import bitsandbytes as bnb  # Import bitsandbytes for INT8 quantization
import numpy as np
from uuid import uuid4

# Load the necessary models and tokenizers
model_path = "Vikhrmodels/llama_asr_tts_24000"
tokenizer = AutoTokenizer.from_pretrained(model_path, cache_dir=".")
# Special tokens
start_audio_token = "<soa>"
end_audio_token = "<eoa>"
end_sequence_token = "<eos>"

# Constants
n_codebooks = 3
max_seq_length = 1024
top_k = 20
from safetensors.torch import load_file
def convert_to_16_bit_wav(data):
    # Based on: https://docs.scipy.org/doc/scipy/reference/generated/scipy.io.wavfile.write.html
    if data.dtype == np.float32:
        # Scale float audio in [-1, 1] to the int16 range.
        data = data / np.abs(data).max()
        data = data * 32767
        data = data.astype(np.int16)
    elif data.dtype == np.int32:
        # Downscale 32-bit ints to 16-bit ints (2 ** 16 = 65536).
        data = data / 65536
        data = data.astype(np.int16)
    elif data.dtype == np.int16:
        pass
    elif data.dtype == np.uint8:
        # Map unsigned 8-bit audio (0..255) onto the signed 16-bit range.
        data = data * 257 - 32768
        data = data.astype(np.int16)
    else:
        raise ValueError("Audio data cannot be converted to 16-bit int format.")
    return data
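
# Illustrative example (not executed): a float32 signal such as
# np.array([0.0, 0.5, -1.0], dtype=np.float32) is rescaled by the float32
# branch above to roughly array([0, 16383, -32767], dtype=np.int16).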
# Load the model with INT8 quantization
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    cache_dir=".",
    load_in_8bit=True,  # Enable loading in INT8
    device_map="auto",  # Automatically map model to available devices
)
# Configurations for Speech Tokenizer
config_path = "audiotokenizer/speechtokenizer_hubert_avg_config.json"
ckpt_path = "audiotokenizer/SpeechTokenizer.pt"
quantizer = SpeechTokenizer.load_from_checkpoint(config_path, ckpt_path)
quantizer.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Move all quantizer submodules to the device and freeze their parameters
def freeze_entire_model(model):
    for n, p in model.named_parameters():
        p.requires_grad = False
    return model

for n, child in quantizer.named_children():
    child.to(device)
    child = freeze_entire_model(child)
# Create audio padding tokens by encoding a 1-sample silent signal
def get_audio_padding_tokens(quantizer):
    audio = torch.zeros((1, 1, 1)).to(device)
    codes = quantizer.encode(audio)
    del audio
    torch.cuda.empty_cache()
    return {"audio_tokens": codes.squeeze(1)}
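
# Note (illustrative): the codes returned above are the quantizer's output for a
# single silent sample; decode_audio slices a few of them to top up a partial
# frame so the audio token count becomes a multiple of n_codebooks.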
# Decode audio tokens back into a waveform file
def decode_audio(tokens, quantizer, pad_tokens, n_original_tokens):
    # Locate the audio span between <soa> and <eoa>
    start = torch.nonzero(tokens == tokenizer(start_audio_token)["input_ids"][-1])
    end = torch.nonzero(tokens == tokenizer(end_audio_token)["input_ids"][-1])
    start = start[0, -1] + 1 if len(start) else 0
    end = end[0, -1] if len(end) else tokens.shape[-1]

    # Map token ids back to raw codebook indices
    audio_tokens = tokens[start:end] % n_original_tokens
    remainder = audio_tokens.shape[-1] % n_codebooks
    if remainder:
        # Pad the last partial frame to a multiple of n_codebooks
        audio_tokens = torch.cat([audio_tokens, pad_tokens[remainder:n_codebooks]], dim=0)

    transposed = audio_tokens.view(-1, n_codebooks).t()
    codes = transposed.view(n_codebooks, 1, -1).to(device)

    audio = quantizer.decode(codes).squeeze(0)
    torch.cuda.empty_cache()

    xp = str(uuid4()) + ".wav"
    AudioSignal(audio.detach().cpu().numpy(), quantizer.sample_rate).write(xp)
    return xp
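
# Illustrative note: with n_codebooks = 3, a flat generated stream
# [c0_t0, c1_t0, c2_t0, c0_t1, c1_t1, c2_t1, ...] is reshaped to (-1, 3) and
# transposed above, so each row holds one codebook's sequence before decoding.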
# Inference function: text input, audio output
def infer_text_to_audio(text, model, tokenizer, quantizer, max_seq_length=1024, top_k=20):
    text_tokenized = tokenizer(text, return_tensors="pt")
    text_input_tokens = text_tokenized["input_ids"].to(device)

    soa = tokenizer(start_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)
    eoa = tokenizer(end_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)

    # Append <soa> so the model continues the sequence with audio tokens
    text_tokens = torch.cat([text_input_tokens, soa], dim=1)
    attention_mask = torch.ones(text_tokens.size(), device=device)

    output_audio_tokens = model.generate(text_tokens, attention_mask=attention_mask, max_new_tokens=max_seq_length, top_k=top_k, do_sample=True)

    padding_tokens = get_audio_padding_tokens(quantizer)["audio_tokens"].to(device)
    audio_signal = decode_audio(output_audio_tokens[0], quantizer, padding_tokens.t()[0], len(tokenizer) - 1024)
    return audio_signal
# Inference function: audio input, text output
def infer_audio_to_text(audio_path, model, tokenizer, quantizer, max_seq_length=1024, top_k=20):
    audio_data, sample_rate = torchaudio.load(audio_path)
    audio = audio_data.view(1, 1, -1).float().to(device)
    codes = quantizer.encode(audio)

    # Take the first n_codebooks_a codebooks and shift the code indices into the
    # audio-token region at the end of the LLM vocabulary
    n_codebooks_a = 1
    raw_audio_tokens = codes[:, :n_codebooks_a] + len(tokenizer) - 1024

    soa = tokenizer(start_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)
    eoa = tokenizer(end_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)
    audio_tokens = torch.cat([soa, raw_audio_tokens.view(1, -1), eoa], dim=1)
    attention_mask = torch.ones(audio_tokens.size(), device=device)

    output_text_tokens = model.generate(audio_tokens, attention_mask=attention_mask, max_new_tokens=max_seq_length, top_k=top_k, do_sample=True)

    output_text_tokens = output_text_tokens.cpu()[0]
    # Keep only ids below the <soa> token id, i.e. the text tokens
    output_text_tokens = output_text_tokens[output_text_tokens < tokenizer(start_audio_token)["input_ids"][-1]]
    decoded_text = tokenizer.decode(output_text_tokens, skip_special_tokens=True)
    return decoded_text
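
# Example usage (hypothetical inputs, for illustration only; not executed here):
#   wav_path = infer_text_to_audio("HELLO, HOW ARE YOU?", model, tokenizer, quantizer)
#   reply = infer_audio_to_text("question.wav", model, tokenizer, quantizer)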
# Gradio wrapper functions
def infer_text_to_audio_gr(text):
    audio_signal = infer_text_to_audio(text.strip().upper(), model, tokenizer, quantizer)
    return audio_signal

def infer_audio_to_text_gr(audio_path):
    generated_text = infer_audio_to_text(audio_path, model, tokenizer, quantizer)
    return generated_text
# Gradio Interface
text_to_audio_interface = gr.Interface(
    fn=infer_text_to_audio_gr,
    inputs=gr.Textbox(label="Input Text"),
    outputs=gr.Audio(label="Audio Response"),
    title="T2S",
    description="Model in audio-response mode",
    allow_flagging="never",
)
audio_to_text_interface = gr.Interface(
    fn=infer_audio_to_text_gr,
    inputs=gr.Audio(type="filepath", label="Input Audio"),
    outputs=gr.Textbox(label="Text Response"),
    title="S2T",
    description="Model in text-response mode",
    allow_flagging="never",
)
# Launch Gradio App
demo = gr.TabbedInterface([text_to_audio_interface, audio_to_text_interface], ["Text - Audio", "Audio - Text"])
demo.launch(share=True)