# streaming-speech-translation / src/asr/cache_aware_modules.py
# Commit 0c397a9 by pltobing: "Formatting black, isort, flake8"
#!/usr/bin/env python3
# License: CC-BY-NC-ND-4.0
# Created by: Patrick Lumbantobing, Vertox-AI
# Copyright (c) 2026 Vertox-AI. All rights reserved.
#
# This work is licensed under the Creative Commons
# Attribution-NonCommercial-NoDerivatives 4.0 International License.
# To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc-nd/4.0/
"""
Cache-aware streaming audio and feature buffers for Nemotron ASR.
Adapted from: https://github.com/NVIDIA-NeMo/NeMo/tree/main
Implements:
- :class:`CacheAwareStreamingAudioBuffer` for audio → feature chunks
compatible with NeMo cache-aware encoders.
- :class:`CacheAwareStreamingASR` for encoder/decoder state management,
hypothesis accumulation, and timestamped text output.
"""
from __future__ import annotations
import re
from collections.abc import Iterable
from typing import Generator, List, Optional
import numpy as np
import numpy.typing as npt
from src.asr.cache_aware_modules_config import (CacheAwareStreamingConfig,
TimestampedResult)
from src.asr.utils import log_softmax
# Epsilon added before np.log(...) when building all-silence feature padding,
# so padded frames map to a finite constant (log(2**-24) ~= -16.64) instead
# of -inf. Matches the log-zero guard used by log-mel frontends.
LOG_ZERO_GUARD_VALUE = float(2**-24)
class CacheAwareStreamingAudioBuffer:
    """
    Streaming audio and feature buffer for cache-aware ASR.

    Handles:
    - Chunking raw audio into overlapping frames for the preprocessor.
    - Dropping padded STFT frames after the first chunk.
    - Maintaining a feature buffer with the pre-encode cache prepended.

    Both ``process_*`` generators yield ready chunks and, when called with
    ``last=True``, always terminate with a ``None`` sentinel after resetting
    the corresponding buffer.
    """

    def __init__(self, preprocessor, streaming_cfg: CacheAwareStreamingConfig) -> None:
        """
        Parameters
        ----------
        preprocessor :
            Callable that maps ``(waveforms, lengths)`` to
            ``(features, feature_lengths)``.
        streaming_cfg :
            Cache-aware streaming configuration.
        """
        self._preprocessor = preprocessor
        self._streaming_cfg = streaming_cfg
        # Raw audio of shape (1, samples); None while empty.
        self.audio_buffer: Optional[npt.NDArray[np.float32]] = None
        # Number of samples already consumed from the start of the stream.
        self.audio_step: int = 0
        # Features of shape (1, input_features, frames); None while empty.
        self.features_buffer: Optional[npt.NDArray[np.float32]] = None
        # Samples needed for one preprocessor chunk; kept as a 1-element
        # int64 array because it is passed straight through as the
        # preprocessor length input.
        self._audio_chunks_lens = np.array(
            [self._streaming_cfg.audio_chunk_frames * self._streaming_cfg.audio_frame_size],
            dtype=np.int64,
        )
        # Hop in samples: dropped from the buffer front after each chunk.
        self._audio_frames_drops_lens = (
            self._streaming_cfg.audio_chunk_frames_drop * self._streaming_cfg.audio_frame_size
        )
        # Keep one frame fewer than the chunk produces: the final STFT frame
        # of each chunk is computed over padding and is discarded.
        self._features_frames_takes_lens = self._streaming_cfg.audio_chunk_frames - 1
        self._chunk_size = self._streaming_cfg.chunk_size[1]
        self._shift_size = self._streaming_cfg.shift_size[1]
        self._pre_encode_cache_size = self._streaming_cfg.pre_encode_cache_size[1]
        # Encoder input width = pre-encode cache frames + fresh chunk frames.
        self._cache_chunk_size = self._pre_encode_cache_size + self._chunk_size
        # NOTE: the unused leftovers `_features_chunk_lengths` and
        # `_current_text` were removed; both belong to CacheAwareStreamingASR.
        # Initial pre-encode cache: the log of the zero-guard epsilon, i.e.
        # what a log-mel frontend emits for pure silence/padding.
        self._first_cache_pre_encode = np.log(
            np.zeros(
                (1, self._streaming_cfg.input_features, self._pre_encode_cache_size),
                dtype=np.float32,
            )
            + LOG_ZERO_GUARD_VALUE
        )

    def len_audio_buffer(self) -> int:
        """Return current audio buffer length (samples)."""
        return int(self.audio_buffer.shape[-1]) if self.audio_buffer is not None else 0

    def len_features_buffer(self) -> int:
        """Return current feature buffer length (frames)."""
        return int(self.features_buffer.shape[-1]) if self.features_buffer is not None else 0

    def reset_buffers(self) -> None:
        """Reset both audio and feature buffers."""
        self.reset_audio_buffer()
        self.reset_features_buffer()

    def reset_audio_buffer(self) -> None:
        """Reset audio buffer and step counter."""
        self.audio_buffer = None
        self.audio_step = 0

    def reset_features_buffer(self) -> None:
        """Reset feature buffer."""
        self.features_buffer = None

    def append_audio_buffer(self, audio_signal: npt.NDArray[np.float32]) -> None:
        """Append new audio samples (shape ``(1, samples)``) to the buffer."""
        if self.audio_buffer is None:
            # Enforce float32 on the first append too, so the preprocessor
            # always sees the declared dtype (no copy if already float32).
            self.audio_buffer = np.asarray(audio_signal, dtype=np.float32)
        else:
            self.audio_buffer = np.concatenate((self.audio_buffer, audio_signal), axis=-1).astype(np.float32)

    def process_audio_buffer(
        self,
        last: bool = False,
    ) -> Generator[Optional[npt.NDArray[np.float32]], None, None]:
        """
        Convert buffered audio into feature chunks.

        Parameters
        ----------
        last :
            When True, flush the remaining tail (zero-padded to a full
            chunk) and terminate with a ``None`` sentinel.

        Yields
        ------
        np.ndarray or None
            Feature chunks of shape ``(1, feats, frames)`` or ``None`` when
            no more chunks are available.
        """
        if self.audio_buffer is None:
            if last:
                yield None
            return
        # Emit full chunks while enough samples are buffered.
        while self._audio_chunks_lens[0] <= self.audio_buffer.shape[-1]:
            audio_chunks = self.audio_buffer[:, : self._audio_chunks_lens[0]]
            audio_features, _ = self._preprocessor(audio_chunks, self._audio_chunks_lens)
            # Advance by the hop size; consecutive chunks overlap.
            self.audio_buffer = self.audio_buffer[:, self._audio_frames_drops_lens :]
            if self.audio_step > 0:
                # Later chunks: drop frames already covered by the previous
                # overlapping chunk, plus the final padded frame.
                audio_features = audio_features[
                    :,
                    :,
                    self._streaming_cfg.audio_chunk_frames_drop : self._features_frames_takes_lens,
                ]
            else:
                # First chunk: keep everything except the final padded frame.
                audio_features = audio_features[:, :, : self._features_frames_takes_lens]
            self.audio_step += self._audio_frames_drops_lens
            yield audio_features
        if last:
            # Fix: the sentinel/reset now runs for every last=True call,
            # even when the tail happens to be empty.
            if self.audio_buffer is not None and self.audio_buffer.shape[-1] > 0:
                # Zero-pad the tail to a full chunk and process it once.
                n_pad = self._audio_chunks_lens[0] - self.audio_buffer.shape[-1]
                zeros_pad = np.zeros((1, n_pad), dtype=np.float32)
                self.audio_buffer = np.concatenate((self.audio_buffer, zeros_pad), axis=-1).astype(np.float32)
                audio_chunks = self.audio_buffer[:, : self._audio_chunks_lens[0]]
                audio_features, _ = self._preprocessor(audio_chunks, self._audio_chunks_lens)
                if self.audio_step > 0:
                    yield audio_features[:, :, self._streaming_cfg.audio_chunk_frames_drop :]
                else:
                    yield audio_features
            self.reset_audio_buffer()
            yield None

    def append_audio_buffer_to_process_for_features(
        self,
        audio_signal: npt.NDArray[np.float32],
        last: bool = False,
    ) -> Generator[Optional[npt.NDArray[np.float32]], None, None]:
        """Append audio and immediately yield any ready feature chunks."""
        self.append_audio_buffer(audio_signal)
        return self.process_audio_buffer(last=last)

    def append_features_buffer(self, audio_features: npt.NDArray[np.float32]) -> None:
        """Append new feature frames, prepending initial pre-encode cache if needed."""
        if self.features_buffer is None:
            self.features_buffer = np.concatenate((self._first_cache_pre_encode, audio_features), axis=-1).astype(
                np.float32
            )
        else:
            self.features_buffer = np.concatenate((self.features_buffer, audio_features), axis=-1).astype(np.float32)

    def process_features_buffer(
        self,
        last: bool = False,
    ) -> Generator[Optional[npt.NDArray[np.float32]], None, None]:
        """
        Convert feature buffer into encoder-ready feature chunks.

        Parameters
        ----------
        last :
            When True, flush the remaining tail (padded with log-zero-guard
            frames) and terminate with a ``None`` sentinel.

        Yields
        ------
        np.ndarray or None
            Feature chunks of shape ``(1, feats, cache_chunk_size)`` or
            ``None`` when no more chunks are available.
        """
        if self.features_buffer is None:
            if last:
                yield None
            return
        # Emit full encoder chunks while enough frames are buffered.
        while self._cache_chunk_size <= self.features_buffer.shape[-1]:
            features_chunk = self.features_buffer[:, :, : self._cache_chunk_size]
            self.features_buffer = self.features_buffer[:, :, self._shift_size :]
            yield features_chunk
        if last:
            # Fix: the sentinel/reset now runs for every last=True call,
            # even when the tail happens to be empty.
            if self.features_buffer is not None and self.features_buffer.shape[-1] > 0:
                # Pad the tail with log-of-epsilon "silence" frames.
                n_pad = self._cache_chunk_size - self.features_buffer.shape[-1]
                zeros_pad = np.log(
                    np.zeros(
                        (1, self.features_buffer.shape[1], n_pad),
                        dtype=np.float32,
                    )
                    + LOG_ZERO_GUARD_VALUE
                )
                features_chunk = np.concatenate((self.features_buffer, zeros_pad), axis=-1).astype(np.float32)
                yield features_chunk
            self.reset_features_buffer()
            yield None

    def append_features_buffer_to_process_for_features_chunk(
        self,
        audio_features: npt.NDArray[np.float32],
        last: bool = False,
    ) -> Generator[Optional[npt.NDArray[np.float32]], None, None]:
        """Append features and immediately yield any ready feature chunks."""
        self.append_features_buffer(audio_features)
        return self.process_features_buffer(last=last)
