gpt-omni committed
Commit 5e4b316
1 Parent(s): 8fc1cf4
app.py ADDED
@@ -0,0 +1,50 @@
+ """A simple web interactive chat demo based on gradio."""
+
+ import os
+ import time
+ import gradio as gr
+ import numpy as np
+ import spaces
+ import torch
+
+
+ from inference import OmniInference
+
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ omni_client = OmniInference('./checkpoint', device)
+ omni_client.warm_up()
+
+
+ OUT_CHUNK = 4096
+ OUT_RATE = 24000
+ OUT_CHANNELS = 1
+
+
+ @spaces.GPU
+ def process_audio(audio):
+     filepath = audio
+     print(f"filepath: {filepath}")
+     if filepath is None:
+         return
+
+     cnt = 0
+     tik = time.time()
+     for chunk in omni_client.run_AT_batch_stream(filepath):
+         # Convert chunk to numpy array
+         if cnt == 0:
+             print(f"first chunk time cost: {time.time() - tik:.3f}")
+         cnt += 1
+         audio_data = np.frombuffer(chunk, dtype=np.int16)
+         audio_data = audio_data.reshape(-1, OUT_CHANNELS)
+         yield OUT_RATE, audio_data.astype(np.int16)
+
+
+ demo = gr.Interface(
+     process_audio,
+     inputs=gr.Audio(type="filepath", label="Microphone"),
+     outputs=[gr.Audio(label="Response", streaming=True, autoplay=True)],
+     title="Chat Mini-Omni Demo",
+     live=True,
+ )
+ demo.queue().launch()
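
Note: the demo above streams (sample_rate, int16 array) tuples into a Gradio Audio component. As a minimal sketch of driving the same streaming API without Gradio (the paths "./checkpoint", "./data/samples/output1.wav", and "reply.wav" are illustrative assumptions), the chunks yielded by run_AT_batch_stream can simply be concatenated and written to a 24 kHz mono WAV:

# Sketch only: consume OmniInference.run_AT_batch_stream directly and save the reply.
import numpy as np
import soundfile as sf
from inference import OmniInference

client = OmniInference("./checkpoint", "cuda:0")
client.warm_up()

chunks = []
for chunk in client.run_AT_batch_stream("./data/samples/output1.wav"):
    # Each chunk is raw 16-bit PCM at 24 kHz, mono (see OUT_RATE / OUT_CHANNELS above).
    chunks.append(np.frombuffer(chunk, dtype=np.int16))
sf.write("reply.wav", np.concatenate(chunks), 24000)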
inference.py ADDED
@@ -0,0 +1,666 @@
1
+ import os
2
+ import lightning as L
3
+ import torch
4
+ import time
5
+ from snac import SNAC
6
+ from litgpt import Tokenizer
7
+ from litgpt.utils import (
8
+ num_parameters,
9
+ )
10
+ from litgpt.generate.base import (
11
+ generate_AA,
12
+ generate_ASR,
13
+ generate_TA,
14
+ generate_TT,
15
+ generate_AT,
16
+ generate_TA_BATCH,
17
+ next_token_batch
18
+ )
19
+ import soundfile as sf
20
+ from litgpt.model import GPT, Config
21
+ from lightning.fabric.utilities.load import _lazy_load as lazy_load
22
+ from utils.snac_utils import layershift, reconscruct_snac, reconstruct_tensors, get_time_str
23
+ from utils.snac_utils import get_snac, generate_audio_data
24
+ import whisper
25
+ from tqdm import tqdm
26
+ from huggingface_hub import snapshot_download
27
+
28
+
29
+ torch.set_printoptions(sci_mode=False)
30
+
31
+
32
+ # TODO
33
+ text_vocabsize = 151936
34
+ text_specialtokens = 64
35
+ audio_vocabsize = 4096
36
+ audio_specialtokens = 64
37
+
38
+ padded_text_vocabsize = text_vocabsize + text_specialtokens
39
+ padded_audio_vocabsize = audio_vocabsize + audio_specialtokens
40
+
41
+ _eot = text_vocabsize
42
+ _pad_t = text_vocabsize + 1
43
+ _input_t = text_vocabsize + 2
44
+ _answer_t = text_vocabsize + 3
45
+ _asr = text_vocabsize + 4
46
+
47
+ _eoa = audio_vocabsize
48
+ _pad_a = audio_vocabsize + 1
49
+ _input_a = audio_vocabsize + 2
50
+ _answer_a = audio_vocabsize + 3
51
+ _split = audio_vocabsize + 4
52
+
53
+
54
+ def get_input_ids_TA(text, text_tokenizer):
55
+ input_ids_item = [[] for _ in range(8)]
56
+ text_tokens = text_tokenizer.encode(text)
57
+ for i in range(7):
58
+ input_ids_item[i] = [layershift(_pad_a, i)] * (len(text_tokens) + 2) + [
59
+ layershift(_answer_a, i)
60
+ ]
61
+ input_ids_item[i] = torch.tensor(input_ids_item[i]).unsqueeze(0)
62
+ input_ids_item[-1] = [_input_t] + text_tokens.tolist() + [_eot] + [_answer_t]
63
+ input_ids_item[-1] = torch.tensor(input_ids_item[-1]).unsqueeze(0)
64
+ return input_ids_item
65
+
66
+
67
+ def get_input_ids_TT(text, text_tokenizer):
68
+ input_ids_item = [[] for i in range(8)]
69
+ text_tokens = text_tokenizer.encode(text).tolist()
70
+
71
+ for i in range(7):
72
+ input_ids_item[i] = torch.tensor(
73
+ [layershift(_pad_a, i)] * (len(text_tokens) + 3)
74
+ ).unsqueeze(0)
75
+ input_ids_item[-1] = [_input_t] + text_tokens + [_eot] + [_answer_t]
76
+ input_ids_item[-1] = torch.tensor(input_ids_item[-1]).unsqueeze(0)
77
+
78
+ return input_ids_item
79
+
80
+
81
+ def get_input_ids_whisper(
82
+ mel, leng, whispermodel, device,
83
+ special_token_a=_answer_a, special_token_t=_answer_t,
84
+ ):
85
+
86
+ with torch.no_grad():
87
+ mel = mel.unsqueeze(0).to(device)
88
+ # audio_feature = whisper.decode(whispermodel,mel, options).audio_features
89
+ audio_feature = whispermodel.embed_audio(mel)[0][:leng]
90
+
91
+ T = audio_feature.size(0)
92
+ input_ids = []
93
+ for i in range(7):
94
+ input_ids_item = []
95
+ input_ids_item.append(layershift(_input_a, i))
96
+ input_ids_item += [layershift(_pad_a, i)] * T
97
+ input_ids_item += [(layershift(_eoa, i)), layershift(special_token_a, i)]
98
+ input_ids.append(torch.tensor(input_ids_item).unsqueeze(0))
99
+ input_id_T = torch.tensor([_input_t] + [_pad_t] * T + [_eot, special_token_t])
100
+ input_ids.append(input_id_T.unsqueeze(0))
101
+ return audio_feature.unsqueeze(0), input_ids
102
+
103
+
104
+ def get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device):
105
+ with torch.no_grad():
106
+ mel = mel.unsqueeze(0).to(device)
107
+ # audio_feature = whisper.decode(whispermodel,mel, options).audio_features
108
+ audio_feature = whispermodel.embed_audio(mel)[0][:leng]
109
+ T = audio_feature.size(0)
110
+ input_ids_AA = []
111
+ for i in range(7):
112
+ input_ids_item = []
113
+ input_ids_item.append(layershift(_input_a, i))
114
+ input_ids_item += [layershift(_pad_a, i)] * T
115
+ input_ids_item += [(layershift(_eoa, i)), layershift(_answer_a, i)]
116
+ input_ids_AA.append(torch.tensor(input_ids_item))
117
+ input_id_T = torch.tensor([_input_t] + [_pad_t] * T + [_eot, _answer_t])
118
+ input_ids_AA.append(input_id_T)
119
+
120
+ input_ids_AT = []
121
+ for i in range(7):
122
+ input_ids_item = []
123
+ input_ids_item.append(layershift(_input_a, i))
124
+ input_ids_item += [layershift(_pad_a, i)] * T
125
+ input_ids_item += [(layershift(_eoa, i)), layershift(_pad_a, i)]
126
+ input_ids_AT.append(torch.tensor(input_ids_item))
127
+ input_id_T = torch.tensor([_input_t] + [_pad_t] * T + [_eot, _answer_t])
128
+ input_ids_AT.append(input_id_T)
129
+
130
+ input_ids = [input_ids_AA, input_ids_AT]
131
+ stacked_inputids = [[] for _ in range(8)]
132
+ for i in range(2):
133
+ for j in range(8):
134
+ stacked_inputids[j].append(input_ids[i][j])
135
+ stacked_inputids = [torch.stack(tensors) for tensors in stacked_inputids]
136
+ return torch.stack([audio_feature, audio_feature]), stacked_inputids
137
+
138
+
139
+ def load_audio(path):
140
+ audio = whisper.load_audio(path)
141
+ duration_ms = (len(audio) / 16000) * 1000
142
+ audio = whisper.pad_or_trim(audio)
143
+ mel = whisper.log_mel_spectrogram(audio)
144
+ return mel, int(duration_ms / 20) + 1
145
+
146
+
147
+ def A1_A2_batch(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
148
+ snacmodel, out_dir=None):
149
+ with fabric.init_tensor():
150
+ model.set_kv_cache(batch_size=2)
151
+ tokenlist = generate_TA_BATCH(
152
+ model,
153
+ audio_feature,
154
+ input_ids,
155
+ [leng, leng],
156
+ ["A1A2", "A1T2"],
157
+ max_returned_tokens=2048,
158
+ temperature=0.9,
159
+ top_k=1,
160
+ eos_id_a=_eoa,
161
+ eos_id_t=_eot,
162
+ pad_id_t=_pad_t,
163
+ shift=padded_text_vocabsize,
164
+ include_prompt=True,
165
+ generate_text=True,
166
+ )
167
+ text_tokenlist = tokenlist[-1]
168
+ if text_vocabsize in text_tokenlist:
169
+ text_tokenlist = text_tokenlist[: text_tokenlist.index(text_vocabsize)]
170
+ text = text_tokenizer.decode(torch.tensor(text_tokenlist)).strip()
171
+
172
+ audio_tokenlist = tokenlist[:-1]
173
+ audiolist = reconscruct_snac(audio_tokenlist)
174
+ audio = reconstruct_tensors(audiolist)
175
+ if out_dir is None:
176
+ out_dir = "./output/default/A1-A2-batch"
177
+ else:
178
+ out_dir = out_dir + "/A1-A2-batch"
179
+ if not os.path.exists(out_dir):
180
+ os.makedirs(out_dir)
181
+ with torch.inference_mode():
182
+ audio_hat = snacmodel.decode(audio)
183
+ sf.write(
184
+ f"{out_dir}/{step:02d}.wav",
185
+ audio_hat.squeeze().cpu().numpy(),
186
+ 24000,
187
+ )
188
+ model.clear_kv_cache()
189
+ return text
190
+
191
+
192
+ def A1_T2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
193
+ with fabric.init_tensor():
194
+ model.set_kv_cache(batch_size=1)
195
+ tokenlist = generate_AT(
196
+ model,
197
+ audio_feature,
198
+ input_ids,
199
+ [leng],
200
+ ["AT"],
201
+ max_returned_tokens=2048,
202
+ temperature=0.9,
203
+ top_k=1,
204
+ eos_id_a=_eoa,
205
+ eos_id_t=_eot,
206
+ pad_id_t=_pad_t,
207
+ shift=padded_text_vocabsize,
208
+ include_prompt=True,
209
+ generate_text=True,
210
+ )
211
+ return text_tokenizer.decode(torch.tensor(tokenlist)).strip()
212
+
213
+
214
+ def A1_A2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
215
+ snacmodel, out_dir=None):
216
+ with fabric.init_tensor():
217
+ model.set_kv_cache(batch_size=1)
218
+ tokenlist = generate_AA(
219
+ model,
220
+ audio_feature,
221
+ input_ids,
222
+ [leng],
223
+ ["A1T2"],
224
+ max_returned_tokens=2048,
225
+ temperature=0.9,
226
+ top_k=1,
227
+ eos_id_a=_eoa,
228
+ eos_id_t=_eot,
229
+ pad_id_t=_pad_t,
230
+ shift=padded_text_vocabsize,
231
+ include_prompt=True,
232
+ generate_text=True,
233
+ )
234
+ audiolist = reconscruct_snac(tokenlist)
235
+ tokenlist = tokenlist[-1]
236
+ if text_vocabsize in tokenlist:
237
+ tokenlist = tokenlist[: tokenlist.index(text_vocabsize)]
238
+ if out_dir is None:
239
+ out_dir = "./output/default/A1-A2"
240
+ else:
241
+ out_dir = out_dir + "/A1-A2"
242
+ if not os.path.exists(out_dir):
243
+ os.makedirs(out_dir)
244
+
245
+ audio = reconstruct_tensors(audiolist)
246
+ with torch.inference_mode():
247
+ audio_hat = snacmodel.decode(audio)
248
+ sf.write(
249
+ f"{out_dir}/{step:02d}.wav",
250
+ audio_hat.squeeze().cpu().numpy(),
251
+ 24000,
252
+ )
253
+ model.clear_kv_cache()
254
+ return text_tokenizer.decode(torch.tensor(tokenlist)).strip()
255
+
256
+
257
+ def A1_T1(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
258
+ with fabric.init_tensor():
259
+ model.set_kv_cache(batch_size=1)
260
+ tokenlist = generate_ASR(
261
+ model,
262
+ audio_feature,
263
+ input_ids,
264
+ [leng],
265
+ ["A1T1"],
266
+ max_returned_tokens=2048,
267
+ temperature=0.9,
268
+ top_k=1,
269
+ eos_id_a=_eoa,
270
+ eos_id_t=_eot,
271
+ pad_id_t=_pad_t,
272
+ shift=padded_text_vocabsize,
273
+ include_prompt=True,
274
+ generate_text=True,
275
+ )
276
+ model.clear_kv_cache()
277
+ return text_tokenizer.decode(torch.tensor(tokenlist)).strip()
278
+
279
+
280
+ def T1_A2(fabric, input_ids, model, text_tokenizer, step,
281
+ snacmodel, out_dir=None):
282
+ with fabric.init_tensor():
283
+ model.set_kv_cache(batch_size=1)
284
+ tokenlist = generate_TA(
285
+ model,
286
+ None,
287
+ input_ids,
288
+ None,
289
+ ["T1A2"],
290
+ max_returned_tokens=2048,
291
+ temperature=0.9,
292
+ top_k=1,
293
+ eos_id_a=_eoa,
294
+ eos_id_t=_eot,
295
+ pad_id_t=_pad_t,
296
+ shift=padded_text_vocabsize,
297
+ include_prompt=True,
298
+ generate_text=True,
299
+ )
300
+
301
+ audiolist = reconscruct_snac(tokenlist)
302
+ tokenlist = tokenlist[-1]
303
+
304
+ if text_vocabsize in tokenlist:
305
+ tokenlist = tokenlist[: tokenlist.index(text_vocabsize)]
306
+ audio = reconstruct_tensors(audiolist)
307
+ if out_dir is None:
308
+ out_dir = "./output/default/T1-A2"
309
+ else:
310
+ out_dir = out_dir + "/T1-A2"
311
+ if not os.path.exists(out_dir):
312
+ os.makedirs(out_dir)
313
+
314
+ with torch.inference_mode():
315
+ audio_hat = snacmodel.decode(audio)
316
+ sf.write(
317
+ f"{out_dir}/{step:02d}.wav",
318
+ audio_hat.squeeze().cpu().numpy(),
319
+ 24000,
320
+ )
321
+ model.clear_kv_cache()
322
+ return text_tokenizer.decode(torch.tensor(tokenlist)).strip()
323
+
324
+
325
+ def T1_T2(fabric, input_ids, model, text_tokenizer, step):
326
+
327
+ with fabric.init_tensor():
328
+ model.set_kv_cache(batch_size=1)
329
+ tokenlist = generate_TT(
330
+ model,
331
+ None,
332
+ input_ids,
333
+ None,
334
+ ["T1T2"],
335
+ max_returned_tokens=2048,
336
+ temperature=0.9,
337
+ top_k=1,
338
+ eos_id_a=_eoa,
339
+ eos_id_t=_eot,
340
+ pad_id_t=_pad_t,
341
+ shift=padded_text_vocabsize,
342
+ include_prompt=True,
343
+ generate_text=True,
344
+ )
345
+ model.clear_kv_cache()
346
+ return text_tokenizer.decode(torch.tensor(tokenlist)).strip()
347
+
348
+
349
+ def load_model(ckpt_dir, device):
350
+ snacmodel = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().to(device)
351
+ whispermodel = whisper.load_model("small").to(device)
352
+ text_tokenizer = Tokenizer(ckpt_dir)
353
+ fabric = L.Fabric(devices=1, strategy="auto")
354
+ config = Config.from_file(ckpt_dir + "/model_config.yaml")
355
+ config.post_adapter = False
356
+
357
+ with fabric.init_module(empty_init=False):
358
+ model = GPT(config)
359
+
360
+ model = fabric.setup(model)
361
+ state_dict = lazy_load(ckpt_dir + "/lit_model.pth")
362
+ model.load_state_dict(state_dict, strict=True)
363
+ model.to(device).eval()
364
+
365
+ return fabric, model, text_tokenizer, snacmodel, whispermodel
366
+
367
+
368
+ def download_model(ckpt_dir):
369
+ repo_id = "gpt-omni/mini-omni"
370
+ snapshot_download(repo_id, local_dir=ckpt_dir, revision="main")
371
+
372
+
373
+ class OmniInference:
374
+
375
+ def __init__(self, ckpt_dir='./checkpoint', device='cuda:0'):
376
+ self.device = device
377
+ if not os.path.exists(ckpt_dir):
378
+ print(f"checkpoint directory {ckpt_dir} not found, downloading from huggingface")
379
+ download_model(ckpt_dir)
380
+ self.fabric, self.model, self.text_tokenizer, self.snacmodel, self.whispermodel = load_model(ckpt_dir, device)
381
+
382
+ def warm_up(self, sample='./data/samples/output1.wav'):
383
+ for _ in self.run_AT_batch_stream(sample):
384
+ pass
385
+
386
+ @torch.inference_mode()
387
+ def run_AT_batch_stream(self,
388
+ audio_path,
389
+ stream_stride=4,
390
+ max_returned_tokens=2048,
391
+ temperature=0.9,
392
+ top_k=1,
393
+ top_p=1.0,
394
+ eos_id_a=_eoa,
395
+ eos_id_t=_eot,
396
+ ):
397
+
398
+ assert os.path.exists(audio_path), f"audio file {audio_path} not found"
399
+ model = self.model
400
+
401
+ with self.fabric.init_tensor():
402
+ model.set_kv_cache(batch_size=2)
403
+
404
+ mel, leng = load_audio(audio_path)
405
+ audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, self.whispermodel, self.device)
406
+ T = input_ids[0].size(1)
407
+ device = input_ids[0].device
408
+
409
+ assert max_returned_tokens > T, f"max_returned_tokens {max_returned_tokens} should be greater than audio length {T}"
410
+
411
+ if model.max_seq_length < max_returned_tokens - 1:
412
+ raise NotImplementedError(
413
+ f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
414
+ )
415
+
416
+ input_pos = torch.tensor([T], device=device)
417
+ list_output = [[] for i in range(8)]
418
+ tokens_A, token_T = next_token_batch(
419
+ model,
420
+ audio_feature.to(torch.float32).to(model.device),
421
+ input_ids,
422
+ [T - 3, T - 3],
423
+ ["A1T2", "A1T2"],
424
+ input_pos=torch.arange(0, T, device=device),
425
+ temperature=temperature,
426
+ top_k=top_k,
427
+ top_p=top_p,
428
+ )
429
+
430
+ for i in range(7):
431
+ list_output[i].append(tokens_A[i].tolist()[0])
432
+ list_output[7].append(token_T.tolist()[0])
433
+
434
+ model_input_ids = [[] for i in range(8)]
435
+ for i in range(7):
436
+ tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize + i * padded_audio_vocabsize
437
+ model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32))
438
+ model_input_ids[i].append(torch.tensor([layershift(4097, i)], device=device))
439
+ model_input_ids[i] = torch.stack(model_input_ids[i])
440
+
441
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
442
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
443
+ model_input_ids[-1] = torch.stack(model_input_ids[-1])
444
+
445
+ text_end = False
446
+ index = 1
447
+ nums_generate = stream_stride
448
+ begin_generate = False
449
+ current_index = 0
450
+ for _ in tqdm(range(2, max_returned_tokens - T + 1)):
451
+ tokens_A, token_T = next_token_batch(
452
+ model,
453
+ None,
454
+ model_input_ids,
455
+ None,
456
+ None,
457
+ input_pos=input_pos,
458
+ temperature=temperature,
459
+ top_k=top_k,
460
+ top_p=top_p,
461
+ )
462
+
463
+ if text_end:
464
+ token_T = torch.tensor([_pad_t], device=device)
465
+
466
+ if tokens_A[-1] == eos_id_a:
467
+ break
468
+
469
+ if token_T == eos_id_t:
470
+ text_end = True
471
+
472
+ for i in range(7):
473
+ list_output[i].append(tokens_A[i].tolist()[0])
474
+ list_output[7].append(token_T.tolist()[0])
475
+
476
+ model_input_ids = [[] for i in range(8)]
477
+ for i in range(7):
478
+ tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize + i * padded_audio_vocabsize
479
+ model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32))
480
+ model_input_ids[i].append(
481
+ torch.tensor([layershift(4097, i)], device=device)
482
+ )
483
+ model_input_ids[i] = torch.stack(model_input_ids[i])
484
+
485
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
486
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
487
+ model_input_ids[-1] = torch.stack(model_input_ids[-1])
488
+
489
+ if index == 7:
490
+ begin_generate = True
491
+
492
+ if begin_generate:
493
+ current_index += 1
494
+ if current_index == nums_generate:
495
+ current_index = 0
496
+ snac = get_snac(list_output, index, nums_generate)
497
+ audio_stream = generate_audio_data(snac, self.snacmodel)
498
+ yield audio_stream
499
+
500
+ input_pos = input_pos.add_(1)
501
+ index += 1
502
+ text = self.text_tokenizer.decode(torch.tensor(list_output[-1]))
503
+ print(f"text output: {text}")
504
+ model.clear_kv_cache()
505
+ return list_output
506
+
507
+
508
+ def test_infer():
509
+ device = "cuda:0"
510
+ out_dir = f"./output/{get_time_str()}"
511
+ ckpt_dir = f"./checkpoint"
512
+ if not os.path.exists(ckpt_dir):
513
+ print(f"checkpoint directory {ckpt_dir} not found, downloading from huggingface")
514
+ download_model(ckpt_dir)
515
+
516
+ fabric, model, text_tokenizer, snacmodel, whispermodel = load_model(ckpt_dir, device)
517
+
518
+ task = ['A1A2', 'asr', "T1A2", "AA-BATCH", 'T1T2', 'AT']
519
+
520
+ # prepare test data
521
+ # TODO
522
+ test_audio_list = sorted(os.listdir('./data/samples'))
523
+ test_audio_list = [os.path.join('./data/samples', path) for path in test_audio_list]
524
+ test_audio_transcripts = [
525
+ "What is your name?",
526
+ "what are your hobbies?",
527
+ "Do you like beijing",
528
+ "How are you feeling today?",
529
+ "what is the weather like today?",
530
+ ]
531
+ test_text_list = [
532
+ "What is your name?",
533
+ "How are you feeling today?",
534
+ "Can you describe your surroundings?",
535
+ "What did you do yesterday?",
536
+ "What is your favorite book and why?",
537
+ "How do you make a cup of tea?",
538
+ "What is the weather like today?",
539
+ "Can you explain the concept of time?",
540
+ "Can you tell me a joke?",
541
+ ]
542
+
543
+ # LOAD MODEL
544
+ with torch.no_grad():
545
+ if "A1A2" in task:
546
+ print("===============================================================")
547
+ print(" testing A1A2")
548
+ print("===============================================================")
549
+ step = 0
550
+ for path in test_audio_list:
551
+ try:
552
+ mel, leng = load_audio(path)
553
+ audio_feature, input_ids = get_input_ids_whisper(mel, leng, whispermodel, device)
554
+ text = A1_A2(
555
+ fabric,
556
+ audio_feature,
557
+ input_ids,
558
+ leng,
559
+ model,
560
+ text_tokenizer,
561
+ step,
562
+ snacmodel,
563
+ out_dir=out_dir,
564
+ )
565
+ print(f"input: {test_audio_transcripts[step]}")
566
+ print(f"output: {text}")
567
+ step += 1
568
+ print(
569
+ "+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++"
570
+ )
571
+ except Exception:
572
+ print(f"[error] failed to process {path}")
573
+ print("===============================================================")
574
+
575
+ if 'asr' in task:
576
+ print("===============================================================")
577
+ print(" testing asr")
578
+ print("===============================================================")
579
+
580
+ index = 0
581
+ step = 0
582
+ for path in test_audio_list:
583
+ mel, leng = load_audio(path)
584
+ audio_feature, input_ids = get_input_ids_whisper(mel, leng, whispermodel, device, special_token_a=_pad_a, special_token_t=_asr)
585
+ output = A1_T1(fabric, audio_feature, input_ids, leng, model, text_tokenizer, index).lower().replace(',', '').replace('.', '').replace('?', '')
586
+ print(f"audio_path: {path}")
587
+ print(f"audio transcript: {test_audio_transcripts[index]}")
588
+ print(f"asr output: {output}")
589
+ print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
590
+ index += 1
591
+
592
+ if "T1A2" in task:
593
+ step = 0
594
+ print("\n")
595
+ print("===============================================================")
596
+ print(" testing T1A2")
597
+ print("===============================================================")
598
+ for text in test_text_list:
599
+ input_ids = get_input_ids_TA(text, text_tokenizer)
600
+ text_output = T1_A2(fabric, input_ids, model, text_tokenizer, step,
601
+ snacmodel, out_dir=out_dir)
602
+ print(f"input: {text}")
603
+ print(f"output: {text_output}")
604
+ print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
605
+ step += 1
606
+ print("===============================================================")
607
+
608
+ if "T1T2" in task:
609
+ step = 0
610
+ print("\n")
611
+ print("===============================================================")
612
+ print(" testing T1T2")
613
+ print("===============================================================")
614
+
615
+ for text in test_text_list:
616
+ input_ids = get_input_ids_TT(text, text_tokenizer)
617
+ text_output = T1_T2(fabric, input_ids, model, text_tokenizer, step)
618
+ print(f" Input: {text}")
619
+ print(f"Output: {text_output}")
620
+ print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
621
+ print("===============================================================")
622
+
623
+ if "AT" in task:
624
+ print("===============================================================")
625
+ print(" testing A1T2")
626
+ print("===============================================================")
627
+ step = 0
628
+ for path in test_audio_list:
629
+ mel, leng = load_audio(path)
630
+ audio_feature, input_ids = get_input_ids_whisper(
631
+ mel, leng, whispermodel, device,
632
+ special_token_a=_pad_a, special_token_t=_answer_t
633
+ )
634
+ text = A1_T2(
635
+ fabric, audio_feature, input_ids, leng, model, text_tokenizer, step
636
+ )
637
+ print(f"input: {test_audio_transcripts[step]}")
638
+ print(f"output: {text}")
639
+ step += 1
640
+ print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
641
+ print("===============================================================")
642
+
643
+ if "AA-BATCH" in task:
644
+ print("===============================================================")
645
+ print(" testing A1A2-BATCH")
646
+ print("===============================================================")
647
+ step = 0
648
+ for path in test_audio_list:
649
+ mel, leng = load_audio(path)
650
+ audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device)
651
+ text = A1_A2_batch(
652
+ fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
653
+ snacmodel, out_dir=out_dir
654
+ )
655
+ print(f"input: {test_audio_transcripts[step]}")
656
+ print(f"output: {text}")
657
+ step += 1
658
+ print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
659
+ print("===============================================================")
660
+
661
+ print("*********************** test end *****************************")
662
+
663
+
664
+
665
+ if __name__ == "__main__":
666
+ test_infer()
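
Note: run_AT_batch_stream feeds each generated audio token back into the model after shifting it into a flattened vocabulary: text ids come first, followed by one block per SNAC codebook layer. The sketch below only illustrates that offset arithmetic; layershift itself lives in utils/snac_utils.py, which is not part of this commit, so the formula is an assumption inferred from the tokens_A shift in the streaming loop above.

# Illustrative sketch of the flattened vocabulary layout (assumed, not the shipped helper).
padded_text_vocabsize = 151936 + 64   # 152000, matches the constants at the top of inference.py
padded_audio_vocabsize = 4096 + 64    # 4160

def global_audio_id(token: int, layer: int) -> int:
    # Same shift applied to tokens_A[i] before it is fed back as model input.
    return padded_text_vocabsize + layer * padded_audio_vocabsize + token

print(global_audio_id(0, 0))     # 152000: first audio id of codebook layer 0
print(global_audio_id(4097, 2))  # pad token of layer 2, cf. layershift(4097, i)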
litgpt/.DS_Store ADDED
Binary file (6.15 kB)
litgpt/__init__.py ADDED
@@ -0,0 +1,19 @@
+ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
+
+ import logging
+ import re
+ from litgpt.model import GPT  # needs to be imported before config
+ from litgpt.config import Config
+ from litgpt.tokenizer import Tokenizer
+
+ # Suppress excessive warnings, see https://github.com/pytorch/pytorch/issues/111632
+ pattern = re.compile(".*Profiler function .* will be ignored")
+ logging.getLogger("torch._dynamo.variables.torch").addFilter(
+     lambda record: not pattern.search(record.getMessage())
+ )
+
+ # Avoid printing state-dict profiling output at the WARNING level when saving a checkpoint
+ logging.getLogger("torch.distributed.fsdp._optim_utils").disabled = True
+ logging.getLogger("torch.distributed.fsdp._debug_utils").disabled = True
+
+ __all__ = ["GPT", "Config", "Tokenizer"]
litgpt/config.py ADDED
@@ -0,0 +1,180 @@
1
+ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2
+
3
+ from copy import deepcopy
4
+ from dataclasses import dataclass, field
5
+ from pathlib import Path
6
+ from typing import Any, Literal, Optional, Type, Union
7
+
8
+ import torch
9
+ import yaml
10
+ from typing_extensions import Self
11
+
12
+ import litgpt.model
13
+ from litgpt.utils import find_multiple
14
+
15
+
16
+ @dataclass
17
+ class Config:
18
+ name: str = ""
19
+ hf_config: dict = field(default_factory=dict)
20
+ scale_embeddings: bool = False
21
+ block_size: int = 4096
22
+ vocab_size: int = 50254
23
+ padding_multiple: int = 512
24
+ padded_vocab_size: Optional[int] = None
25
+ n_layer: int = 16
26
+ n_head: int = 32
27
+ head_size: Optional[int] = None
28
+ n_embd: int = 4096
29
+ rotary_percentage: float = 0.25
30
+ parallel_residual: bool = True
31
+ bias: bool = True
32
+ lm_head_bias: bool = False
33
+ # to use multi-head attention (MHA), set this to `n_head` (default)
34
+ # to use multi-query attention (MQA), set this to 1
35
+ # to use grouped-query attention (GQA), set this to a value in between
36
+ # Example with `n_head=4`
37
+ # ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐
38
+ # │ v ││ v ││ v ││ v │ │ v │ │ v │ │ v │
39
+ # └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘
40
+ # │ │ │ │ │ │ │
41
+ # ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐
42
+ # │ k ││ k ││ k ││ k │ │ k │ │ k │ │ k │
43
+ # └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘
44
+ # │ │ │ │ ┌──┴──┐ ┌──┴──┐ ┌────┬──┴─┬────┐
45
+ # ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐
46
+ # │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │
47
+ # └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘
48
+ # ◀──────────────────▶ ◀──────────────────▶ ◀──────────────────▶
49
+ # MHA GQA MQA
50
+ # n_query_groups=4 n_query_groups=2 n_query_groups=1
51
+ #
52
+ # credit https://arxiv.org/pdf/2305.13245.pdf
53
+ n_query_groups: Optional[int] = None
54
+ shared_attention_norm: bool = False
55
+ norm_class_name: Literal["LayerNorm", "RMSNorm"] = "LayerNorm"
56
+ norm_eps: float = 1e-5
57
+ mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = (
58
+ "GptNeoxMLP"
59
+ )
60
+ gelu_approximate: str = "none"
61
+ intermediate_size: Optional[int] = None
62
+ rope_condense_ratio: int = 1
63
+ rope_base: int = 10000
64
+ n_expert: int = 0
65
+ n_expert_per_token: int = 0
66
+
67
+ add_qkv_bias: Optional[bool] = None
68
+ prompt_vocab_size: Optional[int] = None
69
+ attn_dropout: float = 0.0
70
+ pos_type: str = "rope"
71
+ force_align: bool = False
72
+ use_pretrain_phoneme_emb: bool = False
73
+ tie_word_embeddings: bool = False
74
+
75
+ # setting for mini-omni
76
+ text_vocab_size: int = 152000
77
+ cat_audio_vocab_size: int = 29120
78
+ audio_vocab_size: int = 4160
79
+ whisper_adapter_dim: int = 768
80
+
81
+ post_adapter: bool = False
82
+ post_adapter_layers: int = 6
83
+ asr_adapter: str = "llamamlp"
84
+
85
+ def __post_init__(self):
86
+ if not self.name:
87
+ self.name = self.hf_config.get("name", self.name)
88
+
89
+ if self.head_size is None:
90
+ assert self.n_embd % self.n_head == 0
91
+ self.head_size = self.n_embd // self.n_head
92
+
93
+ # vocab size should be a power of 2 to be optimal on hardware. compute the closest value
94
+ if self.padded_vocab_size is None:
95
+ self.padded_vocab_size = find_multiple(
96
+ self.vocab_size, self.padding_multiple
97
+ )
98
+ else:
99
+ # vocab size shouldn't be larger than padded vocab size
100
+ self.vocab_size = min(self.vocab_size, self.padded_vocab_size)
101
+
102
+ # compute the number of query groups
103
+ if self.n_query_groups is not None:
104
+ assert self.n_head % self.n_query_groups == 0
105
+ else:
106
+ self.n_query_groups = self.n_head
107
+
108
+ # compute the intermediate size for MLP if not set
109
+ if self.intermediate_size is None:
110
+ if self.mlp_class_name == "LLaMAMLP":
111
+ raise ValueError(
112
+ f"The config {self.name!r}, needs to set the `intermediate_size`"
113
+ )
114
+ self.intermediate_size = 4 * self.n_embd
115
+
116
+ self.rope_n_elem = int(self.rotary_percentage * self.head_size)
117
+
118
+ if self.add_qkv_bias is None:
119
+ self.add_qkv_bias = self.bias
120
+
121
+ @classmethod
122
+ def from_name(cls, name: str, **kwargs: Any) -> Optional[Self]:
123
+ if name not in name_to_config:
124
+ # search through all `config['hf_config']['name']`
125
+ try:
126
+ conf_dict = next(
127
+ config
128
+ for config in configs
129
+ if name == config["hf_config"]["name"]
130
+ or config["hf_config"]["org"] + "/" + config["hf_config"]["name"]
131
+ == name
132
+ )
133
+ except StopIteration:
134
+ raise ValueError(f"{name!r} is not a supported config name")
135
+ else:
136
+ conf_dict = name_to_config[name]
137
+
138
+ conf_dict = conf_dict.copy()
139
+ conf_dict.update(kwargs)
140
+ return cls(**conf_dict)
141
+
142
+ @classmethod
143
+ def from_file(cls, path: Union[str, Path], **kwargs: Any) -> Self:
144
+ with open(path, encoding="utf-8") as fp:
145
+ file_kwargs = yaml.safe_load(fp)
146
+ if file_kwargs is None:
147
+ raise ValueError(f"{path} is empty which is likely unexpected.")
148
+ file_kwargs.update(kwargs)
149
+ return cls(**file_kwargs)
150
+
151
+ @classmethod
152
+ def from_checkpoint(cls, path: Path, **kwargs: Any) -> Self:
153
+ """Automatically load `model_config.yaml` and if it doesn't exist - a matching config from `litgpt/config.py`."""
154
+ if (config_path := path / "model_config.yaml").is_file():
155
+ return cls.from_file(config_path, **kwargs)
156
+ if (model_name := path.name) in name_to_config:
157
+ return cls.from_name(model_name, **kwargs)
158
+ raise FileNotFoundError(
159
+ f"For {str(path)!r} neither 'model_config.yaml' nor matching config exists."
160
+ )
161
+
162
+ @property
163
+ def mlp_class(self) -> Type:
164
+ # `self.mlp_class_name` cannot be the type to keep the config serializable
165
+ return getattr(litgpt.model, self.mlp_class_name)
166
+
167
+ @property
168
+ def norm_class(self) -> Type:
169
+ # `self.norm_class_name` cannot be the type to keep the config serializable
170
+ if self.norm_class_name == "RMSNorm":
171
+ from functools import partial
172
+
173
+ from litgpt.model import RMSNorm
174
+
175
+ return partial(RMSNorm, add_unit_offset="Gemma" in self.name)
176
+ return getattr(torch.nn, self.norm_class_name)
177
+
178
+
179
+ configs = []
180
+ name_to_config = {config["name"]: config for config in configs}
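
Note: configs is left empty here, so Config.from_name has nothing to resolve; the mini-omni checkpoint is instead loaded through Config.from_file on model_config.yaml in inference.py. A small sketch of the fields derived in __post_init__, with illustrative values that are assumptions rather than the shipped model configuration:

# Sketch: construct a Config directly and inspect the derived fields.
from litgpt.config import Config

cfg = Config(name="demo", n_layer=24, n_head=16, n_embd=2048,
             vocab_size=151936, padding_multiple=64)
print(cfg.padded_vocab_size)   # 151936 (already a multiple of 64)
print(cfg.head_size)           # 2048 // 16 = 128
print(cfg.n_query_groups)      # defaults to n_head = 16 (plain multi-head attention)
print(cfg.rope_n_elem)         # int(0.25 * 128) = 32
print(cfg.intermediate_size)   # 4 * n_embd = 8192 for the default GptNeoxMLP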
litgpt/generate/__init__.py ADDED
File without changes
litgpt/generate/base.py ADDED
@@ -0,0 +1,795 @@
1
+ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2
+
3
+ from typing import Any, Literal, Optional
4
+
5
+ import torch
6
+ # import torch._dynamo.config
7
+ # import torch._inductor.config
8
+
9
+ from litgpt.model import GPT
10
+ from utils.snac_utils import layershift, snac_config
11
+ from tqdm import tqdm
12
+
13
+
14
+ def multinomial_num_samples_1(probs: torch.Tensor) -> torch.Tensor:
15
+ if torch._dynamo.is_compiling():
16
+ # Faster alternative to `torch.multinomial(probs, num_samples=1)` that is also CUDAGraph friendly
17
+ distribution = torch.empty_like(probs).exponential_(1)
18
+ return torch.argmax(probs / distribution, dim=-1, keepdim=True)
19
+ return torch.multinomial(probs, num_samples=1)
20
+
21
+
22
+ def sample_top_p(logits: torch.Tensor, top_p: float) -> torch.Tensor:
23
+ sorted_logits, sorted_indices = torch.sort(logits, descending=False)
24
+ cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
25
+ # Example:
26
+ # sorted_probs=[0.1, 0.15, 0.2, 0.25, 0.3] -> sorted_cumprobs=[0.1, 0.25, 0.45, 0.7, 1.0]
27
+ # sorted_indices_to_remove = [1, 1, 0, 0, 0] if top_p=0.7
28
+ sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
29
+ # Keep at least 1 token always to prevent the case where no token is selected
30
+ # In this case the most probable one is always kept
31
+ sorted_indices_to_remove[-1:] = 0
32
+ indices_to_remove = sorted_indices_to_remove.scatter(
33
+ 0, sorted_indices, sorted_indices_to_remove
34
+ )
35
+ logits = logits.masked_fill(indices_to_remove, float("-inf"))
36
+ return logits
37
+
38
+
39
+ def sample(
40
+ logits: torch.Tensor,
41
+ temperature: float = 1.0,
42
+ top_k: Optional[int] = None,
43
+ top_p: float = 1.0,
44
+ ) -> torch.Tensor:
45
+ if top_p < 0.0 or top_p > 1.0:
46
+ raise ValueError(f"top_p must be in [0, 1], got {top_p}")
47
+ logits = logits[0, -1]
48
+ # optionally crop the logits to only the top k options
49
+ if top_k is not None:
50
+ v, i = torch.topk(logits, min(top_k, logits.size(-1)))
51
+ # do not use `torch.where` as in nanogpt because it will repeat top-k collisions
52
+ logits = torch.full_like(logits, float("-inf")).scatter_(-1, i, v)
53
+ # optionally scale the logits and sample from a probability distribution
54
+ if temperature > 0.0 or top_p > 0.0:
55
+ if temperature > 0.0:
56
+ logits = logits / temperature
57
+ # optionally crop the logits to smallest set of logits with a cumulative probability above top_p
58
+ if top_p < 1.0:
59
+ logits = sample_top_p(logits, top_p)
60
+ probs = torch.nn.functional.softmax(logits, dim=-1)
61
+ return multinomial_num_samples_1(probs)
62
+ return torch.argmax(logits, dim=-1, keepdim=True)
63
+
64
+
65
+ def next_token(
66
+ model: GPT, input_pos: torch.Tensor, x: list, **kwargs: Any
67
+ ) -> torch.Tensor:
68
+ input_pos = input_pos.to(model.device)
69
+ logits_a, logit_t = model(x, input_pos)
70
+
71
+ next_audio_tokens = []
72
+ for logit_a in logits_a:
73
+ next_a = sample(logit_a, **kwargs).to(dtype=x[0].dtype)
74
+ next_audio_tokens.append(next_a)
75
+ next_t = sample(logit_t, **kwargs).to(dtype=x[0].dtype)
76
+ return next_audio_tokens, next_t
77
+
78
+
79
+ def next_token_asr(
80
+ model: GPT,
81
+ input_pos: torch.Tensor,
82
+ audio_features: torch.Tensor,
83
+ lens: int,
84
+ input_ids: list,
85
+ **kwargs: Any,
86
+ ) -> torch.Tensor:
87
+ input_pos = input_pos.to(model.device)
88
+ input_ids = [input_id.to(model.device) for input_id in input_ids]
89
+ logits_a, logit_t = model(audio_features, input_ids, input_pos, whisper_lens=lens)
90
+
91
+ next_audio_tokens = []
92
+ for logit_a in logits_a:
93
+ next_a = sample(logit_a, **kwargs).to(dtype=input_ids[0].dtype)
94
+