| | |
| | """Get exact NeMo streaming inference output for comparison with Swift.""" |
| |
|
| | import os |
# Set before `torch` is imported further down. Presumably this works around
# the Intel OpenMP "duplicate runtime" abort that can occur when several
# packages each bundle their own OpenMP library — TODO confirm still needed.
os.environ.update(KMP_DUPLICATE_LIB_OK="TRUE")
| |
|
| | import torch |
| | import numpy as np |
| | import librosa |
| | import json |
| |
|
| | from nemo.collections.asr.models import SortformerEncLabelModel |
| |
|
# Frame period the script assumes for the diarizer's output, in seconds.
FRAME_DURATION_SEC = 0.08
# A speaker counts as active in the printed views when its probability exceeds this.
ACTIVITY_THRESHOLD = 0.55
# Audio is loaded (and resampled) at this rate before feature extraction.
SAMPLE_RATE = 16000
# Speaker count of the 4-speaker checkpoint; sizes the initially empty prediction buffer.
NUM_SPEAKERS = 4

DEFAULT_MODEL_PATH = "diar_streaming_sortformer_4spk-v2.nemo"
DEFAULT_AUDIO_PATH = "../audio.wav"
DEFAULT_OUTPUT_PATH = "/tmp/nemo_streaming_reference.json"

# Streaming parameters pinned on `model.sortformer_modules`, in the order they are set.
# The same names are read back into the JSON "config" section.
STREAMING_CONFIG = {
    "chunk_len": 6,
    "chunk_left_context": 1,
    "chunk_right_context": 7,
    "fifo_len": 40,
    "spkcache_len": 188,
    "spkcache_update_period": 31,
}


def _load_model(model_path):
    """Restore the checkpoint on CPU, disable dither, and apply STREAMING_CONFIG."""
    print("Loading NeMo model...")
    model = SortformerEncLabelModel.restore_from(model_path, map_location="cpu")
    model.eval()

    # Zero dither (when the featurizer exposes it) to keep feature extraction deterministic.
    featurizer = getattr(model.preprocessor, "featurizer", None)
    if featurizer is not None and hasattr(featurizer, "dither"):
        featurizer.dither = 0.0

    modules = model.sortformer_modules
    for name, value in STREAMING_CONFIG.items():
        setattr(modules, name, value)

    print(
        f"Config: chunk_len={modules.chunk_len}, left_ctx={modules.chunk_left_context}, "
        f"right_ctx={modules.chunk_right_context}"
    )
    print(
        f" fifo_len={modules.fifo_len}, spkcache_len={modules.spkcache_len}, "
        f"spkcache_update_period={modules.spkcache_update_period}"
    )
    return model


def _extract_features(model, audio_path):
    """Load *audio_path* as mono audio and return features shaped (batch, mel, time)."""
    audio, sr = librosa.load(audio_path, sr=SAMPLE_RATE, mono=True)
    print(f"Loaded audio: {len(audio)} samples ({len(audio) / sr:.2f}s)")

    waveform = torch.from_numpy(audio).unsqueeze(0).float()
    with torch.no_grad():
        audio_len = torch.tensor([waveform.shape[1]])
        features, feat_len = model.process_signal(
            audio_signal=waveform, audio_signal_length=audio_len
        )

    # Keep only the valid frames reported by process_signal.
    features = features[:, :, :feat_len.max()]
    print(f"Features: {features.shape} (batch, mel, time)")
    return features


