mazesmazes committed · Commit 8e30b59 (verified) · Parent(s): 661be01

Training in progress - step 1000

Files changed (7):
  1. alignment.py +6 -3
  2. asr_config.py +16 -24
  3. asr_modeling.py +54 -116
  4. asr_pipeline.py +19 -11
  5. asr_processing.py +3 -5
  6. diarization.py +0 -2
  7. projectors.py +15 -41
alignment.py CHANGED
@@ -120,16 +120,19 @@ class ForcedAligner:
 
             if move_score >= stay_score:
                 # Token j-1 was emitted at frame t-1
-                token_frames[j - 1].insert(0, t - 1)
+                token_frames[j - 1].append(t - 1)
                 j -= 1
-            # Always decrement time (monotonic)
             t -= 1
 
         # Handle any remaining tokens at the start (edge case)
         while j > 0:
-            token_frames[j - 1].insert(0, 0)
+            token_frames[j - 1].append(0)
             j -= 1
 
+        # We appended in reverse-time order; restore monotonic order
+        for frames in token_frames:
+            frames.reverse()
+
         # Convert to spans
         token_spans: list[tuple[int, float, float]] = []
         for token_idx, frames in enumerate(token_frames):
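Why the append/reverse change matters: list.insert(0, ...) shifts every existing element, so building each token's frame list front-first during backtracking is quadratic in the number of frames, while appending and reversing once afterwards is linear and yields the same chronological order. A standalone sketch (illustrative, not repo code):

frames_via_insert: list[int] = []
frames_via_append: list[int] = []
for t in (7, 5, 3, 2):  # frames visited in reverse time during backtracking
    frames_via_insert.insert(0, t)  # old approach: O(len) per call
    frames_via_append.append(t)     # new approach: O(1) per call
frames_via_append.reverse()         # one pass restores chronological order
assert frames_via_insert == frames_via_append == [2, 3, 5, 7]
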
asr_config.py CHANGED
@@ -2,6 +2,9 @@ from typing import Optional
 
 import transformers
 
+# Default conv layers for Whisper/GLM-ASR audio encoders: [(pad, kernel, stride), ...]
+DEFAULT_ENCODER_CONV_LAYERS = [(1, 3, 1), (1, 3, 2)]
+
 
 class ASRConfig(transformers.PretrainedConfig):
     """Configuration class for the ASR model.
@@ -107,8 +110,7 @@ class ASRConfig(transformers.PretrainedConfig):
         self.system_prompt = system_prompt
         self.encoder_dim = encoder_dim
         self.llm_dim = llm_dim
-        # Default conv layers for Whisper/GLM-ASR: [(pad, kernel, stride), ...]
-        self.encoder_conv_layers = encoder_conv_layers or [(1, 3, 1), (1, 3, 2)]
+        self.encoder_conv_layers = encoder_conv_layers or DEFAULT_ENCODER_CONV_LAYERS
         self.audio_sample_rate = audio_sample_rate
         self.projector_init_std = projector_init_std
         self.projector_pool_stride = projector_pool_stride
@@ -151,28 +153,18 @@ class ASRConfig(transformers.PretrainedConfig):
         ]
         self.freeze_projector = freeze_projector
 
-        # Generation parameters (use explicit value if provided, else use default)
-        self.num_beams = num_beams if num_beams is not None else generation_defaults["num_beams"]
-        self.max_new_tokens = (
-            max_new_tokens if max_new_tokens is not None else generation_defaults["max_new_tokens"]
-        )
-        self.min_new_tokens = (
-            min_new_tokens if min_new_tokens is not None else generation_defaults["min_new_tokens"]
-        )
-        self.repetition_penalty = (
-            repetition_penalty
-            if repetition_penalty is not None
-            else generation_defaults["repetition_penalty"]
-        )
-        self.length_penalty = (
-            length_penalty if length_penalty is not None else generation_defaults["length_penalty"]
-        )
-        self.no_repeat_ngram_size = (
-            no_repeat_ngram_size
-            if no_repeat_ngram_size is not None
-            else generation_defaults["no_repeat_ngram_size"]
-        )
-        self.use_cache = use_cache if use_cache is not None else generation_defaults["use_cache"]
+        explicit_generation_args = {
+            "num_beams": num_beams,
+            "max_new_tokens": max_new_tokens,
+            "min_new_tokens": min_new_tokens,
+            "repetition_penalty": repetition_penalty,
+            "length_penalty": length_penalty,
+            "no_repeat_ngram_size": no_repeat_ngram_size,
+            "use_cache": use_cache,
+        }
+        for key, default in generation_defaults.items():
+            value = explicit_generation_args[key]
+            setattr(self, key, value if value is not None else default)
         self.do_sample = do_sample
        self.temperature = temperature
        self.top_p = top_p
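The explicit-argument-or-default boilerplate collapses into one loop over generation_defaults; behavior is unchanged (an explicit constructor argument wins, None falls back to the default). A minimal sketch of the pattern, with hypothetical default values:

generation_defaults = {"num_beams": 5, "max_new_tokens": 256}  # hypothetical values

class _Sketch:
    def __init__(self, num_beams=None, max_new_tokens=None):
        explicit = {"num_beams": num_beams, "max_new_tokens": max_new_tokens}
        for key, default in generation_defaults.items():
            value = explicit[key]
            setattr(self, key, value if value is not None else default)

cfg = _Sketch(num_beams=1)
assert cfg.num_beams == 1          # explicit value kept
assert cfg.max_new_tokens == 256   # falls back to the default

Since DEFAULT_ENCODER_CONV_LAYERS now lives at module level (and is shared with asr_processing.py below), here is how (pad, kernel, stride) triples typically turn into an encoder output length, assuming the standard 1-D convolution length formula; the helper name is illustrative, not the repo's:

DEFAULT_ENCODER_CONV_LAYERS = [(1, 3, 1), (1, 3, 2)]

def encoder_output_length(mel_length: int) -> int:
    length = mel_length
    for pad, kernel, stride in DEFAULT_ENCODER_CONV_LAYERS:
        length = (length + 2 * pad - kernel) // stride + 1
    return length

# 3000 mel frames (30 s at a 10 ms hop) -> 1500 encoder frames, as in Whisper
assert encoder_output_length(3000) == 1500
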
asr_modeling.py CHANGED
@@ -5,8 +5,8 @@ from typing import Iterator, Optional, Union
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F  # noqa: N812
 from transformers import (
-    AutoConfig,
     AutoModel,
     AutoModelForCausalLM,
     AutoTokenizer,
@@ -27,6 +27,26 @@ except ImportError:
     from torchaudio.transforms import SpecAugment
 
 
+def _gather_audio_embeds(audio_embeds: torch.Tensor, token_counts: torch.Tensor) -> torch.Tensor:
+    """Flatten per-sample audio embeddings into a packed tensor.
+
+    For each row i, takes the first ``token_counts[i]`` rows of
+    ``audio_embeds[i]`` and concatenates them. If any token count exceeds
+    ``audio_embeds.shape[1]``, the deficit is zero-padded.
+
+    Equivalent to a per-sample slice/cat loop but with O(1) host-device
+    syncs per call (one ``max().item()``) instead of one per sample.
+    """
+    _, max_len, _ = audio_embeds.shape
+    needed = int(token_counts.max().item())
+    if needed > max_len:
+        audio_embeds = F.pad(audio_embeds, (0, 0, 0, needed - max_len))
+        max_len = needed
+    indices = torch.arange(max_len, device=audio_embeds.device).unsqueeze(0)
+    mask = indices < token_counts.unsqueeze(1)
+    return audio_embeds[mask]
+
+
 class ASRModel(PreTrainedModel, GenerationMixin):
     """Audio-to-text model combining an audio encoder, projector, and language model."""
 
@@ -402,61 +422,25 @@ class ASRModel(PreTrainedModel, GenerationMixin):
     def _encode_audio(
         self,
         audio_features: torch.Tensor,
-        audio_attention_mask: torch.Tensor,
-        expected_token_counts: torch.Tensor | None = None,
+        expected_token_counts: torch.Tensor,
     ) -> torch.Tensor:
-        """Encode audio and project to LLM embedding space.
+        """Encode audio features and return flattened embeddings matching expected_token_counts.
 
         Args:
             audio_features: Mel spectrogram features (batch, n_mels, mel_len)
-            audio_attention_mask: Mask indicating real vs padded mel frames (batch, mel_len)
-            expected_token_counts: Expected number of audio tokens per sample from input_ids.
-                If provided, output will match these counts exactly (padding/truncating as needed).
+            expected_token_counts: Per-sample audio token counts as int64 tensor (batch,).
 
         Returns:
-            Flattened audio embeddings of shape (total_audio_tokens, hidden_dim).
+            Flattened audio embeddings of shape (sum(expected_token_counts), hidden_dim).
         """
         with torch.no_grad():
             encoder_out = self.audio_tower(input_features=audio_features)
             hidden_states = encoder_out.last_hidden_state
 
-        # Project to LLM space
         audio_embeds = self.projector(hidden_states)
 
-        # Use expected token counts if provided (from input_ids), otherwise compute from audio
-        if expected_token_counts is not None:
-            token_counts = expected_token_counts
-        else:
-            # Compute per-sample encoder output lengths using conv formulas
-            encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
-            token_counts = torch.tensor(
-                [
-                    self.projector.get_output_length(int(length.item()))
-                    for length in encoder_lengths
-                ],
-                device=audio_embeds.device,
-            )
-
-        # Extract embeddings matching expected token counts per sample
-        batch_size = audio_embeds.shape[0]
-        hidden_dim = audio_embeds.shape[2]
-
-        result_embeds = []
-        for i in range(batch_size):
-            count = int(token_counts[i].item())
-            sample_embeds = audio_embeds[i, :count, :]  # Take first 'count' embeddings
-            # Pad with zeros if we don't have enough embeddings
-            if sample_embeds.shape[0] < count:
-                padding = torch.zeros(
-                    count - sample_embeds.shape[0],
-                    hidden_dim,
-                    device=audio_embeds.device,
-                    dtype=audio_embeds.dtype,
-                )
-                sample_embeds = torch.cat([sample_embeds, padding], dim=0)
-            result_embeds.append(sample_embeds)
-
-        return torch.cat(result_embeds, dim=0)
+        token_counts = expected_token_counts.to(device=audio_embeds.device, dtype=torch.long)
+        return _gather_audio_embeds(audio_embeds, token_counts)
 
     def forward(
         self,
@@ -470,34 +454,33 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         labels: Optional[torch.Tensor] = None,
         use_cache: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
+        audio_token_counts: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> CausalLMOutputWithPast:
         """Forward pass for training and inference."""
-        # Get text embeddings if not provided
         if inputs_embeds is None:
             inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
 
         if input_features is not None and input_ids is not None:
-            # Apply SpecAugment during training if enabled
             if self.training and self.spec_augment is not None:
                 input_features = self.spec_augment(input_features)
 
-            # Count expected audio tokens from input_ids (ground truth from collator)
-            audio_token_counts = (input_ids == self.audio_token_id).sum(dim=-1)
+            is_audio_token = input_ids == self.audio_token_id
+            if audio_token_counts is None:
+                audio_token_counts = is_audio_token.sum(dim=-1)
+            else:
+                audio_token_counts = audio_token_counts.to(
+                    device=input_ids.device, dtype=torch.long
+                )
 
-            # Encode audio -> flattened (total_audio_tokens, hidden_dim)
-            audio_embeds = self._encode_audio(
-                input_features, audio_attention_mask, audio_token_counts
-            )
+            audio_embeds = self._encode_audio(input_features, audio_token_counts)
 
-            # Replace <audio> token placeholders with audio embeddings using masked_scatter
-            audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
+            audio_token_mask = is_audio_token.unsqueeze(-1)
             inputs_embeds = inputs_embeds.masked_scatter(
                 audio_token_mask.to(inputs_embeds.device),
                 audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
             )
 
-        # Run through language model (let it compute loss if labels provided)
         outputs = self.language_model(
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -509,7 +492,6 @@ class ASRModel(PreTrainedModel, GenerationMixin):
             **kwargs,
         )
 
-        # Add auxiliary loss from MoE projectors if available
         if outputs.loss is not None and hasattr(self.projector, "get_aux_loss"):
             aux_loss = self.projector.get_aux_loss()
             if aux_loss is not None and aux_loss.numel() > 0:
@@ -569,7 +551,13 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         batch_size = input_features.shape[0]
 
         # Encode audio -> flattened embeddings
-        audio_embeds = self._encode_audio(input_features, audio_attention_mask)
+        encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
+        token_counts = torch.tensor(
+            [self.projector.get_output_length(int(length.item())) for length in encoder_lengths],
+            device=input_features.device,
+            dtype=torch.long,
+        )
+        audio_embeds = self._encode_audio(input_features, token_counts)
 
         # If input_ids not provided, build prompt with correct number of audio tokens
         if input_ids is None:
@@ -653,7 +641,13 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         batch_size = input_features.shape[0]
 
         # Encode audio -> flattened embeddings
-        audio_embeds = self._encode_audio(input_features, audio_attention_mask)
+        encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
+        token_counts = torch.tensor(
+            [self.projector.get_output_length(int(length.item())) for length in encoder_lengths],
+            device=input_features.device,
+            dtype=torch.long,
+        )
+        audio_embeds = self._encode_audio(input_features, token_counts)
 
         # Build prompt with correct number of audio tokens
         num_audio_tokens = self._get_num_audio_tokens(audio_attention_mask)
@@ -747,63 +741,11 @@ class ASRModel(PreTrainedModel, GenerationMixin):
 
         thread.join()
 
-    @torch.no_grad()
-    def generate_text_only(
-        self,
-        messages: list[dict[str, str]],
-        max_new_tokens: int = 256,
-        **generate_kwargs,
-    ) -> str:
-        """Generate text using only the LLM (no audio encoding).
-
-        Used for SIFT-style response generation from metadata prompts.
-
-        Args:
-            messages: List of chat messages [{"role": "user", "content": "..."}]
-            max_new_tokens: Maximum tokens to generate
-            **generate_kwargs: Additional generation arguments
-
-        Returns:
-            Generated text response
-        """
-        device = next(self.language_model.parameters()).device
-
-        # Apply chat template
-        input_ids = self.tokenizer.apply_chat_template(
-            messages,
-            tokenize=True,
-            add_generation_prompt=True,
-            return_tensors="pt",
-            enable_thinking=False,  # Disable Qwen3 thinking mode for ASR
-        ).to(device)
-
-        if input_ids.dim() == 1:
-            input_ids = input_ids.unsqueeze(0)
-
-        attention_mask = torch.ones_like(input_ids)
-
-        # Generate using language model directly
-        output = self.language_model.generate(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            max_new_tokens=max_new_tokens,
-            do_sample=False,
-            pad_token_id=self.tokenizer.pad_token_id,
-            eos_token_id=self.tokenizer.eos_token_id,
-            **generate_kwargs,
-        )
-
-        # Decode only the new tokens
-        new_tokens = output[0, input_ids.shape[1] :]
-        response = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
-        return response.strip()
-
     def save_pretrained(self, save_directory: Union[str, Path], **kwargs) -> None:
         """Save model, tokenizer, and processor."""
         import shutil
-        from pathlib import Path as PathlibPath
 
-        save_dir = PathlibPath(save_directory)
+        save_dir = Path(save_directory)
         save_dir.mkdir(parents=True, exist_ok=True)
 
         # Update config with actual vocab size
@@ -874,7 +816,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
             json.dump(processor_config, f, indent=2)
 
         # Copy source files for auto-loading
-        src_dir = PathlibPath(__file__).parent
+        src_dir = Path(__file__).parent
         for asr_file in src_dir.glob("asr_*.py"):
             shutil.copy(asr_file, save_dir / asr_file.name)
         # Copy projectors module
@@ -896,11 +838,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         # Call parent's push_to_hub
         return super().push_to_hub(repo_id, **kwargs)
 
-    def create_or_update_model_card(self, output_dir: Union[str, Path]) -> None:
-        """No-op for model card creation - we use MODEL_CARD.md in repo instead."""
-        pass
-
 
 # Register with transformers Auto classes
-AutoConfig.register("asr_model", ASRConfig)
+# (AutoConfig.register is handled in asr_config.py at module load.)
 AutoModel.register(ASRConfig, ASRModel)
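The per-sample slice/pad/concat loop in _encode_audio is replaced by the module-level _gather_audio_embeds, which builds one boolean mask and needs a single host-device sync (token_counts.max().item()). A standalone check that the two formulations agree; this mirrors the diff but does not import the repo code:

import torch
import torch.nn.functional as F

def gather_vectorized(embeds: torch.Tensor, counts: torch.Tensor) -> torch.Tensor:
    # New approach: one mask, one sync.
    _, max_len, _ = embeds.shape
    needed = int(counts.max().item())
    if needed > max_len:  # zero-pad if a requested count exceeds the sequence length
        embeds = F.pad(embeds, (0, 0, 0, needed - max_len))
        max_len = needed
    mask = torch.arange(max_len, device=embeds.device).unsqueeze(0) < counts.unsqueeze(1)
    return embeds[mask]

def gather_loop(embeds: torch.Tensor, counts: torch.Tensor) -> torch.Tensor:
    # Old approach: slice, zero-pad, and concatenate per sample.
    out = []
    for i, count in enumerate(counts.tolist()):
        sample = embeds[i, :count, :]
        if sample.shape[0] < count:
            pad = torch.zeros(count - sample.shape[0], embeds.shape[2], dtype=embeds.dtype)
            sample = torch.cat([sample, pad], dim=0)
        out.append(sample)
    return torch.cat(out, dim=0)

embeds = torch.randn(2, 4, 3)
counts = torch.tensor([2, 5])  # second count exceeds seq len, so one zero row is appended
assert torch.equal(gather_vectorized(embeds, counts), gather_loop(embeds, counts))
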
asr_pipeline.py CHANGED
@@ -7,6 +7,7 @@ from typing import Any
 import numpy as np
 import torch
 import transformers
+from transformers.pipelines.audio_utils import ffmpeg_read
 
 try:
     from .alignment import ForcedAligner
@@ -20,6 +21,13 @@ except ImportError:
 # Re-export for backwards compatibility
 __all__ = ["ForcedAligner", "SpeakerDiarizer", "ASRPipeline"]
 
+_THINK_TAG_RE = re.compile(r"<think>.*?</think>\s*", flags=re.DOTALL)
+_DEFAULT_MIN_REPEATS = 3
+_TRAILING_CHAR_RE = re.compile(r"(.)\1{" + str(_DEFAULT_MIN_REPEATS - 1) + r",}$")
+_TRAILING_WORD_RE = re.compile(
+    r"\b(\w+)(?:\s+\1){" + str(_DEFAULT_MIN_REPEATS - 1) + r",}\s*$", re.IGNORECASE
+)
+
 
 class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
     """ASR Pipeline for audio-to-text transcription."""
@@ -152,8 +160,6 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
 
     def _extract_audio(self, inputs) -> dict | None:
         """Extract audio array from various input formats using HF utilities."""
-        from transformers.pipelines.audio_utils import ffmpeg_read
-
         if isinstance(inputs, dict):
             if "array" in inputs:
                 return {
@@ -257,8 +263,8 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
 
         text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
         # Strip <think>...</think> tags (Qwen3 doesn't respect /no_think prompt)
-        text = re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL).strip()
-        # Truncate repetitions at end of text
+        if "<think>" in text:
+            text = _THINK_TAG_RE.sub("", text).strip()
         text = _truncate_repetitions(text)
         return {"text": text}
 
@@ -281,14 +287,16 @@ def _truncate_repetitions(text: str, min_repeats: int = 3) -> str:
     if not text:
         return text
 
-    # 1. Truncate repeated characters at end (e.g., "444444" -> "4")
-    char_pattern = re.compile(r"(.)\1{" + str(min_repeats - 1) + r",}$")
-    text = char_pattern.sub(r"\1", text)
+    if min_repeats == _DEFAULT_MIN_REPEATS:
+        char_pattern = _TRAILING_CHAR_RE
+        word_pattern = _TRAILING_WORD_RE
+    else:
+        char_pattern = re.compile(r"(.)\1{" + str(min_repeats - 1) + r",}$")
+        word_pattern = re.compile(
+            r"\b(\w+)(?:\s+\1){" + str(min_repeats - 1) + r",}\s*$", re.IGNORECASE
+        )
 
-    # 2. Truncate repeated words at end (e.g., "the the the" -> "the")
-    word_pattern = re.compile(
-        r"\b(\w+)(?:\s+\1){" + str(min_repeats - 1) + r",}\s*$", re.IGNORECASE
-    )
+    text = char_pattern.sub(r"\1", text)
     while word_pattern.search(text):
         text = word_pattern.sub(r"\1", text)
 
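The <think>-stripping and trailing-repetition regexes are now compiled once at import time; _truncate_repetitions only recompiles when a non-default min_repeats is passed. A quick standalone look at what the trailing-repetition patterns do (same regexes as the diff, made-up input strings):

import re

MIN_REPEATS = 3
CHAR_RE = re.compile(r"(.)\1{" + str(MIN_REPEATS - 1) + r",}$")
WORD_RE = re.compile(r"\b(\w+)(?:\s+\1){" + str(MIN_REPEATS - 1) + r",}\s*$", re.IGNORECASE)

def truncate_repetitions(text: str) -> str:
    text = CHAR_RE.sub(r"\1", text)   # trailing "4444" collapses to "4"
    while WORD_RE.search(text):       # trailing "the the the" collapses to "the"
        text = WORD_RE.sub(r"\1", text)
    return text

assert truncate_repetitions("the answer is 4444") == "the answer is 4"
assert truncate_repetitions("we went to the the the the") == "we went to the"
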
asr_processing.py CHANGED
@@ -5,9 +5,9 @@ import transformers
 from transformers import ProcessorMixin
 
 try:
-    from .asr_config import ASRConfig
+    from .asr_config import DEFAULT_ENCODER_CONV_LAYERS, ASRConfig
 except ImportError:
-    from asr_config import ASRConfig  # type: ignore[no-redef]
+    from asr_config import DEFAULT_ENCODER_CONV_LAYERS, ASRConfig  # type: ignore[no-redef]
 
 
 class ASRProcessor(ProcessorMixin):
@@ -18,8 +18,6 @@ class ASRProcessor(ProcessorMixin):
     tokenizer_class = "AutoTokenizer"
     AUDIO_TOKEN = "<audio>"
     TRANSCRIBE_PROMPT = ""
-    # Default conv layers for Whisper/GLM-ASR: [(pad, kernel, stride), ...]
-    DEFAULT_ENCODER_CONV_LAYERS = [(1, 3, 1), (1, 3, 2)]
 
     def __init__(
         self,
@@ -40,7 +38,7 @@ class ASRProcessor(ProcessorMixin):
         self.tokenizer = tokenizer
         self.audio_token_id = tokenizer.convert_tokens_to_ids(self.AUDIO_TOKEN)
         self.projector = projector
-        self.encoder_conv_layers = encoder_conv_layers or self.DEFAULT_ENCODER_CONV_LAYERS
+        self.encoder_conv_layers = encoder_conv_layers or DEFAULT_ENCODER_CONV_LAYERS
 
     def _compute_encoder_output_length(self, mel_length: int) -> int:
         """Compute encoder output length using conv layer formulas."""
diarization.py CHANGED
@@ -154,8 +154,6 @@ class SpeakerClusterer:
         Returns:
             Cluster labels of shape [N]
         """
-        import warnings
-
         if len(embeddings.shape) != 2:
             raise ValueError(f"Expected 2D array, got shape {embeddings.shape}")
 
projectors.py CHANGED
@@ -58,13 +58,7 @@ class MLPAudioProjector(nn.Module):
         Returns:
             Projected features of shape [batch, (seq_len - k) // k + 1, llm_dim]
         """
-        batch, seq, dim = x.shape
-        # Truncate to match GLM-ASR: use (seq - k) // k + 1 frames
-        # This drops trailing frames that don't fill a complete k-frame window
-        out_len = (seq - self.k) // self.k + 1
-        x = x[:, : out_len * self.k, :]  # Truncate to exact multiple
-        x = x.reshape(batch, out_len, dim * self.k)
-
+        x = _frame_stack(x, self.k)
         x = self.linear_1(x)
         x = self.norm(x)
         x = self.act(x)
@@ -76,6 +70,17 @@ class MLPAudioProjector(nn.Module):
 # =============================================================================
 
 
+def _frame_stack(x: torch.Tensor, k: int) -> torch.Tensor:
+    """Stack k adjacent frames along the feature dim.
+
+    Truncates trailing frames that don't fill a complete k-frame window,
+    matching GLM-ASR's `(seq_len - k) // k + 1` formula.
+    """
+    batch, seq, dim = x.shape
+    out_len = (seq - k) // k + 1
+    return x[:, : out_len * k, :].reshape(batch, out_len, dim * k)
+
+
 class SimpleAdapter(nn.Module):
     """Simple 2-layer GELU adapter (from MOSA paper)."""
 
@@ -89,34 +94,6 @@ class SimpleAdapter(nn.Module):
         return self.fc2(self.act(self.fc1(x)))
 
 
-class SwiGLU(nn.Module):
-    """SwiGLU activation with gated linear units (used in LLaMA, Mistral, etc.)."""
-
-    def __init__(self, dim: int, hidden_dim: int, bias: bool = False):
-        super().__init__()
-        self.w1 = nn.Linear(dim, hidden_dim, bias=bias)  # Gate
-        self.w2 = nn.Linear(dim, hidden_dim, bias=bias)  # Value
-        self.w3 = nn.Linear(hidden_dim, dim, bias=bias)  # Output
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.w3(F.silu(self.w1(x)) * self.w2(x))
-
-
-class AsymmetricSwiGLU(nn.Module):
-    """SwiGLU that handles different input and output dimensions."""
-
-    def __init__(
-        self, in_features: int, hidden_features: int, out_features: int, bias: bool = False
-    ):
-        super().__init__()
-        self.w1 = nn.Linear(in_features, hidden_features, bias=bias)  # Gate
-        self.w2 = nn.Linear(in_features, hidden_features, bias=bias)  # Value
-        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)  # Output
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.w3(F.silu(self.w1(x)) * self.w2(x))
-
-
 class MOSAProjector(nn.Module):
     """MOSA-Base projector: simple 2-layer ReLU router with 4 simple adapters.
 
@@ -281,13 +258,10 @@ class MoEAudioProjector(nn.Module):
         Returns:
             Projected features of shape [batch, out_len, llm_dim]
         """
-        # 1. Frame Stacking
-        batch, seq, dim = x.shape
-        out_len = (seq - self.k) // self.k + 1
-        x = x[:, : out_len * self.k, :]
-        x = x.reshape(batch, out_len, dim * self.k)
+        x = _frame_stack(x, self.k)
+        batch, out_len, _ = x.shape
 
-        # 2. Normalize stacked input (like main branch SharedMoEBlock)
+        # Normalize stacked input (like main branch SharedMoEBlock)
         x = self.norm(x)
         flat_x = x.view(-1, x.size(-1))  # [tokens, in_dim]
 
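Both MLPAudioProjector and MoEAudioProjector now share _frame_stack instead of carrying their own copies of the frame-stacking code, and the unused SwiGLU modules are removed. A small shape check of the stacking formula (standalone copy of the helper's logic; the example sizes are arbitrary):

import torch

def frame_stack(x: torch.Tensor, k: int) -> torch.Tensor:
    # Concatenate k adjacent frames along the feature dim, dropping the leftover tail.
    batch, seq, dim = x.shape
    out_len = (seq - k) // k + 1
    return x[:, : out_len * k, :].reshape(batch, out_len, dim * k)

x = torch.randn(2, 101, 768)        # 101 encoder frames, 768-dim features
y = frame_stack(x, k=4)
assert y.shape == (2, 25, 3072)     # (101 - 4) // 4 + 1 = 25 stacked frames; 1 frame dropped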