Files changed (1)
  1. app.py +57 -302
app.py CHANGED
@@ -3,300 +3,24 @@
import os
import time
import gradio as gr
+ import base64
import numpy as np
+ import requests
- import spaces
- import torch

- import os
- import lightning as L
- import torch
- import time
- import spaces
- from snac import SNAC
- from litgpt import Tokenizer
- from litgpt.utils import (
-     num_parameters,
- )
- from litgpt.generate.base import (
-     generate_AA,
-     generate_ASR,
-     generate_TA,
-     generate_TT,
-     generate_AT,
-     generate_TA_BATCH,
- )
- from typing import Any, Literal, Optional
- import soundfile as sf
- from litgpt.model import GPT, Config
- from lightning.fabric.utilities.load import _lazy_load as lazy_load
- from utils.snac_utils import layershift, reconscruct_snac, reconstruct_tensors, get_time_str
- from utils.snac_utils import get_snac
- import whisper
- from tqdm import tqdm
- from huggingface_hub import snapshot_download
- from litgpt.generate.base import sample

+ API_URL = os.getenv("API_URL", None)
+ client = None

- device = "cuda" if torch.cuda.is_available() else "cpu"
- ckpt_dir = "./checkpoint"
+ if API_URL is None:
+     from inference import OmniInference
+     omni_client = OmniInference('./checkpoint', 'cuda:0')
+     omni_client.warm_up()


OUT_CHUNK = 4096
OUT_RATE = 24000
OUT_CHANNELS = 1

- # TODO
- text_vocabsize = 151936
- text_specialtokens = 64
- audio_vocabsize = 4096
- audio_specialtokens = 64
-
- padded_text_vocabsize = text_vocabsize + text_specialtokens
- padded_audio_vocabsize = audio_vocabsize + audio_specialtokens
-
- _eot = text_vocabsize
- _pad_t = text_vocabsize + 1
- _input_t = text_vocabsize + 2
- _answer_t = text_vocabsize + 3
- _asr = text_vocabsize + 4
-
- _eoa = audio_vocabsize
- _pad_a = audio_vocabsize + 1
- _input_a = audio_vocabsize + 2
- _answer_a = audio_vocabsize + 3
- _split = audio_vocabsize + 4
-
-
- def download_model(ckpt_dir):
-     repo_id = "gpt-omni/mini-omni"
-     snapshot_download(repo_id, local_dir=ckpt_dir, revision="main")
-
-
- if not os.path.exists(ckpt_dir):
-     print(f"checkpoint directory {ckpt_dir} not found, downloading from huggingface")
-     download_model(ckpt_dir)
-
-
- snacmodel = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().to(device)
- whispermodel = whisper.load_model("small").to(device)
- whispermodel.eval()
- text_tokenizer = Tokenizer(ckpt_dir)
- # fabric = L.Fabric(devices=1, strategy="auto")
- config = Config.from_file(ckpt_dir + "/model_config.yaml")
- config.post_adapter = False
-
- model = GPT(config, device=device)
-
- state_dict = lazy_load(ckpt_dir + "/lit_model.pth")
- model.load_state_dict(state_dict, strict=True)
- model = model.to(device)
- model.eval()
-
-
- def get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device):
-     # with torch.no_grad():
-     mel = mel.unsqueeze(0).to(device)
-     # audio_feature = whisper.decode(whispermodel,mel, options).audio_features
-     audio_feature = whispermodel.embed_audio(mel)[0][:leng]
-     T = audio_feature.size(0)
-     input_ids_AA = []
-     for i in range(7):
-         input_ids_item = []
-         input_ids_item.append(layershift(_input_a, i))
-         input_ids_item += [layershift(_pad_a, i)] * T
-         input_ids_item += [(layershift(_eoa, i)), layershift(_answer_a, i)]
-         input_ids_AA.append(torch.tensor(input_ids_item))
-     input_id_T = torch.tensor([_input_t] + [_pad_t] * T + [_eot, _answer_t])
-     input_ids_AA.append(input_id_T)
-
-     input_ids_AT = []
-     for i in range(7):
-         input_ids_item = []
-         input_ids_item.append(layershift(_input_a, i))
-         input_ids_item += [layershift(_pad_a, i)] * T
-         input_ids_item += [(layershift(_eoa, i)), layershift(_pad_a, i)]
-         input_ids_AT.append(torch.tensor(input_ids_item))
-     input_id_T = torch.tensor([_input_t] + [_pad_t] * T + [_eot, _answer_t])
-     input_ids_AT.append(input_id_T)
-
-     input_ids = [input_ids_AA, input_ids_AT]
-     stacked_inputids = [[] for _ in range(8)]
-     for i in range(2):
-         for j in range(8):
-             stacked_inputids[j].append(input_ids[i][j])
-     stacked_inputids = [torch.stack(tensors) for tensors in stacked_inputids]
-     return torch.stack([audio_feature, audio_feature]), stacked_inputids
-
-
- def next_token_batch(
-     model: GPT,
-     audio_features: torch.tensor,
-     input_ids: list,
-     whisper_lens: int,
-     task: list,
-     input_pos: torch.Tensor,
-     **kwargs: Any,
- ) -> torch.Tensor:
-     input_pos = input_pos.to(model.device)
-     input_ids = [input_id.to(model.device) for input_id in input_ids]
-     logits_a, logit_t = model(
-         audio_features, input_ids, input_pos, whisper_lens=whisper_lens, task=task
-     )
-
-     for i in range(7):
-         logits_a[i] = logits_a[i][0].unsqueeze(0)
-     logit_t = logit_t[1].unsqueeze(0)
-
-     next_audio_tokens = []
-     for logit_a in logits_a:
-         next_a = sample(logit_a, **kwargs).to(dtype=input_ids[0].dtype)
-         next_audio_tokens.append(next_a)
-     next_t = sample(logit_t, **kwargs).to(dtype=input_ids[0].dtype)
-     return next_audio_tokens, next_t
-
-
- def load_audio(path):
-     audio = whisper.load_audio(path)
-     duration_ms = (len(audio) / 16000) * 1000
-     audio = whisper.pad_or_trim(audio)
-     mel = whisper.log_mel_spectrogram(audio)
-     return mel, int(duration_ms / 20) + 1
-
-
- def generate_audio_data(snac_tokens, snacmodel, device=None):
-     audio = reconstruct_tensors(snac_tokens, device)
-     with torch.inference_mode():
-         audio_hat = snacmodel.decode(audio)
-     audio_data = audio_hat.cpu().numpy().astype(np.float64) * 32768.0
-     audio_data = audio_data.astype(np.int16)
-     audio_data = audio_data.tobytes()
-     return audio_data
-
-
- @spaces.GPU
- @torch.inference_mode()
- def run_AT_batch_stream(
-     audio_path,
-     stream_stride=4,
-     max_returned_tokens=2048,
-     temperature=0.9,
-     top_k=1,
-     top_p=1.0,
-     eos_id_a=_eoa,
-     eos_id_t=_eot,
- ):
-
-     assert os.path.exists(audio_path), f"audio file {audio_path} not found"
-
-     model.set_kv_cache(batch_size=2, device=device)
-
-     mel, leng = load_audio(audio_path)
-     audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device)
-     T = input_ids[0].size(1)
-     # device = input_ids[0].device
-
-     assert max_returned_tokens > T, f"max_returned_tokens {max_returned_tokens} should be greater than audio length {T}"
-
-     if model.max_seq_length < max_returned_tokens - 1:
-         raise NotImplementedError(
-             f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
-         )
-
-     input_pos = torch.tensor([T], device=device)
-     list_output = [[] for i in range(8)]
-     tokens_A, token_T = next_token_batch(
-         model,
-         audio_feature.to(torch.float32).to(model.device),
-         input_ids,
-         [T - 3, T - 3],
-         ["A1T2", "A1T2"],
-         input_pos=torch.arange(0, T, device=device),
-         temperature=temperature,
-         top_k=top_k,
-         top_p=top_p,
-     )
-
-     for i in range(7):
-         list_output[i].append(tokens_A[i].tolist()[0])
-     list_output[7].append(token_T.tolist()[0])
-
-     model_input_ids = [[] for i in range(8)]
-     for i in range(7):
-         tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize + i * padded_audio_vocabsize
-         model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32))
-         model_input_ids[i].append(torch.tensor([layershift(4097, i)], device=device))
-         model_input_ids[i] = torch.stack(model_input_ids[i])
-
-     model_input_ids[-1].append(token_T.clone().to(torch.int32))
-     model_input_ids[-1].append(token_T.clone().to(torch.int32))
-     model_input_ids[-1] = torch.stack(model_input_ids[-1])
-
-     text_end = False
-     index = 1
-     nums_generate = stream_stride
-     begin_generate = False
-     current_index = 0
-     for _ in tqdm(range(2, max_returned_tokens - T + 1)):
-         tokens_A, token_T = next_token_batch(
-             model,
-             None,
-             model_input_ids,
-             None,
-             None,
-             input_pos=input_pos,
-             temperature=temperature,
-             top_k=top_k,
-             top_p=top_p,
-         )
-
-         if text_end:
-             token_T = torch.tensor([_pad_t], device=device)
-
-         if tokens_A[-1] == eos_id_a:
-             break
-
-         if token_T == eos_id_t:
-             text_end = True
-
-         for i in range(7):
-             list_output[i].append(tokens_A[i].tolist()[0])
-         list_output[7].append(token_T.tolist()[0])
-
-         model_input_ids = [[] for i in range(8)]
-         for i in range(7):
-             tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize + i * padded_audio_vocabsize
-             model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32))
-             model_input_ids[i].append(
-                 torch.tensor([layershift(4097, i)], device=device)
-             )
-             model_input_ids[i] = torch.stack(model_input_ids[i])
-
-         model_input_ids[-1].append(token_T.clone().to(torch.int32))
-         model_input_ids[-1].append(token_T.clone().to(torch.int32))
-         model_input_ids[-1] = torch.stack(model_input_ids[-1])
-
-         if index == 7:
-             begin_generate = True
-
-         if begin_generate:
-             current_index += 1
-             if current_index == nums_generate:
-                 current_index = 0
-                 snac = get_snac(list_output, index, nums_generate)
-                 audio_stream = generate_audio_data(snac, snacmodel, device)
-                 yield audio_stream
-
-         input_pos = input_pos.add_(1)
-         index += 1
-     text = text_tokenizer.decode(torch.tensor(list_output[-1]))
-     print(f"text output: {text}")
-     model.clear_kv_cache()
-     return list_output
-
-
- for chunk in run_AT_batch_stream('./data/samples/output1.wav'):
-     pass
-

def process_audio(audio):
    filepath = audio
@@ -305,23 +29,54 @@ def process_audio(audio):
        return

    cnt = 0
-     tik = time.time()
-     for chunk in run_AT_batch_stream(filepath):
-         # Convert chunk to numpy array
-         if cnt == 0:
-             print(f"first chunk time cost: {time.time() - tik:.3f}")
-         cnt += 1
-         audio_data = np.frombuffer(chunk, dtype=np.int16)
-         audio_data = audio_data.reshape(-1, OUT_CHANNELS)
-         yield OUT_RATE, audio_data.astype(np.int16)
+     if API_URL is not None:
+         with open(filepath, "rb") as f:
+             data = f.read()
+             base64_encoded = str(base64.b64encode(data), encoding="utf-8")
+             files = {"audio": base64_encoded}
+             tik = time.time()
+             with requests.post(API_URL, json=files, stream=True) as response:
+                 try:
+                     for chunk in response.iter_content(chunk_size=OUT_CHUNK):
+                         if chunk:
+                             # Convert chunk to numpy array
+                             if cnt == 0:
+                                 print(f"first chunk time cost: {time.time() - tik:.3f}")
+                             cnt += 1
+                             audio_data = np.frombuffer(chunk, dtype=np.int16)
+                             audio_data = audio_data.reshape(-1, OUT_CHANNELS)
+                             yield OUT_RATE, audio_data.astype(np.int16)
+
+                 except Exception as e:
+                     print(f"error: {e}")
+     else:
+         tik = time.time()
+         for chunk in omni_client.run_AT_batch_stream(filepath):
+             # Convert chunk to numpy array
+             if cnt == 0:
+                 print(f"first chunk time cost: {time.time() - tik:.3f}")
+             cnt += 1
+             audio_data = np.frombuffer(chunk, dtype=np.int16)
+             audio_data = audio_data.reshape(-1, OUT_CHANNELS)
+             yield OUT_RATE, audio_data.astype(np.int16)


- demo = gr.Interface(
-     process_audio,
-     inputs=gr.Audio(type="filepath", label="Microphone"),
-     outputs=[gr.Audio(label="Response", streaming=True, autoplay=True)],
-     title="Chat Mini-Omni Demo",
-     # live=True,
- )
- demo.queue()
- demo.launch()
+ def main(port=None):
+
+     demo = gr.Interface(
+         process_audio,
+         inputs=gr.Audio(type="filepath", label="Microphone"),
+         outputs=[gr.Audio(label="Response", streaming=True, autoplay=True)],
+         title="Chat Mini-Omni Demo",
+         live=True,
+     )
+     if port is not None:
+         demo.queue().launch(share=False, server_name="0.0.0.0", server_port=port)
+     else:
+         demo.queue().launch()
+
+
+ if __name__ == "__main__":
+     import fire
+
+     fire.Fire(main)
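
With this change, app.py no longer loads the model stack (litgpt, whisper, SNAC) itself. It either runs inference locally through OmniInference('./checkpoint', 'cuda:0') when the API_URL environment variable is unset, or acts as a thin streaming client that POSTs the recorded audio to API_URL. Launching is now driven by fire, so "python app.py" starts the demo on Gradio's default port, "python app.py --port 7860" pins a port through the new main(port=None), and exporting API_URL before launch switches process_audio to the requests-based path.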
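The server behind API_URL is not part of this diff; the client only implies its contract: a POST with JSON body {"audio": "<base64-encoded wav>"} answered by a streamed response of raw little-endian int16 PCM at OUT_RATE (24 kHz), mono, read in OUT_CHUNK-sized pieces. Below is a minimal sketch of a compatible endpoint, assuming Flask and the same OmniInference class the local branch imports; the /chat route, the port, and the temporary-file handling are illustrative placeholders, not something this PR defines.

# Hypothetical companion endpoint for the API_URL branch (illustrative only).
import base64
import tempfile

from flask import Flask, Response, request, stream_with_context

from inference import OmniInference  # same class the local branch imports

app = Flask(__name__)
omni = OmniInference("./checkpoint", "cuda:0")
omni.warm_up()


@app.route("/chat", methods=["POST"])  # route name is a placeholder
def chat():
    # run_AT_batch_stream expects a file path, so write the decoded
    # base64 payload to a temporary wav file first.
    audio_b64 = request.get_json()["audio"]
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        f.write(base64.b64decode(audio_b64))
        wav_path = f.name

    def pcm_chunks():
        # Each yielded chunk is raw int16 PCM; the Gradio client rebuilds
        # it with np.frombuffer(chunk, dtype=np.int16) at 24 kHz mono.
        for chunk in omni.run_AT_batch_stream(wav_path):
            yield chunk

    return Response(stream_with_context(pcm_chunks()), mimetype="audio/pcm")


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=60808)  # port is a placeholder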