WJ88 committed on
Commit
bf6fbd5
·
verified ·
1 Parent(s): 0e644a1

back to giga chad version

Files changed (1)
  1. app.py +63 -78
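
The app is configured entirely through environment variables; the names below come from the Config block in the diff. A hypothetical launch sketch (the subprocess wrapper is illustrative, not part of the repo) showing the knobs this commit touches. The old version of app.py (left side of the diff) is shown first, then the new version.

# Hypothetical launcher; env var names are taken from app.py's Config block.
import os
import subprocess

env = dict(
    os.environ,
    PARAKEET_MODEL="nvidia/parakeet-tdt-0.6b-v3",   # default model in the diff
    PARAKEET_BEAM_SIZE="32",      # MALSD beam width
    PARAKEET_BATCH="8",           # offline batch size
    PARAKEET_CHUNK_S="2.0",       # streaming chunk length; this commit lowers the default from 4.0
    PARAKEET_FLUSH_PAD_S="2.0",   # silence padding appended on flush
    LOG_LEVEL="INFO",             # this commit lowers the default from DEBUG
)
subprocess.run(["python", "app.py"], env=env, check=True)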
app.py CHANGED
@@ -1,46 +1,54 @@
 from __future__ import annotations
 import os
 import copy
 import uuid
 import logging
 from typing import List, Optional, Tuple, Dict
 # Reduce progress/log spam before heavy imports
 os.environ.setdefault("TQDM_DISABLE", "1")
 os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
 import numpy as np
 import torch
 import torchaudio
 import soundfile as sf
 import gradio as gr
 # NeMo
 from nemo.collections.asr.models import ASRModel
-from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis  # For hypothesis handling
 from omegaconf import OmegaConf
 from nemo.utils import logging as nemo_logging
 # ----------------------------
 # Config
 # ----------------------------
-MODEL_NAME = os.environ.get("PARAKEET_MODEL", "nvidia/parakeet-tdt-0.6b-v3")
-TARGET_SR = 16_000
-BEAM_SIZE = int(os.environ.get("PARAKEET_BEAM_SIZE", "32"))  # Increased for subtle quality gains
 OFFLINE_BATCH = int(os.environ.get("PARAKEET_BATCH", "8"))
-CHUNK_S = float(os.environ.get("PARAKEET_CHUNK_S", "4.0"))
-FLUSH_PAD_S = float(os.environ.get("PARAKEET_FLUSH_PAD_S", "2.0"))
 # ----------------------------
 # Logging (unified)
 # ----------------------------
-LOG_LEVEL = os.environ.get("LOG_LEVEL", "DEBUG").upper()
 logger = logging.getLogger("parakeet_app")
 logger.setLevel(getattr(logging, LOG_LEVEL, logging.INFO))
 _handler = logging.StreamHandler()
 _handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s"))
 logger.handlers = [_handler]
 logger.propagate = False
 # Quiet NeMo logs
 nemo_logging.setLevel(logging.ERROR)
 logging.getLogger("nemo").setLevel(logging.ERROR)
 logging.getLogger("nemo.collections.asr").setLevel(logging.ERROR)
 torch.set_grad_enabled(False)
 # ----------------------------
 # Audio utils
 # ----------------------------
@@ -48,6 +56,7 @@ def to_mono_np(x: np.ndarray) -> np.ndarray:
     if x.ndim == 2:
         x = x.mean(axis=1)
     return x.astype(np.float32, copy=False)
 class ResamplerCache:
     def __init__(self):
         self._cache: Dict[int, torchaudio.transforms.Resample] = {}
@@ -62,19 +71,22 @@ class ResamplerCache:
         t = t.unsqueeze(0)
         y = self._cache[src_sr](t)
         return y.squeeze(0).numpy()
 RESAMPLER = ResamplerCache()
 def load_mono16k(path: str) -> np.ndarray:
     """Load any audio file, convert to mono float32 at 16 kHz."""
     try:
-        wav, sr = sf.read(path, dtype="float32", always_2d=True)  # (T,C)
         wav = wav.mean(axis=1).astype(np.float32, copy=False)
         return RESAMPLER.resample(wav, sr)
     except Exception:
-        wav_t, sr = torchaudio.load(path)  # (C,T)
         if wav_t.dtype != torch.float32:
             wav_t = wav_t.float()
         wav = wav_t.mean(dim=0).numpy()
         return RESAMPLER.resample(wav, int(sr))
 # ----------------------------
 # Model manager (MALSD batched beam everywhere, loop_labels=True)
 # ----------------------------
@@ -87,17 +99,22 @@ class ParakeetManager:
         self.model.eval()
         for p in self.model.parameters():
             p.requires_grad = False
         # Base decoding cfg differs by class
         if hasattr(self.model, "decoder") and hasattr(self.model.decoder, "decoder"):
             self._base_decoding = copy.deepcopy(self.model.decoder.decoder.cfg)
         else:
             self._base_decoding = copy.deepcopy(self.model.cfg.decoding)
         self._set_malsd_beam()
         # Enable encoder caching for better streaming context (per NeMo docs/tutorials)
         if hasattr(self.model.encoder, "set_default_att_context_size"):
-            self.model.encoder.set_default_att_context_size([512, 16])  # Large left for cumulative context, small right for buffering
             logger.info("encoder_caching_enabled left=512 right=16")
         logger.info(f"model_loaded strategy=malsd_batch beam_size={BEAM_SIZE}")
     def _set_malsd_beam(self):
         cfg = copy.deepcopy(self._base_decoding)
         cfg.strategy = "malsd_batch"
@@ -105,17 +122,18 @@
             "beam_size": BEAM_SIZE,
             "return_best_hypothesis": True,
             "score_norm": True,
-            "allow_cuda_graphs": False,  # CPU-only
             "max_symbols_per_step": 10,
         })
         OmegaConf.set_struct(cfg, False)
         cfg["loop_labels"] = True
         cfg["fused_batch_size"] = -1
-        cfg["compute_timestamps"] = True  # Enabled for word-level timestamps
         if hasattr(cfg, "greedy"):
             cfg.greedy.use_cuda_graph_decoder = False
         self.model.change_decoding_strategy(cfg)
         logger.info("decoding_set strategy=malsd_batch loop_labels=True")
     def _transcribe(self, items: List, *, partial=None):
         with torch.inference_mode():
             return self.model.transcribe(
@@ -125,6 +143,7 @@
                 return_hypotheses=True,
                 partial_hypothesis=partial,
             )
     # Offline batch
     def transcribe_files(self, paths: List[str]):
         n = 0 if not paths else len(paths)
@@ -137,35 +156,18 @@
         for p, o in zip(paths, out):
             h = o[0] if isinstance(o, list) and o else o
             text = h if isinstance(h, str) else getattr(h, "text", "")
-            # Extract timestamps if available
-            if hasattr(h, 'timestep') and h.timestep:
-                word_timestamps = h.timestep.get('word', [])
-                if word_timestamps and text:
-                    # Format timed text
-                    words = text.split()
-                    if len(words) == len(word_timestamps):
-                        timed_parts = [f"{word} ({ts['start']}-{ts['end']}s)" for word, ts in zip(words, word_timestamps)]
-                        text = ' '.join(timed_parts)
-                    logger.debug(f"File timestamps for {p}: {word_timestamps}")
             results.append({"path": p, "text": text})
         logger.info("files_run ok")
         return results
     # Streaming step (rolling hypothesis)
     def stream_step(self, audio_16k: np.ndarray, prev_hyp) -> object:
         out = self._transcribe([audio_16k], partial=[prev_hyp] if prev_hyp is not None else None)
         h = out[0][0] if isinstance(out[0], list) else out[0]
-        return h  # Hypothesis
 # ----------------------------
-# Helper for token merging
-# ----------------------------
-def common_prefix_len(a: list, b: list) -> int:
-    min_len = min(len(a), len(b))
-    for i in range(min_len):
-        if a[i] != b[i]:
-            return i
-    return min_len
-# ----------------------------
-# Streaming session (rolling hypothesis with token merging)
 # ----------------------------
 class StreamingSession:
     def __init__(self, manager: ParakeetManager, chunk_s: float, flush_pad_s: float):
@@ -175,84 +177,61 @@
         self.hyp = None
         self.pending = np.zeros(0, dtype=np.float32)
         self.text = ""
-        self.tokens: List[int] = []  # Track current token sequence for merging
         logger.info(f"mic_reset chunk={self.chunk_s}s flush_pad={self.flush_pad_s}s")
     def add_audio(self, audio: np.ndarray, src_sr: int):
         mono = to_mono_np(audio)
         res = RESAMPLER.resample(mono, src_sr)
-        # Normalize volume
-        if np.max(np.abs(res)) > 0:
-            res = res / np.max(np.abs(res)) * 0.95  # Scale to [-0.95, 0.95]
-        # Simple VAD (trim silence via torchaudio.functional.vad)
-        from torchaudio.functional import vad
-        res = vad(torch.from_numpy(res), sample_rate=TARGET_SR, trigger_level=7.0).numpy()
         self.pending = np.concatenate([self.pending, res]) if self.pending.size else res
         self._drain()
-    def _merge_tokens(self, new_hyp: Hypothesis) -> None:
-        """Merge new hypothesis tokens with existing, update text and hyp."""
-        # Handle all possible types: tensor, ndarray, list, None
-        if new_hyp.y_sequence is None:
-            new_tokens = []
-        elif isinstance(new_hyp.y_sequence, torch.Tensor):
-            new_tokens = new_hyp.y_sequence.cpu().tolist()
-        elif isinstance(new_hyp.y_sequence, np.ndarray):
-            new_tokens = new_hyp.y_sequence.tolist()
-        else:
-            new_tokens = list(new_hyp.y_sequence)
-        # Ensure self.tokens is a list
-        self.tokens = list(self.tokens)
-        logger.debug(f"New hyp text: '{new_hyp.text}', y_sequence type: {type(new_hyp.y_sequence)}, len: {len(new_tokens) if new_tokens else 0}")
-        if len(new_tokens) > 0:
-            prefix_len = common_prefix_len(self.tokens, new_tokens)
-            if prefix_len < len(new_tokens):  # Skip if no new tokens
-                merged_tokens = self.tokens + new_tokens[prefix_len:]
-                logger.debug(f"Prev tokens len: {len(self.tokens)}, New tokens len: {len(new_tokens)}, Prefix len: {prefix_len}, Merged tokens len: {len(merged_tokens)}")
-                self.text = self.mgr.model.tokenizer.ids_to_text(merged_tokens)
-                self.tokens = merged_tokens
-                # Update hyp for next partial (copy and set as tensor, as NeMo expects)
-                self.hyp = copy.deepcopy(new_hyp)
-                self.hyp.y_sequence = torch.tensor(merged_tokens, dtype=torch.long)
-                logger.debug(f"Merged tokens: len={len(merged_tokens)}")  # For debug
-        # Log timestamps if available
-        if hasattr(new_hyp, 'timestep') and new_hyp.timestep:
-            word_timestamps = new_hyp.timestep.get('word', [])
-            if word_timestamps:
-                logger.debug(f"New hyp word timestamps: {word_timestamps}")
     def _drain(self):
         C = int(self.chunk_s * TARGET_SR)
         while self.pending.size >= C:
             chunk = self.pending[:C]
             self.pending = self.pending[C:]
             try:
-                new_hyp = self.mgr.stream_step(chunk, self.hyp)
-                logger.debug(f"Post-step hyp text: '{new_hyp.text}'")
-                self._merge_tokens(new_hyp)
             except Exception:
                 logger.exception("mic_step failed")
                 break
     def flush(self) -> str:
         if self.pending.size:
             pad = np.zeros(int(self.flush_pad_s * TARGET_SR), dtype=np.float32)
             final = np.concatenate([self.pending, pad])
             try:
-                new_hyp = self.mgr.stream_step(final, self.hyp)
-                self._merge_tokens(new_hyp)
-                if self.text:  # Add period only if there's text
-                    self.text += '.'
             except Exception:
                 logger.exception("mic_flush failed")
         self.pending = np.zeros(0, dtype=np.float32)
         return self.text
 # ----------------------------
 # Simple session registry (avoid deepcopy in gr.State)
 # ----------------------------
 SESS: Dict[str, StreamingSession] = {}
 def _new_session_id() -> str:
     return uuid.uuid4().hex
 # ----------------------------
 # Gradio callbacks
 # ----------------------------
 MANAGER = ParakeetManager(device="cpu")
 def _parse_gr_audio(x) -> Tuple[np.ndarray, int]:
     if x is None:
         return np.zeros(0, dtype=np.float32), TARGET_SR
@@ -263,6 +242,7 @@ def _parse_gr_audio(x) -> Tuple[np.ndarray, int]:
     if isinstance(x, np.ndarray):
         return x.astype(np.float32, copy=False), TARGET_SR
     logger.error(f"unsupported_gr_audio_payload type={type(x)}"); raise ValueError("Unsupported audio payload")
 def mic_step(audio_chunk, sess_id: Optional[str]):
     if not sess_id or sess_id not in SESS:
         sess_id = _new_session_id()
@@ -276,12 +256,14 @@ def mic_step(audio_chunk, sess_id: Optional[str]):
     if wav.size:
         sess.add_audio(wav, sr)
     return sess_id, sess.text
 def mic_flush(sess_id: Optional[str]):
     if not sess_id or sess_id not in SESS:
         return None, ""
     text = SESS[sess_id].flush()
     logger.info("mic_flush ok")
     return None, text
 def files_run(files):
     n = 0 if not files else len(files)
     logger.info(f"files_ui start count={n}")
@@ -300,6 +282,7 @@ def files_run(files):
     table = [[os.path.basename(r["path"]), r["text"]] for r in results]
     logger.info("files_ui ok")
     return table
 # ----------------------------
 # UI
 # ----------------------------
@@ -308,13 +291,15 @@ with gr.Blocks(title="Parakeet-TDT v3 (Unified MALSD Beam)") as demo:
         mic = gr.Audio(sources=["microphone"], type="numpy", streaming=True, label="Speak")
         text_out = gr.Textbox(label="Transcript", lines=8)
         flush_btn = gr.Button("Flush")
-        state_id = gr.State()  # only a string id
         mic.stream(mic_step, inputs=[mic, state_id], outputs=[state_id, text_out])
         flush_btn.click(mic_flush, inputs=[state_id], outputs=[state_id, text_out])
     with gr.Tab("Files"):
         files = gr.File(file_count="multiple", type="filepath", label="Upload audio files")
         run_btn = gr.Button("Run")
         results_table = gr.Dataframe(headers=["file", "text"], label="Results",
                                      row_count=(0, "dynamic"), col_count=(2, "fixed"))
         run_btn.click(files_run, inputs=[files], outputs=[results_table])
 demo.queue().launch(ssr_mode=False)
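
For reference, the streaming path removed above merged hypotheses at the token level via common_prefix_len. A minimal, NeMo-free sketch of that merging rule (the token IDs are illustrative only); the updated app.py follows below.

# Standalone sketch of the removed token-level merge rule.
from typing import List

def common_prefix_len(a: List[int], b: List[int]) -> int:
    """Length of the longest shared prefix of two token sequences."""
    n = min(len(a), len(b))
    for i in range(n):
        if a[i] != b[i]:
            return i
    return n

def merge(prev: List[int], new: List[int]) -> List[int]:
    """Keep prev intact and append only the part of new past the shared prefix."""
    return prev + new[common_prefix_len(prev, new):]

assert merge([1, 2, 3], [1, 2, 3, 4, 5]) == [1, 2, 3, 4, 5]  # new extends prev: take the extension
assert merge([1, 2, 3], [9, 9]) == [1, 2, 3, 9, 9]           # divergent hypothesis: appended whole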
 
+# Just a comment to snapshot this commit/version: this code works great for both mic and files.
 from __future__ import annotations
 import os
 import copy
 import uuid
 import logging
 from typing import List, Optional, Tuple, Dict
+
 # Reduce progress/log spam before heavy imports
 os.environ.setdefault("TQDM_DISABLE", "1")
 os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
+
 import numpy as np
 import torch
 import torchaudio
 import soundfile as sf
 import gradio as gr
+
 # NeMo
 from nemo.collections.asr.models import ASRModel
 from omegaconf import OmegaConf
 from nemo.utils import logging as nemo_logging
+
 # ----------------------------
 # Config
 # ----------------------------
+MODEL_NAME = os.environ.get("PARAKEET_MODEL", "nvidia/parakeet-tdt-0.6b-v3")
+TARGET_SR = 16_000
+BEAM_SIZE = int(os.environ.get("PARAKEET_BEAM_SIZE", "32"))  # Increased for subtle quality gains
 OFFLINE_BATCH = int(os.environ.get("PARAKEET_BATCH", "8"))
+CHUNK_S = float(os.environ.get("PARAKEET_CHUNK_S", "2.0"))
+FLUSH_PAD_S = float(os.environ.get("PARAKEET_FLUSH_PAD_S", "2.0"))
+
 # ----------------------------
 # Logging (unified)
 # ----------------------------
+LOG_LEVEL = os.environ.get("LOG_LEVEL", "INFO").upper()
 logger = logging.getLogger("parakeet_app")
 logger.setLevel(getattr(logging, LOG_LEVEL, logging.INFO))
 _handler = logging.StreamHandler()
 _handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s"))
 logger.handlers = [_handler]
 logger.propagate = False
+
 # Quiet NeMo logs
 nemo_logging.setLevel(logging.ERROR)
 logging.getLogger("nemo").setLevel(logging.ERROR)
 logging.getLogger("nemo.collections.asr").setLevel(logging.ERROR)
+
 torch.set_grad_enabled(False)
+
 # ----------------------------
 # Audio utils
 # ----------------------------
 
     if x.ndim == 2:
         x = x.mean(axis=1)
     return x.astype(np.float32, copy=False)
+
 class ResamplerCache:
     def __init__(self):
         self._cache: Dict[int, torchaudio.transforms.Resample] = {}
 
         t = t.unsqueeze(0)
         y = self._cache[src_sr](t)
         return y.squeeze(0).numpy()
+
 RESAMPLER = ResamplerCache()
+
 def load_mono16k(path: str) -> np.ndarray:
     """Load any audio file, convert to mono float32 at 16 kHz."""
     try:
+        wav, sr = sf.read(path, dtype="float32", always_2d=True)  # (T,C)
         wav = wav.mean(axis=1).astype(np.float32, copy=False)
         return RESAMPLER.resample(wav, sr)
     except Exception:
+        wav_t, sr = torchaudio.load(path)  # (C,T)
         if wav_t.dtype != torch.float32:
             wav_t = wav_t.float()
         wav = wav_t.mean(dim=0).numpy()
         return RESAMPLER.resample(wav, int(sr))
+
 # ----------------------------
 # Model manager (MALSD batched beam everywhere, loop_labels=True)
 # ----------------------------
 
         self.model.eval()
         for p in self.model.parameters():
             p.requires_grad = False
+
         # Base decoding cfg differs by class
         if hasattr(self.model, "decoder") and hasattr(self.model.decoder, "decoder"):
             self._base_decoding = copy.deepcopy(self.model.decoder.decoder.cfg)
         else:
             self._base_decoding = copy.deepcopy(self.model.cfg.decoding)
+
         self._set_malsd_beam()
+
         # Enable encoder caching for better streaming context (per NeMo docs/tutorials)
         if hasattr(self.model.encoder, "set_default_att_context_size"):
+            self.model.encoder.set_default_att_context_size([512, 16])  # Large left for cumulative context, small right for buffering
             logger.info("encoder_caching_enabled left=512 right=16")
+
         logger.info(f"model_loaded strategy=malsd_batch beam_size={BEAM_SIZE}")
+
     def _set_malsd_beam(self):
         cfg = copy.deepcopy(self._base_decoding)
         cfg.strategy = "malsd_batch"
 
             "beam_size": BEAM_SIZE,
             "return_best_hypothesis": True,
             "score_norm": True,
+            "allow_cuda_graphs": False,  # CPU-only
             "max_symbols_per_step": 10,
         })
         OmegaConf.set_struct(cfg, False)
         cfg["loop_labels"] = True
         cfg["fused_batch_size"] = -1
+        cfg["compute_timestamps"] = False
         if hasattr(cfg, "greedy"):
             cfg.greedy.use_cuda_graph_decoder = False
         self.model.change_decoding_strategy(cfg)
         logger.info("decoding_set strategy=malsd_batch loop_labels=True")
+
     def _transcribe(self, items: List, *, partial=None):
         with torch.inference_mode():
             return self.model.transcribe(
 
                 return_hypotheses=True,
                 partial_hypothesis=partial,
             )
+
     # Offline batch
     def transcribe_files(self, paths: List[str]):
         n = 0 if not paths else len(paths)
 
         for p, o in zip(paths, out):
             h = o[0] if isinstance(o, list) and o else o
             text = h if isinstance(h, str) else getattr(h, "text", "")
             results.append({"path": p, "text": text})
         logger.info("files_run ok")
         return results
+
     # Streaming step (rolling hypothesis)
     def stream_step(self, audio_16k: np.ndarray, prev_hyp) -> object:
         out = self._transcribe([audio_16k], partial=[prev_hyp] if prev_hyp is not None else None)
         h = out[0][0] if isinstance(out[0], list) else out[0]
+        return h  # Hypothesis
+
 # ----------------------------
+# Streaming session (no overlap, rolling hypothesis)
 # ----------------------------
 class StreamingSession:
     def __init__(self, manager: ParakeetManager, chunk_s: float, flush_pad_s: float):
 
         self.hyp = None
         self.pending = np.zeros(0, dtype=np.float32)
         self.text = ""
         logger.info(f"mic_reset chunk={self.chunk_s}s flush_pad={self.flush_pad_s}s")
+
     def add_audio(self, audio: np.ndarray, src_sr: int):
         mono = to_mono_np(audio)
         res = RESAMPLER.resample(mono, src_sr)
         self.pending = np.concatenate([self.pending, res]) if self.pending.size else res
         self._drain()
+
     def _drain(self):
         C = int(self.chunk_s * TARGET_SR)
         while self.pending.size >= C:
             chunk = self.pending[:C]
             self.pending = self.pending[C:]
             try:
+                self.hyp = self.mgr.stream_step(chunk, self.hyp)
+                new_text = getattr(self.hyp, "text", "")
+                if new_text:
+                    if self.text and new_text.startswith(self.text):  # If cumulative (partial extends), replace with extended
+                        self.text = new_text
+                    else:  # Else append (handles per-chunk case)
+                        self.text += (' ' if self.text else '') + new_text
             except Exception:
                 logger.exception("mic_step failed")
                 break
+
     def flush(self) -> str:
         if self.pending.size:
             pad = np.zeros(int(self.flush_pad_s * TARGET_SR), dtype=np.float32)
             final = np.concatenate([self.pending, pad])
             try:
+                self.hyp = self.mgr.stream_step(final, self.hyp)
+                new_text = getattr(self.hyp, "text", "")
+                if new_text:
+                    if self.text and new_text.startswith(self.text):
+                        self.text = new_text
+                    else:
+                        self.text += (' ' if self.text else '') + new_text
+                self.text += '.'  # Add period for sentence closure on flush
             except Exception:
                 logger.exception("mic_flush failed")
         self.pending = np.zeros(0, dtype=np.float32)
         return self.text
+
 # ----------------------------
 # Simple session registry (avoid deepcopy in gr.State)
 # ----------------------------
 SESS: Dict[str, StreamingSession] = {}
 def _new_session_id() -> str:
     return uuid.uuid4().hex
+
 # ----------------------------
 # Gradio callbacks
 # ----------------------------
 MANAGER = ParakeetManager(device="cpu")
+
 def _parse_gr_audio(x) -> Tuple[np.ndarray, int]:
     if x is None:
         return np.zeros(0, dtype=np.float32), TARGET_SR
 
     if isinstance(x, np.ndarray):
         return x.astype(np.float32, copy=False), TARGET_SR
     logger.error(f"unsupported_gr_audio_payload type={type(x)}"); raise ValueError("Unsupported audio payload")
+
 def mic_step(audio_chunk, sess_id: Optional[str]):
     if not sess_id or sess_id not in SESS:
         sess_id = _new_session_id()
 
     if wav.size:
         sess.add_audio(wav, sr)
     return sess_id, sess.text
+
 def mic_flush(sess_id: Optional[str]):
     if not sess_id or sess_id not in SESS:
         return None, ""
     text = SESS[sess_id].flush()
     logger.info("mic_flush ok")
     return None, text
+
 def files_run(files):
     n = 0 if not files else len(files)
     logger.info(f"files_ui start count={n}")
 
     table = [[os.path.basename(r["path"]), r["text"]] for r in results]
     logger.info("files_ui ok")
     return table
+
 # ----------------------------
 # UI
 # ----------------------------
 
         mic = gr.Audio(sources=["microphone"], type="numpy", streaming=True, label="Speak")
         text_out = gr.Textbox(label="Transcript", lines=8)
         flush_btn = gr.Button("Flush")
+        state_id = gr.State()  # only a string id
         mic.stream(mic_step, inputs=[mic, state_id], outputs=[state_id, text_out])
         flush_btn.click(mic_flush, inputs=[state_id], outputs=[state_id, text_out])
+
     with gr.Tab("Files"):
         files = gr.File(file_count="multiple", type="filepath", label="Upload audio files")
         run_btn = gr.Button("Run")
         results_table = gr.Dataframe(headers=["file", "text"], label="Results",
                                      row_count=(0, "dynamic"), col_count=(2, "fixed"))
         run_btn.click(files_run, inputs=[files], outputs=[results_table])
+
 demo.queue().launch(ssr_mode=False)
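
The replacement streaming logic accumulates plain hypothesis text instead of token IDs: when a partial hypothesis extends the current transcript it replaces it, otherwise the chunk's text is appended. A minimal sketch of that rule, assuming only that the decoder's hypothesis exposes a .text string as in the diff:

# Standalone sketch of the text-accumulation rule used in _drain()/flush().
def accumulate(current: str, new_text: str) -> str:
    """Replace when new_text extends current (cumulative partial), else append (per-chunk)."""
    if not new_text:
        return current
    if current and new_text.startswith(current):
        return new_text
    return current + (' ' if current else '') + new_text

assert accumulate("hello", "hello world") == "hello world"                     # extended partial replaces
assert accumulate("hello world", "how are you") == "hello world how are you"  # chunk output appends
assert accumulate("", "hi") == "hi"                                           # first chunk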