| |
| |
|
|
| """ |
| 流式Mel特征处理器 |
| |
| 用于实时音频流的Mel频谱特征提取,支持chunk-based处理。 |
| 支持配置CNN冗余以保证与离线处理的一致性。 |
| """ |
|
|
| import logging |
| from typing import Dict |
| from typing import Optional |
| from typing import Tuple |
|
|
| import numpy as np |
| import torch |
|
|
| from .processing_audio_minicpma import MiniCPMAAudioProcessor |
|
|
# Module-level logger; used by StreamingMelProcessorExact's snapshot/restore helpers
# for debug/info diagnostics.
logger = logging.getLogger(__name__)
|
|
|
|
class StreamingMelProcessorExact:
    """
    Streaming Mel processor that targets strict equivalence with offline processing.

    Approach:
    - Accumulate all audio seen so far in a buffer; after each append, run the same
      ``feature_extractor`` over the whole buffer to compute the mel spectrogram.
    - Only emit "stable" frames: frames whose window no longer depends on future
      (right-hand) audio, i.e. ``center + n_fft // 2 <= len(buffer)``.
    - At end of stream (``flush``) emit the remaining frames so the concatenated
      output lines up with a full offline computation.
    - Optionally (``enable_sliding_window``) drop old audio from the left of the
      buffer once it reaches a trigger length; frame indices are kept in a global
      coordinate system via ``base_T`` (global index = base_T + local index).

    NOTE(review): the extractor's log-norm setting is switched by buffer length
    (see ``_extract_full``), and after a slide the frames at the new left edge lose
    their original left context, so exact equivalence may not hold in those cases.

    Cost: every call re-extracts features over the whole accumulated buffer
    (could be optimized to be incremental).
    """

    # Minimum amount of audio (seconds) that must remain in the buffer after a slide.
    _MIN_KEEP_SECONDS = 1.0

    def __init__(
        self,
        feature_extractor: "MiniCPMAAudioProcessor",
        chunk_ms: int = 100,
        first_chunk_ms: Optional[int] = None,
        sample_rate: int = 16000,
        n_fft: int = 400,
        hop_length: int = 160,
        n_mels: int = 80,
        verbose: bool = False,
        cnn_redundancy_ms: int = 10,
        enable_sliding_window: bool = False,
        slide_trigger_seconds: float = 30.0,
        slide_stride_seconds: float = 10.0,
    ):
        """
        Args:
            feature_extractor: Processor called on the whole buffer. It must expose
                ``set_spac_log_norm(...)``, and its call result must have an
                ``input_features`` tensor.
            chunk_ms: Duration of each regular chunk, in milliseconds.
            first_chunk_ms: Duration of the first chunk; defaults to ``chunk_ms``.
                Its sample count is rounded down to whole hops (at least one hop).
            sample_rate: Audio sample rate in Hz.
            n_fft: FFT window size; ``n_fft // 2`` samples of right context make a
                frame stable.
            hop_length: Samples between consecutive frames.
            n_mels: Number of mel bins; used for the shape of empty outputs.
            verbose: Print per-call diagnostics to stdout.
            cnn_redundancy_ms: Extra context emitted on both sides of each chunk's
                core frames.
            enable_sliding_window: Drop old audio once the buffer reaches the trigger.
            slide_trigger_seconds: Buffer length (seconds) that triggers a slide.
            slide_stride_seconds: Audio (seconds) dropped per slide.
        """
        self.feature_extractor = feature_extractor
        self.chunk_ms = chunk_ms
        self.first_chunk_ms = first_chunk_ms if first_chunk_ms is not None else chunk_ms
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.n_mels = n_mels
        self.verbose = verbose

        # Regular chunk size in samples, and how many frames one emission advances.
        self.chunk_samples = int(round(chunk_ms * sample_rate / 1000))
        self.chunk_frames = self.chunk_samples // hop_length

        # First chunk is rounded down to whole hops, never less than one hop.
        raw_first_samples = int(round(self.first_chunk_ms * sample_rate / 1000))
        self.first_chunk_samples = max(hop_length, (raw_first_samples // hop_length) * hop_length)
        # Right-hand context (samples) a centered frame needs before it is stable.
        self.half_window = n_fft // 2

        # Context frames added on each side of a chunk's core range.
        self.cnn_redundancy_ms = cnn_redundancy_ms
        self.cnn_redundancy_samples = int(cnn_redundancy_ms * sample_rate / 1000)
        self.cnn_redundancy_frames = max(0, self.cnn_redundancy_samples // hop_length)

        # Sliding-window configuration (only used when enable_sliding_window is True).
        self.enable_sliding_window = enable_sliding_window
        self.trigger_seconds = slide_trigger_seconds
        self.slide_seconds = slide_stride_seconds

        # Initializes the buffer, counters and sliding-window offsets.
        self.reset()

    def reset(self):
        """Clear all streaming state so a new stream can be processed."""
        self.buffer = np.zeros(0, dtype=np.float32)
        # Global index of the first frame not yet covered by an emitted core range.
        self.last_emitted_T = 0
        self.total_samples_processed = 0
        self.chunk_count = 0
        self.is_first = True
        # Sliding-window offsets: samples dropped from the left and the equivalent
        # number of frames (global frame index = base_T + local frame index).
        self.left_samples_dropped = 0
        self.base_T = 0

    def get_chunk_size(self) -> int:
        """Number of samples the caller should feed in the next ``process`` call."""
        return self.first_chunk_samples if self.is_first else self.chunk_samples

    def get_expected_output_frames(self) -> int:
        """Not implemented.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("get_expected_output_frames is not implemented")

    def _extract_full(self) -> torch.Tensor:
        """Run the feature extractor over the whole buffer and return ``input_features``.

        Raises:
            ValueError: if the buffer holds fewer than ``n_fft`` samples.
        """
        if len(self.buffer) < self.n_fft:
            raise ValueError(f"buffer length is shorter than n_fft {len(self.buffer)} < {self.n_fft}")

        # The extractor's log-norm setting is switched by buffer length:
        # log_floor_db=-10 below 5 s, dynamic_range_db=8 otherwise
        # (see set_spac_log_norm for the exact semantics).
        if len(self.buffer) < 5 * self.sample_rate:
            self.feature_extractor.set_spac_log_norm(log_floor_db=-10)
        else:
            self.feature_extractor.set_spac_log_norm(dynamic_range_db=8)
        feats = self.feature_extractor(
            self.buffer,
            sampling_rate=self.sample_rate,
            return_tensors="pt",
            padding=False,
        )
        return feats.input_features

    def _stable_frames_count(self) -> int:
        """Number of leading buffer-local frames that no longer depend on future audio.

        Frame ``t`` is centered at sample ``t * hop_length`` and needs ``half_window``
        samples to its right, so it is stable once
        ``t * hop_length + half_window <= len(buffer)``.
        """
        buffer_len = int(self.buffer.shape[0])
        if buffer_len <= 0 or buffer_len < self.half_window:
            return 0
        return max(0, (buffer_len - self.half_window) // self.hop_length + 1)

    def _maybe_slide_buffer(self):
        """Trigger-mode sliding window.

        Once the buffer reaches the trigger length, drop up to one stride of audio
        from its left side. The drop is limited so that:

        (a) at least ``_MIN_KEEP_SECONDS`` of audio remains in the buffer; and
        (b) the audio underlying frames that the next emission still needs is kept,
            including the left STFT context of the first such frame. The next
            emission starts at ``last_emitted_T - cnn_redundancy_frames``.

        The drop is rounded down to whole hops so that ``base_T`` keeps global frame
        indices aligned.
        """
        if not self.enable_sliding_window:
            return

        sr = self.sample_rate
        hop = self.hop_length
        buffer_len = len(self.buffer)

        trigger_samples = int(self.trigger_seconds * sr)
        if buffer_len < trigger_samples:
            return

        stride_samples = int(self.slide_seconds * sr)
        min_keep_samples = int(self._MIN_KEEP_SECONDS * sr)

        # (a) Never leave less than min_keep_samples in the buffer.
        drop = min(stride_samples, max(0, buffer_len - min_keep_samples))

        # (b) Never discard audio for frames that have not been emitted yet.
        # A frame centered at local sample k * hop reads half_window samples to its left.
        left_ctx_frames = -(-self.half_window // hop)  # ceil(half_window / hop)
        next_emit_local = self.last_emitted_T - self.cnn_redundancy_frames - self.base_T
        drop = min(drop, max(0, next_emit_local - left_ctx_frames) * hop)

        drop = (drop // hop) * hop
        if drop <= 0:
            return

        self.buffer = self.buffer[drop:]
        self.left_samples_dropped += drop
        self.base_T += drop // hop

        if self.verbose:
            print(
                f"[Slide] Trigger模式: drop={drop/sr:.2f}s samples, base_T={self.base_T}, buffer_after={len(self.buffer)/sr:.2f}s"
            )

    def process(self, audio_chunk: np.ndarray, is_last_chunk: bool = False) -> Tuple[torch.Tensor, Dict]:
        """Append ``audio_chunk`` to the buffer and emit the next chunk of frames if ready.

        A chunk's core range is the global frames
        ``[last_emitted_T, last_emitted_T + chunk_frames)``. The range, widened by
        ``cnn_redundancy_frames`` on both sides, is emitted once the stable frames
        reach its padded right edge. When ``is_last_chunk`` is set it is emitted
        unconditionally, clipped to the frames available.

        Args:
            audio_chunk: Audio samples (array or sequence); copied and converted to float32.
            is_last_chunk: Force emission even if not enough stable frames exist yet.

        Returns:
            ``(mel_output, info)``:
            - ``mel_output`` is a slice of the extractor's ``input_features`` along
              the last (time) axis; it may be empty.
            - ``info`` is a dict of diagnostic counters.
        """
        self.chunk_count += 1

        # np.array always copies, so the buffer never aliases the caller's data.
        chunk = np.array(audio_chunk, dtype=np.float32)
        if len(self.buffer) == 0:
            self.buffer = chunk
        else:
            self.buffer = np.concatenate([self.buffer, chunk])

        self._maybe_slide_buffer()

        mel_full = self._extract_full()
        T_full = mel_full.shape[-1]
        stable_T = min(T_full, self._stable_frames_count())
        stable_T_global = self.base_T + stable_T

        # Core range of the next chunk, in global frame coordinates.
        core_start_g = self.last_emitted_T
        core_end_g = core_start_g + self.chunk_frames
        required_stable_g = core_end_g + self.cnn_redundancy_frames

        if self.verbose:
            print(
                f"[Exact] buffer_len={len(self.buffer)} samples, T_full(local)={T_full}, "
                f"stable_T(local)={stable_T}, base_T={self.base_T}, "
                f"stable_T(global)={stable_T_global}, last_emitted={self.last_emitted_T}"
            )

        if stable_T_global >= required_stable_g or is_last_chunk:
            emit_start_g = max(0, core_start_g - self.cnn_redundancy_frames)
            emit_end_g = core_end_g + self.cnn_redundancy_frames

            # Convert to buffer-local indices and clip to the frames actually present.
            emit_start = max(0, emit_start_g - self.base_T)
            emit_end = emit_end_g - self.base_T
            emit_start = max(0, min(emit_start, T_full))
            emit_end = max(emit_start, min(emit_end, T_full))

            mel_output = mel_full[:, :, emit_start:emit_end]
            self.last_emitted_T = core_end_g
        else:
            mel_output = mel_full[:, :, 0:0]

        self.total_samples_processed += len(chunk)
        self.is_first = False

        info = {
            "type": "exact_chunk",
            "chunk_number": self.chunk_count,
            "emitted_frames": mel_output.shape[-1],
            "stable_T": stable_T,
            "T_full": T_full,
            "base_T": self.base_T,
            "stable_T_global": stable_T_global,
            "buffer_len_samples": int(self.buffer.shape[0]),
            "left_samples_dropped": self.left_samples_dropped,
            "core_start": core_start_g,
            "core_end": core_end_g,
        }
        return mel_output, info

    def flush(self) -> torch.Tensor:
        """Emit every frame not yet emitted (global coordinates) at the end of the stream.

        Returns:
            - The frames from ``last_emitted_T`` to the end of the buffer.
            - An empty slice if nothing remains.
            - A ``(1, n_mels, 0)`` tensor when the buffer is empty.
        """
        if len(self.buffer) == 0:
            return torch.zeros(1, self.n_mels, 0)

        mel_full = self._extract_full()
        T_local = mel_full.shape[-1]
        T_global = self.base_T + T_local

        if self.last_emitted_T < T_global:
            start_l = max(0, self.last_emitted_T - self.base_T)
            tail = mel_full[:, :, start_l:]
            self.last_emitted_T = T_global
            if self.verbose:
                print(f"[Exact] flush {tail.shape[-1]} frames (T_global={T_global})")
            return tail
        return mel_full[:, :, 0:0]

    def get_config(self) -> Dict:
        """Return the static configuration of this processor."""
        return {
            "chunk_ms": self.chunk_ms,
            "first_chunk_ms": self.first_chunk_ms,
            "effective_first_chunk_ms": self.first_chunk_samples / self.sample_rate * 1000.0,
            "sample_rate": self.sample_rate,
            "n_fft": self.n_fft,
            "hop_length": self.hop_length,
            "cnn_redundancy_ms": self.cnn_redundancy_ms,
            "cnn_redundancy_frames": self.cnn_redundancy_frames,
            "enable_sliding_window": self.enable_sliding_window,
            "trigger_seconds": self.trigger_seconds,
            "slide_seconds": self.slide_seconds,
        }

    def get_state(self) -> Dict:
        """Return lightweight streaming counters (no buffer contents)."""
        return {
            "chunk_count": self.chunk_count,
            "last_emitted_T": self.last_emitted_T,
            "total_samples_processed": self.total_samples_processed,
            "buffer_len": int(self.buffer.shape[0]),
            "base_T": self.base_T,
            "left_samples_dropped": self.left_samples_dropped,
        }

    def get_snapshot(self) -> Dict:
        """Capture the full state, including a buffer copy, for rollback after run-ahead.

        Also records the feature extractor's ``dynamic_log_norm``,
        ``dynamic_range_db`` and ``log_floor_db`` attributes. Each is ``None`` if absent.

        Returns:
            A dict that ``restore_snapshot`` accepts to restore this state.
        """
        buffer_copy = self.buffer.copy()
        snapshot = {
            "chunk_count": self.chunk_count,
            "last_emitted_T": self.last_emitted_T,
            "total_samples_processed": self.total_samples_processed,
            "buffer": buffer_copy,
            "base_T": self.base_T,
            "left_samples_dropped": self.left_samples_dropped,
            "is_first": self.is_first,
            "fe_dynamic_log_norm": getattr(self.feature_extractor, "dynamic_log_norm", None),
            "fe_dynamic_range_db": getattr(self.feature_extractor, "dynamic_range_db", None),
            "fe_log_floor_db": getattr(self.feature_extractor, "log_floor_db", None),
        }
        # The buffer checksum is O(n); only compute it when DEBUG output is enabled.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                "[MelProcessor] Created snapshot: chunk_count=%d, last_emitted_T=%d, "
                "buffer_len=%d, buffer_sum=%.6f, total_samples=%d",
                self.chunk_count,
                self.last_emitted_T,
                len(buffer_copy),
                float(buffer_copy.sum()) if len(buffer_copy) > 0 else 0.0,
                self.total_samples_processed,
            )
        return snapshot

    def restore_snapshot(self, snapshot: Dict) -> None:
        """Restore the state captured by ``get_snapshot``.

        Feature-extractor settings are only written back when the snapshot recorded
        a non-None value.

        Args:
            snapshot: Dict previously returned by ``get_snapshot``; its buffer is copied.
        """
        prev_state = {
            "chunk_count": self.chunk_count,
            "last_emitted_T": self.last_emitted_T,
            "buffer_len": len(self.buffer),
        }

        self.chunk_count = snapshot["chunk_count"]
        self.last_emitted_T = snapshot["last_emitted_T"]
        self.total_samples_processed = snapshot["total_samples_processed"]
        self.buffer = snapshot["buffer"].copy()
        self.base_T = snapshot["base_T"]
        self.left_samples_dropped = snapshot["left_samples_dropped"]
        self.is_first = snapshot["is_first"]

        if snapshot.get("fe_dynamic_log_norm") is not None:
            self.feature_extractor.dynamic_log_norm = snapshot["fe_dynamic_log_norm"]
        if snapshot.get("fe_dynamic_range_db") is not None:
            self.feature_extractor.dynamic_range_db = snapshot["fe_dynamic_range_db"]
        if snapshot.get("fe_log_floor_db") is not None:
            self.feature_extractor.log_floor_db = snapshot["fe_log_floor_db"]

        logger.info(
            "[MelProcessor] Restored snapshot: chunk_count %d->%d, last_emitted_T %d->%d, "
            "buffer_len %d->%d, total_samples=%d",
            prev_state["chunk_count"],
            self.chunk_count,
            prev_state["last_emitted_T"],
            self.last_emitted_T,
            prev_state["buffer_len"],
            len(self.buffer),
            self.total_samples_processed,
        )
|