WJ88 committed
Commit 0e644a1 · verified · 1 Parent(s): c6358db

Update app.py

Files changed (1)
  1. app.py +78 -62
app.py CHANGED
@@ -4,50 +4,43 @@ import copy
4
  import uuid
5
  import logging
6
  from typing import List, Optional, Tuple, Dict
7
-
8
  # Reduce progress/log spam before heavy imports
9
  os.environ.setdefault("TQDM_DISABLE", "1")
10
  os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
11
-
12
  import numpy as np
13
  import torch
14
  import torchaudio
15
  import soundfile as sf
16
  import gradio as gr
17
-
18
  # NeMo
19
  from nemo.collections.asr.models import ASRModel
 
20
  from omegaconf import OmegaConf
21
  from nemo.utils import logging as nemo_logging
22
-
23
  # ----------------------------
24
  # Config
25
  # ----------------------------
26
- MODEL_NAME = os.environ.get("PARAKEET_MODEL", "nvidia/parakeet-tdt-0.6b-v3")
27
- TARGET_SR = 16_000
28
- BEAM_SIZE = int(os.environ.get("PARAKEET_BEAM_SIZE", "32")) # Increased for subtle quality gains
29
  OFFLINE_BATCH = int(os.environ.get("PARAKEET_BATCH", "8"))
30
- CHUNK_S = float(os.environ.get("PARAKEET_CHUNK_S", "2.0"))
31
- FLUSH_PAD_S = float(os.environ.get("PARAKEET_FLUSH_PAD_S", "2.0"))
32
-
33
  # ----------------------------
34
  # Logging (unified)
35
  # ----------------------------
36
- LOG_LEVEL = os.environ.get("LOG_LEVEL", "INFO").upper()
37
  logger = logging.getLogger("parakeet_app")
38
  logger.setLevel(getattr(logging, LOG_LEVEL, logging.INFO))
39
  _handler = logging.StreamHandler()
40
  _handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s"))
41
  logger.handlers = [_handler]
42
  logger.propagate = False
43
-
44
  # Quiet NeMo logs
45
  nemo_logging.setLevel(logging.ERROR)
46
  logging.getLogger("nemo").setLevel(logging.ERROR)
47
  logging.getLogger("nemo.collections.asr").setLevel(logging.ERROR)
48
-
49
  torch.set_grad_enabled(False)
50
-
51
  # ----------------------------
52
  # Audio utils
53
  # ----------------------------
@@ -55,7 +48,6 @@ def to_mono_np(x: np.ndarray) -> np.ndarray:
55
  if x.ndim == 2:
56
  x = x.mean(axis=1)
57
  return x.astype(np.float32, copy=False)
58
-
59
  class ResamplerCache:
60
  def __init__(self):
61
  self._cache: Dict[int, torchaudio.transforms.Resample] = {}
@@ -70,22 +62,19 @@ class ResamplerCache:
70
  t = t.unsqueeze(0)
71
  y = self._cache[src_sr](t)
72
  return y.squeeze(0).numpy()
73
-
74
  RESAMPLER = ResamplerCache()
75
-
76
  def load_mono16k(path: str) -> np.ndarray:
77
  """Load any audio file, convert to mono float32 at 16 kHz."""
78
  try:
79
- wav, sr = sf.read(path, dtype="float32", always_2d=True) # (T,C)
80
  wav = wav.mean(axis=1).astype(np.float32, copy=False)
81
  return RESAMPLER.resample(wav, sr)
82
  except Exception:
83
- wav_t, sr = torchaudio.load(path) # (C,T)
84
  if wav_t.dtype != torch.float32:
85
  wav_t = wav_t.float()
86
  wav = wav_t.mean(dim=0).numpy()
87
  return RESAMPLER.resample(wav, int(sr))
88
-
89
  # ----------------------------
90
  # Model manager (MALSD batched beam everywhere, loop_labels=True)
91
  # ----------------------------
@@ -98,22 +87,17 @@ class ParakeetManager:
98
  self.model.eval()
99
  for p in self.model.parameters():
100
  p.requires_grad = False
101
-
102
  # Base decoding cfg differs by class
103
  if hasattr(self.model, "decoder") and hasattr(self.model.decoder, "decoder"):
104
  self._base_decoding = copy.deepcopy(self.model.decoder.decoder.cfg)
105
  else:
106
  self._base_decoding = copy.deepcopy(self.model.cfg.decoding)
107
-
108
  self._set_malsd_beam()
109
-
110
  # Enable encoder caching for better streaming context (per NeMo docs/tutorials)
111
  if hasattr(self.model.encoder, "set_default_att_context_size"):
112
- self.model.encoder.set_default_att_context_size([512, 16]) # Large left for cumulative context, small right for buffering
113
  logger.info("encoder_caching_enabled left=512 right=16")
114
-
115
  logger.info(f"model_loaded strategy=malsd_batch beam_size={BEAM_SIZE}")
116
-
117
  def _set_malsd_beam(self):
118
  cfg = copy.deepcopy(self._base_decoding)
119
  cfg.strategy = "malsd_batch"
@@ -121,18 +105,17 @@ class ParakeetManager:
121
  "beam_size": BEAM_SIZE,
122
  "return_best_hypothesis": True,
123
  "score_norm": True,
124
- "allow_cuda_graphs": False, # CPU-only
125
  "max_symbols_per_step": 10,
126
  })
127
  OmegaConf.set_struct(cfg, False)
128
  cfg["loop_labels"] = True
129
  cfg["fused_batch_size"] = -1
130
- cfg["compute_timestamps"] = False
131
  if hasattr(cfg, "greedy"):
132
  cfg.greedy.use_cuda_graph_decoder = False
133
  self.model.change_decoding_strategy(cfg)
134
  logger.info("decoding_set strategy=malsd_batch loop_labels=True")
135
-
136
  def _transcribe(self, items: List, *, partial=None):
137
  with torch.inference_mode():
138
  return self.model.transcribe(
@@ -142,7 +125,6 @@ class ParakeetManager:
142
  return_hypotheses=True,
143
  partial_hypothesis=partial,
144
  )
145
-
146
  # Offline batch
147
  def transcribe_files(self, paths: List[str]):
148
  n = 0 if not paths else len(paths)
@@ -155,18 +137,35 @@ class ParakeetManager:
155
  for p, o in zip(paths, out):
156
  h = o[0] if isinstance(o, list) and o else o
157
  text = h if isinstance(h, str) else getattr(h, "text", "")
158
  results.append({"path": p, "text": text})
159
  logger.info("files_run ok")
160
  return results
161
-
162
  # Streaming step (rolling hypothesis)
163
  def stream_step(self, audio_16k: np.ndarray, prev_hyp) -> object:
164
  out = self._transcribe([audio_16k], partial=[prev_hyp] if prev_hyp is not None else None)
165
  h = out[0][0] if isinstance(out[0], list) else out[0]
166
- return h # Hypothesis
167
-
168
  # ----------------------------
169
- # Streaming session (no overlap, rolling hypothesis)
170
  # ----------------------------
171
  class StreamingSession:
172
  def __init__(self, manager: ParakeetManager, chunk_s: float, flush_pad_s: float):
@@ -176,61 +175,84 @@ class StreamingSession:
176
  self.hyp = None
177
  self.pending = np.zeros(0, dtype=np.float32)
178
  self.text = ""
 
179
  logger.info(f"mic_reset chunk={self.chunk_s}s flush_pad={self.flush_pad_s}s")
180
-
181
  def add_audio(self, audio: np.ndarray, src_sr: int):
182
  mono = to_mono_np(audio)
183
  res = RESAMPLER.resample(mono, src_sr)
184
  self.pending = np.concatenate([self.pending, res]) if self.pending.size else res
185
  self._drain()
186
-
187
  def _drain(self):
188
  C = int(self.chunk_s * TARGET_SR)
189
  while self.pending.size >= C:
190
  chunk = self.pending[:C]
191
  self.pending = self.pending[C:]
192
  try:
193
- self.hyp = self.mgr.stream_step(chunk, self.hyp)
194
- new_text = getattr(self.hyp, "text", "")
195
- if new_text:
196
- if self.text and new_text.startswith(self.text): # If cumulative (partial extends), replace with extended
197
- self.text = new_text
198
- else: # Else append (handles per-chunk case)
199
- self.text += (' ' if self.text else '') + new_text
200
  except Exception:
201
  logger.exception("mic_step failed")
202
  break
203
-
204
  def flush(self) -> str:
205
  if self.pending.size:
206
  pad = np.zeros(int(self.flush_pad_s * TARGET_SR), dtype=np.float32)
207
  final = np.concatenate([self.pending, pad])
208
  try:
209
- self.hyp = self.mgr.stream_step(final, self.hyp)
210
- new_text = getattr(self.hyp, "text", "")
211
- if new_text:
212
- if self.text and new_text.startswith(self.text):
213
- self.text = new_text
214
- else:
215
- self.text += (' ' if self.text else '') + new_text
216
- self.text += '.' # Add period for sentence closure on flush
217
  except Exception:
218
  logger.exception("mic_flush failed")
219
  self.pending = np.zeros(0, dtype=np.float32)
220
  return self.text
221
-
222
  # ----------------------------
223
  # Simple session registry (avoid deepcopy in gr.State)
224
  # ----------------------------
225
  SESS: Dict[str, StreamingSession] = {}
226
  def _new_session_id() -> str:
227
  return uuid.uuid4().hex
228
-
229
  # ----------------------------
230
  # Gradio callbacks
231
  # ----------------------------
232
  MANAGER = ParakeetManager(device="cpu")
233
-
234
  def _parse_gr_audio(x) -> Tuple[np.ndarray, int]:
235
  if x is None:
236
  return np.zeros(0, dtype=np.float32), TARGET_SR
@@ -241,7 +263,6 @@ def _parse_gr_audio(x) -> Tuple[np.ndarray, int]:
241
  if isinstance(x, np.ndarray):
242
  return x.astype(np.float32, copy=False), TARGET_SR
243
  logger.error(f"unsupported_gr_audio_payload type={type(x)}"); raise ValueError("Unsupported audio payload")
244
-
245
  def mic_step(audio_chunk, sess_id: Optional[str]):
246
  if not sess_id or sess_id not in SESS:
247
  sess_id = _new_session_id()
@@ -255,14 +276,12 @@ def mic_step(audio_chunk, sess_id: Optional[str]):
255
  if wav.size:
256
  sess.add_audio(wav, sr)
257
  return sess_id, sess.text
258
-
259
  def mic_flush(sess_id: Optional[str]):
260
  if not sess_id or sess_id not in SESS:
261
  return None, ""
262
  text = SESS[sess_id].flush()
263
  logger.info("mic_flush ok")
264
  return None, text
265
-
266
  def files_run(files):
267
  n = 0 if not files else len(files)
268
  logger.info(f"files_ui start count={n}")
@@ -281,7 +300,6 @@ def files_run(files):
281
  table = [[os.path.basename(r["path"]), r["text"]] for r in results]
282
  logger.info("files_ui ok")
283
  return table
284
-
285
  # ----------------------------
286
  # UI
287
  # ----------------------------
@@ -290,15 +308,13 @@ with gr.Blocks(title="Parakeet-TDT v3 (Unified MALSD Beam)") as demo:
290
  mic = gr.Audio(sources=["microphone"], type="numpy", streaming=True, label="Speak")
291
  text_out = gr.Textbox(label="Transcript", lines=8)
292
  flush_btn = gr.Button("Flush")
293
- state_id = gr.State() # only a string id
294
  mic.stream(mic_step, inputs=[mic, state_id], outputs=[state_id, text_out])
295
  flush_btn.click(mic_flush, inputs=[state_id], outputs=[state_id, text_out])
296
-
297
  with gr.Tab("Files"):
298
  files = gr.File(file_count="multiple", type="filepath", label="Upload audio files")
299
  run_btn = gr.Button("Run")
300
  results_table = gr.Dataframe(headers=["file", "text"], label="Results",
301
  row_count=(0, "dynamic"), col_count=(2, "fixed"))
302
  run_btn.click(files_run, inputs=[files], outputs=[results_table])
303
-
304
  demo.queue().launch(ssr_mode=False)
 
4
  import uuid
5
  import logging
6
  from typing import List, Optional, Tuple, Dict
 
7
  # Reduce progress/log spam before heavy imports
8
  os.environ.setdefault("TQDM_DISABLE", "1")
9
  os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
 
10
  import numpy as np
11
  import torch
12
  import torchaudio
13
  import soundfile as sf
14
  import gradio as gr
 
15
  # NeMo
16
  from nemo.collections.asr.models import ASRModel
17
+ from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis # For hypothesis handling
18
  from omegaconf import OmegaConf
19
  from nemo.utils import logging as nemo_logging
 
20
  # ----------------------------
21
  # Config
22
  # ----------------------------
23
+ MODEL_NAME = os.environ.get("PARAKEET_MODEL", "nvidia/parakeet-tdt-0.6b-v3")
24
+ TARGET_SR = 16_000
25
+ BEAM_SIZE = int(os.environ.get("PARAKEET_BEAM_SIZE", "32")) # Increased for subtle quality gains
26
  OFFLINE_BATCH = int(os.environ.get("PARAKEET_BATCH", "8"))
27
+ CHUNK_S = float(os.environ.get("PARAKEET_CHUNK_S", "4.0"))
28
+ FLUSH_PAD_S = float(os.environ.get("PARAKEET_FLUSH_PAD_S", "2.0"))
 
29
  # ----------------------------
30
  # Logging (unified)
31
  # ----------------------------
32
+ LOG_LEVEL = os.environ.get("LOG_LEVEL", "DEBUG").upper()
33
  logger = logging.getLogger("parakeet_app")
34
  logger.setLevel(getattr(logging, LOG_LEVEL, logging.INFO))
35
  _handler = logging.StreamHandler()
36
  _handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s"))
37
  logger.handlers = [_handler]
38
  logger.propagate = False
 
39
  # Quiet NeMo logs
40
  nemo_logging.setLevel(logging.ERROR)
41
  logging.getLogger("nemo").setLevel(logging.ERROR)
42
  logging.getLogger("nemo.collections.asr").setLevel(logging.ERROR)
 
43
  torch.set_grad_enabled(False)
 
44
  # ----------------------------
45
  # Audio utils
46
  # ----------------------------
 
48
  if x.ndim == 2:
49
  x = x.mean(axis=1)
50
  return x.astype(np.float32, copy=False)
 
51
  class ResamplerCache:
52
  def __init__(self):
53
  self._cache: Dict[int, torchaudio.transforms.Resample] = {}
 
62
  t = t.unsqueeze(0)
63
  y = self._cache[src_sr](t)
64
  return y.squeeze(0).numpy()
 
65
  RESAMPLER = ResamplerCache()
 
66
  def load_mono16k(path: str) -> np.ndarray:
67
  """Load any audio file, convert to mono float32 at 16 kHz."""
68
  try:
69
+ wav, sr = sf.read(path, dtype="float32", always_2d=True) # (T,C)
70
  wav = wav.mean(axis=1).astype(np.float32, copy=False)
71
  return RESAMPLER.resample(wav, sr)
72
  except Exception:
73
+ wav_t, sr = torchaudio.load(path) # (C,T)
74
  if wav_t.dtype != torch.float32:
75
  wav_t = wav_t.float()
76
  wav = wav_t.mean(dim=0).numpy()
77
  return RESAMPLER.resample(wav, int(sr))
 
78
  # ----------------------------
79
  # Model manager (MALSD batched beam everywhere, loop_labels=True)
80
  # ----------------------------
 
87
  self.model.eval()
88
  for p in self.model.parameters():
89
  p.requires_grad = False
 
90
  # Base decoding cfg differs by class
91
  if hasattr(self.model, "decoder") and hasattr(self.model.decoder, "decoder"):
92
  self._base_decoding = copy.deepcopy(self.model.decoder.decoder.cfg)
93
  else:
94
  self._base_decoding = copy.deepcopy(self.model.cfg.decoding)
 
95
  self._set_malsd_beam()
 
96
  # Enable encoder caching for better streaming context (per NeMo docs/tutorials)
97
  if hasattr(self.model.encoder, "set_default_att_context_size"):
98
+ self.model.encoder.set_default_att_context_size([512, 16]) # Large left for cumulative context, small right for buffering
99
  logger.info("encoder_caching_enabled left=512 right=16")
 
100
  logger.info(f"model_loaded strategy=malsd_batch beam_size={BEAM_SIZE}")
 
101
  def _set_malsd_beam(self):
102
  cfg = copy.deepcopy(self._base_decoding)
103
  cfg.strategy = "malsd_batch"
 
105
  "beam_size": BEAM_SIZE,
106
  "return_best_hypothesis": True,
107
  "score_norm": True,
108
+ "allow_cuda_graphs": False, # CPU-only
109
  "max_symbols_per_step": 10,
110
  })
111
  OmegaConf.set_struct(cfg, False)
112
  cfg["loop_labels"] = True
113
  cfg["fused_batch_size"] = -1
114
+ cfg["compute_timestamps"] = True # Enabled for word-level timestamps
115
  if hasattr(cfg, "greedy"):
116
  cfg.greedy.use_cuda_graph_decoder = False
117
  self.model.change_decoding_strategy(cfg)
118
  logger.info("decoding_set strategy=malsd_batch loop_labels=True")
 
119
  def _transcribe(self, items: List, *, partial=None):
120
  with torch.inference_mode():
121
  return self.model.transcribe(
 
125
  return_hypotheses=True,
126
  partial_hypothesis=partial,
127
  )
 
128
  # Offline batch
129
  def transcribe_files(self, paths: List[str]):
130
  n = 0 if not paths else len(paths)
 
137
  for p, o in zip(paths, out):
138
  h = o[0] if isinstance(o, list) and o else o
139
  text = h if isinstance(h, str) else getattr(h, "text", "")
140
+ # Extract timestamps if available
141
+ if hasattr(h, 'timestep') and h.timestep:
142
+ word_timestamps = h.timestep.get('word', [])
143
+ if word_timestamps and text:
144
+ # Format timed text
145
+ words = text.split()
146
+ if len(words) == len(word_timestamps):
147
+ timed_parts = [f"{word} ({ts['start']}-{ts['end']}s)" for word, ts in zip(words, word_timestamps)]
148
+ text = ' '.join(timed_parts)
149
+ logger.debug(f"File timestamps for {p}: {word_timestamps}")
150
  results.append({"path": p, "text": text})
151
  logger.info("files_run ok")
152
  return results
 
153
  # Streaming step (rolling hypothesis)
154
  def stream_step(self, audio_16k: np.ndarray, prev_hyp) -> object:
155
  out = self._transcribe([audio_16k], partial=[prev_hyp] if prev_hyp is not None else None)
156
  h = out[0][0] if isinstance(out[0], list) else out[0]
157
+ return h # Hypothesis
 
158
  # ----------------------------
159
+ # Helper for token merging
160
+ # ----------------------------
161
+ def common_prefix_len(a: list, b: list) -> int:
162
+ min_len = min(len(a), len(b))
163
+ for i in range(min_len):
164
+ if a[i] != b[i]:
165
+ return i
166
+ return min_len
167
+ # ----------------------------
168
+ # Streaming session (rolling hypothesis with token merging)
169
  # ----------------------------
170
  class StreamingSession:
171
  def __init__(self, manager: ParakeetManager, chunk_s: float, flush_pad_s: float):
 
175
  self.hyp = None
176
  self.pending = np.zeros(0, dtype=np.float32)
177
  self.text = ""
178
+ self.tokens: List[int] = [] # Track current token sequence for merging
179
  logger.info(f"mic_reset chunk={self.chunk_s}s flush_pad={self.flush_pad_s}s")
 
180
  def add_audio(self, audio: np.ndarray, src_sr: int):
181
  mono = to_mono_np(audio)
182
  res = RESAMPLER.resample(mono, src_sr)
183
+ # Normalize volume
184
+ if np.max(np.abs(res)) > 0:
185
+ res = res / np.max(np.abs(res)) * 0.95 # Scale to [-0.95, 0.95]
186
 + # Simple VAD: trim leading silence with torchaudio.functional.vad
187
+ from torchaudio.functional import vad
188
+ res = vad(torch.from_numpy(res), sample_rate=TARGET_SR, trigger_level=7.0).numpy()
189
  self.pending = np.concatenate([self.pending, res]) if self.pending.size else res
190
  self._drain()
191
+ def _merge_tokens(self, new_hyp: Hypothesis) -> None:
192
+ """Merge new hypothesis tokens with existing, update text and hyp."""
193
+ # Handle all possible types: tensor, ndarray, list, None
194
+ if new_hyp.y_sequence is None:
195
+ new_tokens = []
196
+ elif isinstance(new_hyp.y_sequence, torch.Tensor):
197
+ new_tokens = new_hyp.y_sequence.cpu().tolist()
198
+ elif isinstance(new_hyp.y_sequence, np.ndarray):
199
+ new_tokens = new_hyp.y_sequence.tolist()
200
+ else:
201
+ new_tokens = list(new_hyp.y_sequence)
202
+ # Ensure self.tokens is list
203
+ self.tokens = list(self.tokens)
204
+ logger.debug(f"New hyp text: '{new_hyp.text}', y_sequence type: {type(new_hyp.y_sequence)}, len: {len(new_tokens) if new_tokens else 0}")
205
+ if len(new_tokens) > 0:
206
+ prefix_len = common_prefix_len(self.tokens, new_tokens)
207
+ if prefix_len < len(new_tokens): # Skip if no new tokens
208
+ merged_tokens = self.tokens + new_tokens[prefix_len:]
209
+ logger.debug(f"Prev tokens len: {len(self.tokens)}, New tokens len: {len(new_tokens)}, Prefix len: {prefix_len}, Merged tokens len: {len(merged_tokens)}")
210
+ self.text = self.mgr.model.tokenizer.ids_to_text(merged_tokens)
211
+ self.tokens = merged_tokens
212
+ # Update hyp for next partial (copy and set as tensor, as NeMo expects)
213
+ self.hyp = copy.deepcopy(new_hyp)
214
+ self.hyp.y_sequence = torch.tensor(merged_tokens, dtype=torch.long)
215
+ logger.debug(f"Merged tokens: len={len(merged_tokens)}") # For debug
216
+ # Log timestamps if available
217
+ if hasattr(new_hyp, 'timestep') and new_hyp.timestep:
218
+ word_timestamps = new_hyp.timestep.get('word', [])
219
+ if word_timestamps:
220
+ logger.debug(f"New hyp word timestamps: {word_timestamps}")
221
  def _drain(self):
222
  C = int(self.chunk_s * TARGET_SR)
223
  while self.pending.size >= C:
224
  chunk = self.pending[:C]
225
  self.pending = self.pending[C:]
226
  try:
227
+ new_hyp = self.mgr.stream_step(chunk, self.hyp)
228
+ logger.debug(f"Post-step hyp text: '{new_hyp.text}'")
229
+ self._merge_tokens(new_hyp)
230
  except Exception:
231
  logger.exception("mic_step failed")
232
  break
 
233
  def flush(self) -> str:
234
  if self.pending.size:
235
  pad = np.zeros(int(self.flush_pad_s * TARGET_SR), dtype=np.float32)
236
  final = np.concatenate([self.pending, pad])
237
  try:
238
+ new_hyp = self.mgr.stream_step(final, self.hyp)
239
+ self._merge_tokens(new_hyp)
240
+ if self.text: # Add period only if there's text
241
+ self.text += '.'
242
  except Exception:
243
  logger.exception("mic_flush failed")
244
  self.pending = np.zeros(0, dtype=np.float32)
245
  return self.text
 
246
  # ----------------------------
247
  # Simple session registry (avoid deepcopy in gr.State)
248
  # ----------------------------
249
  SESS: Dict[str, StreamingSession] = {}
250
  def _new_session_id() -> str:
251
  return uuid.uuid4().hex
 
252
  # ----------------------------
253
  # Gradio callbacks
254
  # ----------------------------
255
  MANAGER = ParakeetManager(device="cpu")
 
256
  def _parse_gr_audio(x) -> Tuple[np.ndarray, int]:
257
  if x is None:
258
  return np.zeros(0, dtype=np.float32), TARGET_SR
 
263
  if isinstance(x, np.ndarray):
264
  return x.astype(np.float32, copy=False), TARGET_SR
265
  logger.error(f"unsupported_gr_audio_payload type={type(x)}"); raise ValueError("Unsupported audio payload")
 
266
  def mic_step(audio_chunk, sess_id: Optional[str]):
267
  if not sess_id or sess_id not in SESS:
268
  sess_id = _new_session_id()
 
276
  if wav.size:
277
  sess.add_audio(wav, sr)
278
  return sess_id, sess.text
 
279
  def mic_flush(sess_id: Optional[str]):
280
  if not sess_id or sess_id not in SESS:
281
  return None, ""
282
  text = SESS[sess_id].flush()
283
  logger.info("mic_flush ok")
284
  return None, text
 
285
  def files_run(files):
286
  n = 0 if not files else len(files)
287
  logger.info(f"files_ui start count={n}")
 
300
  table = [[os.path.basename(r["path"]), r["text"]] for r in results]
301
  logger.info("files_ui ok")
302
  return table
 
303
  # ----------------------------
304
  # UI
305
  # ----------------------------
 
308
  mic = gr.Audio(sources=["microphone"], type="numpy", streaming=True, label="Speak")
309
  text_out = gr.Textbox(label="Transcript", lines=8)
310
  flush_btn = gr.Button("Flush")
311
+ state_id = gr.State() # only a string id
312
  mic.stream(mic_step, inputs=[mic, state_id], outputs=[state_id, text_out])
313
  flush_btn.click(mic_flush, inputs=[state_id], outputs=[state_id, text_out])
 
314
  with gr.Tab("Files"):
315
  files = gr.File(file_count="multiple", type="filepath", label="Upload audio files")
316
  run_btn = gr.Button("Run")
317
  results_table = gr.Dataframe(headers=["file", "text"], label="Results",
318
  row_count=(0, "dynamic"), col_count=(2, "fixed"))
319
  run_btn.click(files_run, inputs=[files], outputs=[results_table])
 
320
  demo.queue().launch(ssr_mode=False)
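
The rolling-hypothesis merge added in this commit (common_prefix_len plus the token concatenation in StreamingSession._merge_tokens) can be checked in isolation. A minimal sketch, assuming plain Python lists of token ids; the ids and the two-step example are illustrative, not model output:

from typing import List

def common_prefix_len(a: list, b: list) -> int:
    # Length of the shared leading run of token ids.
    min_len = min(len(a), len(b))
    for i in range(min_len):
        if a[i] != b[i]:
            return i
    return min_len

def merge_tokens(prev: List[int], new: List[int]) -> List[int]:
    # Append only the suffix of the new hypothesis that extends the previous
    # one, mirroring the logic in StreamingSession._merge_tokens.
    prefix = common_prefix_len(prev, new)
    return prev + new[prefix:] if prefix < len(new) else prev

tokens: List[int] = []
for step in ([17, 42, 7], [17, 42, 7, 99, 3]):  # two streaming steps, illustrative ids
    tokens = merge_tokens(tokens, step)
print(tokens)  # [17, 42, 7, 99, 3]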