| |
| """ |
| Convert Sortformer to CoreML with proper dynamic length handling. |
| |
| The key issue: Original conversion traced with fixed lengths (spkcache=120, fifo=40), |
| but at runtime we need to handle empty state (spkcache=0, fifo=0) for first chunk. |
| |
| Solution: Use scripting instead of tracing, or trace with multiple example lengths. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import coremltools as ct |
| import numpy as np |
| import os |
| import sys |
|
|
| SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) |
| sys.path.insert(0, os.path.join(SCRIPT_DIR, 'NeMo')) |
|
|
| from nemo.collections.asr.models import SortformerEncLabelModel |
|
|
print("=" * 70)
print("CONVERTING SORTFORMER WITH DYNAMIC LENGTH SUPPORT")
print("=" * 70)

# Load the pretrained streaming Sortformer diarization checkpoint from the
# script directory.  strict=False tolerates minor state-dict key mismatches
# between the .nemo checkpoint and the checked-out NeMo code version.
model_path = os.path.join(SCRIPT_DIR, 'diar_streaming_sortformer_4spk-v2.nemo')
print(f"Loading model: {model_path}")
model = SortformerEncLabelModel.restore_from(model_path, map_location='cpu', strict=False)
model.eval()
|
|
| |
# Streaming configuration.  Lengths here are in post-subsampling embedding
# frames (spkcache/fifo), except chunk_len and the contexts which count
# pre-subsampling units -- NOTE(review): confirm against SortformerModules.
modules = model.sortformer_modules
modules.chunk_len = 6
modules.chunk_left_context = 1
modules.chunk_right_context = 1
modules.fifo_len = 40
modules.spkcache_len = 120
modules.spkcache_update_period = 30

print(f"Config: chunk_len={modules.chunk_len}, left={modules.chunk_left_context}, right={modules.chunk_right_context}")
print(f" fifo_len={modules.fifo_len}, spkcache_len={modules.spkcache_len}")

# Raw feature frames per chunk fed to pre_encode: (6 + 1 + 1) * subsampling
# factor.  With subsampling 8 this is 64 -- matches the "Full state" test
# and the fixed CoreML input shape below.
chunk_frames = (modules.chunk_len + modules.chunk_left_context + modules.chunk_right_context) * modules.subsampling_factor
fc_d_model = modules.fc_d_model  # conformer model dim (presumably 512 -- TODO confirm)
feat_dim = 128                   # mel feature dimension of the input chunks

print(f"Chunk frames: {chunk_frames}")
|
|
class DynamicPreEncoderWrapper(nn.Module):
    """Pre-encoder wrapper that handles dynamic spkcache/fifo/chunk lengths.

    Runs the Conformer pre-encode (subsampling) on the raw chunk features,
    then packs the valid frames of [spkcache | fifo | chunk] back-to-back
    (no gaps) into a fixed-size, zero-padded buffer of ``max_total`` frames.

    Eager-mode only: uses ``.item()`` on length tensors, so it is NOT
    suitable for torch.jit.trace -- it exists as the PyTorch reference
    implementation for the tests below.
    """

    def __init__(self, model, max_spkcache=120, max_fifo=40, max_chunk=8):
        super().__init__()
        # NeMo SortformerEncLabelModel (anything exposing encoder.pre_encode).
        self.model = model
        self.max_spkcache = max_spkcache
        self.max_fifo = max_fifo
        self.max_chunk = max_chunk
        self.max_total = max_spkcache + max_fifo + max_chunk

    def forward(self, chunk, chunk_lengths, spkcache, spkcache_lengths, fifo, fifo_lengths):
        """Return (packed_embs, total_length, chunk_embs, chunk_emb_lengths)."""
        # Subsample raw features into conformer-space embeddings.
        chunk_embs, chunk_emb_lengths = self.model.encoder.pre_encode(x=chunk, lengths=chunk_lengths)

        # Scalar valid lengths; empty length tensors count as length 0.
        spk_len = spkcache_lengths[0].item() if spkcache_lengths.numel() > 0 else 0
        fifo_len = fifo_lengths[0].item() if fifo_lengths.numel() > 0 else 0
        chunk_len = chunk_emb_lengths[0].item()
        total_len = spk_len + fifo_len + chunk_len

        # Zero-padded output buffer; embedding dim D comes from spkcache.
        B, _, D = spkcache.shape
        output = torch.zeros(B, self.max_total, D, device=chunk.device, dtype=chunk.dtype)

        # Pack the three segments contiguously; trailing frames stay zero.
        if spk_len > 0:
            output[:, :spk_len, :] = spkcache[:, :spk_len, :]
        if fifo_len > 0:
            output[:, spk_len:spk_len + fifo_len, :] = fifo[:, :fifo_len, :]
        output[:, spk_len + fifo_len:total_len, :] = chunk_embs[:, :chunk_len, :]

        # Bug fix: allocate the length tensor on the same device as the data
        # (the original always created it on CPU, breaking GPU pipelines).
        total_length = torch.tensor([total_len], dtype=torch.long, device=chunk.device)

        return output, total_length, chunk_embs, chunk_emb_lengths
|
|
|
|
class DynamicHeadWrapper(nn.Module):
    """Head wrapper: conformer layers + speaker head, with length masking.

    Consumes the packed pre-encoder embeddings, runs the frontend conformer
    (pre-encode bypassed, since it already ran upstream) and the speaker
    prediction head, then zeroes predictions beyond the valid length so
    padded frames cannot leak into downstream post-processing.
    """

    def __init__(self, model):
        super().__init__()
        # Exposes frontend_encoder(...) and forward_infer(...).
        self.model = model

    def forward(self, pre_encoder_embs, pre_encoder_lengths, chunk_embs, chunk_emb_lengths):
        """Return (masked preds, chunk_embs, chunk_emb_lengths) -- the last
        two are passed through untouched for the caller's state update."""
        fc_embs, fc_lengths = self.model.frontend_encoder(
            processed_signal=pre_encoder_embs,
            processed_signal_length=pre_encoder_lengths,
            bypass_pre_encode=True,
        )

        # Per-frame speaker activity predictions, shape (B, T, num_spk).
        preds = self.model.forward_infer(fc_embs, fc_lengths)

        # Boolean mask over the time axis: True for frames < valid length.
        max_len = preds.shape[1]
        length = pre_encoder_lengths[0]
        mask = torch.arange(max_len, device=preds.device) < length
        # Bug fix: cast the mask to preds' dtype instead of hard-coding
        # float32, so fp16 inference is not silently upcast.
        preds = preds * mask.unsqueeze(0).unsqueeze(-1).to(preds.dtype)

        return preds, chunk_embs, chunk_emb_lengths
