ksych committed
Commit 6fd7d49 · verified · 1 Parent(s): e66f4b5

Add inference code

Files changed (1)
  1. inference.py +189 -0
inference.py CHANGED
import librosa

import torchaudio
import torch

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
)

from speechtokenizer import SpeechTokenizer
from WavTokenizer.decoder.pretrained import WavTokenizer
from audiotools import AudioSignal

def resample(audio_data: torch.Tensor, sample_rate: int):
    print("Input sample rate:", sample_rate)
    if sample_rate == 24000:
        audio_data24k = audio_data
        audio_data16k = torch.tensor(
            librosa.resample(
                audio_data.cpu().detach().numpy(), orig_sr=sample_rate, target_sr=16000
            )
        )
    elif sample_rate == 16000:
        audio_data16k = audio_data
        audio_data24k = torch.tensor(
            librosa.resample(
                audio_data.cpu().detach().numpy(), orig_sr=sample_rate, target_sr=24000
            )
        )
    else:
        print("Resampling everything")
        audio_data16k = torch.tensor(
            librosa.resample(
                audio_data.cpu().detach().numpy(), orig_sr=sample_rate, target_sr=16000
            )
        )
        audio_data24k = torch.tensor(
            librosa.resample(
                audio_data.cpu().detach().numpy(), orig_sr=sample_rate, target_sr=24000
            )
        )

    return (audio_data16k.view(1, -1).float().to(device),
            audio_data24k.view(1, -1).float().to(device))


def decode_tts(tokens, quantizer, n_codebooks, n_original_tokens, start_audio_token_id, end_audio_token_id):
    # find start and end indices of audio tokens
    start = torch.nonzero(tokens == start_audio_token_id)
    end = torch.nonzero(tokens == end_audio_token_id)

    start = start[0, -1] + 1 if len(start) else 0
    end = end[0, -1] if len(end) else tokens.shape[-1]

    # subtract length of original vocabulary -> tokens in range [0, 1024)
    audio_tokens = tokens[start:end] % n_original_tokens
    remainder = audio_tokens.shape[-1] % n_codebooks

    if remainder:
        # pad if last frame is incomplete
        pad_tokens = torch.zeros(n_codebooks - remainder, dtype=audio_tokens.dtype, device=device)
        audio_tokens = torch.cat([audio_tokens, pad_tokens], dim=0)

    # regroup the flat token stream into (n_codebooks, 1, n_frames) codes
    transposed = audio_tokens.view(-1, n_codebooks).t()
    codes = transposed.view(n_codebooks, 1, -1).to(device)

    audio = quantizer.decode(codes).squeeze(0)

    del tokens
    del audio_tokens
    torch.cuda.empty_cache()

    return AudioSignal(audio.detach().cpu().numpy(), quantizer.sample_rate)

def infer_text_to_audio(text, model, tokenizer, quantizer, max_seq_length=1024, top_k=20):
    text_tokenized = tokenizer(text, return_tensors="pt")
    text_input_tokens = text_tokenized["input_ids"].to(device)

    soa = tokenizer(start_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)
    eoa = tokenizer(end_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)

    text_tokens = torch.cat([text_input_tokens, soa], dim=1)
    attention_mask = torch.ones(text_tokens.size(), device=device)

    output_audio_tokens = model.generate(
        text_tokens,
        attention_mask=attention_mask,
        max_new_tokens=max_seq_length,
        top_k=top_k,
        do_sample=True,
        temperature=0.1,
        repetition_penalty=1.1,
        length_penalty=1.2,
        no_repeat_ngram_size=3,
    )

    audio_signal = decode_tts(output_audio_tokens[0], quantizer, 3, len(tokenizer), soa, eoa)

    return audio_signal

def infer_audio_to_text(audio_path, model, tokenizer, quantizer_speech, quantizer_wav, max_seq_length=1024, top_k=20):
    audio_data, sample_rate = torchaudio.load(audio_path)

    audio_16k, audio_24k = resample(audio_data, sample_rate)
    bandwidth_id = torch.tensor([0])

    # semantic tokens: first SpeechTokenizer codebook, shifted past the text vocabulary
    codes_semantics = quantizer_speech.encode(audio_16k.reshape(1, 1, -1))
    raw_semantic_tokens = codes_semantics + len(tokenizer)
    raw_semantic_tokens = raw_semantic_tokens[:1].view(1, -1)

    # acoustic tokens: WavTokenizer codes, shifted past the text vocabulary and the 1024 semantic codes
    _, codes = quantizer_wav.encode_infer(audio_24k, bandwidth_id=bandwidth_id)
    raw_acoustic_tokens = codes + len(tokenizer) + 1024
    raw_acoustic_tokens = raw_acoustic_tokens.view(1, -1)

    audio_tokens = torch.cat([raw_semantic_tokens, raw_acoustic_tokens], dim=1)

    soa = tokenizer(start_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)
    eoa = tokenizer(end_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)
    audio_tokens = torch.cat([soa, audio_tokens, eoa], dim=1)

    # text_tokens = tokenizer("is said with", return_tensors="pt")["input_ids"].to(device)
    tokens = torch.cat([audio_tokens], dim=1)

    attention_mask = torch.ones(tokens.size(), device=device)

    output_text_tokens = model.generate(
        tokens,
        attention_mask=attention_mask,
        max_new_tokens=max_seq_length,
        do_sample=True,
        temperature=0.1,
        top_p=0.9,
        top_k=top_k,
    )

    # keep only ids below the start-of-audio special token, i.e. drop audio tokens
    output_text_tokens = output_text_tokens.cpu()[0]
    output_text_tokens = output_text_tokens[output_text_tokens < tokenizer(start_audio_token)["input_ids"][-1]]
    decoded_text = tokenizer.decode(output_text_tokens, skip_special_tokens=True)

    return decoded_text

device = "cuda"

n_codebooks_tts = 3
n_codebooks_asr = 1

start_audio_token = "<|start_of_audio|>"
end_audio_token = "<|end_of_audio|>"
end_sequence_token = "<|end_of_text|>"

base_model = "Vikhrmodels/salt-asr_speech_1_wav_1_tts_speech_3_text-10k"


if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained(base_model, cache_dir=".")
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        cache_dir=".",
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa",
        device_map={"": 0},
    )

    quantizer_speech = SpeechTokenizer.load_from_checkpoint("speechtokenizer/config.json",
                                                            "speechtokenizer/SpeechTokenizer.pt")
    quantizer_speech = quantizer_speech.eval().to(device)
    codebook_size = quantizer_speech.quantizer.bins

    quantizer_wav = WavTokenizer.from_pretrained0802("wavtokenizer/config.yaml",
                                                     "wavtokenizer/WavTokenizer_small_600_24k_4096.ckpt")
    quantizer_wav = quantizer_wav.to(device)

    text = ("Say 'COUNT NUMBERS FROM ONE TO TEN' with a male speaker delivers a very monotone and "
            "low-pitched speech with a moderate speed in a setting with almost no noise, "
            "creating a clear and quiet recording.")

    audio_signal = infer_text_to_audio(text, model, tokenizer, quantizer_speech, top_k=60)
    audio_signal.write("output.wav")

    audio_path = "./input.wav"
    generated_text = infer_audio_to_text(audio_path, model, tokenizer, quantizer_speech, quantizer_wav, top_k=10)
    print(generated_text)
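
The two inference paths can also be chained as a quick round-trip sanity check: transcribe the freshly generated output.wav with infer_audio_to_text and compare the result with the prompt. This is only a sketch that reuses the functions defined above, assuming model, tokenizer and both quantizers are already loaded as in the __main__ block; it is not part of the committed file.

# Hypothetical round-trip check: synthesize speech, then transcribe it back.
# Assumes model, tokenizer, quantizer_speech and quantizer_wav are loaded
# exactly as in the __main__ block above.
audio_signal = infer_text_to_audio(text, model, tokenizer, quantizer_speech, top_k=60)
audio_signal.write("output.wav")

transcript = infer_audio_to_text("output.wav", model, tokenizer, quantizer_speech, quantizer_wav, top_k=10)
print("Round-trip transcript:", transcript)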