back to giga chad version

app.py CHANGED
@@ -1,46 +1,54 @@
+# This is just a comment to make a somewhat of snapshot of this commit, version, this code works amazing, for mic and for file, its just great
 from __future__ import annotations
 import os
 import copy
 import uuid
 import logging
 from typing import List, Optional, Tuple, Dict
+
 # Reduce progress/log spam before heavy imports
 os.environ.setdefault("TQDM_DISABLE", "1")
 os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
+
 import numpy as np
 import torch
 import torchaudio
 import soundfile as sf
 import gradio as gr
+
 # NeMo
 from nemo.collections.asr.models import ASRModel
-from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis  # For hypothesis handling
 from omegaconf import OmegaConf
 from nemo.utils import logging as nemo_logging
+
 # ----------------------------
 # Config
 # ----------------------------
-MODEL_NAME
-TARGET_SR
-BEAM_SIZE
+MODEL_NAME = os.environ.get("PARAKEET_MODEL", "nvidia/parakeet-tdt-0.6b-v3")
+TARGET_SR = 16_000
+BEAM_SIZE = int(os.environ.get("PARAKEET_BEAM_SIZE", "32"))  # Increased for subtle quality gains
 OFFLINE_BATCH= int(os.environ.get("PARAKEET_BATCH", "8"))
-CHUNK_S
-FLUSH_PAD_S
+CHUNK_S = float(os.environ.get("PARAKEET_CHUNK_S", "2.0"))
+FLUSH_PAD_S = float(os.environ.get("PARAKEET_FLUSH_PAD_S", "2.0"))
+
 # ----------------------------
 # Logging (unified)
 # ----------------------------
-LOG_LEVEL = os.environ.get("LOG_LEVEL", "
+LOG_LEVEL = os.environ.get("LOG_LEVEL", "INFO").upper()
 logger = logging.getLogger("parakeet_app")
 logger.setLevel(getattr(logging, LOG_LEVEL, logging.INFO))
 _handler = logging.StreamHandler()
 _handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s"))
 logger.handlers = [_handler]
 logger.propagate = False
+
 # Quiet NeMo logs
 nemo_logging.setLevel(logging.ERROR)
 logging.getLogger("nemo").setLevel(logging.ERROR)
 logging.getLogger("nemo.collections.asr").setLevel(logging.ERROR)
+
 torch.set_grad_enabled(False)
+
 # ----------------------------
 # Audio utils
 # ----------------------------
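A side note on the Config hunk: every restored value is read through `os.environ.get`, so the defaults can be overridden per launch without editing the file. A minimal, hypothetical launcher sketch (the variable names come from the diff; the values here are illustrative):

```python
# Illustrative launcher: override the env-driven config, then run app.py.
# Assumes app.py sits in the current directory; values are examples only.
import os
import subprocess

env = dict(
    os.environ,
    PARAKEET_MODEL="nvidia/parakeet-tdt-0.6b-v3",  # default from the diff
    PARAKEET_BEAM_SIZE="16",   # smaller beam -> faster CPU decoding
    PARAKEET_CHUNK_S="1.0",    # drain every 1 s instead of 2 s
    LOG_LEVEL="DEBUG",
)
subprocess.run(["python", "app.py"], env=env)
```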
@@ -48,6 +56,7 @@ def to_mono_np(x: np.ndarray) -> np.ndarray:
     if x.ndim == 2:
         x = x.mean(axis=1)
     return x.astype(np.float32, copy=False)
+
 class ResamplerCache:
     def __init__(self):
         self._cache: Dict[int, torchaudio.transforms.Resample] = {}
@@ -62,19 +71,22 @@ class ResamplerCache:
         t = t.unsqueeze(0)
         y = self._cache[src_sr](t)
         return y.squeeze(0).numpy()
+
 RESAMPLER = ResamplerCache()
+
 def load_mono16k(path: str) -> np.ndarray:
     """Load any audio file, convert to mono float32 at 16 kHz."""
     try:
-        wav, sr = sf.read(path, dtype="float32", always_2d=True)
+        wav, sr = sf.read(path, dtype="float32", always_2d=True)  # (T,C)
         wav = wav.mean(axis=1).astype(np.float32, copy=False)
         return RESAMPLER.resample(wav, sr)
     except Exception:
-        wav_t, sr = torchaudio.load(path)
+        wav_t, sr = torchaudio.load(path)  # (C,T)
         if wav_t.dtype != torch.float32:
             wav_t = wav_t.float()
         wav = wav_t.mean(dim=0).numpy()
         return RESAMPLER.resample(wav, int(sr))
+
 # ----------------------------
 # Model manager (MALSD batched beam everywhere, loop_labels=True)
 # ----------------------------
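The ResamplerCache hunk only gains whitespace here, but the pattern deserves a note: `torchaudio.transforms.Resample` precomputes a filter kernel per source rate, so keeping one instance per `src_sr` avoids rebuilding that kernel on every mic chunk. A standalone sketch of the same idea (names here are illustrative, not the app's):

```python
# Sketch of the caching idea behind ResamplerCache.
import numpy as np
import torch
import torchaudio

TARGET_SR = 16_000
_cache: dict[int, torchaudio.transforms.Resample] = {}

def resample(wav: np.ndarray, src_sr: int) -> np.ndarray:
    if src_sr == TARGET_SR:
        return wav
    if src_sr not in _cache:
        # Kernel is built once per source rate, then reused for every chunk.
        _cache[src_sr] = torchaudio.transforms.Resample(src_sr, TARGET_SR)
    t = torch.from_numpy(wav).unsqueeze(0)       # (1, T)
    return _cache[src_sr](t).squeeze(0).numpy()  # back to (T,)

print(resample(np.zeros(44_100, dtype=np.float32), 44_100).shape)  # (16000,)
```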
@@ -87,17 +99,22 @@ class ParakeetManager:
         self.model.eval()
         for p in self.model.parameters():
             p.requires_grad = False
+
         # Base decoding cfg differs by class
         if hasattr(self.model, "decoder") and hasattr(self.model.decoder, "decoder"):
             self._base_decoding = copy.deepcopy(self.model.decoder.decoder.cfg)
         else:
             self._base_decoding = copy.deepcopy(self.model.cfg.decoding)
+
         self._set_malsd_beam()
+
         # Enable encoder caching for better streaming context (per NeMo docs/tutorials)
         if hasattr(self.model.encoder, "set_default_att_context_size"):
-            self.model.encoder.set_default_att_context_size([512, 16])
+            self.model.encoder.set_default_att_context_size([512, 16])  # Large left for cumulative context, small right for buffering
             logger.info("encoder_caching_enabled left=512 right=16")
+
         logger.info(f"model_loaded strategy=malsd_batch beam_size={BEAM_SIZE}")
+
     def _set_malsd_beam(self):
         cfg = copy.deepcopy(self._base_decoding)
         cfg.strategy = "malsd_batch"
@@ -105,17 +122,18 @@ class ParakeetManager:
             "beam_size": BEAM_SIZE,
             "return_best_hypothesis": True,
             "score_norm": True,
-            "allow_cuda_graphs": False,
+            "allow_cuda_graphs": False,  # CPU-only
             "max_symbols_per_step": 10,
         })
         OmegaConf.set_struct(cfg, False)
         cfg["loop_labels"] = True
         cfg["fused_batch_size"] = -1
-        cfg["compute_timestamps"] =
+        cfg["compute_timestamps"] = False
         if hasattr(cfg, "greedy"):
             cfg.greedy.use_cuda_graph_decoder = False
         self.model.change_decoding_strategy(cfg)
         logger.info("decoding_set strategy=malsd_batch loop_labels=True")
+
     def _transcribe(self, items: List, *, partial=None):
         with torch.inference_mode():
             return self.model.transcribe(
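On the `_set_malsd_beam` hunk: the `update(...)` call above touches keys the decoding schema already knows, while `OmegaConf.set_struct(cfg, False)` is what allows the extra keys (`loop_labels`, `fused_batch_size`, `compute_timestamps`) to be added at all. A self-contained illustration, independent of NeMo:

```python
# Why the code calls OmegaConf.set_struct(cfg, False) before adding keys:
# a struct-mode config rejects keys absent from its original schema.
from omegaconf import OmegaConf

cfg = OmegaConf.create({"strategy": "greedy", "beam": {"beam_size": 4}})
OmegaConf.set_struct(cfg, True)   # model configs typically behave like this

try:
    cfg["loop_labels"] = True     # unknown key -> rejected in struct mode
except Exception as e:
    print(type(e).__name__)      # e.g. ConfigKeyError

OmegaConf.set_struct(cfg, False)  # relax the schema...
cfg["loop_labels"] = True         # ...now the extra decoding flag sticks
print(OmegaConf.to_yaml(cfg))
```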
@@ -125,6 +143,7 @@ class ParakeetManager:
                 return_hypotheses=True,
                 partial_hypothesis=partial,
             )
+
     # Offline batch
     def transcribe_files(self, paths: List[str]):
         n = 0 if not paths else len(paths)
@@ -137,35 +156,18 @@ class ParakeetManager:
         for p, o in zip(paths, out):
             h = o[0] if isinstance(o, list) and o else o
             text = h if isinstance(h, str) else getattr(h, "text", "")
-            # Extract timestamps if available
-            if hasattr(h, 'timestep') and h.timestep:
-                word_timestamps = h.timestep.get('word', [])
-                if word_timestamps and text:
-                    # Format timed text
-                    words = text.split()
-                    if len(words) == len(word_timestamps):
-                        timed_parts = [f"{word} ({ts['start']}-{ts['end']}s)" for word, ts in zip(words, word_timestamps)]
-                        text = ' '.join(timed_parts)
-                    logger.debug(f"File timestamps for {p}: {word_timestamps}")
             results.append({"path": p, "text": text})
         logger.info("files_run ok")
         return results
+
     # Streaming step (rolling hypothesis)
     def stream_step(self, audio_16k: np.ndarray, prev_hyp) -> object:
         out = self._transcribe([audio_16k], partial=[prev_hyp] if prev_hyp is not None else None)
         h = out[0][0] if isinstance(out[0], list) else out[0]
-        return h
+        return h  # Hypothesis
+
 # ----------------------------
-#
-# ----------------------------
-def common_prefix_len(a: list, b: list) -> int:
-    min_len = min(len(a), len(b))
-    for i in range(min_len):
-        if a[i] != b[i]:
-            return i
-    return min_len
-# ----------------------------
-# Streaming session (rolling hypothesis with token merging)
+# Streaming session (no overlap, rolling hypothesis)
 # ----------------------------
 class StreamingSession:
     def __init__(self, manager: ParakeetManager, chunk_s: float, flush_pad_s: float):
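On `stream_step`: the previous `Hypothesis` is fed back through `partial_hypothesis`, so decoding continues from accumulated state instead of restarting on each chunk. A toy stand-in for that contract (`FakeManager`/`FakeHyp` are invented here; the real decoding happens inside NeMo's `transcribe`):

```python
# Minimal mock of the rolling-hypothesis contract: each call receives the
# previous hypothesis and returns an extended one.
from dataclasses import dataclass
from typing import Optional

@dataclass
class FakeHyp:
    text: str

class FakeManager:
    def stream_step(self, audio_chunk, prev_hyp: Optional[FakeHyp]) -> FakeHyp:
        prev = prev_hyp.text if prev_hyp else ""
        new = f"<{len(audio_chunk)} samples>"
        return FakeHyp(text=(prev + " " + new) if prev else new)

mgr, hyp = FakeManager(), None
for chunk in ([0.0] * 32_000, [0.0] * 32_000):  # two 2 s chunks at 16 kHz
    hyp = mgr.stream_step(chunk, hyp)
print(hyp.text)  # "<32000 samples> <32000 samples>"
```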
@@ -175,84 +177,61 @@ class StreamingSession:
         self.hyp = None
         self.pending = np.zeros(0, dtype=np.float32)
         self.text = ""
-        self.tokens: List[int] = []  # Track current token sequence for merging
         logger.info(f"mic_reset chunk={self.chunk_s}s flush_pad={self.flush_pad_s}s")
+
     def add_audio(self, audio: np.ndarray, src_sr: int):
         mono = to_mono_np(audio)
         res = RESAMPLER.resample(mono, src_sr)
-        # Normalize volume
-        if np.max(np.abs(res)) > 0:
-            res = res / np.max(np.abs(res)) * 0.95  # Scale to [-0.95, 0.95]
-        # Simple VAD (trim silence; use torchaudio's if import functional as F)
-        from torchaudio.functional import vad
-        res = vad(torch.from_numpy(res), sample_rate=TARGET_SR, trigger_level=7.0).numpy()
         self.pending = np.concatenate([self.pending, res]) if self.pending.size else res
         self._drain()
-
-        """Merge new hypothesis tokens with existing, update text and hyp."""
-        # Handle all possible types: tensor, ndarray, list, None
-        if new_hyp.y_sequence is None:
-            new_tokens = []
-        elif isinstance(new_hyp.y_sequence, torch.Tensor):
-            new_tokens = new_hyp.y_sequence.cpu().tolist()
-        elif isinstance(new_hyp.y_sequence, np.ndarray):
-            new_tokens = new_hyp.y_sequence.tolist()
-        else:
-            new_tokens = list(new_hyp.y_sequence)
-        # Ensure self.tokens is list
-        self.tokens = list(self.tokens)
-        logger.debug(f"New hyp text: '{new_hyp.text}', y_sequence type: {type(new_hyp.y_sequence)}, len: {len(new_tokens) if new_tokens else 0}")
-        if len(new_tokens) > 0:
-            prefix_len = common_prefix_len(self.tokens, new_tokens)
-            if prefix_len < len(new_tokens):  # Skip if no new tokens
-                merged_tokens = self.tokens + new_tokens[prefix_len:]
-                logger.debug(f"Prev tokens len: {len(self.tokens)}, New tokens len: {len(new_tokens)}, Prefix len: {prefix_len}, Merged tokens len: {len(merged_tokens)}")
-                self.text = self.mgr.model.tokenizer.ids_to_text(merged_tokens)
-                self.tokens = merged_tokens
-                # Update hyp for next partial (copy and set as tensor, as NeMo expects)
-                self.hyp = copy.deepcopy(new_hyp)
-                self.hyp.y_sequence = torch.tensor(merged_tokens, dtype=torch.long)
-                logger.debug(f"Merged tokens: len={len(merged_tokens)}")  # For debug
-        # Log timestamps if available
-        if hasattr(new_hyp, 'timestep') and new_hyp.timestep:
-            word_timestamps = new_hyp.timestep.get('word', [])
-            if word_timestamps:
-                logger.debug(f"New hyp word timestamps: {word_timestamps}")
+
     def _drain(self):
         C = int(self.chunk_s * TARGET_SR)
         while self.pending.size >= C:
             chunk = self.pending[:C]
             self.pending = self.pending[C:]
             try:
-
-
-
+                self.hyp = self.mgr.stream_step(chunk, self.hyp)
+                new_text = getattr(self.hyp, "text", "")
+                if new_text:
+                    if self.text and new_text.startswith(self.text):  # If cumulative (partial extends), replace with extended
+                        self.text = new_text
+                    else:  # Else append (handles per-chunk case)
+                        self.text += (' ' if self.text else '') + new_text
             except Exception:
                 logger.exception("mic_step failed")
                 break
+
     def flush(self) -> str:
         if self.pending.size:
             pad = np.zeros(int(self.flush_pad_s * TARGET_SR), dtype=np.float32)
             final = np.concatenate([self.pending, pad])
             try:
-
-                self.
-                if
-                self.text
+                self.hyp = self.mgr.stream_step(final, self.hyp)
+                new_text = getattr(self.hyp, "text", "")
+                if new_text:
+                    if self.text and new_text.startswith(self.text):
+                        self.text = new_text
+                    else:
+                        self.text += (' ' if self.text else '') + new_text
+                self.text += '.'  # Add period for sentence closure on flush
             except Exception:
                 logger.exception("mic_flush failed")
         self.pending = np.zeros(0, dtype=np.float32)
         return self.text
+
 # ----------------------------
 # Simple session registry (avoid deepcopy in gr.State)
 # ----------------------------
 SESS: Dict[str, StreamingSession] = {}
 def _new_session_id() -> str:
     return uuid.uuid4().hex
+
 # ----------------------------
 # Gradio callbacks
 # ----------------------------
 MANAGER = ParakeetManager(device="cpu")
+
 def _parse_gr_audio(x) -> Tuple[np.ndarray, int]:
     if x is None:
         return np.zeros(0, dtype=np.float32), TARGET_SR
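The new `_drain` and `flush` bodies share one merge rule: if the fresh hypothesis text extends the accumulated text, replace it; otherwise append with a space (flush additionally tacks on a period). Pulled out as a standalone helper for clarity (`merge_text` is not in the diff, just an extraction of its logic):

```python
# The text-merge rule applied by the new _drain/flush hunks.
def merge_text(acc: str, new_text: str) -> str:
    if not new_text:
        return acc
    if acc and new_text.startswith(acc):  # cumulative hypothesis: replace
        return new_text
    return acc + (' ' if acc else '') + new_text  # per-chunk text: append

assert merge_text("hello", "hello world") == "hello world"  # extended partial
assert merge_text("hello", "world") == "hello world"        # fresh chunk
assert merge_text("", "hi") == "hi"
print("merge rule ok")
```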
@@ -263,6 +242,7 @@ def _parse_gr_audio(x) -> Tuple[np.ndarray, int]:
     if isinstance(x, np.ndarray):
         return x.astype(np.float32, copy=False), TARGET_SR
     logger.error(f"unsupported_gr_audio_payload type={type(x)}"); raise ValueError("Unsupported audio payload")
+
 def mic_step(audio_chunk, sess_id: Optional[str]):
     if not sess_id or sess_id not in SESS:
         sess_id = _new_session_id()
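On the registry this callback leans on: `gr.State` carries only the uuid string, while the heavyweight `StreamingSession` lives in the module-level `SESS` dict, sidestepping Gradio's copying of state values between callbacks. The bare pattern, with a stand-in object in place of a real session:

```python
# Session-registry pattern: state holds a string id, the dict holds objects.
import uuid

SESS: dict[str, object] = {}

def get_or_create(sess_id: str | None) -> tuple[str, object]:
    if not sess_id or sess_id not in SESS:
        sess_id = uuid.uuid4().hex
        SESS[sess_id] = object()  # stand-in for StreamingSession(...)
    return sess_id, SESS[sess_id]

sid, sess = get_or_create(None)    # first callback: creates the session
sid2, sess2 = get_or_create(sid)   # later callbacks: same id, same object
assert sid == sid2 and sess is sess2
```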
@@ -276,12 +256,14 @@ def mic_step(audio_chunk, sess_id: Optional[str]):
     if wav.size:
         sess.add_audio(wav, sr)
     return sess_id, sess.text
+
 def mic_flush(sess_id: Optional[str]):
     if not sess_id or sess_id not in SESS:
         return None, ""
     text = SESS[sess_id].flush()
     logger.info("mic_flush ok")
     return None, text
+
 def files_run(files):
     n = 0 if not files else len(files)
     logger.info(f"files_ui start count={n}")
@@ -300,6 +282,7 @@ def files_run(files):
     table = [[os.path.basename(r["path"]), r["text"]] for r in results]
     logger.info("files_ui ok")
     return table
+
 # ----------------------------
 # UI
 # ----------------------------
@@ -308,13 +291,15 @@ with gr.Blocks(title="Parakeet-TDT v3 (Unified MALSD Beam)") as demo:
         mic = gr.Audio(sources=["microphone"], type="numpy", streaming=True, label="Speak")
         text_out = gr.Textbox(label="Transcript", lines=8)
         flush_btn = gr.Button("Flush")
-        state_id = gr.State()
+        state_id = gr.State()  # only a string id
         mic.stream(mic_step, inputs=[mic, state_id], outputs=[state_id, text_out])
         flush_btn.click(mic_flush, inputs=[state_id], outputs=[state_id, text_out])
+
     with gr.Tab("Files"):
         files = gr.File(file_count="multiple", type="filepath", label="Upload audio files")
         run_btn = gr.Button("Run")
         results_table = gr.Dataframe(headers=["file", "text"], label="Results",
                                      row_count=(0, "dynamic"), col_count=(2, "fixed"))
         run_btn.click(files_run, inputs=[files], outputs=[results_table])
+
 demo.queue().launch(ssr_mode=False)
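Finally, the cadence this version settles on: audio accrues in `pending` until a full `CHUNK_S` window exists, whole chunks are drained through the decoder, and the remainder waits for the next mic callback or flush. A synthetic walk-through of that buffering loop (callback sizes are illustrative):

```python
# Sketch of the pending-buffer cadence in StreamingSession._drain.
import numpy as np

TARGET_SR, CHUNK_S = 16_000, 2.0
C = int(CHUNK_S * TARGET_SR)  # samples per decode chunk

pending = np.zeros(0, dtype=np.float32)
for ms in (500, 900, 800, 2600):  # mic callbacks of varying length
    new = np.zeros(int(ms / 1000 * TARGET_SR), dtype=np.float32)
    pending = np.concatenate([pending, new])
    while pending.size >= C:      # same loop shape as _drain
        chunk, pending = pending[:C], pending[C:]
        print(f"decoded {chunk.size / TARGET_SR:.1f}s chunk")
    print(f"pending {pending.size / TARGET_SR:.2f}s")
```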