|
|
|
|
| |
print("\n" + "=" * 70)
print("TESTING DYNAMIC WRAPPERS")
print("=" * 70)

# Eager-mode sanity checks of the two wrappers before any tracing.
pre_encoder = DynamicPreEncoderWrapper(model)
head = DynamicHeadWrapper(model)
pre_encoder.eval()
head.eval()

# Test 1: very first chunk of a stream -- both states present but empty
# (lengths 0).  NOTE(review): 56 raw frames = (6 + 1) * 8, i.e. chunk plus
# right context only; presumably the first chunk has no left context --
# confirm against the streaming caller.
print("\nTest 1: Empty state (chunk 0)")
chunk = torch.randn(1, 56, 128)
chunk_len = torch.tensor([56], dtype=torch.long)
spkcache = torch.zeros(1, 120, 512)
spkcache_len = torch.tensor([0], dtype=torch.long)
fifo = torch.zeros(1, 40, 512)
fifo_len = torch.tensor([0], dtype=torch.long)

with torch.no_grad():
    pre_out, pre_len, chunk_embs, chunk_emb_len = pre_encoder(
        chunk, chunk_len, spkcache, spkcache_len, fifo, fifo_len
    )
    preds, _, _ = head(pre_out, pre_len, chunk_embs, chunk_emb_len)

print(f" Pre-encoder output: {pre_out.shape}, length={pre_len.item()}")
print(f" Chunk embeddings: {chunk_embs.shape}, length={chunk_emb_len.item()}")
print(f" Predictions: {preds.shape}")
# Per-frame sums give a quick fingerprint for eyeballing masking behavior.
sums = [f"{preds[0, i, :].sum().item():.4f}" for i in range(min(8, preds.shape[1]))]
print(f" First 8 pred frames sum: {sums}")
|
|
| |
# Test 2: steady state -- speaker cache and FIFO completely full, full
# 64-frame chunk (chunk + left + right context).
print("\nTest 2: Full state")
chunk = torch.randn(1, 64, 128)
chunk_len = torch.tensor([64], dtype=torch.long)
spkcache = torch.randn(1, 120, 512)
spkcache_len = torch.tensor([120], dtype=torch.long)
fifo = torch.randn(1, 40, 512)
fifo_len = torch.tensor([40], dtype=torch.long)

with torch.no_grad():
    pre_out, pre_len, chunk_embs, chunk_emb_len = pre_encoder(
        chunk, chunk_len, spkcache, spkcache_len, fifo, fifo_len
    )
    preds, _, _ = head(pre_out, pre_len, chunk_embs, chunk_emb_len)

print(f" Pre-encoder output: {pre_out.shape}, length={pre_len.item()}")
print(f" Chunk embeddings: {chunk_embs.shape}, length={chunk_emb_len.item()}")
print(f" Predictions: {preds.shape}")

print("\n" + "=" * 70)
print("ISSUE IDENTIFIED")
print("=" * 70)
print("""
The problem is that the current CoreML model was traced with FIXED lengths.
When lengths change at runtime, the traced operations don't adapt.

The fix requires re-tracing with proper dynamic handling OR using coremltools
flexible shapes feature.

For now, let's try a simpler approach: always pad inputs to max size and
use the length parameters only for extracting the correct output slice.
""")
|
|
| |
| |
| |
|
|
print("\nATTEMPTING CONVERSION WITH FLEXIBLE SHAPES...")

