WJ88 committed
Commit 8f87442 · verified · 1 Parent(s): 7a412d6

Test; the previous version was working OK.

Files changed (1): app.py +147 -96
app.py CHANGED
@@ -1,10 +1,10 @@
-
-"""Gradio Blocks app for streaming ASR with NVIDIA NeMo Parakeet-TDT-0.6B-v3.
-
-Fixes for HF Spaces + Gradio SSR:
-- Uses Blocks + .stream() API for input streaming.
-- Forces client-side rendering by setting ssr_mode=False on launch.
-- Accepts both (sr, np.ndarray) and {"sampling_rate","data"} chunk formats.
+"""Gradio Blocks app for CPU streaming ASR with NVIDIA NeMo Parakeet RNNT (label-looping greedy).
+
+Key points (CPU-tailored):
+- Incremental decoding with StreamingBatchedAudioBuffer and persistent RNNT state.
+- Avoids re-decoding the entire accumulated buffer each chunk.
+- Disables dither/padding for streaming and reuses a torchaudio resampler.
+- Works on Hugging Face Spaces (CPU) with SSR disabled.
 """
 from __future__ import annotations
 
@@ -14,12 +14,10 @@ from typing import Optional, Tuple, Union, Dict
 import copy
 import numpy as np
 import torch
-# torch.set_num_threads(2)
 import torchaudio
 import gradio as gr
 
 import nemo.collections.asr as nemo_asr
-# from omegaconf import OmegaConf
 from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig
 from nemo.collections.asr.parts.utils.rnnt_utils import batched_hyps_to_hypotheses
 from nemo.collections.asr.parts.utils.streaming_utils import ContextSize, StreamingBatchedAudioBuffer
@@ -33,13 +31,13 @@ Chunk = Union[Tuple[int, np.ndarray], Dict[str, ArrayLike]]
 # ----------------------------
 @dataclass
 class AppConfig:
-    model_name: str = "nvidia/parakeet-tdt-0.6b-v3"
+    model_name: str = "nvidia/parakeet-tdt-0.6b-v3"  # or any RNNT model compatible with label-looping greedy
     left_s: float = 10.0
     chunk_s: float = 2.0
     right_s: float = 2.0
     max_buffer_s: float = 40.0
     batch_size: int = 1
-    device: str = "cpu"  # set "cuda" to force GPU if available
+    device: str = "cpu"  # CPU-only for HF Spaces
 
 
 # ----------------------------
@@ -50,47 +48,64 @@ def _floor_multiple(a: int, b: int) -> int:
 
 
 # ----------------------------
-# ASR Engine
+# ASR Engine (stateful incremental streaming)
 # ----------------------------
 class ParakeetStreamer:
     def __init__(self, cfg: AppConfig) -> None:
         self.cfg = cfg
+        self.device = torch.device("cpu")
+        torch.set_grad_enabled(False)
+        torch.set_float32_matmul_precision("high")
+        # Optionally tune CPU threads (uncomment to adjust for your HF Space)
+        # torch.set_num_threads(max(1, (torch.get_num_threads() or 1)))
 
-        # Load model
+        # Load model (RNNT)
         self.model = (
             nemo_asr.models.EncDecRNNTModel.from_pretrained(cfg.model_name)
-            .to(cfg.device)
+            .to(self.device)
             .eval()
         )
         for p in self.model.parameters():
             p.requires_grad_(False)
 
+        # Set streaming-friendly preprocessor params
+        try:
+            if hasattr(self.model, "preprocessor") and hasattr(self.model.preprocessor, "featurizer"):
+                self.model.preprocessor.featurizer.dither = 0.0
+                self.model.preprocessor.featurizer.pad_to = 0
+        except Exception:
+            pass
+
         # Decoding strategy: greedy-batch with label-looping
         dec_cfg = RNNTDecodingConfig(
-            strategy="greedy_batch", fused_batch_size=-1, compute_timestamps=False
+            strategy="greedy_batch",
+            fused_batch_size=-1,
+            compute_timestamps=False,
         )
         dec_cfg.greedy.loop_labels = True
+        dec_cfg.greedy.preserve_alignments = False
        self.model.change_decoding_strategy(dec_cfg)
         self._decoding_computer = self.model.decoding.decoding.decoding_computer
 
-        # Clone + tweak preprocessor for inference
-        mcfg = copy.deepcopy(self.model.cfg)
-        # OmegaConf.set_struct(mcfg.preprocessor, False)
-        # mcfg.preprocessor.dither = 0.0
-        # mcfg.preprocessor.pad_to = 0
-        # OmegaConf.set_struct(mcfg.preprocessor, True)
+        # Clone + read model cfg for derived params
+        mcfg = copy.deepcopy(getattr(self.model, "_cfg", getattr(self.model, "cfg", None)))
+        if mcfg is None:
+            raise RuntimeError("Unable to access model config. Update NeMo or provide a compatible RNNT model.")
 
-        # Derived constants
         self.sample_rate: int = int(mcfg.preprocessor.sample_rate)
         window_stride: float = float(mcfg.preprocessor.window_stride)
         self.frames_per_second: float = 1.0 / window_stride
+
+        # Encoder subsampling factor
+        if not hasattr(self.model, "encoder") or not hasattr(self.model.encoder, "subsampling_factor"):
+            raise RuntimeError("Model encoder must expose subsampling_factor for streaming alignment.")
         self.subsampling: int = int(self.model.encoder.subsampling_factor)
 
-        # Encoder-step to audio alignment
+        # Map encoder frames to audio samples
         feat_f2a = _floor_multiple(int(self.sample_rate * window_stride), self.subsampling)
         self.enc_f2a = feat_f2a * self.subsampling
 
-        # Context sizes (encoder and samples)
+        # Context sizes (encoder frames and audio samples)
         self.ctx_enc = ContextSize(
             left=int(cfg.left_s * self.frames_per_second / self.subsampling),
             chunk=int(cfg.chunk_s * self.frames_per_second / self.subsampling),
@@ -104,119 +119,155 @@ class ParakeetStreamer:
 
         self.max_samples = int(cfg.max_buffer_s * self.sample_rate)
 
-    # -------- audio helpers --------
+        # Persistent streaming state
+        self._stream_np: Optional[np.ndarray] = None
+        self._buf: Optional[StreamingBatchedAudioBuffer] = None
+        self._prev_state = None
+        self._cur_hyps = None
+        self._l = 0  # left cursor (samples)
+        self._r = 0  # right cursor (samples)
+
+        # Cached resampler
+        self._resampler: Optional[torchaudio.transforms.Resample] = None
+        self._resampler_in_sr: Optional[int] = None
+
+    def reset(self):
+        self._stream_np = None
+        self._buf = None
+        self._prev_state = None
+        self._cur_hyps = None
+        self._l = 0
+        self._r = 0
+        self._resampler = None
+        self._resampler_in_sr = None
+
     @staticmethod
     def _to_mono(x: np.ndarray) -> np.ndarray:
         x = np.asarray(x)
         if x.ndim == 2:
-            # handle (samples, channels) or (channels, samples)
-            if x.shape[0] == 2 and x.shape[1] != 2:
-                # ambiguous case; fallback to last axis
-                x = x.mean(axis=-1)
-            else:
-                x = x.mean(axis=-1 if x.shape[-1] in (1, 2) else 1)
+            # average over last axis
+            x = x.mean(axis=-1 if x.shape[-1] in (1, 2) else 1)
         return x.astype(np.float32, copy=False)
 
     def _resample_if_needed(self, x: np.ndarray, in_sr: int) -> np.ndarray:
-        if int(in_sr) == self.sample_rate:
+        in_sr = int(in_sr)
+        if in_sr == self.sample_rate:
             return x
-        y = torchaudio.functional.resample(torch.from_numpy(x), in_sr, self.sample_rate)
+        if self._resampler is None or self._resampler_in_sr != in_sr:
+            self._resampler = torchaudio.transforms.Resample(orig_freq=in_sr, new_freq=self.sample_rate)
+            self._resampler_in_sr = in_sr
+        y = self._resampler(torch.from_numpy(x))
         return y.numpy().astype(np.float32, copy=False)
 
     @staticmethod
     def _parse_chunk(new_chunk: Chunk) -> Tuple[int, np.ndarray]:
-        # Accept tuple (sr, np.ndarray) or dict {"sampling_rate": int, "data": array-like}
+        # Accept dict {"sampling_rate"|"sample_rate", "data"} or tuple (sr, np.ndarray)
         if isinstance(new_chunk, dict):
             sr = int(new_chunk.get("sampling_rate") or new_chunk.get("sample_rate"))
             data = new_chunk["data"]
             if isinstance(data, torch.Tensor):
                 data = data.detach().cpu().numpy()
             return sr, np.asarray(data)
-        # assume (sr, np.ndarray)
         sr, data = new_chunk
         return int(sr), np.asarray(data)
 
-    # -------- core decoding --------
     @torch.inference_mode()
-    def _decode_buffer(self, audio_np: np.ndarray) -> str:
-        if audio_np.size == 0:
+    def _decode_increment(self) -> str:
+        if self._stream_np is None or self._stream_np.size == 0:
             return ""
 
-        a = torch.from_numpy(audio_np).unsqueeze(0).to(torch.float32).to(self.cfg.device)
-        total_len = torch.tensor([a.shape[1]], dtype=torch.long, device=self.cfg.device)
-
-        cur_hyps = None
-        prev_state = None
-
-        l = 0
-        r = min(self.ctx_samp.chunk + self.ctx_samp.right, a.shape[1])
-
-        buf = StreamingBatchedAudioBuffer(
-            batch_size=self.cfg.batch_size,
-            context_samples=self.ctx_samp,
-            dtype=a.dtype,
-            device=self.cfg.device,
-        )
-
-        remaining = total_len.clone()
-
-        while l < a.shape[1]:
-            clen = int(min(r, a.shape[1]) - l)
-            is_last = r >= a.shape[1]
-
-            is_last_b = torch.tensor([clen >= remaining[0]], dtype=torch.bool, device=self.cfg.device)
-            clen_b = torch.where(is_last_b, remaining, torch.full_like(remaining, fill_value=clen))
-
-            buf.add_audio_batch_(
-                a[:, l:r], audio_lengths=clen_b, is_last_chunk=is_last, is_last_chunk_batch=is_last_b
+        # Lazily initialize buffer and cursors
+        if self._buf is None:
+            self._buf = StreamingBatchedAudioBuffer(
+                batch_size=self.cfg.batch_size,
+                context_samples=self.ctx_samp,
+                dtype=torch.float32,
+                device=self.device,
+            )
+            self._l = 0
+            # First decode when we have chunk+right samples available
+            self._r = self.ctx_samp.chunk + self.ctx_samp.right
+
+        a = torch.from_numpy(self._stream_np).unsqueeze(0).to(torch.float32).to(self.device)
+
+        # Decode as long as we have enough samples for the next window [left: right]
+        while self._l < a.shape[1]:
+            if a.shape[1] < self._r:
+                break  # wait for more right-context samples
+            clen = int(self._r - self._l)
+            if clen <= 0:
+                break
+
+            is_last_chunk = False  # not final; mic keeps streaming
+            is_last_b = torch.tensor([False], dtype=torch.bool, device=self.device)
+            clen_b = torch.tensor([clen], dtype=torch.long, device=self.device)
+
+            self._buf.add_audio_batch_(
+                a[:, self._l:self._r],
+                audio_lengths=clen_b,
+                is_last_chunk=is_last_chunk,
+                is_last_chunk_batch=is_last_b,
             )
 
-            enc, _ = self.model(input_signal=buf.samples, input_signal_length=buf.context_size_batch.total())
+            enc, _ = self.model(
+                input_signal=self._buf.samples,
+                input_signal_length=self._buf.context_size_batch.total(),
+            )
            enc = enc.transpose(1, 2)  # [B, T, C]
 
-            enc_ctx = buf.context_size.subsample(factor=self.enc_f2a)
-            enc_ctx_b = buf.context_size_batch.subsample(factor=self.enc_f2a)
+            enc_ctx = self._buf.context_size.subsample(factor=self.enc_f2a)
+            enc_ctx_b = self._buf.context_size_batch.subsample(factor=self.enc_f2a)
 
-            enc = enc[:, enc_ctx.left:]  # drop left context before decoding
+            # Drop left context before decoding; decode only the chunk frames
+            enc = enc[:, enc_ctx.left:]
 
-            hyps, _, prev_state = self._decoding_computer(
-                x=enc, out_len=enc_ctx_b.chunk, prev_batched_state=prev_state
+            hyps, _, self._prev_state = self._decoding_computer(
+                x=enc, out_len=enc_ctx_b.chunk, prev_batched_state=self._prev_state
             )
 
-            if cur_hyps is None:
-                cur_hyps = hyps
+            if self._cur_hyps is None:
+                self._cur_hyps = hyps
             else:
-                cur_hyps.merge_(hyps)
+                self._cur_hyps.merge_(hyps)
 
-            remaining -= clen_b
-            l = r
-            r = min(r + self.ctx_samp.chunk, a.shape[1])
+            # Advance to next chunk window
+            self._l = self._r
+            self._r = self._r + self.ctx_samp.chunk
 
-        outs = batched_hyps_to_hypotheses(cur_hyps, None, batch_size=self.cfg.batch_size) if cur_hyps is not None else []
+        outs = (
+            batched_hyps_to_hypotheses(self._cur_hyps, None, batch_size=self.cfg.batch_size)
+            if self._cur_hyps is not None
+            else []
+        )
         for h in outs:
             h.text = self.model.tokenizer.ids_to_text(h.y_sequence.tolist())
-
         return outs[0].text if outs else ""
 
-    # -------- public streaming API (stateless) --------
-    def transcribe(self, stream: Optional[np.ndarray], new_chunk: Optional[Chunk]):
+    # Public API for Gradio streaming callback (stateful)
+    def transcribe(self, state: Optional[np.ndarray], new_chunk: Optional[Chunk]):
+        # Reset when a new session starts
+        if state is None and self._cur_hyps is not None:
+            self.reset()
         if new_chunk is None:
-            return stream, ""
+            return state, ""
 
         in_sr, data = self._parse_chunk(new_chunk)
         y = self._to_mono(data)
-        y = self._resample_if_needed(y, int(in_sr))
+        y = self._resample_if_needed(y, in_sr)
 
-        if stream is None or len(stream) == 0:
-            a = y
+        if self._stream_np is None or self._stream_np.size == 0:
+            self._stream_np = y
         else:
-            a = np.concatenate([stream, y])
-
-        if a.size > self.max_samples:
-            a = a[-self.max_samples:]
+            self._stream_np = np.concatenate([self._stream_np, y])
+        if self._stream_np.size > self.max_samples:
+            # Trim buffer and shift cursors accordingly
+            drop = self._stream_np.size - self.max_samples
+            self._stream_np = self._stream_np[-self.max_samples:]
+            self._l = max(0, self._l - drop)
+            self._r = max(self.ctx_samp.chunk + self.ctx_samp.right, self._r - drop)
 
-        text = self._decode_buffer(a) if a.size else ""
-        return a, text
+        text = self._decode_increment() if self._stream_np.size else ""
+        return self._stream_np, text
 
 
 # ----------------------------
@@ -226,26 +277,26 @@ def build_demo(cfg: Optional[AppConfig] = None):
     cfg = cfg or AppConfig()
     engine = ParakeetStreamer(cfg)
 
-    with gr.Blocks(title="Parakeet-TDT-0.6B-v3 — CPU streaming") as demo:
-        gr.Markdown("**Multilingual buffered streaming (10-2-2) in memory**")
+    with gr.Blocks(title="Parakeet RNNT — CPU Streaming") as demo:
+        gr.Markdown("**Buffered streaming (10-2-2) on CPU with incremental decoding**")
         with gr.Row():
             mic = gr.Audio(
                 sources=["microphone"],
                 type="numpy",
                 streaming=True,
-                label="Mic",
+                label="Microphone",
                 recording=False,
            )
         out = gr.Textbox(label="Transcript", lines=3)
         state = gr.State(value=None)
 
-        # Stream mic to backend every ~1s for more context and lower CPU churn
+        # Stream mic to backend periodically. Increase to 1.0 for lower CPU, decrease for lower latency.
        mic.stream(
             fn=engine.transcribe,
             inputs=[state, mic],
             outputs=[state, out],
             stream_every=1.0,
-            time_limit=120,
+            time_limit=180,
             concurrency_limit=1,
         )
 
@@ -254,6 +305,6 @@
 
 if __name__ == "__main__":
     demo = build_demo()
-    # Disable SSR explicitly to avoid Audio preprocessing via file paths on HF Spaces.
     demo.queue()
+    # Disable SSR to avoid file-path based Audio preprocessing on HF Spaces
    demo.launch(ssr_mode=False, show_api=False)
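
For reference, the 10-2-2 context sizes in __init__ reduce to small integer frame counts. Below is a worked sketch of that arithmetic, assuming 16 kHz audio, a 0.01 s window stride, and an 8x encoder subsampling factor (typical FastConformer-style values; these are assumptions, not values read from this commit):

# Worked example of the ContextSize / enc_f2a arithmetic in __init__
# (assumed: sample_rate=16000, window_stride=0.01, subsampling=8)
sample_rate, window_stride, subsampling = 16000, 0.01, 8
frames_per_second = 1.0 / window_stride                 # 100 feature frames per second
left = int(10.0 * frames_per_second / subsampling)      # 125 encoder frames of left context
chunk = int(2.0 * frames_per_second / subsampling)      # 25 encoder frames per chunk
right = int(2.0 * frames_per_second / subsampling)      # 25 encoder frames of right context
feat_f2a = (int(sample_rate * window_stride) // subsampling) * subsampling  # 160 samples per feature frame
enc_f2a = feat_f2a * subsampling                        # 1280 audio samples per encoder frame
print(left, chunk, right, enc_f2a)                      # 125 25 25 1280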
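To sanity-check the new stateful path without a microphone, a minimal smoke test could feed synthetic one-second chunks through transcribe(). This is a sketch, not part of the commit; it assumes this commit's app.py is importable and that nemo_toolkit, torch, and torchaudio are installed:

import numpy as np

from app import AppConfig, ParakeetStreamer  # hypothetical import of this commit's app.py

engine = ParakeetStreamer(AppConfig())

sr = 16000  # assumed mic rate; transcribe() resamples if it differs from the model rate
state = None
for i in range(5):
    # one-second chunks of quiet noise stand in for microphone audio
    chunk = (sr, (0.01 * np.random.randn(sr)).astype(np.float32))
    state, text = engine.transcribe(state, chunk)
    print(f"after chunk {i + 1}: {text!r}")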