# salt-116k / inference.py
# Author: Ksenia Sycheva
# Commit message: Add inference code
# Commit: ee7a752
import torchaudio
import torch
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
)
from speechtokenizer import SpeechTokenizer
from audiotools import AudioSignal
def decode_tts(tokens, quantizer, n_codebooks, n_original_tokens, start_audio_token_id, end_audio_token_id):
    """Decode the audio span of one generated TTS sequence into an AudioSignal.

    The audio span starts right after the first start-of-audio marker (or at
    position 0 if there is none) and ends at the first end-of-audio marker that
    follows it (or at the end of the sequence). Inside the span, tokens are laid
    out frame by frame, ``n_codebooks`` consecutive tokens per frame.

    Args:
        tokens: 1-D tensor of token ids for a single sequence.
        quantizer: speech tokenizer providing ``decode(codes)`` and ``sample_rate``.
        n_codebooks: number of codebook tokens per audio frame.
        n_original_tokens: number of vocabulary ids preceding the audio tokens;
            audio ids are mapped back to codebook indices modulo this value.
        start_audio_token_id: id of the start-of-audio marker (int or a
            single-element tensor).
        end_audio_token_id: id of the end-of-audio marker (int or a
            single-element tensor).

    Returns:
        AudioSignal holding the decoded waveform at ``quantizer.sample_rate``.

    Raises:
        ValueError: if the sequence contains no audio tokens.
    """
    # Flatten the comparison so plain ints and (1, 1) id tensors both yield
    # positions along ``tokens``.
    start_hits = (tokens == start_audio_token_id).reshape(-1).nonzero().flatten()
    start = int(start_hits[0]) + 1 if len(start_hits) else 0
    end_hits = (tokens == end_audio_token_id).reshape(-1).nonzero().flatten()
    # Only an end marker at or after the start of the span can close it.
    end_hits = end_hits[end_hits >= start]
    end = int(end_hits[0]) if len(end_hits) else tokens.shape[-1]

    # Map audio vocabulary ids back to codebook indices.
    audio_tokens = tokens[start:end] % n_original_tokens
    if audio_tokens.numel() == 0:
        raise ValueError("no audio tokens found between the start/end audio markers")

    remainder = audio_tokens.shape[-1] % n_codebooks
    if remainder:
        # Pad the incomplete last frame with code 0 so the span splits into
        # whole frames; keep the token dtype/device so the codes stay integer.
        pad = torch.zeros(n_codebooks - remainder, dtype=audio_tokens.dtype, device=audio_tokens.device)
        audio_tokens = torch.cat([audio_tokens, pad], dim=0)

    # (frames * n_codebooks,) -> (frames, n_codebooks) -> (n_codebooks, 1, frames),
    # the layout handed to quantizer.decode.
    codes = audio_tokens.reshape(-1, n_codebooks).t().reshape(n_codebooks, 1, -1).to(device)
    with torch.no_grad():
        audio = quantizer.decode(codes).squeeze(0)

    del audio_tokens, codes
    torch.cuda.empty_cache()
    return AudioSignal(audio.detach().cpu().numpy(), quantizer.sample_rate)
def infer_text_to_audio(text, model, tokenizer, quantizer, max_seq_length=1024, top_k=20, temperature=0.8):
    """Generate speech for ``text`` with the LM and decode it to audio.

    The prompt is the tokenized text followed by the start-of-audio marker.
    The model samples a continuation, and its audio span is decoded with
    ``quantizer`` via ``decode_tts``.

    Args:
        text: prompt text (utterance plus speaker/style description).
        model: causal LM used for generation.
        tokenizer: tokenizer whose last ``quantizer.quantizer.bins`` ids are
            the audio codebook tokens.
        quantizer: speech tokenizer used to decode the generated codes.
        max_seq_length: maximum number of new tokens to generate.
        top_k: top-k sampling cutoff.
        temperature: sampling temperature (defaults to the previous fixed 0.8).

    Returns:
        AudioSignal produced by ``decode_tts``.
    """
    # Derive the codebook size from the quantizer instead of relying on a
    # global that only exists when this file runs as a script.
    audio_vocab_size = quantizer.quantizer.bins
    n_original_tokens = len(tokenizer) - audio_vocab_size

    text_input_tokens = tokenizer(text, return_tensors="pt")["input_ids"].to(device)
    # Keep only the marker's last id (presumably dropping a BOS the tokenizer
    # may prepend — verify against the tokenizer config).
    soa = tokenizer(start_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)
    eoa = tokenizer(end_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)

    prompt = torch.cat([text_input_tokens, soa], dim=1)
    attention_mask = torch.ones(prompt.size(), device=device)
    output_audio_tokens = model.generate(
        prompt,
        attention_mask=attention_mask,
        max_new_tokens=max_seq_length,
        top_k=top_k,
        do_sample=True,
        temperature=temperature,
        no_repeat_ngram_size=3,
    )
    return decode_tts(
        output_audio_tokens[0],
        quantizer,
        n_codebooks_tts,
        n_original_tokens,
        soa,
        eoa,
    )
