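"""Inference examples for the SALT speech/text LM (Vikhrmodels/salt-116k):
text-to-speech and speech-to-text over SpeechTokenizer audio codes."""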
import torchaudio
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
)
from speechtokenizer import SpeechTokenizer
from audiotools import AudioSignal


def decode_tts(tokens, quantizer, n_codebooks, n_original_tokens, start_audio_token_id, end_audio_token_id):
    # find the first start-of-audio and end-of-audio markers in the sequence
    start = torch.nonzero(tokens == start_audio_token_id)
    end = torch.nonzero(tokens == end_audio_token_id)

    start = start[0, -1] + 1 if len(start) else 0
    end = end[0, -1] if len(end) else tokens.shape[-1]

    # subtract the length of the original text vocabulary -> tokens in range [0, codebook_size)
    audio_tokens = tokens[start:end] % n_original_tokens
    remainder = audio_tokens.shape[-1] % n_codebooks

    if remainder:
        # pad with zeros if the last frame is incomplete
        pad_tokens = torch.zeros(n_codebooks - remainder, dtype=audio_tokens.dtype, device=audio_tokens.device)
        audio_tokens = torch.cat([audio_tokens, pad_tokens], dim=0)

    # de-interleave: flat [q0_t0, q1_t0, q2_t0, q0_t1, ...] -> (n_codebooks, 1, n_frames)
    transposed = audio_tokens.view(-1, n_codebooks).t()
    codes = transposed.view(n_codebooks, 1, -1).to(device)

    audio = quantizer.decode(codes).squeeze(0)

    del tokens
    del audio_tokens
    torch.cuda.empty_cache()

    return AudioSignal(audio.detach().cpu().numpy(), quantizer.sample_rate)
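

# Usage sketch for decode_tts (hypothetical ids; infer_text_to_audio below shows
# the real call). The flat stream is assumed to interleave one token per
# codebook per frame, which is what the view/transpose above undoes:
#   ids = model.generate(prompt_ids, max_new_tokens=1024)[0]
#   signal = decode_tts(ids, quantizer_speech, n_codebooks_tts,
#                       len(tokenizer) - codebook_size, soa_id, eoa_id)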


def infer_text_to_audio(text, model, tokenizer, quantizer, max_seq_length=1024, top_k=20):
    text_tokenized = tokenizer(text, return_tensors="pt")
    text_input_tokens = text_tokenized["input_ids"].to(device)

    # single-token ids of the audio boundary markers
    soa = tokenizer(start_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)
    eoa = tokenizer(end_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)

    # prompt: text followed by <soa>; the model continues with audio tokens
    text_tokens = torch.cat([text_input_tokens, soa], dim=1)
    attention_mask = torch.ones(text_tokens.size(), device=device)

    output_audio_tokens = model.generate(
        text_tokens,
        attention_mask=attention_mask,
        max_new_tokens=max_seq_length,
        top_k=top_k,
        do_sample=True,
        temperature=0.8,
        no_repeat_ngram_size=3,
    )

    audio_signal = decode_tts(output_audio_tokens[0], quantizer, n_codebooks_tts, len(tokenizer) - codebook_size, soa, eoa)

    return audio_signal


def infer_audio_to_text(audio_path, model, tokenizer, quantizer, max_seq_length=1024, top_k=20):
    audio_data, sample_rate = torchaudio.load(audio_path)
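    # NOTE (assumption): the released SpeechTokenizer checkpoint operates at 16 kHz;
    # if the input file differs, resample first, e.g.:
    #   audio_data = torchaudio.functional.resample(audio_data, sample_rate, quantizer.sample_rate)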
    audio = audio_data.view(1, 1, -1).float().to(device)

    # SpeechTokenizer.encode returns codes of shape (n_q, batch, frames);
    # keep only the first codebook(s) for ASR and shift into the audio id range
    codes = quantizer.encode(audio)
    raw_audio_tokens = codes[:n_codebooks_asr] + len(tokenizer) - codebook_size

    soa = tokenizer(start_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)
    eoa = tokenizer(end_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)

    # prompt: <soa> audio tokens <eoa>; the model continues with the transcript
    tokens = torch.cat([soa, raw_audio_tokens.view(1, -1), eoa], dim=1)
    attention_mask = torch.ones(tokens.size(), device=device)

    output_text_tokens = model.generate(
        tokens,
        attention_mask=attention_mask,
        max_new_tokens=max_seq_length,
        temperature=0.6,
        top_p=0.9,
        top_k=top_k,
        no_repeat_ngram_size=4,
        length_penalty=2.0,
        repetition_penalty=1.5,
    )

    # keep only text tokens: drop everything at or above the <soa> id
    output_text_tokens = output_text_tokens.cpu()[0]
    output_text_tokens = output_text_tokens[output_text_tokens < tokenizer(start_audio_token)["input_ids"][-1]]
    decoded_text = tokenizer.decode(output_text_tokens, skip_special_tokens=True)

    return decoded_text
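

# Round-trip sketch (assumes the __main__ block below has run and produced
# output.wav): feeding the generated audio back through ASR should roughly
# recover the spoken text.
#   transcript = infer_audio_to_text("output.wav", model, tokenizer, quantizer_speech)
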
device = "cuda"
n_special_tokens = 3
n_codebooks_tts = 3
n_codebooks_asr = 1
start_audio_token = "<soa>"
end_audio_token = "<eoa>"
end_sequence_token = "<eos>"
base_model = "Vikhrmodels/salt-116k"


if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained(base_model, cache_dir=".")
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        cache_dir=".",
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa",
        device_map={"": 0},
    )

    quantizer_speech = SpeechTokenizer.load_from_checkpoint(
        "speechtokenizer/config.json",
        "speechtokenizer/SpeechTokenizer.pt",
    )
    quantizer_speech = quantizer_speech.eval().to(device)
    codebook_size = quantizer_speech.quantizer.bins

    # TTS: text description -> waveform written to output.wav
    text = ("Say 'COUNT NUMBERS FROM ONE TO TEN' with a male speaker delivers a very monotone and "
            "low-pitched speech with a moderate speed in a setting with almost no noise, "
            "creating a clear and quiet recording.")
    audio_signal = infer_text_to_audio(text, model, tokenizer, quantizer_speech, top_k=50)
    audio_signal.write("output.wav")

    # ASR: waveform at audio_path -> transcript
    audio_path = "./input.wav"
    generated_text = infer_audio_to_text(audio_path, model, tokenizer, quantizer_speech)
    print(generated_text)