Training in progress - step 1000
Browse files
- asr_config.py +18 -29
- asr_modeling.py +53 -58
- asr_pipeline.py +23 -29
- asr_processing.py +7 -6
- projectors.py +41 -33
asr_config.py
CHANGED

@@ -6,6 +6,19 @@ import transformers
 DEFAULT_ENCODER_CONV_LAYERS = [(1, 3, 1), (1, 3, 2)]


+def compute_encoder_output_length(mel_length, conv_layers=None):
+    """Apply encoder conv layer formulas to compute output length.
+
+    Works with both Python ints and torch tensors of mel lengths; the formula
+    `(L + 2*p - (k-1) - 1) // s + 1` per layer is identical for both.
+    """
+    layers = conv_layers if conv_layers is not None else DEFAULT_ENCODER_CONV_LAYERS
+    length = mel_length
+    for padding, kernel_size, stride in layers:
+        length = (length + 2 * padding - (kernel_size - 1) - 1) // stride + 1
+    return length
+
+
 class ASRConfig(transformers.PretrainedConfig):
     """Configuration class for the ASR model.

@@ -14,7 +27,7 @@ class ASRConfig(transformers.PretrainedConfig):
     - Text decoder (Qwen)
     - Projector (MLP, MOSA, MoE, QFormer)
     - Generation parameters
-    - Training options (…)
+    - Training options (LoRA)
     """

     model_type = "asr_model"

@@ -38,9 +51,6 @@ class ASRConfig(transformers.PretrainedConfig):
         downsample_rate: int = 5,  # Granite default
         projector_hidden_dim: Optional[int] = None,
         projector_type: str = "mlp",  # "mlp", "mosa", "moe", "qformer"
-        projector_num_layers: int = 2,  # Number of layers in MLP projector
-        projector_init_std: float = 0.02,  # Weight initialization std
-        projector_dropout: float = 0.0,  # Dropout rate for projector layers
         # MoE-specific configuration
         num_experts: int = 4,  # Number of experts in MoE projectors
         num_experts_per_tok: int = 2,  # Top-k experts per token

@@ -51,14 +61,6 @@ class ASRConfig(transformers.PretrainedConfig):
         qformer_num_layers: int = 2,  # Number of QFormer transformer layers
         qformer_num_heads: int = 16,  # Number of attention heads in QFormer
         qformer_intermediate_size: Optional[int] = None,  # FFN size (defaults to 4x hidden)
-        label_smoothing: float = 0.0,  # Label smoothing for cross-entropy loss
-        inference_warmup_tokens: int = 10,
-        # SpecAugment settings
-        use_specaugment: bool = False,
-        num_time_masks: int = 2,
-        time_mask_length: int = 10,
-        num_freq_masks: int = 0,
-        freq_mask_length: int = 10,
         # LoRA configuration (for Stage 2 fine-tuning)
         use_lora: bool = False,
         lora_rank: int = 8,  # SALMONN default

@@ -88,22 +90,20 @@ class ASRConfig(transformers.PretrainedConfig):
             model_dtype: Model dtype ("bfloat16", "float16", "float32")
             projector_type: Projector architecture ("mlp", "mosa", "moe", "qformer")
             use_lora: Enable LoRA adapters for Stage 2 fine-tuning
-            use_specaugment: Enable SpecAugment data augmentation
         """
-        # Set default generation parameters (greedy decoding only)
+        # Set default generation parameters (greedy decoding only).
+        # Applied via setattr below — keeping these out of kwargs so they
+        # don't get re-overwritten by super().__init__(**kwargs) at the end.
         generation_defaults = {
             "num_beams": 1,
             "max_new_tokens": 128,
             "min_new_tokens": 0,
             "repetition_penalty": 1.0,
             "length_penalty": 1.0,
-            "no_repeat_ngram_size": 0,
+            "no_repeat_ngram_size": 0,
             "use_cache": True,
         }

-        # Apply defaults (config.json values take precedence)
-        kwargs = {**generation_defaults, **kwargs}
-
         self.audio_model_id = audio_model_id
         self.text_model_id = text_model_id
         self.attn_implementation = attn_implementation

@@ -113,13 +113,10 @@ class ASRConfig(transformers.PretrainedConfig):
         self.llm_dim = llm_dim
         self.encoder_conv_layers = encoder_conv_layers or DEFAULT_ENCODER_CONV_LAYERS
         self.audio_sample_rate = audio_sample_rate
-        self.projector_init_std = projector_init_std
         self.projector_pool_stride = projector_pool_stride
         self.downsample_rate = downsample_rate
         self.projector_hidden_dim = projector_hidden_dim
         self.projector_type = projector_type
-        self.projector_num_layers = projector_num_layers
-        self.projector_dropout = projector_dropout
         # MoE-specific configuration
         self.num_experts = num_experts
         self.num_experts_per_tok = num_experts_per_tok

@@ -130,14 +127,6 @@ class ASRConfig(transformers.PretrainedConfig):
         self.qformer_num_layers = qformer_num_layers
         self.qformer_num_heads = qformer_num_heads
         self.qformer_intermediate_size = qformer_intermediate_size
-        self.label_smoothing = label_smoothing
-        self.inference_warmup_tokens = inference_warmup_tokens
-        # SpecAugment configuration
-        self.use_specaugment = use_specaugment
-        self.num_time_masks = num_time_masks
-        self.time_mask_length = time_mask_length
-        self.num_freq_masks = num_freq_masks
-        self.freq_mask_length = freq_mask_length
         # LoRA configuration
         self.use_lora = use_lora
         self.lora_rank = lora_rank
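A quick sanity check of the new shared helper against the default conv layers, using the flat-module import path from the except ImportError branch; the lengths below are illustrative, not taken from the repo:

import torch

from asr_config import DEFAULT_ENCODER_CONV_LAYERS, compute_encoder_output_length

# With the defaults [(1, 3, 1), (1, 3, 2)], the first layer preserves length and
# the second roughly halves it: 3000 mel frames -> 1500 encoder frames.
assert compute_encoder_output_length(3000) == 1500

# The same call works elementwise on a tensor of per-sample mel lengths.
mel_lengths = torch.tensor([3000, 1234, 80])
print(compute_encoder_output_length(mel_lengths, DEFAULT_ENCODER_CONV_LAYERS))
# tensor([1500,  617,   40])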
asr_modeling.py
CHANGED

@@ -17,16 +17,13 @@ from transformers.generation import GenerationMixin
 from transformers.modeling_outputs import CausalLMOutputWithPast

 try:
-    from .asr_config import ASRConfig
+    from .asr_config import ASRConfig, compute_encoder_output_length
     from .projectors import PROJECTOR_CLASSES
 except ImportError:
-    from asr_config import ASRConfig  # type: ignore[no-redef]
+    from asr_config import ASRConfig, compute_encoder_output_length  # type: ignore[no-redef]
     from projectors import PROJECTOR_CLASSES  # type: ignore[no-redef]


-from torchaudio.transforms import SpecAugment
-
-
 def _gather_audio_embeds(audio_embeds: torch.Tensor, token_counts: torch.Tensor) -> torch.Tensor:
     """Flatten per-sample audio embeddings into a packed tensor.

@@ -56,7 +53,6 @@ class ASRModel(PreTrainedModel, GenerationMixin):
     _supports_flash_attn_2 = True
     supports_gradient_checkpointing = True
     _is_loading_from_pretrained: bool = False
-    _pretrained_model_path: Optional[str] = None

     TRANSCRIBE_PROMPT = "Transcribe the speech to text"

@@ -72,7 +68,6 @@ class ASRModel(PreTrainedModel, GenerationMixin):

         # Set flag to avoid device_map="auto" in sub-model loaders
         cls._is_loading_from_pretrained = True
-        cls._pretrained_model_path = pretrained_model_name_or_path

         try:
             model = cls(config, **kwargs)

@@ -134,7 +129,6 @@ class ASRModel(PreTrainedModel, GenerationMixin):
             return model
         finally:
             cls._is_loading_from_pretrained = False
-            cls._pretrained_model_path = None

     def __init__(self, config: ASRConfig, **kwargs) -> None:
         super().__init__(config)

@@ -190,17 +184,6 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         if getattr(config, "freeze_projector", False):
             self.projector.requires_grad_(False)

-        # SpecAugment for data augmentation during training
-        if getattr(config, "use_specaugment", False):
-            self.spec_augment = SpecAugment(
-                n_time_masks=config.num_time_masks,
-                time_mask_param=config.time_mask_length,
-                n_freq_masks=config.num_freq_masks,
-                freq_mask_param=config.freq_mask_length,
-            )
-        else:
-            self.spec_augment = None
-
         # For model parallelism
         self._no_split_modules = getattr(self.language_model, "_no_split_modules", [])

@@ -340,7 +323,13 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         self.tokenizer.add_special_tokens(
             {"additional_special_tokens": existing_special + ["<audio>"]}
         )
-        self.language_model.resize_token_embeddings(len(self.tokenizer))
+        # mean_resizing=True initializes the new <audio> row at the mean of
+        # existing rows so its scale matches the pretrained distribution. The
+        # input-side <audio> embedding is overwritten via masked_scatter and
+        # never seen by the LM, but with tied embeddings (Qwen3-0.6B) this
+        # same row is the lm_head column for predicting <audio>; a Gaussian
+        # draw at config.initializer_range was visible in early-step logits.
+        self.language_model.resize_token_embeddings(len(self.tokenizer), mean_resizing=True)

         self.audio_token_id = self.tokenizer.convert_tokens_to_ids("<audio>")
         self.tokenizer.padding_side = "right"

@@ -352,9 +341,20 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         cfg.eos_token_id = self.tokenizer.eos_token_id
         cfg.bos_token_id = self.tokenizer.bos_token_id

-    def …
-        """…
+    def train(self, mode: bool = True):
+        """Set train/eval mode, but keep frozen submodules out of train mode.
+
+        HF Trainer calls `model.train()` at the top of every training step, which
+        recursively switches every submodule into train mode — re-enabling dropout
+        on modules with `requires_grad_(False)`. The frozen encoder (and the LM
+        when `freeze_language_model=True`) should always run deterministically;
+        train-mode dropout only adds noise that can't improve a frozen network.
+        """
+        super().train(mode)
+        self.audio_tower.train(False)
+        if getattr(self.config, "freeze_language_model", True):
+            self.language_model.train(False)
+        return self

     def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func=None):
         """Enable/disable gradient checkpointing for the language model."""

@@ -396,34 +396,40 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         )

     def state_dict(self, *args, **kwargs) -> dict[str, torch.Tensor]:
-        """Save trainable weights: projector, plus the language model when fine-tuned."""
+        """Save trainable weights: projector, plus the language model when fine-tuned.
+
+        With LoRA attached, the language_model entries are flattened to plain
+        (non-PEFT) HF naming so model.safetensors round-trips through
+        ASRModel.from_pretrained — which builds a vanilla base LM, overlays
+        these weights, and only then re-attaches PEFT. lora_*/adapter weights
+        are skipped here; PEFT serializes them separately as
+        adapter_model.safetensors via the save_pretrained path below.
+        """
         sd = {f"projector.{k}": v for k, v in self.projector.state_dict().items()}
         if not getattr(self.config, "freeze_language_model", True):
-            …
+            lm = self.language_model
+            if hasattr(lm, "peft_config"):
+                for k, v in lm.state_dict().items():
+                    if "lora_" in k:
+                        continue
+                    if k.startswith("base_model.model."):
+                        k = k[len("base_model.model.") :]
+                    # LoRA layers wrap the original Linear as `<name>.base_layer.<weight|bias>`.
+                    k = k.replace(".base_layer.", ".")
+                    sd[f"language_model.{k}"] = v
+            else:
+                sd.update({f"language_model.{k}": v for k, v in lm.state_dict().items()})
         return sd

     def _compute_encoder_output_lengths(
         self,
         audio_attention_mask: torch.Tensor,
     ) -> torch.Tensor:
-        """Compute per-sample encoder output lengths using conv layer formulas.
-
-        …
-
-        Returns:
-            Tensor of encoder output lengths per sample (batch,)
-        """
-        # Get mel frame lengths from attention mask
-        lengths = audio_attention_mask.sum(dim=-1)
-
-        # Apply conv layer formulas: output = (input + 2*pad - (kernel-1) - 1) // stride + 1
-        for padding, kernel_size, stride in self.config.encoder_conv_layers:
-            lengths = (lengths + 2 * padding - (kernel_size - 1) - 1) // stride + 1
-
-        return lengths
+        """Compute per-sample encoder output lengths using conv layer formulas."""
+        return compute_encoder_output_length(
+            audio_attention_mask.sum(dim=-1),
+            self.config.encoder_conv_layers,
+        )

     def _encode_audio(
         self,

@@ -468,9 +474,6 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         inputs_embeds = self.language_model.get_input_embeddings()(input_ids)

         if input_features is not None and input_ids is not None:
-            if self.training and self.spec_augment is not None:
-                input_features = self.spec_augment(input_features)
-
             is_audio_token = input_ids == self.audio_token_id
             if audio_token_counts is None:
                 audio_token_counts = is_audio_token.sum(dim=-1)

@@ -556,13 +559,9 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         device = input_features.device
         batch_size = input_features.shape[0]

-        # Encode audio -> flattened embeddings
+        # Encode audio -> flattened embeddings (no per-sample host sync)
         encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
-        token_counts = torch.tensor(
-            [self.projector.get_output_length(int(length.item())) for length in encoder_lengths],
-            device=input_features.device,
-            dtype=torch.long,
-        )
+        token_counts = self.projector.get_output_length(encoder_lengths).to(torch.long)
         audio_embeds = self._encode_audio(input_features, token_counts)

         # If input_ids not provided, build prompt with correct number of audio tokens

@@ -646,13 +645,9 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         device = input_features.device
         batch_size = input_features.shape[0]

-        # Encode audio -> flattened embeddings
+        # Encode audio -> flattened embeddings (no per-sample host sync)
         encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
-        token_counts = torch.tensor(
-            [self.projector.get_output_length(int(length.item())) for length in encoder_lengths],
-            device=input_features.device,
-            dtype=torch.long,
-        )
+        token_counts = self.projector.get_output_length(encoder_lengths).to(torch.long)
         audio_embeds = self._encode_audio(input_features, token_counts)

         # Build prompt with correct number of audio tokens
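The train() override is an instance of a general pattern for partially frozen models; a minimal standalone sketch, with illustrative module and attribute names rather than the repo's:

import torch.nn as nn

class PartiallyFrozen(nn.Module):
    def __init__(self):
        super().__init__()
        self.frozen_encoder = nn.Sequential(nn.Linear(8, 8), nn.Dropout(0.1))
        self.trainable_head = nn.Linear(8, 2)
        self.frozen_encoder.requires_grad_(False)

    def train(self, mode: bool = True):
        # Propagate the mode to every submodule first, then force the frozen
        # part back into eval so its dropout stays off during training steps.
        super().train(mode)
        self.frozen_encoder.eval()
        return self

model = PartiallyFrozen().train()
assert model.training and not model.frozen_encoder.training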
asr_pipeline.py
CHANGED

@@ -23,9 +23,9 @@ __all__ = ["ForcedAligner", "SpeakerDiarizer", "ASRPipeline"]

 _THINK_TAG_RE = re.compile(r"<think>.*?</think>\s*", flags=re.DOTALL)
 _DEFAULT_MIN_REPEATS = 3
-_TRAILING_CHAR_RE = re.compile(r"(.)\1{" + str(_DEFAULT_MIN_REPEATS - 1) + r",}$")
+_TRAILING_CHAR_RE = re.compile(rf"(.)\1{{{_DEFAULT_MIN_REPEATS - 1},}}$")
 _TRAILING_WORD_RE = re.compile(
-    r"\b(\w+)(?:\s+\1){" + str(_DEFAULT_MIN_REPEATS - 1) + r",}\s*$", re.IGNORECASE
+    rf"\b(\w+)(?:\s+\1){{{_DEFAULT_MIN_REPEATS - 1},}}\s*$", re.IGNORECASE
 )


@@ -291,10 +291,8 @@ def _truncate_repetitions(text: str, min_repeats: int = 3) -> str:
         char_pattern = _TRAILING_CHAR_RE
         word_pattern = _TRAILING_WORD_RE
     else:
-        char_pattern = re.compile(r"(.)\1{" + str(min_repeats - 1) + r",}$")
-        word_pattern = re.compile(
-            r"\b(\w+)(?:\s+\1){" + str(min_repeats - 1) + r",}\s*$", re.IGNORECASE
-        )
+        char_pattern = re.compile(rf"(.)\1{{{min_repeats - 1},}}$")
+        word_pattern = re.compile(rf"\b(\w+)(?:\s+\1){{{min_repeats - 1},}}\s*$", re.IGNORECASE)

     text = char_pattern.sub(r"\1", text)
     while word_pattern.search(text):

@@ -303,28 +301,24 @@ def _truncate_repetitions(text: str, min_repeats: int = 3) -> str:
     # 3. Truncate repeated phrases (2-20 words) at end
     # e.g., "i am sorry i am sorry i am sorry" -> "i am sorry"
     words = text.split()
-    if len(words) …
-    …
-            # Keep prefix + one instance of the phrase
-            text = (match.group(1) + match.group(2)).strip()
-            words = text.split()
-            break
+    if len(words) < min_repeats * 2:
+        return text
+
+    # Cheap pre-check: trailing window must contain duplicates for any phrase repeat
+    # to be possible. set(window) == window means all unique → no repetition.
+    window = words[-min_repeats * 2 :]
+    if len(set(window)) == len(window):
+        return text
+
+    for phrase_len in range(2, min(21, len(words) // min_repeats + 1)):
+        phrase_escaped = re.escape(" ".join(words[-phrase_len:]))
+        phrase_pattern = re.compile(
+            rf"(^|.*?\s)({phrase_escaped})(?:\s+{phrase_escaped}){{{min_repeats - 1},}}\s*$",
+            re.IGNORECASE,
+        )
+        match = phrase_pattern.match(text)
+        if match:
+            text = (match.group(1) + match.group(2)).strip()
+            break

     return text
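To see what the phrase pass does, here is the trailing-phrase regex applied by hand to the docstring's "i am sorry" example; the input string and the fixed phrase_len are made up for illustration:

import re

min_repeats = 3
text = "thanks for watching i am sorry i am sorry i am sorry"
words = text.split()

phrase_len = 3  # try the 3-word trailing phrase "i am sorry"
phrase_escaped = re.escape(" ".join(words[-phrase_len:]))
phrase_pattern = re.compile(
    rf"(^|.*?\s)({phrase_escaped})(?:\s+{phrase_escaped}){{{min_repeats - 1},}}\s*$",
    re.IGNORECASE,
)
match = phrase_pattern.match(text)
# Keep the prefix plus one instance of the repeated phrase.
print((match.group(1) + match.group(2)).strip())
# thanks for watching i am sorry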
asr_processing.py
CHANGED

@@ -5,9 +5,13 @@ import transformers
 from transformers import ProcessorMixin

 try:
-    from .asr_config import DEFAULT_ENCODER_CONV_LAYERS, ASRConfig
+    from .asr_config import DEFAULT_ENCODER_CONV_LAYERS, ASRConfig, compute_encoder_output_length
 except ImportError:
-    from asr_config import DEFAULT_ENCODER_CONV_LAYERS, ASRConfig  # type: ignore[no-redef]
+    from asr_config import (  # type: ignore[no-redef]
+        DEFAULT_ENCODER_CONV_LAYERS,
+        ASRConfig,
+        compute_encoder_output_length,
+    )


 class ASRProcessor(ProcessorMixin):

@@ -42,10 +46,7 @@ class ASRProcessor(ProcessorMixin):

     def _compute_encoder_output_length(self, mel_length: int) -> int:
         """Compute encoder output length using conv layer formulas."""
-        length = mel_length
-        for padding, kernel_size, stride in self.encoder_conv_layers:
-            length = (length + 2 * padding - (kernel_size - 1) - 1) // stride + 1
-        return length
+        return compute_encoder_output_length(mel_length, self.encoder_conv_layers)

     def __call__(
         self,
projectors.py
CHANGED

@@ -43,6 +43,11 @@ class MLPAudioProjector(nn.Module):
         self.norm = LlamaRMSNorm(hidden_dim, eps=1e-6)
         self.act = nn.GELU()
         self.linear_2 = nn.Linear(hidden_dim, llm_dim, bias=False)
+        # Output norm aligns the projector's RMS with the LM's embed_tokens
+        # distribution. Without it, linear_2's Kaiming-uniform init produces
+        # outputs ~30× quieter than embed rows, which saturates softmax at
+        # audio positions and starves them of gradient.
+        self.norm_2 = LlamaRMSNorm(llm_dim, eps=1e-6)

     def get_output_length(self, input_length: int) -> int:
         """Calculate output sequence length given input length (matches GLM-ASR)."""

@@ -62,7 +67,8 @@ class MLPAudioProjector(nn.Module):
         x = self.linear_1(x)
         x = self.norm(x)
         x = self.act(x)
-        return self.linear_2(x)
+        x = self.linear_2(x)
+        return self.norm_2(x)


 # =============================================================================

@@ -102,6 +108,12 @@ class MOSAProjector(nn.Module):
     Uses Conv1d for downsampling (2 layers, stride 2 each = 4x total).
     """

+    ADAPTER_HIDDEN_DIM = 4096
+    ROUTER_HIDDEN_DIM = 512
+    CONV_KERNEL = 3
+    CONV_STRIDE = 2
+    CONV_PADDING = 1
+
     def __init__(self, config):
         """Initialize MOSA projector.

@@ -112,31 +124,28 @@ class MOSAProjector(nn.Module):
         self.encoder_dim = getattr(config, "encoder_dim", None) or 1280
         self.llm_dim = getattr(config, "llm_dim", None) or 2048
         self.num_experts = getattr(config, "num_experts", None) or 4  # MOSA-Base uses 4
-        adapter_hidden = getattr(config, "adapter_hidden_dim", None) or 4096
-        router_hidden = getattr(config, "router_hidden_dim", None) or 512

-        …
+        conv_kwargs = {
+            "kernel_size": self.CONV_KERNEL,
+            "stride": self.CONV_STRIDE,
+            "padding": self.CONV_PADDING,
+        }
         self.downsampler = nn.Sequential(
-            nn.Conv1d(self.encoder_dim, self.encoder_dim, …),
+            nn.Conv1d(self.encoder_dim, self.encoder_dim, **conv_kwargs),
             nn.GELU(),
-            nn.Conv1d(self.encoder_dim, self.llm_dim, …),
+            nn.Conv1d(self.encoder_dim, self.llm_dim, **conv_kwargs),
             nn.GELU(),
         )

-        # --- 2. Simple Router (MOSA-Base: 2 layers with ReLU) ---
-        # Takes downsampled features (llm_dim) -> 512 -> num_experts
         self.router = nn.Sequential(
-            nn.Linear(self.llm_dim, router_hidden),
+            nn.Linear(self.llm_dim, self.ROUTER_HIDDEN_DIM),
             nn.ReLU(),
-            nn.Linear(router_hidden, self.num_experts),
+            nn.Linear(self.ROUTER_HIDDEN_DIM, self.num_experts),
         )

-        # --- 3. Experts (Simple 2-layer GELU adapters) ---
-        # Each expert: llm_dim -> hidden -> llm_dim (much smaller than frame-stacking)
         self.experts = nn.ModuleList(
             [
-                SimpleAdapter(self.llm_dim, adapter_hidden, self.llm_dim)
+                SimpleAdapter(self.llm_dim, self.ADAPTER_HIDDEN_DIM, self.llm_dim)
                 for _ in range(self.num_experts)
             ]
         )

@@ -150,26 +159,22 @@ class MOSAProjector(nn.Module):
         Returns:
             Projected features of shape [batch, out_len, llm_dim]
         """
-        …
-        x = x.transpose(1, 2)
-        x = self.downsampler(x)
-        # Permute back: [B, D, S] -> [B, S, D]
-        x = x.transpose(1, 2)
-
-        # --- 2. Routing ---
+        x = self.downsampler(x.transpose(1, 2)).transpose(1, 2)
+
         routing_weights = F.softmax(self.router(x), dim=-1)  # (B, out_len, num_experts)

-        # …
+        # Accumulate weighted expert outputs without materializing all experts at once.
+        output = self.experts[0](x) * routing_weights[..., 0:1]
+        for i, expert in enumerate(self.experts[1:], start=1):
+            output = output + expert(x) * routing_weights[..., i : i + 1]
+        return output

     def get_output_length(self, input_length: int) -> int:
         """Calculate output sequence length after Conv1d downsampling (4x reduction)."""
-        …
-        return …
+        length = input_length
+        for _ in range(2):
+            length = (length + 2 * self.CONV_PADDING - self.CONV_KERNEL) // self.CONV_STRIDE + 1
+        return length


 # =============================================================================

@@ -414,10 +419,13 @@ class QFormerAudioProjector(nn.Module):
         # Final projection to LLM dimension (Granite uses bias=True)
         self.linear = nn.Linear(qformer_hidden, llm_dim)

-    def get_output_length(self, input_length: int) -> int:
-        """Calculate output sequence length given input length."""
-        …
+    def get_output_length(self, input_length):
+        """Calculate output sequence length given input length.
+
+        Accepts either Python ints or torch tensors; uses ceiling division so
+        the formula is identical for both — math.ceil would block tensors.
+        """
+        nblocks = (input_length + self.window_size - 1) // self.window_size
         return nblocks * self.num_queries

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
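As a worked check of MOSAProjector.get_output_length: two Conv1d layers with kernel 3, stride 2, padding 1 give close to a 4x reduction. A small standalone sketch of the same arithmetic, with illustrative input lengths and a hypothetical helper name:

def mosa_output_length(length, kernel=3, stride=2, padding=1, num_layers=2):
    # Same Conv1d length formula the projector applies, one pass per conv layer.
    for _ in range(num_layers):
        length = (length + 2 * padding - kernel) // stride + 1
    return length

for n in (1500, 750, 100):
    print(n, "->", mosa_output_length(n))
# 1500 -> 375, 750 -> 188, 100 -> 25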
|