class CacheAwareStreamingASR:
    """
    Cache-aware streaming ASR wrapper around encoder/decoder ONNX models.

    Maintains encoder caches, decoder recurrent state, and an evolving
    hypothesis (tokens, timestamps, logprobs), producing incremental
    :class:`TimestampedResult` objects from feature chunks.
    """

    def __init__(
        self,
        asr_encoder,
        asr_decoder,
        vocab: List[str],
        blank_idx: int,
        streaming_cfg: CacheAwareStreamingConfig,
    ) -> None:
        """
        Parameters
        ----------
        asr_encoder :
            ONNX Runtime session for the cache-aware encoder.
        asr_decoder :
            ONNX Runtime session for the decoder/joint network.
        vocab :
            Mapping from token IDs to text pieces (list indexed by ID;
            entries are joined as strings in ``_decode_tokens``).
        blank_idx :
            Index of the blank label in the vocabulary.
        streaming_cfg :
            Cache-aware streaming configuration.
        """
        self._asr_encoder = asr_encoder
        self._asr_decoder = asr_decoder
        self._vocab = vocab
        self._vocab_size = len(self._vocab)
        self._blank_idx = blank_idx
        self._streaming_cfg = streaming_cfg
        # encoder cache: "last_channel" (presumably the attention cache),
        # "last_time" (presumably the conv cache, given conv_context_size
        # below), and the valid length of the channel cache
        self._cache_last_channel: npt.NDArray[np.float32] | None = None
        self._cache_last_time: npt.NDArray[np.float32] | None = None
        self._cache_last_channel_len: npt.NDArray[np.int64] | None = None
        self.set_init_encoder_cache()
        # encoder lengths: every encoder call receives pre-encode cache
        # frames plus one fresh chunk
        self._chunk_size = self._streaming_cfg.chunk_size[1]
        self._pre_encode_cache_size = self._streaming_cfg.pre_encode_cache_size[1]
        self._cache_chunk_size = self._pre_encode_cache_size + self._chunk_size
        self._features_chunk_lengths = np.array([self._cache_chunk_size], dtype=np.int64)
        self._encoder_out_lengths = np.array(
            [self._streaming_cfg.valid_encoder_out_len],
            dtype=np.int64,
        )
        # decoder state: recurrent hidden states plus the running hypothesis
        # (token IDs, per-token frame indices, per-token logprobs)
        self._prev_state: tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]] | None = None
        self._tokens: List[int] | None = None
        self._timestamps: List[int] | None = None
        self._logprobs: List[float] | None = None
        self._t_index: int | None = None
        self.set_init_decoder_state()
        self.set_init_decoder_vars()
        self._current_text: str = ""
        # Whitespace normalisation for joined token pieces: strips leading
        # and trailing whitespace, collapses internal runs; only whitespace
        # directly followed by a word character (group 1) is kept as a
        # single space.
        self._DECODE_SPACE_PATTERN = re.compile(r"\A\s|\s\B|(\s)\b")

    def set_init_encoder_cache(self) -> None:
        """Initialise encoder caches to zeros."""
        # Caches are allocated layer-major then transposed to batch-major:
        # (1, layers, cache_len, d_model) and (1, layers, d_model, conv_ctx).
        self._cache_last_channel = np.zeros(
            (
                self._streaming_cfg.len_layers,
                1,
                self._streaming_cfg.last_channel_cache_size,
                self._streaming_cfg.d_model,
            ),
            dtype=np.float32,
        ).transpose(1, 0, 2, 3)
        self._cache_last_time = np.zeros(
            (
                self._streaming_cfg.len_layers,
                1,
                self._streaming_cfg.d_model,
                self._streaming_cfg.conv_context_size[0],
            ),
            dtype=np.float32,
        ).transpose(1, 0, 2, 3)
        # No cached frames are valid at stream start.
        self._cache_last_channel_len = np.zeros(1, dtype=np.int64)

    def set_init_decoder_state(self) -> None:
        """Initialise decoder hidden states to zeros based on input shapes."""
        # Read state shapes from the ONNX input metadata and pin the batch
        # dimension (axis 1) to 1 — presumably (layers, batch, hidden);
        # TODO confirm against the exported decoder graph.
        shapes = {x.name: x.shape for x in self._asr_decoder.get_inputs()}
        self._prev_state = (
            np.zeros(
                shape=(shapes["input_states_1"][0], 1, shapes["input_states_1"][2]),
                dtype=np.float32,
            ),
            np.zeros(
                shape=(shapes["input_states_2"][0], 1, shapes["input_states_2"][2]),
                dtype=np.float32,
            ),
        )

    def set_init_decoder_vars(self) -> None:
        """Reset token, timestamp, logprob lists and time index."""
        self._tokens = []
        self._timestamps = []
        self._logprobs = []
        self._t_index = 0

    def reset_states(self) -> None:
        """Reset encoder cache, decoder state, and current text."""
        self.set_init_encoder_cache()
        self.set_init_decoder_state()
        self.set_init_decoder_vars()
        self._current_text = ""

    def process_encoder_step(
        self,
        features_chunk: npt.NDArray[np.float32],
    ) -> npt.NDArray[np.float32]:
        """
        Run one encoder step with cache-aware inputs.

        Parameters
        ----------
        features_chunk :
            Feature chunk of shape ``(1, feats, cache_chunk_size)``.

        Returns
        -------
        encoder_out: ``(batch, time, dimension)``
        """
        # NOTE(review): `assert` is stripped under `python -O`; raise an
        # explicit error instead if this check must survive in production.
        assert self._features_chunk_lengths[0] == features_chunk.shape[-1]
        (
            encoder_out,
            encoder_out_lens,
            cache_last_channel_next,
            cache_last_time_next,
            cache_last_channel_next_len,
        ) = self._asr_encoder.run(
            [
                "outputs",
                "encoded_lengths",
                "cache_last_channel_next",
                "cache_last_time_next",
                "cache_last_channel_next_len",
            ],
            {
                "audio_signal": features_chunk,
                "length": self._features_chunk_lengths,
                "cache_last_channel": self._cache_last_channel,
                "cache_last_time": self._cache_last_time,
                "cache_last_channel_len": self._cache_last_channel_len,
            },
        )
        # Carry the updated caches into the next encoder call.
        self._cache_last_channel = cache_last_channel_next
        self._cache_last_time = cache_last_time_next
        self._cache_last_channel_len = cache_last_channel_next_len
        # The ONNX encoder emits (batch, dim, time); process_decoder_step
        # indexes time on axis 0 of the batch slice, hence the transpose.
        return encoder_out.transpose(0, 2, 1)

    def _decode_tokens(
        self, ids: Iterable[int], indices: Iterable[int] | None, logprobs: Iterable[float] | None
    ) -> TimestampedResult:
        """
        Decode token ids including timestamps, running text, and text delta.

        Parameters
        ----------
        ids :
            All token IDs emitted since the start of the stream.
        indices :
            Encoder-frame index for each token, or ``None``.
        logprobs :
            Log-probability for each token, or ``None``.

        Returns
        -------
        TimestampedResult:
            contains running text, timestamps, all tokens, all logprobs, and text delta
        """
        tokens = [self._vocab[i] for i in ids]
        # Join the pieces and normalise whitespace (see _DECODE_SPACE_PATTERN).
        text = re.sub(self._DECODE_SPACE_PATTERN, lambda x: " " if x.group(1) else "", "".join(tokens))
        # Delta relative to the previously reported running text.
        # NOTE(review): assumes the running text only ever grows at the end;
        # if normalisation ever shortens it, the delta degrades to "".
        n_added_chars = len(text) - len(self._current_text)
        added_text = text[-n_added_chars:] if n_added_chars > 0 else ""
        # Scale encoder-frame indices by window_step * subsampling_factor —
        # presumably converting to preprocessor-frame (or time) units;
        # confirm the units against CacheAwareStreamingConfig.
        timestamps = (
            None
            if indices is None
            else (
                self._streaming_cfg.window_step * self._streaming_cfg.subsampling_factor * np.asarray(indices)
            ).tolist()
        )
        return TimestampedResult(
            text, timestamps, tokens, None if logprobs is None else np.asarray(logprobs).tolist(), added_text
        )

    def process_decoder_step(
        self, encoder_out: npt.NDArray[np.float32]
    ) -> tuple[Optional[str], Optional[str]]:
        """
        Run decoder steps with chunked encoder output.

        Parameters
        ----------
        encoder_out :
            Encoder output of shape ``(batch, time, dim)``; only batch 0
            is decoded.

        Returns
        -------
        text: string
            full transcript from the start (``None`` if nothing emitted yet)
        added_text: string
            text delta (``None`` if nothing emitted yet)
        """
        encodings = encoder_out[0]  # (time, dim) slice for batch 0
        encodings_len = self._encoder_out_lengths[0]
        # NOTE(review): `assert` is stripped under `python -O`.
        assert encodings_len == encodings.shape[0]
        step = 0
        emitted_tokens = 0
        # Greedy transducer-style loop: at each encoder frame, keep emitting
        # tokens until blank or the per-frame token cap is hit.
        while step < encodings_len:
            outputs, state1, state2 = self._asr_decoder.run(
                ["outputs", "output_states_1", "output_states_2"],
                {
                    # current frame as (1, dim, 1)
                    "encoder_outputs": encodings[step : step + 1, :, None],
                    # previous emitted token, or blank at stream start
                    "targets": [[self._tokens[-1] if self._tokens else self._blank_idx]],
                    "target_length": [1],
                    "input_states_1": self._prev_state[0],
                    "input_states_2": self._prev_state[1],
                },
            )
            logits = outputs.squeeze()
            state = (state1, state2)
            assert logits.shape[-1] <= self._vocab_size
            token = logits.argmax()
            if token != self._blank_idx:
                # Advance decoder state only on a real emission; on blank the
                # state produced by this run is discarded.
                self._prev_state = state
                self._tokens.append(int(token))
                self._timestamps.append(self._t_index)
                emitted_tokens += 1
                self._logprobs.append(log_softmax(logits)[token])
            if token == self._blank_idx or emitted_tokens == self._streaming_cfg.max_tokens_per_step:
                # Move to the next encoder frame on blank or when the
                # per-frame emission cap is reached.
                self._t_index += 1
                emitted_tokens = 0
                step += 1
        if len(self._tokens) > 0:
            res = self._decode_tokens(self._tokens, self._timestamps, self._logprobs)
            self._current_text = res.text
            return res.text, res.added_text
        else:
            return None, None