def _run_streaming(model, features):
    """Feed *features* to `forward_streaming_step` one context-padded chunk at a time.

    Returns the accumulated predictions for batch item 0 as a NumPy array
    of shape (frames, speakers).
    """
    modules = model.sortformer_modules
    subsampling = modules.subsampling_factor
    core_frames = modules.chunk_len * subsampling
    left_frames = modules.chunk_left_context * subsampling
    right_frames = modules.chunk_right_context * subsampling

    total_mel_frames = features.shape[2]
    print(f"Total mel frames: {total_mel_frames}")
    print(f"Core frames per chunk: {core_frames}")

    device = features.device
    streaming_state = modules.init_streaming_state(device=device)
    # Starts with zero frames on the time axis; each step returns the extended tensor.
    total_preds = torch.zeros((1, 0, NUM_SPEAKERS), device=device)

    num_chunks = 0
    stt_feat = 0
    while stt_feat < total_mel_frames:
        end_feat = min(stt_feat + core_frames, total_mel_frames)
        # Context is clipped so the slice never runs past either end of the sequence.
        left_offset = min(left_frames, stt_feat)
        right_offset = min(right_frames, total_mel_frames - end_feat)

        # (batch, mel, time) -> (batch, time, mel) for the streaming step.
        chunk_t = features[:, :, stt_feat - left_offset:end_feat + right_offset].transpose(1, 2)
        chunk_len_tensor = torch.tensor([chunk_t.shape[1]], dtype=torch.long, device=device)

        with torch.no_grad():
            # left/right_offset tell the step how many context frames surround the
            # core chunk (presumably excluded from its new predictions — verify in NeMo).
            streaming_state, total_preds = model.forward_streaming_step(
                processed_signal=chunk_t,
                processed_signal_length=chunk_len_tensor,
                streaming_state=streaming_state,
                total_preds=total_preds,
                left_offset=left_offset,
                right_offset=right_offset,
            )

        num_chunks += 1
        stt_feat = end_feat

    print(f"Processed {num_chunks} chunks")
    # .cpu() is a no-op for CPU tensors but keeps this working if a GPU is ever used.
    return total_preds[0].cpu().numpy()


def _print_timeline(preds, threshold=ACTIVITY_THRESHOLD, frame_sec=FRAME_DURATION_SEC):
    """Print one line per output frame with each speaker's probability and an activity marker."""
    print(f"\n=== NeMo Streaming Timeline ({frame_sec * 1000:.0f}ms per frame, threshold={threshold}) ===")
    print("Frame Time " + " ".join(f"Spk{s}" for s in range(preds.shape[1])) + " | Visual")
    print("-" * 60)
    for frame, probs in enumerate(preds):
        time_sec = frame * frame_sec
        prob_text = " ".join(f"{p:.3f}" for p in probs)
        visual = "".join("■" if p > threshold else "·" for p in probs)
        print(f"{frame:5d} {time_sec:5.2f}s {prob_text} | [{visual}]")
    print("-" * 60)


def _print_summary(preds, threshold=ACTIVITY_THRESHOLD, frame_sec=FRAME_DURATION_SEC):
    """Print per-speaker active time and its share of all output frames."""
    print("\n=== Speaker Activity Summary ===")
    total_frames = preds.shape[0]
    for spk in range(preds.shape[1]):
        active_frames = int(np.sum(preds[:, spk] > threshold))
        active_time = active_frames * frame_sec
        # Guard against zero output frames, which previously produced NaN.
        percent = 100.0 * active_frames / total_frames if total_frames else 0.0
        print(f"Speaker_{spk}: {active_time:.1f}s active ({percent:.1f}%)")


def _save_reference(preds, modules, output_path):
    """Write the predictions and the streaming config as JSON to *output_path*.

    "probabilities" is the (frames, speakers) array flattened row-major:
    all speakers of frame 0, then all speakers of frame 1, and so on.
    """
    output = {
        "total_frames": int(preds.shape[0]),
        "frame_duration_seconds": FRAME_DURATION_SEC,
        "probabilities": preds.flatten().tolist(),
        "config": {name: getattr(modules, name) for name in STREAMING_CONFIG},
    }
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2)
    print(f"\nSaved to {output_path}")


def main(
    model_path=DEFAULT_MODEL_PATH,
    audio_path=DEFAULT_AUDIO_PATH,
    output_path=DEFAULT_OUTPUT_PATH,
):
    """Run NeMo Sortformer streaming inference and save a reference for the Swift port.

    Args:
        model_path: ``.nemo`` checkpoint, restored on CPU.
        audio_path: audio file to diarize, loaded mono at SAMPLE_RATE.
        output_path: destination of the JSON reference file.
    """
    model = _load_model(model_path)
    features = _extract_features(model, audio_path)
    preds = _run_streaming(model, features)

    print(f"\nTotal output frames: {preds.shape[0]}")
    print(f"Predictions shape: {preds.shape}")

    _print_timeline(preds)
    _print_summary(preds)
    _save_reference(preds, model.sortformer_modules, output_path)
| |
|
if __name__ == "__main__":
    # Run the reference dump only when executed as a script, never on import.
    main()
| |
|