mazesmazes committed on
Commit d6de53b · verified · 1 Parent(s): 72a24d8

Training in progress - step 1000

Files changed (5)
  1. asr_config.py +18 -29
  2. asr_modeling.py +53 -58
  3. asr_pipeline.py +23 -29
  4. asr_processing.py +7 -6
  5. projectors.py +41 -33
asr_config.py CHANGED
@@ -6,6 +6,19 @@ import transformers
6
  DEFAULT_ENCODER_CONV_LAYERS = [(1, 3, 1), (1, 3, 2)]
7
 
8
 
9
  class ASRConfig(transformers.PretrainedConfig):
10
  """Configuration class for the ASR model.
11
 
@@ -14,7 +27,7 @@ class ASRConfig(transformers.PretrainedConfig):
14
  - Text decoder (Qwen)
15
  - Projector (MLP, MOSA, MoE, QFormer)
16
  - Generation parameters
17
- - Training options (SpecAugment, LoRA)
18
  """
19
 
20
  model_type = "asr_model"
@@ -38,9 +51,6 @@ class ASRConfig(transformers.PretrainedConfig):
38
  downsample_rate: int = 5, # Granite default
39
  projector_hidden_dim: Optional[int] = None,
40
  projector_type: str = "mlp", # "mlp", "mosa", "moe", "qformer"
41
- projector_num_layers: int = 2, # Number of layers in MLP projector
42
- projector_init_std: float = 0.02, # Weight initialization std
43
- projector_dropout: float = 0.0, # Dropout rate for projector layers
44
  # MoE-specific configuration
45
  num_experts: int = 4, # Number of experts in MoE projectors
46
  num_experts_per_tok: int = 2, # Top-k experts per token
@@ -51,14 +61,6 @@ class ASRConfig(transformers.PretrainedConfig):
51
  qformer_num_layers: int = 2, # Number of QFormer transformer layers
52
  qformer_num_heads: int = 16, # Number of attention heads in QFormer
53
  qformer_intermediate_size: Optional[int] = None, # FFN size (defaults to 4x hidden)
54
- label_smoothing: float = 0.0, # Label smoothing for cross-entropy loss
55
- inference_warmup_tokens: int = 10,
56
- # SpecAugment settings
57
- use_specaugment: bool = False,
58
- num_time_masks: int = 2,
59
- time_mask_length: int = 10,
60
- num_freq_masks: int = 0,
61
- freq_mask_length: int = 10,
62
  # LoRA configuration (for Stage 2 fine-tuning)
63
  use_lora: bool = False,
64
  lora_rank: int = 8, # SALMONN default
@@ -88,22 +90,20 @@ class ASRConfig(transformers.PretrainedConfig):
88
  model_dtype: Model dtype ("bfloat16", "float16", "float32")
89
  projector_type: Projector architecture ("mlp", "mosa", "moe", "qformer")
90
  use_lora: Enable LoRA adapters for Stage 2 fine-tuning
91
- use_specaugment: Enable SpecAugment data augmentation
92
  """
93
- # Set default generation parameters (greedy decoding only)
 
 
94
  generation_defaults = {
95
  "num_beams": 1,
96
  "max_new_tokens": 128,
97
  "min_new_tokens": 0,
98
  "repetition_penalty": 1.0,
99
  "length_penalty": 1.0,
100
- "no_repeat_ngram_size": 0, # Prevent repeating 3-grams like "so so so"
101
  "use_cache": True,
102
  }
103
 
104
- # Apply defaults (config.json values take precedence)
105
- kwargs = {**generation_defaults, **kwargs}
106
-
107
  self.audio_model_id = audio_model_id
108
  self.text_model_id = text_model_id
109
  self.attn_implementation = attn_implementation
@@ -113,13 +113,10 @@ class ASRConfig(transformers.PretrainedConfig):
113
  self.llm_dim = llm_dim
114
  self.encoder_conv_layers = encoder_conv_layers or DEFAULT_ENCODER_CONV_LAYERS
115
  self.audio_sample_rate = audio_sample_rate
116
- self.projector_init_std = projector_init_std
117
  self.projector_pool_stride = projector_pool_stride
118
  self.downsample_rate = downsample_rate
119
  self.projector_hidden_dim = projector_hidden_dim
120
  self.projector_type = projector_type
121
- self.projector_num_layers = projector_num_layers
122
- self.projector_dropout = projector_dropout
123
  # MoE-specific configuration
124
  self.num_experts = num_experts
125
  self.num_experts_per_tok = num_experts_per_tok
@@ -130,14 +127,6 @@ class ASRConfig(transformers.PretrainedConfig):
130
  self.qformer_num_layers = qformer_num_layers
131
  self.qformer_num_heads = qformer_num_heads
132
  self.qformer_intermediate_size = qformer_intermediate_size
133
- self.label_smoothing = label_smoothing
134
- self.inference_warmup_tokens = inference_warmup_tokens
135
- # SpecAugment configuration
136
- self.use_specaugment = use_specaugment
137
- self.num_time_masks = num_time_masks
138
- self.time_mask_length = time_mask_length
139
- self.num_freq_masks = num_freq_masks
140
- self.freq_mask_length = freq_mask_length
141
  # LoRA configuration
142
  self.use_lora = use_lora
143
  self.lora_rank = lora_rank
 
6
  DEFAULT_ENCODER_CONV_LAYERS = [(1, 3, 1), (1, 3, 2)]
7
 
8
 
9
+ def compute_encoder_output_length(mel_length, conv_layers=None):
10
+ """Apply encoder conv layer formulas to compute output length.
11
+
12
+ Works with both Python ints and torch tensors of mel lengths; the formula
13
+ `(L + 2*p - (k-1) - 1) // s + 1` per layer is identical for both.
14
+ """
15
+ layers = conv_layers if conv_layers is not None else DEFAULT_ENCODER_CONV_LAYERS
16
+ length = mel_length
17
+ for padding, kernel_size, stride in layers:
18
+ length = (length + 2 * padding - (kernel_size - 1) - 1) // stride + 1
19
+ return length
20
+
21
+
22
  class ASRConfig(transformers.PretrainedConfig):
23
  """Configuration class for the ASR model.
24
 
 
27
  - Text decoder (Qwen)
28
  - Projector (MLP, MOSA, MoE, QFormer)
29
  - Generation parameters
30
+ - Training options (LoRA)
31
  """
32
 
33
  model_type = "asr_model"
 
51
  downsample_rate: int = 5, # Granite default
52
  projector_hidden_dim: Optional[int] = None,
53
  projector_type: str = "mlp", # "mlp", "mosa", "moe", "qformer"
 
 
 
54
  # MoE-specific configuration
55
  num_experts: int = 4, # Number of experts in MoE projectors
56
  num_experts_per_tok: int = 2, # Top-k experts per token
 
61
  qformer_num_layers: int = 2, # Number of QFormer transformer layers
62
  qformer_num_heads: int = 16, # Number of attention heads in QFormer
63
  qformer_intermediate_size: Optional[int] = None, # FFN size (defaults to 4x hidden)
64
  # LoRA configuration (for Stage 2 fine-tuning)
65
  use_lora: bool = False,
66
  lora_rank: int = 8, # SALMONN default
 
90
  model_dtype: Model dtype ("bfloat16", "float16", "float32")
91
  projector_type: Projector architecture ("mlp", "mosa", "moe", "qformer")
92
  use_lora: Enable LoRA adapters for Stage 2 fine-tuning
 
93
  """
94
+ # Set default generation parameters (greedy decoding only).
95
+ # Applied via setattr below — keeping these out of kwargs so they
96
+ # don't get re-overwritten by super().__init__(**kwargs) at the end.
97
  generation_defaults = {
98
  "num_beams": 1,
99
  "max_new_tokens": 128,
100
  "min_new_tokens": 0,
101
  "repetition_penalty": 1.0,
102
  "length_penalty": 1.0,
103
+ "no_repeat_ngram_size": 0,
104
  "use_cache": True,
105
  }
106
 
 
 
 
107
  self.audio_model_id = audio_model_id
108
  self.text_model_id = text_model_id
109
  self.attn_implementation = attn_implementation
 
113
  self.llm_dim = llm_dim
114
  self.encoder_conv_layers = encoder_conv_layers or DEFAULT_ENCODER_CONV_LAYERS
115
  self.audio_sample_rate = audio_sample_rate
 
116
  self.projector_pool_stride = projector_pool_stride
117
  self.downsample_rate = downsample_rate
118
  self.projector_hidden_dim = projector_hidden_dim
119
  self.projector_type = projector_type
 
 
120
  # MoE-specific configuration
121
  self.num_experts = num_experts
122
  self.num_experts_per_tok = num_experts_per_tok
 
127
  self.qformer_num_layers = qformer_num_layers
128
  self.qformer_num_heads = qformer_num_heads
129
  self.qformer_intermediate_size = qformer_intermediate_size
130
  # LoRA configuration
131
  self.use_lora = use_lora
132
  self.lora_rank = lora_rank
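
A quick sanity check of the new compute_encoder_output_length helper (copied below so the snippet runs standalone): with the default (padding, kernel, stride) layers (1, 3, 1) and (1, 3, 2), the first conv preserves the mel length and the second roughly halves it, for both int and tensor inputs.

import torch

def compute_encoder_output_length(mel_length, conv_layers=None):
    # Mirror of the helper added above; defaults match DEFAULT_ENCODER_CONV_LAYERS.
    layers = conv_layers if conv_layers is not None else [(1, 3, 1), (1, 3, 2)]
    length = mel_length
    for padding, kernel_size, stride in layers:
        length = (length + 2 * padding - (kernel_size - 1) - 1) // stride + 1
    return length

print(compute_encoder_output_length(3000))                       # 1500
print(compute_encoder_output_length(torch.tensor([3000, 800])))  # tensor([1500,  400])
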
asr_modeling.py CHANGED
@@ -17,16 +17,13 @@ from transformers.generation import GenerationMixin
17
  from transformers.modeling_outputs import CausalLMOutputWithPast
18
 
19
  try:
20
- from .asr_config import ASRConfig
21
  from .projectors import PROJECTOR_CLASSES
22
  except ImportError:
23
- from asr_config import ASRConfig # type: ignore[no-redef]
24
  from projectors import PROJECTOR_CLASSES # type: ignore[no-redef]
25
 
26
 
27
- from torchaudio.transforms import SpecAugment
28
-
29
-
30
  def _gather_audio_embeds(audio_embeds: torch.Tensor, token_counts: torch.Tensor) -> torch.Tensor:
31
  """Flatten per-sample audio embeddings into a packed tensor.
32
 
@@ -56,7 +53,6 @@ class ASRModel(PreTrainedModel, GenerationMixin):
56
  _supports_flash_attn_2 = True
57
  supports_gradient_checkpointing = True
58
  _is_loading_from_pretrained: bool = False
59
- _pretrained_model_path: Optional[str] = None
60
 
61
  TRANSCRIBE_PROMPT = "Transcribe the speech to text"
62
 
@@ -72,7 +68,6 @@ class ASRModel(PreTrainedModel, GenerationMixin):
72
 
73
  # Set flag to avoid device_map="auto" in sub-model loaders
74
  cls._is_loading_from_pretrained = True
75
- cls._pretrained_model_path = pretrained_model_name_or_path
76
 
77
  try:
78
  model = cls(config, **kwargs)
@@ -134,7 +129,6 @@ class ASRModel(PreTrainedModel, GenerationMixin):
134
  return model
135
  finally:
136
  cls._is_loading_from_pretrained = False
137
- cls._pretrained_model_path = None
138
 
139
  def __init__(self, config: ASRConfig, **kwargs) -> None:
140
  super().__init__(config)
@@ -190,17 +184,6 @@ class ASRModel(PreTrainedModel, GenerationMixin):
190
  if getattr(config, "freeze_projector", False):
191
  self.projector.requires_grad_(False)
192
 
193
- # SpecAugment for data augmentation during training
194
- if getattr(config, "use_specaugment", False):
195
- self.spec_augment = SpecAugment(
196
- n_time_masks=config.num_time_masks,
197
- time_mask_param=config.time_mask_length,
198
- n_freq_masks=config.num_freq_masks,
199
- freq_mask_param=config.freq_mask_length,
200
- )
201
- else:
202
- self.spec_augment = None
203
-
204
  # For model parallelism
205
  self._no_split_modules = getattr(self.language_model, "_no_split_modules", [])
206
 
@@ -340,7 +323,13 @@ class ASRModel(PreTrainedModel, GenerationMixin):
340
  self.tokenizer.add_special_tokens(
341
  {"additional_special_tokens": existing_special + ["<audio>"]}
342
  )
343
- self.language_model.resize_token_embeddings(len(self.tokenizer), mean_resizing=False)
344
 
345
  self.audio_token_id = self.tokenizer.convert_tokens_to_ids("<audio>")
346
  self.tokenizer.padding_side = "right"
@@ -352,9 +341,20 @@ class ASRModel(PreTrainedModel, GenerationMixin):
352
  cfg.eos_token_id = self.tokenizer.eos_token_id
353
  cfg.bos_token_id = self.tokenizer.bos_token_id
354
 
355
- def _init_weights(self, _module):
356
- """Weight initialization (projector weights are initialized in MoEAudioProjector)."""
357
- pass
358
 
359
  def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func=None):
360
  """Enable/disable gradient checkpointing for the language model."""
@@ -396,34 +396,40 @@ class ASRModel(PreTrainedModel, GenerationMixin):
396
  )
397
 
398
  def state_dict(self, *args, **kwargs) -> dict[str, torch.Tensor]:
399
- """Save trainable weights: projector, plus the language model when fine-tuned."""
400
  sd = {f"projector.{k}": v for k, v in self.projector.state_dict().items()}
401
  if not getattr(self.config, "freeze_language_model", True):
402
- sd.update(
403
- {f"language_model.{k}": v for k, v in self.language_model.state_dict().items()}
404
- )
405
  return sd
406
 
407
  def _compute_encoder_output_lengths(
408
  self,
409
  audio_attention_mask: torch.Tensor,
410
  ) -> torch.Tensor:
411
- """Compute per-sample encoder output lengths using conv layer formulas.
412
-
413
- Args:
414
- audio_attention_mask: Mask indicating real vs padded mel frames (batch, mel_len)
415
-
416
- Returns:
417
- Tensor of encoder output lengths per sample (batch,)
418
- """
419
- # Get mel frame lengths from attention mask
420
- lengths = audio_attention_mask.sum(dim=-1)
421
-
422
- # Apply conv layer formulas: output = (input + 2*pad - (kernel-1) - 1) // stride + 1
423
- for padding, kernel_size, stride in self.config.encoder_conv_layers:
424
- lengths = (lengths + 2 * padding - (kernel_size - 1) - 1) // stride + 1
425
-
426
- return lengths
427
 
428
  def _encode_audio(
429
  self,
@@ -468,9 +474,6 @@ class ASRModel(PreTrainedModel, GenerationMixin):
468
  inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
469
 
470
  if input_features is not None and input_ids is not None:
471
- if self.training and self.spec_augment is not None:
472
- input_features = self.spec_augment(input_features)
473
-
474
  is_audio_token = input_ids == self.audio_token_id
475
  if audio_token_counts is None:
476
  audio_token_counts = is_audio_token.sum(dim=-1)
@@ -556,13 +559,9 @@ class ASRModel(PreTrainedModel, GenerationMixin):
556
  device = input_features.device
557
  batch_size = input_features.shape[0]
558
 
559
- # Encode audio -> flattened embeddings
560
  encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
561
- token_counts = torch.tensor(
562
- [self.projector.get_output_length(int(length.item())) for length in encoder_lengths],
563
- device=input_features.device,
564
- dtype=torch.long,
565
- )
566
  audio_embeds = self._encode_audio(input_features, token_counts)
567
 
568
  # If input_ids not provided, build prompt with correct number of audio tokens
@@ -646,13 +645,9 @@ class ASRModel(PreTrainedModel, GenerationMixin):
646
  device = input_features.device
647
  batch_size = input_features.shape[0]
648
 
649
- # Encode audio -> flattened embeddings
650
  encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
651
- token_counts = torch.tensor(
652
- [self.projector.get_output_length(int(length.item())) for length in encoder_lengths],
653
- device=input_features.device,
654
- dtype=torch.long,
655
- )
656
  audio_embeds = self._encode_audio(input_features, token_counts)
657
 
658
  # Build prompt with correct number of audio tokens
 
17
  from transformers.modeling_outputs import CausalLMOutputWithPast
18
 
19
  try:
20
+ from .asr_config import ASRConfig, compute_encoder_output_length
21
  from .projectors import PROJECTOR_CLASSES
22
  except ImportError:
23
+ from asr_config import ASRConfig, compute_encoder_output_length # type: ignore[no-redef]
24
  from projectors import PROJECTOR_CLASSES # type: ignore[no-redef]
25
 
26
 
 
 
 
27
  def _gather_audio_embeds(audio_embeds: torch.Tensor, token_counts: torch.Tensor) -> torch.Tensor:
28
  """Flatten per-sample audio embeddings into a packed tensor.
29
 
 
53
  _supports_flash_attn_2 = True
54
  supports_gradient_checkpointing = True
55
  _is_loading_from_pretrained: bool = False
 
56
 
57
  TRANSCRIBE_PROMPT = "Transcribe the speech to text"
58
 
 
68
 
69
  # Set flag to avoid device_map="auto" in sub-model loaders
70
  cls._is_loading_from_pretrained = True
 
71
 
72
  try:
73
  model = cls(config, **kwargs)
 
129
  return model
130
  finally:
131
  cls._is_loading_from_pretrained = False
 
132
 
133
  def __init__(self, config: ASRConfig, **kwargs) -> None:
134
  super().__init__(config)
 
184
  if getattr(config, "freeze_projector", False):
185
  self.projector.requires_grad_(False)
186
 
187
  # For model parallelism
188
  self._no_split_modules = getattr(self.language_model, "_no_split_modules", [])
189
 
 
323
  self.tokenizer.add_special_tokens(
324
  {"additional_special_tokens": existing_special + ["<audio>"]}
325
  )
326
+ # mean_resizing=True initializes the new <audio> row at the mean of
327
+ # existing rows so its scale matches the pretrained distribution. The
328
+ # input-side <audio> embedding is overwritten via masked_scatter and
329
+ # never seen by the LM, but with tied embeddings (Qwen3-0.6B) this
330
+ # same row is the lm_head column for predicting <audio>; a Gaussian
331
+ # draw at config.initializer_range was visible in early-step logits.
332
+ self.language_model.resize_token_embeddings(len(self.tokenizer), mean_resizing=True)
333
 
334
  self.audio_token_id = self.tokenizer.convert_tokens_to_ids("<audio>")
335
  self.tokenizer.padding_side = "right"
 
341
  cfg.eos_token_id = self.tokenizer.eos_token_id
342
  cfg.bos_token_id = self.tokenizer.bos_token_id
343
 
344
+ def train(self, mode: bool = True):
345
+ """Set train/eval mode, but keep frozen submodules out of train mode.
346
+
347
+ HF Trainer calls `model.train()` at the top of every training step, which
348
+ recursively switches every submodule into train mode — re-enabling dropout
349
+ on modules with `requires_grad_(False)`. The frozen encoder (and the LM
350
+ when `freeze_language_model=True`) should always run deterministically;
351
+ train-mode dropout only adds noise that can't improve a frozen network.
352
+ """
353
+ super().train(mode)
354
+ self.audio_tower.train(False)
355
+ if getattr(self.config, "freeze_language_model", True):
356
+ self.language_model.train(False)
357
+ return self
358
 
359
  def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func=None):
360
  """Enable/disable gradient checkpointing for the language model."""
 
396
  )
397
 
398
  def state_dict(self, *args, **kwargs) -> dict[str, torch.Tensor]:
399
+ """Save trainable weights: projector, plus the language model when fine-tuned.
400
+
401
+ With LoRA attached, the language_model entries are flattened to plain
402
+ (non-PEFT) HF naming so model.safetensors round-trips through
403
+ ASRModel.from_pretrained — which builds a vanilla base LM, overlays
404
+ these weights, and only then re-attaches PEFT. lora_*/adapter weights
405
+ are skipped here; PEFT serializes them separately as
406
+ adapter_model.safetensors via the save_pretrained path below.
407
+ """
408
  sd = {f"projector.{k}": v for k, v in self.projector.state_dict().items()}
409
  if not getattr(self.config, "freeze_language_model", True):
410
+ lm = self.language_model
411
+ if hasattr(lm, "peft_config"):
412
+ for k, v in lm.state_dict().items():
413
+ if "lora_" in k:
414
+ continue
415
+ if k.startswith("base_model.model."):
416
+ k = k[len("base_model.model.") :]
417
+ # LoRA layers wrap the original Linear as `<name>.base_layer.<weight|bias>`.
418
+ k = k.replace(".base_layer.", ".")
419
+ sd[f"language_model.{k}"] = v
420
+ else:
421
+ sd.update({f"language_model.{k}": v for k, v in lm.state_dict().items()})
422
  return sd
423
 
424
  def _compute_encoder_output_lengths(
425
  self,
426
  audio_attention_mask: torch.Tensor,
427
  ) -> torch.Tensor:
428
+ """Compute per-sample encoder output lengths using conv layer formulas."""
429
+ return compute_encoder_output_length(
430
+ audio_attention_mask.sum(dim=-1),
431
+ self.config.encoder_conv_layers,
432
+ )
433
 
434
  def _encode_audio(
435
  self,
 
474
  inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
475
 
476
  if input_features is not None and input_ids is not None:
 
 
 
477
  is_audio_token = input_ids == self.audio_token_id
478
  if audio_token_counts is None:
479
  audio_token_counts = is_audio_token.sum(dim=-1)
 
559
  device = input_features.device
560
  batch_size = input_features.shape[0]
561
 
562
+ # Encode audio -> flattened embeddings (no per-sample host sync)
563
  encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
564
+ token_counts = self.projector.get_output_length(encoder_lengths).to(torch.long)
565
  audio_embeds = self._encode_audio(input_features, token_counts)
566
 
567
  # If input_ids not provided, build prompt with correct number of audio tokens
 
645
  device = input_features.device
646
  batch_size = input_features.shape[0]
647
 
648
+ # Encode audio -> flattened embeddings (no per-sample host sync)
649
  encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
650
+ token_counts = self.projector.get_output_length(encoder_lengths).to(torch.long)
651
  audio_embeds = self._encode_audio(input_features, token_counts)
652
 
653
  # Build prompt with correct number of audio tokens
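
The key flattening in the new state_dict override is easiest to see in isolation. The snippet below mirrors that renaming on a few illustrative PEFT-style parameter names (the keys are examples, not taken from this checkpoint).

def flatten_peft_key(key):
    # Mirrors ASRModel.state_dict: skip LoRA adapter entries, strip the PEFT
    # wrapper prefix, and unwrap LoRA-wrapped Linear layers.
    if "lora_" in key:
        return None  # serialized separately as adapter_model.safetensors
    if key.startswith("base_model.model."):
        key = key[len("base_model.model.") :]
    return key.replace(".base_layer.", ".")

example_keys = [
    "base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight",
    "base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight",
    "base_model.model.lm_head.weight",
]
print([flatten_peft_key(k) for k in example_keys])
# ['model.layers.0.self_attn.q_proj.weight', None, 'lm_head.weight']
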
asr_pipeline.py CHANGED
@@ -23,9 +23,9 @@ __all__ = ["ForcedAligner", "SpeakerDiarizer", "ASRPipeline"]
23
 
24
  _THINK_TAG_RE = re.compile(r"<think>.*?</think>\s*", flags=re.DOTALL)
25
  _DEFAULT_MIN_REPEATS = 3
26
- _TRAILING_CHAR_RE = re.compile(r"(.)\1{" + str(_DEFAULT_MIN_REPEATS - 1) + r",}$")
27
  _TRAILING_WORD_RE = re.compile(
28
- r"\b(\w+)(?:\s+\1){" + str(_DEFAULT_MIN_REPEATS - 1) + r",}\s*$", re.IGNORECASE
29
  )
30
 
31
 
@@ -291,10 +291,8 @@ def _truncate_repetitions(text: str, min_repeats: int = 3) -> str:
291
  char_pattern = _TRAILING_CHAR_RE
292
  word_pattern = _TRAILING_WORD_RE
293
  else:
294
- char_pattern = re.compile(r"(.)\1{" + str(min_repeats - 1) + r",}$")
295
- word_pattern = re.compile(
296
- r"\b(\w+)(?:\s+\1){" + str(min_repeats - 1) + r",}\s*$", re.IGNORECASE
297
- )
298
 
299
  text = char_pattern.sub(r"\1", text)
300
  while word_pattern.search(text):
@@ -303,28 +301,24 @@ def _truncate_repetitions(text: str, min_repeats: int = 3) -> str:
303
  # 3. Truncate repeated phrases (2-20 words) at end
304
  # e.g., "i am sorry i am sorry i am sorry" -> "i am sorry"
305
  words = text.split()
306
- if len(words) >= min_repeats * 2:
307
- # Try phrase lengths from 2 to 20 words
308
- for phrase_len in range(2, min(21, len(words) // min_repeats + 1)):
309
- # Check if the last phrase_len words repeat
310
- phrase = " ".join(words[-phrase_len:])
311
- # Build pattern to match repeated phrases at end
312
- phrase_escaped = re.escape(phrase)
313
- phrase_pattern = re.compile(
314
- r"(^|.*?\s)("
315
- + phrase_escaped
316
- + r")(?:\s+"
317
- + phrase_escaped
318
- + r"){"
319
- + str(min_repeats - 1)
320
- + r",}\s*$",
321
- re.IGNORECASE,
322
- )
323
- match = phrase_pattern.match(text)
324
- if match:
325
- # Keep prefix + one instance of the phrase
326
- text = (match.group(1) + match.group(2)).strip()
327
- words = text.split()
328
- break
329
 
330
  return text
 
23
 
24
  _THINK_TAG_RE = re.compile(r"<think>.*?</think>\s*", flags=re.DOTALL)
25
  _DEFAULT_MIN_REPEATS = 3
26
+ _TRAILING_CHAR_RE = re.compile(rf"(.)\1{{{_DEFAULT_MIN_REPEATS - 1},}}$")
27
  _TRAILING_WORD_RE = re.compile(
28
+ rf"\b(\w+)(?:\s+\1){{{_DEFAULT_MIN_REPEATS - 1},}}\s*$", re.IGNORECASE
29
  )
30
 
31
 
 
291
  char_pattern = _TRAILING_CHAR_RE
292
  word_pattern = _TRAILING_WORD_RE
293
  else:
294
+ char_pattern = re.compile(rf"(.)\1{{{min_repeats - 1},}}$")
295
+ word_pattern = re.compile(rf"\b(\w+)(?:\s+\1){{{min_repeats - 1},}}\s*$", re.IGNORECASE)
 
 
296
 
297
  text = char_pattern.sub(r"\1", text)
298
  while word_pattern.search(text):
 
301
  # 3. Truncate repeated phrases (2-20 words) at end
302
  # e.g., "i am sorry i am sorry i am sorry" -> "i am sorry"
303
  words = text.split()
304
+ if len(words) < min_repeats * 2:
305
+ return text
306
+
307
+ # Cheap pre-check: trailing window must contain duplicates for any phrase repeat
308
 + # to be possible. len(set(window)) == len(window) means all words are unique → no repetition.
309
+ window = words[-min_repeats * 2 :]
310
+ if len(set(window)) == len(window):
311
+ return text
312
+
313
+ for phrase_len in range(2, min(21, len(words) // min_repeats + 1)):
314
+ phrase_escaped = re.escape(" ".join(words[-phrase_len:]))
315
+ phrase_pattern = re.compile(
316
+ rf"(^|.*?\s)({phrase_escaped})(?:\s+{phrase_escaped}){{{min_repeats - 1},}}\s*$",
317
+ re.IGNORECASE,
318
+ )
319
+ match = phrase_pattern.match(text)
320
+ if match:
321
+ text = (match.group(1) + match.group(2)).strip()
322
+ break
323
 
324
  return text
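
The rewritten f-string regexes expand to the same patterns as the old string-concatenation versions; with min_repeats = 3 (the default), three or more trailing repeats of a character or word collapse to one.

import re

min_repeats = 3
char_pattern = re.compile(rf"(.)\1{{{min_repeats - 1},}}$")
word_pattern = re.compile(rf"\b(\w+)(?:\s+\1){{{min_repeats - 1},}}\s*$", re.IGNORECASE)

print(char_pattern.pattern)                      # (.)\1{2,}$
print(char_pattern.sub(r"\1", "mmmm hmmmm"))     # mmmm hm
print(word_pattern.sub(r"\1", "okay so so so"))  # okay so
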
asr_processing.py CHANGED
@@ -5,9 +5,13 @@ import transformers
5
  from transformers import ProcessorMixin
6
 
7
  try:
8
- from .asr_config import DEFAULT_ENCODER_CONV_LAYERS, ASRConfig
9
  except ImportError:
10
- from asr_config import DEFAULT_ENCODER_CONV_LAYERS, ASRConfig # type: ignore[no-redef]
11
 
12
 
13
  class ASRProcessor(ProcessorMixin):
@@ -42,10 +46,7 @@ class ASRProcessor(ProcessorMixin):
42
 
43
  def _compute_encoder_output_length(self, mel_length: int) -> int:
44
  """Compute encoder output length using conv layer formulas."""
45
- length = mel_length
46
- for padding, kernel_size, stride in self.encoder_conv_layers:
47
- length = (length + 2 * padding - (kernel_size - 1) - 1) // stride + 1
48
- return length
49
 
50
  def __call__(
51
  self,
 
5
  from transformers import ProcessorMixin
6
 
7
  try:
8
+ from .asr_config import DEFAULT_ENCODER_CONV_LAYERS, ASRConfig, compute_encoder_output_length
9
  except ImportError:
10
+ from asr_config import ( # type: ignore[no-redef]
11
+ DEFAULT_ENCODER_CONV_LAYERS,
12
+ ASRConfig,
13
+ compute_encoder_output_length,
14
+ )
15
 
16
 
17
  class ASRProcessor(ProcessorMixin):
 
46
 
47
  def _compute_encoder_output_length(self, mel_length: int) -> int:
48
  """Compute encoder output length using conv layer formulas."""
49
+ return compute_encoder_output_length(mel_length, self.encoder_conv_layers)
 
 
 
50
 
51
  def __call__(
52
  self,
projectors.py CHANGED
@@ -43,6 +43,11 @@ class MLPAudioProjector(nn.Module):
43
  self.norm = LlamaRMSNorm(hidden_dim, eps=1e-6)
44
  self.act = nn.GELU()
45
  self.linear_2 = nn.Linear(hidden_dim, llm_dim, bias=False)
46
 
47
  def get_output_length(self, input_length: int) -> int:
48
  """Calculate output sequence length given input length (matches GLM-ASR)."""
@@ -62,7 +67,8 @@ class MLPAudioProjector(nn.Module):
62
  x = self.linear_1(x)
63
  x = self.norm(x)
64
  x = self.act(x)
65
- return self.linear_2(x)
 
66
 
67
 
68
  # =============================================================================
@@ -102,6 +108,12 @@ class MOSAProjector(nn.Module):
102
  Uses Conv1d for downsampling (2 layers, stride 2 each = 4x total).
103
  """
104
 
105
  def __init__(self, config):
106
  """Initialize MOSA projector.
107
 
@@ -112,31 +124,28 @@ class MOSAProjector(nn.Module):
112
  self.encoder_dim = getattr(config, "encoder_dim", None) or 1280
113
  self.llm_dim = getattr(config, "llm_dim", None) or 2048
114
  self.num_experts = getattr(config, "num_experts", None) or 4 # MOSA-Base uses 4
115
- adapter_hidden = getattr(config, "adapter_hidden_dim", None) or 4096
116
- router_hidden = getattr(config, "router_hidden_dim", None) or 512
117
 
118
- # --- 1. Conv1d Downsampler (4x reduction) ---
119
- # 2 layers of stride-2 convolution
 
 
 
120
  self.downsampler = nn.Sequential(
121
- nn.Conv1d(self.encoder_dim, self.encoder_dim, kernel_size=3, stride=2, padding=1),
122
  nn.GELU(),
123
- nn.Conv1d(self.encoder_dim, self.llm_dim, kernel_size=3, stride=2, padding=1),
124
  nn.GELU(),
125
  )
126
 
127
- # --- 2. Simple Router (MOSA-Base: 2 layers with ReLU) ---
128
- # Takes downsampled features (llm_dim) -> 512 -> num_experts
129
  self.router = nn.Sequential(
130
- nn.Linear(self.llm_dim, router_hidden),
131
  nn.ReLU(),
132
- nn.Linear(router_hidden, self.num_experts),
133
  )
134
 
135
- # --- 3. Experts (Simple 2-layer GELU adapters) ---
136
- # Each expert: llm_dim -> hidden -> llm_dim (much smaller than frame-stacking)
137
  self.experts = nn.ModuleList(
138
  [
139
- SimpleAdapter(self.llm_dim, adapter_hidden, self.llm_dim)
140
  for _ in range(self.num_experts)
141
  ]
142
  )
@@ -150,26 +159,22 @@ class MOSAProjector(nn.Module):
150
  Returns:
151
  Projected features of shape [batch, out_len, llm_dim]
152
  """
153
- # --- 1. Conv1d Downsampling ---
154
- # Permute for Conv1d: [B, S, D] -> [B, D, S]
155
- x = x.transpose(1, 2)
156
- x = self.downsampler(x)
157
- # Permute back: [B, D, S] -> [B, S, D]
158
- x = x.transpose(1, 2)
159
-
160
- # --- 2. Routing ---
161
  routing_weights = F.softmax(self.router(x), dim=-1) # (B, out_len, num_experts)
162
 
163
- # --- 3. Expert Mixture (Dense Execution) ---
164
- expert_outputs = torch.stack([expert(x) for expert in self.experts]) # (E, B, out_len, D)
165
- return torch.einsum("ebsd, bse -> bsd", expert_outputs, routing_weights)
 
 
166
 
167
  def get_output_length(self, input_length: int) -> int:
168
  """Calculate output sequence length after Conv1d downsampling (4x reduction)."""
169
- # Conv1d with stride 2, kernel 3, padding 1: out = (in + 2*1 - 3) // 2 + 1 = (in - 1) // 2 + 1
170
- # Applied twice for 4x total reduction
171
- after_conv1 = (input_length + 2 * 1 - 3) // 2 + 1
172
- return (after_conv1 + 2 * 1 - 3) // 2 + 1
173
 
174
 
175
  # =============================================================================
@@ -414,10 +419,13 @@ class QFormerAudioProjector(nn.Module):
414
  # Final projection to LLM dimension (Granite uses bias=True)
415
  self.linear = nn.Linear(qformer_hidden, llm_dim)
416
 
417
- def get_output_length(self, input_length: int) -> int:
418
- """Calculate output sequence length given input length."""
419
- # QFormer uses window-based processing with num_queries per window
420
- nblocks = math.ceil(input_length / self.window_size)
 
 
 
421
  return nblocks * self.num_queries
422
 
423
  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
43
  self.norm = LlamaRMSNorm(hidden_dim, eps=1e-6)
44
  self.act = nn.GELU()
45
  self.linear_2 = nn.Linear(hidden_dim, llm_dim, bias=False)
46
+ # Output norm aligns the projector's RMS with the LM's embed_tokens
47
+ # distribution. Without it, linear_2's Kaiming-uniform init produces
48
+ # outputs ~30× quieter than embed rows, which saturates softmax at
49
+ # audio positions and starves them of gradient.
50
+ self.norm_2 = LlamaRMSNorm(llm_dim, eps=1e-6)
51
 
52
  def get_output_length(self, input_length: int) -> int:
53
  """Calculate output sequence length given input length (matches GLM-ASR)."""
 
67
  x = self.linear_1(x)
68
  x = self.norm(x)
69
  x = self.act(x)
70
+ x = self.linear_2(x)
71
+ return self.norm_2(x)
72
 
73
 
74
  # =============================================================================
 
108
  Uses Conv1d for downsampling (2 layers, stride 2 each = 4x total).
109
  """
110
 
111
+ ADAPTER_HIDDEN_DIM = 4096
112
+ ROUTER_HIDDEN_DIM = 512
113
+ CONV_KERNEL = 3
114
+ CONV_STRIDE = 2
115
+ CONV_PADDING = 1
116
+
117
  def __init__(self, config):
118
  """Initialize MOSA projector.
119
 
 
124
  self.encoder_dim = getattr(config, "encoder_dim", None) or 1280
125
  self.llm_dim = getattr(config, "llm_dim", None) or 2048
126
  self.num_experts = getattr(config, "num_experts", None) or 4 # MOSA-Base uses 4
 
 
127
 
128
+ conv_kwargs = {
129
+ "kernel_size": self.CONV_KERNEL,
130
+ "stride": self.CONV_STRIDE,
131
+ "padding": self.CONV_PADDING,
132
+ }
133
  self.downsampler = nn.Sequential(
134
+ nn.Conv1d(self.encoder_dim, self.encoder_dim, **conv_kwargs),
135
  nn.GELU(),
136
+ nn.Conv1d(self.encoder_dim, self.llm_dim, **conv_kwargs),
137
  nn.GELU(),
138
  )
139
 
 
 
140
  self.router = nn.Sequential(
141
+ nn.Linear(self.llm_dim, self.ROUTER_HIDDEN_DIM),
142
  nn.ReLU(),
143
+ nn.Linear(self.ROUTER_HIDDEN_DIM, self.num_experts),
144
  )
145
 
 
 
146
  self.experts = nn.ModuleList(
147
  [
148
+ SimpleAdapter(self.llm_dim, self.ADAPTER_HIDDEN_DIM, self.llm_dim)
149
  for _ in range(self.num_experts)
150
  ]
151
  )
 
159
  Returns:
160
  Projected features of shape [batch, out_len, llm_dim]
161
  """
162
+ x = self.downsampler(x.transpose(1, 2)).transpose(1, 2)
163
+
164
  routing_weights = F.softmax(self.router(x), dim=-1) # (B, out_len, num_experts)
165
 
166
+ # Accumulate weighted expert outputs without materializing all experts at once.
167
+ output = self.experts[0](x) * routing_weights[..., 0:1]
168
+ for i, expert in enumerate(self.experts[1:], start=1):
169
+ output = output + expert(x) * routing_weights[..., i : i + 1]
170
+ return output
171
 
172
  def get_output_length(self, input_length: int) -> int:
173
  """Calculate output sequence length after Conv1d downsampling (4x reduction)."""
174
+ length = input_length
175
+ for _ in range(2):
176
+ length = (length + 2 * self.CONV_PADDING - self.CONV_KERNEL) // self.CONV_STRIDE + 1
177
+ return length
178
 
179
 
180
  # =============================================================================
 
419
  # Final projection to LLM dimension (Granite uses bias=True)
420
  self.linear = nn.Linear(qformer_hidden, llm_dim)
421
 
422
+ def get_output_length(self, input_length):
423
+ """Calculate output sequence length given input length.
424
+
425
+ Accepts either Python ints or torch tensors; uses ceiling division so
426
+ the formula is identical for both — math.ceil would block tensors.
427
+ """
428
+ nblocks = (input_length + self.window_size - 1) // self.window_size
429
  return nblocks * self.num_queries
430
 
431
  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
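
The accumulate-one-expert-at-a-time loop in the new MOSA forward computes the same mixture as the dense einsum it replaced. The sketch below checks the equivalence with plain Linear layers standing in for SimpleAdapter, so it is an illustration of the formulation rather than the real projector.

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, S, D, E = 2, 5, 8, 4
x = torch.randn(B, S, D)
experts = nn.ModuleList([nn.Linear(D, D) for _ in range(E)])  # stand-ins for SimpleAdapter
router = nn.Linear(D, E)
routing_weights = F.softmax(router(x), dim=-1)                # (B, S, E)

# Dense formulation: materialize every expert output, contract over experts.
stacked = torch.stack([expert(x) for expert in experts])      # (E, B, S, D)
dense = torch.einsum("ebsd, bse -> bsd", stacked, routing_weights)

# Loop formulation from the diff: accumulate one weighted expert at a time.
out = experts[0](x) * routing_weights[..., 0:1]
for i, expert in enumerate(experts[1:], start=1):
    out = out + expert(x) * routing_weights[..., i : i + 1]

print(torch.allclose(dense, out, atol=1e-6))  # True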