def infer_audio_to_text(audio_path, model, tokenizer, quantizer, max_seq_length=1024, top_k=20):
    """Transcribe the audio file at ``audio_path`` with the LM.

    The audio is encoded with ``quantizer``. Its first ``n_codebooks_asr``
    codebooks are shifted into the audio range of the vocabulary and wrapped
    in start/end-of-audio markers. The model's continuation is then decoded
    as text.

    Args:
        audio_path: path of an audio file readable by ``torchaudio.load``.
        model: causal LM used for generation.
        tokenizer: tokenizer whose last ``quantizer.quantizer.bins`` ids are
            the audio codebook tokens.
        quantizer: speech tokenizer providing ``encode`` and ``sample_rate``.
        max_seq_length: maximum number of new tokens to generate.
        top_k: top-k cutoff passed to ``generate``.

    Returns:
        The decoded transcription string.
    """
    audio_data, sample_rate = torchaudio.load(audio_path)
    # Down-mix to mono: flattening a (channels, time) tensor would otherwise
    # concatenate the channels one after another into a single stream.
    audio_data = audio_data.mean(dim=0, keepdim=True)
    # Resample to quantizer.sample_rate, the same rate decode_tts uses for the
    # audio it produces, so the codes are computed at the quantizer's rate.
    if sample_rate != quantizer.sample_rate:
        audio_data = torchaudio.functional.resample(audio_data, sample_rate, quantizer.sample_rate)
    audio = audio_data.view(1, 1, -1).float().to(device)

    with torch.no_grad():
        codes = quantizer.encode(audio)

    # Assumes codes are laid out as (codebook, batch, frame), the same layout
    # decode_tts passes to quantizer.decode. Select codebooks on axis 0 and
    # interleave them frame by frame, as decode_tts expects for TTS output.
    audio_vocab_offset = len(tokenizer) - quantizer.quantizer.bins
    selected = codes[:n_codebooks_asr]
    raw_audio_tokens = selected.permute(1, 2, 0).reshape(1, -1) + audio_vocab_offset

    soa = tokenizer(start_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)
    eoa = tokenizer(end_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)
    prompt = torch.cat([soa, raw_audio_tokens, eoa], dim=1)
    attention_mask = torch.ones(prompt.size(), device=device)

    # NOTE(review): temperature/top_p/top_k only take effect when sampling is
    # enabled (do_sample is left to the model's generation config here), and
    # length_penalty only affects beam-based generation.
    output_text_tokens = model.generate(
        prompt,
        attention_mask=attention_mask,
        max_new_tokens=max_seq_length,
        temperature=0.6,
        top_p=0.9,
        top_k=top_k,
        no_repeat_ngram_size=4,
        length_penalty=2.0,
        repetition_penalty=1.5,
    )

    # Drop the echoed prompt. Then keep only ids below the <soa> id, which
    # removes audio and marker tokens (assumes all text ids precede <soa>).
    generated = output_text_tokens[0, prompt.shape[1]:].cpu()
    generated = generated[generated < soa[0, -1].item()]
    return tokenizer.decode(generated, skip_special_tokens=True)
# Device that model inputs and audio codes are moved to.
device = "cuda"

# NOTE(review): n_special_tokens is not referenced by the code shown here.
n_special_tokens = 3
# Codebooks per audio frame for the TTS and ASR paths.
n_codebooks_tts, n_codebooks_asr = 3, 1

# Marker strings tokenized to delimit audio spans (end_sequence_token is
# not referenced by the code shown here).
start_audio_token, end_audio_token, end_sequence_token = "<soa>", "<eoa>", "<eos>"

# Hugging Face Hub id of the checkpoint loaded in __main__.
base_model = "Vikhrmodels/salt-116k"
if __name__ == "__main__":
    # Language model and tokenizer, cached in the working directory.
    tokenizer = AutoTokenizer.from_pretrained(base_model, cache_dir=".")
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        cache_dir=".",
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa",
        device_map={"": 0},
    )

    # Speech tokenizer used to decode TTS output and encode ASR input.
    speech_quantizer = SpeechTokenizer.load_from_checkpoint(
        "speechtokenizer/config.json",
        "speechtokenizer/SpeechTokenizer.pt",
    ).eval().to(device)
    # Codebook size; audio token ids occupy the last `codebook_size` ids of the
    # tokenizer vocabulary. Kept as a module-level name.
    codebook_size = speech_quantizer.quantizer.bins

    # Text-to-speech demo: synthesize the prompt and save it.
    tts_prompt = (
        "Say 'COUNT NUMBERS FROM ONE TO TEN' with a male speaker delivers a very monotone and "
        "low-pitched speech with a moderate speed in a setting with almost no noise, "
        "creating a clear and quiet recording."
    )
    infer_text_to_audio(tts_prompt, model, tokenizer, speech_quantizer, top_k=50).write("output.wav")

    # Speech-to-text demo: transcribe a local file and print the result.
    print(infer_audio_to_text("./input.wav", model, tokenizer, speech_quantizer))