# Best-effort conversion: any failure below is caught and printed rather
# than aborting the script, so the eager-mode tests above still report.
try:
| class SimplePipelineWrapper(nn.Module): |
| def __init__(self, model): |
| super().__init__() |
| self.model = model |
|
|
| def forward(self, chunk, chunk_lengths, spkcache, spkcache_lengths, fifo, fifo_lengths): |
| |
| chunk_embs, chunk_emb_lens = self.model.encoder.pre_encode(x=chunk, lengths=chunk_lengths) |
|
|
| |
| spk_len = spkcache_lengths[0] |
| fifo_len = fifo_lengths[0] |
| chunk_len = chunk_emb_lens[0] |
|
|
| |
| |
| B = chunk.shape[0] |
| max_out = 168 |
| D = 512 |
|
|
| concat_embs = torch.zeros(B, max_out, D, device=chunk.device, dtype=chunk.dtype) |
|
|
| |
| for i in range(120): |
| if i < spk_len: |
| concat_embs[:, i, :] = spkcache[:, i, :] |
|
|
| |
| for i in range(40): |
| if i < fifo_len: |
| concat_embs[:, 120 + i, :] = fifo[:, i, :] |
|
|
| |
| for i in range(8): |
| if i < chunk_len: |
| concat_embs[:, 120 + 40 + i, :] = chunk_embs[:, i, :] |
|
|
| total_len = spk_len + fifo_len + chunk_len |
| total_lens = total_len.unsqueeze(0) |
|
|
| |
| fc_embs, fc_lens = self.model.frontend_encoder( |
| processed_signal=concat_embs, |
| processed_signal_length=total_lens, |
| bypass_pre_encode=True, |
| ) |
|
|
| |
| preds = self.model.forward_infer(fc_embs, fc_lens) |
|
|
| return preds, chunk_embs, chunk_emb_lens |
|
|
    wrapper = SimplePipelineWrapper(model)
    wrapper.eval()

    # Trace with the empty-state (first-chunk) example.
    # NOTE(review): the chunk buffer is 64 frames but chunk_lengths says 56
    # -- presumably the buffer is padded to max size with only 56 valid
    # frames; confirm this matches the runtime caller.
    print("Tracing with empty state example...")
    chunk = torch.randn(1, 64, 128)
    chunk_len = torch.tensor([56], dtype=torch.long)
    spkcache = torch.zeros(1, 120, 512)
    spkcache_len = torch.tensor([0], dtype=torch.long)
    fifo = torch.zeros(1, 40, 512)
    fifo_len = torch.tensor([0], dtype=torch.long)

    with torch.no_grad():
        traced = torch.jit.trace(wrapper, (chunk, chunk_len, spkcache, spkcache_len, fifo, fifo_len))

    # Convert with FIXED tensor shapes; the *_lengths inputs carry all the
    # dynamic information.  NOTE(review): lengths are declared int32 here
    # although the trace used int64 -- verify coremltools inserts the cast.
    print("Converting to CoreML...")
    mlmodel = ct.convert(
        traced,
        inputs=[
            ct.TensorType(name="chunk", shape=(1, 64, 128), dtype=np.float32),
            ct.TensorType(name="chunk_lengths", shape=(1,), dtype=np.int32),
            ct.TensorType(name="spkcache", shape=(1, 120, 512), dtype=np.float32),
            ct.TensorType(name="spkcache_lengths", shape=(1,), dtype=np.int32),
            ct.TensorType(name="fifo", shape=(1, 40, 512), dtype=np.float32),
            ct.TensorType(name="fifo_lengths", shape=(1,), dtype=np.int32),
        ],
        outputs=[
            ct.TensorType(name="speaker_preds", dtype=np.float32),
            ct.TensorType(name="chunk_pre_encoder_embs", dtype=np.float32),
            ct.TensorType(name="chunk_pre_encoder_lengths", dtype=np.int32),
        ],
        minimum_deployment_target=ct.target.iOS16,
        compute_precision=ct.precision.FLOAT32,
        compute_units=ct.ComputeUnit.CPU_ONLY,
    )

    output_path = os.path.join(SCRIPT_DIR, 'coreml_models', 'SortformerPipeline_Dynamic.mlpackage')
    mlmodel.save(output_path)
    print(f"Saved to: {output_path}")

    # Smoke-test the converted model on the same empty-state inputs.
    print("\nTesting new CoreML model...")
    test_output = mlmodel.predict({
        'chunk': chunk.numpy(),
        'chunk_lengths': chunk_len.numpy().astype(np.int32),
        'spkcache': spkcache.numpy(),
        'spkcache_lengths': spkcache_len.numpy().astype(np.int32),
        'fifo': fifo.numpy(),
        'fifo_lengths': fifo_len.numpy().astype(np.int32),
    })

    coreml_preds = np.array(test_output['speaker_preds'])
    print(f"CoreML predictions shape: {coreml_preds.shape}")
    print(f"CoreML first 8 frames:")
    for i in range(min(8, coreml_preds.shape[1])):
        print(f" Frame {i}: {coreml_preds[0, i, :]}")

except Exception as e:
    # Best-effort: report the failure with a traceback and keep going.
    print(f"Error during conversion: {e}")
    import traceback
    traceback.print_exc()
|
|