Training in progress - step 1000
- alignment.py +6 -3
- asr_config.py +16 -24
- asr_modeling.py +54 -116
- asr_pipeline.py +19 -11
- asr_processing.py +3 -5
- diarization.py +0 -2
- projectors.py +15 -41
alignment.py
CHANGED
@@ -120,16 +120,19 @@ class ForcedAligner:
 
             if move_score >= stay_score:
                 # Token j-1 was emitted at frame t-1
-                token_frames[j - 1].
+                token_frames[j - 1].append(t - 1)
                 j -= 1
-            # Always decrement time (monotonic)
             t -= 1
 
         # Handle any remaining tokens at the start (edge case)
         while j > 0:
-            token_frames[j - 1].
+            token_frames[j - 1].append(0)
             j -= 1
 
+        # We appended in reverse-time order; restore monotonic order
+        for frames in token_frames:
+            frames.reverse()
+
         # Convert to spans
         token_spans: list[tuple[int, float, float]] = []
         for token_idx, frames in enumerate(token_frames):
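For intuition, a minimal standalone sketch of what the added reversal buys: backtracking walks frames from last to first, so each token's frame list arrives newest-first and needs one reverse to become monotonic. The frame indices below are made up for illustration.

# token_frames as produced by backtracking (appended in reverse-time order)
token_frames = [[7, 5, 4], [12, 10, 9]]
for frames in token_frames:
    frames.reverse()
print(token_frames)  # [[4, 5, 7], [9, 10, 12]] -> monotonic, ready for span conversion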
asr_config.py
CHANGED
@@ -2,6 +2,9 @@ from typing import Optional
 
 import transformers
 
+# Default conv layers for Whisper/GLM-ASR audio encoders: [(pad, kernel, stride), ...]
+DEFAULT_ENCODER_CONV_LAYERS = [(1, 3, 1), (1, 3, 2)]
+
 
 class ASRConfig(transformers.PretrainedConfig):
     """Configuration class for the ASR model.
@@ -107,8 +110,7 @@ class ASRConfig(transformers.PretrainedConfig):
         self.system_prompt = system_prompt
         self.encoder_dim = encoder_dim
         self.llm_dim = llm_dim
-
-        self.encoder_conv_layers = encoder_conv_layers or [(1, 3, 1), (1, 3, 2)]
+        self.encoder_conv_layers = encoder_conv_layers or DEFAULT_ENCODER_CONV_LAYERS
         self.audio_sample_rate = audio_sample_rate
         self.projector_init_std = projector_init_std
         self.projector_pool_stride = projector_pool_stride
@@ -151,28 +153,18 @@ class ASRConfig(transformers.PretrainedConfig):
         ]
         self.freeze_projector = freeze_projector
 
-
-
-
-
-
-
-
-
-
-
-
-            else
-        )
-        self.length_penalty = (
-            length_penalty if length_penalty is not None else generation_defaults["length_penalty"]
-        )
-        self.no_repeat_ngram_size = (
-            no_repeat_ngram_size
-            if no_repeat_ngram_size is not None
-            else generation_defaults["no_repeat_ngram_size"]
-        )
-        self.use_cache = use_cache if use_cache is not None else generation_defaults["use_cache"]
+        explicit_generation_args = {
+            "num_beams": num_beams,
+            "max_new_tokens": max_new_tokens,
+            "min_new_tokens": min_new_tokens,
+            "repetition_penalty": repetition_penalty,
+            "length_penalty": length_penalty,
+            "no_repeat_ngram_size": no_repeat_ngram_size,
+            "use_cache": use_cache,
+        }
+        for key, default in generation_defaults.items():
+            value = explicit_generation_args[key]
+            setattr(self, key, value if value is not None else default)
         self.do_sample = do_sample
         self.temperature = temperature
         self.top_p = top_p
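A small sketch of how the new defaults loop resolves generation arguments. Only the key names come from the hunk; the default values below are made up for illustration.

generation_defaults = {"num_beams": 1, "max_new_tokens": 256}        # illustrative defaults
explicit_generation_args = {"num_beams": 4, "max_new_tokens": None}  # user set num_beams only

resolved = {}
for key, default in generation_defaults.items():
    value = explicit_generation_args[key]
    resolved[key] = value if value is not None else default
print(resolved)  # {'num_beams': 4, 'max_new_tokens': 256}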
asr_modeling.py
CHANGED
@@ -5,8 +5,8 @@ from typing import Iterator, Optional, Union
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F  # noqa: N812
 from transformers import (
-    AutoConfig,
     AutoModel,
     AutoModelForCausalLM,
     AutoTokenizer,
@@ -27,6 +27,26 @@ except ImportError:
     from torchaudio.transforms import SpecAugment
 
 
+def _gather_audio_embeds(audio_embeds: torch.Tensor, token_counts: torch.Tensor) -> torch.Tensor:
+    """Flatten per-sample audio embeddings into a packed tensor.
+
+    For each row i, takes the first ``token_counts[i]`` rows of
+    ``audio_embeds[i]`` and concatenates them. If any token count exceeds
+    ``audio_embeds.shape[1]``, the deficit is zero-padded.
+
+    Equivalent to a per-sample slice/cat loop but with O(1) host-device
+    syncs per call (one ``max().item()``) instead of one per sample.
+    """
+    _, max_len, _ = audio_embeds.shape
+    needed = int(token_counts.max().item())
+    if needed > max_len:
+        audio_embeds = F.pad(audio_embeds, (0, 0, 0, needed - max_len))
+        max_len = needed
+    indices = torch.arange(max_len, device=audio_embeds.device).unsqueeze(0)
+    mask = indices < token_counts.unsqueeze(1)
+    return audio_embeds[mask]
+
+
 class ASRModel(PreTrainedModel, GenerationMixin):
     """Audio-to-text model combining an audio encoder, projector, and language model."""
 
@@ -402,61 +422,25 @@ class ASRModel(PreTrainedModel, GenerationMixin):
     def _encode_audio(
         self,
         audio_features: torch.Tensor,
-
-        expected_token_counts: torch.Tensor | None = None,
+        expected_token_counts: torch.Tensor,
     ) -> torch.Tensor:
-        """Encode audio and
+        """Encode audio features and return flattened embeddings matching expected_token_counts.
 
         Args:
             audio_features: Mel spectrogram features (batch, n_mels, mel_len)
-
-            expected_token_counts: Expected number of audio tokens per sample from input_ids.
-                If provided, output will match these counts exactly (padding/truncating as needed).
+            expected_token_counts: Per-sample audio token counts as int64 tensor (batch,).
 
         Returns:
-            Flattened audio embeddings of shape (
+            Flattened audio embeddings of shape (sum(expected_token_counts), hidden_dim).
         """
         with torch.no_grad():
            encoder_out = self.audio_tower(input_features=audio_features)
            hidden_states = encoder_out.last_hidden_state
 
-        # Project to LLM space
         audio_embeds = self.projector(hidden_states)
 
-
-
-            token_counts = expected_token_counts
-        else:
-            # Compute per-sample encoder output lengths using conv formulas
-            encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
-            token_counts = torch.tensor(
-                [
-                    self.projector.get_output_length(int(length.item()))
-                    for length in encoder_lengths
-                ],
-                device=audio_embeds.device,
-            )
-
-        # Extract embeddings matching expected token counts per sample
-        batch_size = audio_embeds.shape[0]
-        hidden_dim = audio_embeds.shape[2]
-
-        result_embeds = []
-        for i in range(batch_size):
-            count = int(token_counts[i].item())
-            sample_embeds = audio_embeds[i, :count, :]  # Take first 'count' embeddings
-            # Pad with zeros if we don't have enough embeddings
-            if sample_embeds.shape[0] < count:
-                padding = torch.zeros(
-                    count - sample_embeds.shape[0],
-                    hidden_dim,
-                    device=audio_embeds.device,
-                    dtype=audio_embeds.dtype,
-                )
-                sample_embeds = torch.cat([sample_embeds, padding], dim=0)
-            result_embeds.append(sample_embeds)
-
-        return torch.cat(result_embeds, dim=0)
+        token_counts = expected_token_counts.to(device=audio_embeds.device, dtype=torch.long)
+        return _gather_audio_embeds(audio_embeds, token_counts)
 
     def forward(
         self,
@@ -470,34 +454,33 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         labels: Optional[torch.Tensor] = None,
         use_cache: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
+        audio_token_counts: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> CausalLMOutputWithPast:
         """Forward pass for training and inference."""
-        # Get text embeddings if not provided
         if inputs_embeds is None:
             inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
 
         if input_features is not None and input_ids is not None:
-            # Apply SpecAugment during training if enabled
             if self.training and self.spec_augment is not None:
                 input_features = self.spec_augment(input_features)
 
-
-            audio_token_counts
+            is_audio_token = input_ids == self.audio_token_id
+            if audio_token_counts is None:
+                audio_token_counts = is_audio_token.sum(dim=-1)
+            else:
+                audio_token_counts = audio_token_counts.to(
+                    device=input_ids.device, dtype=torch.long
+                )
 
-
-            audio_embeds = self._encode_audio(
-                input_features, audio_attention_mask, audio_token_counts
-            )
+            audio_embeds = self._encode_audio(input_features, audio_token_counts)
 
-
-            audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
+            audio_token_mask = is_audio_token.unsqueeze(-1)
             inputs_embeds = inputs_embeds.masked_scatter(
                 audio_token_mask.to(inputs_embeds.device),
                 audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
             )
 
-        # Run through language model (let it compute loss if labels provided)
         outputs = self.language_model(
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -509,7 +492,6 @@ class ASRModel(PreTrainedModel, GenerationMixin):
             **kwargs,
         )
 
-        # Add auxiliary loss from MoE projectors if available
         if outputs.loss is not None and hasattr(self.projector, "get_aux_loss"):
             aux_loss = self.projector.get_aux_loss()
             if aux_loss is not None and aux_loss.numel() > 0:
@@ -569,7 +551,13 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         batch_size = input_features.shape[0]
 
         # Encode audio -> flattened embeddings
-
+        encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
+        token_counts = torch.tensor(
+            [self.projector.get_output_length(int(length.item())) for length in encoder_lengths],
+            device=input_features.device,
+            dtype=torch.long,
+        )
+        audio_embeds = self._encode_audio(input_features, token_counts)
 
         # If input_ids not provided, build prompt with correct number of audio tokens
         if input_ids is None:
@@ -653,7 +641,13 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         batch_size = input_features.shape[0]
 
         # Encode audio -> flattened embeddings
-
+        encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
+        token_counts = torch.tensor(
+            [self.projector.get_output_length(int(length.item())) for length in encoder_lengths],
+            device=input_features.device,
+            dtype=torch.long,
+        )
+        audio_embeds = self._encode_audio(input_features, token_counts)
 
         # Build prompt with correct number of audio tokens
         num_audio_tokens = self._get_num_audio_tokens(audio_attention_mask)
@@ -747,63 +741,11 @@ class ASRModel(PreTrainedModel, GenerationMixin):
 
         thread.join()
 
-    @torch.no_grad()
-    def generate_text_only(
-        self,
-        messages: list[dict[str, str]],
-        max_new_tokens: int = 256,
-        **generate_kwargs,
-    ) -> str:
-        """Generate text using only the LLM (no audio encoding).
-
-        Used for SIFT-style response generation from metadata prompts.
-
-        Args:
-            messages: List of chat messages [{"role": "user", "content": "..."}]
-            max_new_tokens: Maximum tokens to generate
-            **generate_kwargs: Additional generation arguments
-
-        Returns:
-            Generated text response
-        """
-        device = next(self.language_model.parameters()).device
-
-        # Apply chat template
-        input_ids = self.tokenizer.apply_chat_template(
-            messages,
-            tokenize=True,
-            add_generation_prompt=True,
-            return_tensors="pt",
-            enable_thinking=False,  # Disable Qwen3 thinking mode for ASR
-        ).to(device)
-
-        if input_ids.dim() == 1:
-            input_ids = input_ids.unsqueeze(0)
-
-        attention_mask = torch.ones_like(input_ids)
-
-        # Generate using language model directly
-        output = self.language_model.generate(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            max_new_tokens=max_new_tokens,
-            do_sample=False,
-            pad_token_id=self.tokenizer.pad_token_id,
-            eos_token_id=self.tokenizer.eos_token_id,
-            **generate_kwargs,
-        )
-
-        # Decode only the new tokens
-        new_tokens = output[0, input_ids.shape[1] :]
-        response = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
-        return response.strip()
-
     def save_pretrained(self, save_directory: Union[str, Path], **kwargs) -> None:
         """Save model, tokenizer, and processor."""
         import shutil
-        from pathlib import Path as PathlibPath
 
-        save_dir =
+        save_dir = Path(save_directory)
         save_dir.mkdir(parents=True, exist_ok=True)
 
         # Update config with actual vocab size
@@ -874,7 +816,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
             json.dump(processor_config, f, indent=2)
 
         # Copy source files for auto-loading
-        src_dir =
+        src_dir = Path(__file__).parent
         for asr_file in src_dir.glob("asr_*.py"):
             shutil.copy(asr_file, save_dir / asr_file.name)
         # Copy projectors module
@@ -896,11 +838,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         # Call parent's push_to_hub
         return super().push_to_hub(repo_id, **kwargs)
 
-    def create_or_update_model_card(self, output_dir: Union[str, Path]) -> None:
-        """No-op for model card creation - we use MODEL_CARD.md in repo instead."""
-        pass
-
 
 # Register with transformers Auto classes
-AutoConfig.register
+# (AutoConfig.register is handled in asr_config.py at module load.)
 AutoModel.register(ASRConfig, ASRModel)
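A quick shape check of the packing trick used by the new _gather_audio_embeds helper. The helper body is repeated here verbatim so the snippet runs standalone; the tensor sizes are arbitrary.

import torch
import torch.nn.functional as F

def _gather_audio_embeds(audio_embeds, token_counts):
    _, max_len, _ = audio_embeds.shape
    needed = int(token_counts.max().item())
    if needed > max_len:
        audio_embeds = F.pad(audio_embeds, (0, 0, 0, needed - max_len))
        max_len = needed
    indices = torch.arange(max_len, device=audio_embeds.device).unsqueeze(0)
    mask = indices < token_counts.unsqueeze(1)
    return audio_embeds[mask]  # boolean mask keeps rows in (sample, frame) order

audio_embeds = torch.randn(2, 10, 4)   # (batch, max_frames, hidden)
token_counts = torch.tensor([3, 7])    # per-sample audio token counts
packed = _gather_audio_embeds(audio_embeds, token_counts)
print(packed.shape)                    # torch.Size([10, 4]) == (3 + 7, hidden)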
asr_pipeline.py
CHANGED
@@ -7,6 +7,7 @@ from typing import Any
 import numpy as np
 import torch
 import transformers
+from transformers.pipelines.audio_utils import ffmpeg_read
 
 try:
     from .alignment import ForcedAligner
@@ -20,6 +21,13 @@ except ImportError:
 # Re-export for backwards compatibility
 __all__ = ["ForcedAligner", "SpeakerDiarizer", "ASRPipeline"]
 
+_THINK_TAG_RE = re.compile(r"<think>.*?</think>\s*", flags=re.DOTALL)
+_DEFAULT_MIN_REPEATS = 3
+_TRAILING_CHAR_RE = re.compile(r"(.)\1{" + str(_DEFAULT_MIN_REPEATS - 1) + r",}$")
+_TRAILING_WORD_RE = re.compile(
+    r"\b(\w+)(?:\s+\1){" + str(_DEFAULT_MIN_REPEATS - 1) + r",}\s*$", re.IGNORECASE
+)
+
 
 class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
     """ASR Pipeline for audio-to-text transcription."""
@@ -152,8 +160,6 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
 
     def _extract_audio(self, inputs) -> dict | None:
         """Extract audio array from various input formats using HF utilities."""
-        from transformers.pipelines.audio_utils import ffmpeg_read
-
         if isinstance(inputs, dict):
             if "array" in inputs:
                 return {
@@ -257,8 +263,8 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
 
         text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
         # Strip <think>...</think> tags (Qwen3 doesn't respect /no_think prompt)
-
-
+        if "<think>" in text:
+            text = _THINK_TAG_RE.sub("", text).strip()
         text = _truncate_repetitions(text)
         return {"text": text}
 
@@ -281,14 +287,16 @@ def _truncate_repetitions(text: str, min_repeats: int = 3) -> str:
     if not text:
         return text
 
-
-
-
+    if min_repeats == _DEFAULT_MIN_REPEATS:
+        char_pattern = _TRAILING_CHAR_RE
+        word_pattern = _TRAILING_WORD_RE
+    else:
+        char_pattern = re.compile(r"(.)\1{" + str(min_repeats - 1) + r",}$")
+        word_pattern = re.compile(
+            r"\b(\w+)(?:\s+\1){" + str(min_repeats - 1) + r",}\s*$", re.IGNORECASE
+        )
 
-
-    word_pattern = re.compile(
-        r"\b(\w+)(?:\s+\1){" + str(min_repeats - 1) + r",}\s*$", re.IGNORECASE
-    )
+    text = char_pattern.sub(r"\1", text)
     while word_pattern.search(text):
        text = word_pattern.sub(r"\1", text)
 
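Illustration of the post-processing path the hunks above set up, using the module-level regexes from the diff on a made-up transcript:

import re

_THINK_TAG_RE = re.compile(r"<think>.*?</think>\s*", flags=re.DOTALL)
_TRAILING_CHAR_RE = re.compile(r"(.)\1{2,}$")
_TRAILING_WORD_RE = re.compile(r"\b(\w+)(?:\s+\1){2,}\s*$", re.IGNORECASE)

text = "<think>transcribing...</think>the cat sat on the mat mat mat mat"
if "<think>" in text:
    text = _THINK_TAG_RE.sub("", text).strip()
text = _TRAILING_CHAR_RE.sub(r"\1", text)          # collapse a trailing repeated character
while _TRAILING_WORD_RE.search(text):              # collapse trailing repeated words
    text = _TRAILING_WORD_RE.sub(r"\1", text)
print(text)  # "the cat sat on the mat"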
asr_processing.py
CHANGED
@@ -5,9 +5,9 @@ import transformers
 from transformers import ProcessorMixin
 
 try:
-    from .asr_config import ASRConfig
+    from .asr_config import DEFAULT_ENCODER_CONV_LAYERS, ASRConfig
 except ImportError:
-    from asr_config import ASRConfig  # type: ignore[no-redef]
+    from asr_config import DEFAULT_ENCODER_CONV_LAYERS, ASRConfig  # type: ignore[no-redef]
 
 
 class ASRProcessor(ProcessorMixin):
@@ -18,8 +18,6 @@ class ASRProcessor(ProcessorMixin):
     tokenizer_class = "AutoTokenizer"
     AUDIO_TOKEN = "<audio>"
     TRANSCRIBE_PROMPT = ""
-    # Default conv layers for Whisper/GLM-ASR: [(pad, kernel, stride), ...]
-    DEFAULT_ENCODER_CONV_LAYERS = [(1, 3, 1), (1, 3, 2)]
 
     def __init__(
         self,
@@ -40,7 +38,7 @@ class ASRProcessor(ProcessorMixin):
         self.tokenizer = tokenizer
         self.audio_token_id = tokenizer.convert_tokens_to_ids(self.AUDIO_TOKEN)
         self.projector = projector
-        self.encoder_conv_layers = encoder_conv_layers or
+        self.encoder_conv_layers = encoder_conv_layers or DEFAULT_ENCODER_CONV_LAYERS
 
     def _compute_encoder_output_length(self, mel_length: int) -> int:
         """Compute encoder output length using conv layer formulas."""
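For reference, a sketch of the conv-length arithmetic this kind of processor relies on. conv_output_length is a hypothetical standalone name (the real method is _compute_encoder_output_length); the tuples are the DEFAULT_ENCODER_CONV_LAYERS from the diff, interpreted as (pad, kernel, stride) with the standard Conv1d output-length formula.

DEFAULT_ENCODER_CONV_LAYERS = [(1, 3, 1), (1, 3, 2)]  # (pad, kernel, stride) per layer

def conv_output_length(mel_length: int, conv_layers=DEFAULT_ENCODER_CONV_LAYERS) -> int:
    # out = (in + 2*pad - kernel) // stride + 1, applied layer by layer
    length = mel_length
    for pad, kernel, stride in conv_layers:
        length = (length + 2 * pad - kernel) // stride + 1
    return length

print(conv_output_length(3000))  # 1500: a 30 s Whisper mel is halved by the stride-2 layer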
diarization.py
CHANGED
@@ -154,8 +154,6 @@ class SpeakerClusterer:
         Returns:
             Cluster labels of shape [N]
         """
-        import warnings
-
         if len(embeddings.shape) != 2:
             raise ValueError(f"Expected 2D array, got shape {embeddings.shape}")
 
projectors.py
CHANGED
@@ -58,13 +58,7 @@ class MLPAudioProjector(nn.Module):
         Returns:
             Projected features of shape [batch, (seq_len - k) // k + 1, llm_dim]
         """
-
-        # Truncate to match GLM-ASR: use (seq - k) // k + 1 frames
-        # This drops trailing frames that don't fill a complete k-frame window
-        out_len = (seq - self.k) // self.k + 1
-        x = x[:, : out_len * self.k, :]  # Truncate to exact multiple
-        x = x.reshape(batch, out_len, dim * self.k)
-
+        x = _frame_stack(x, self.k)
         x = self.linear_1(x)
         x = self.norm(x)
         x = self.act(x)
@@ -76,6 +70,17 @@
 # =============================================================================
 
 
+def _frame_stack(x: torch.Tensor, k: int) -> torch.Tensor:
+    """Stack k adjacent frames along the feature dim.
+
+    Truncates trailing frames that don't fill a complete k-frame window,
+    matching GLM-ASR's `(seq_len - k) // k + 1` formula.
+    """
+    batch, seq, dim = x.shape
+    out_len = (seq - k) // k + 1
+    return x[:, : out_len * k, :].reshape(batch, out_len, dim * k)
+
+
 class SimpleAdapter(nn.Module):
     """Simple 2-layer GELU adapter (from MOSA paper)."""
 
@@ -89,34 +94,6 @@
         return self.fc2(self.act(self.fc1(x)))
 
 
-class SwiGLU(nn.Module):
-    """SwiGLU activation with gated linear units (used in LLaMA, Mistral, etc.)."""
-
-    def __init__(self, dim: int, hidden_dim: int, bias: bool = False):
-        super().__init__()
-        self.w1 = nn.Linear(dim, hidden_dim, bias=bias)  # Gate
-        self.w2 = nn.Linear(dim, hidden_dim, bias=bias)  # Value
-        self.w3 = nn.Linear(hidden_dim, dim, bias=bias)  # Output
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.w3(F.silu(self.w1(x)) * self.w2(x))
-
-
-class AsymmetricSwiGLU(nn.Module):
-    """SwiGLU that handles different input and output dimensions."""
-
-    def __init__(
-        self, in_features: int, hidden_features: int, out_features: int, bias: bool = False
-    ):
-        super().__init__()
-        self.w1 = nn.Linear(in_features, hidden_features, bias=bias)  # Gate
-        self.w2 = nn.Linear(in_features, hidden_features, bias=bias)  # Value
-        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)  # Output
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.w3(F.silu(self.w1(x)) * self.w2(x))
-
-
 class MOSAProjector(nn.Module):
     """MOSA-Base projector: simple 2-layer ReLU router with 4 simple adapters.
 
@@ -281,13 +258,10 @@ class MoEAudioProjector(nn.Module):
         Returns:
             Projected features of shape [batch, out_len, llm_dim]
         """
-
-        batch,
-        out_len = (seq - self.k) // self.k + 1
-        x = x[:, : out_len * self.k, :]
-        x = x.reshape(batch, out_len, dim * self.k)
+        x = _frame_stack(x, self.k)
+        batch, out_len, _ = x.shape
 
-        #
+        # Normalize stacked input (like main branch SharedMoEBlock)
         x = self.norm(x)
         flat_x = x.view(-1, x.size(-1))  # [tokens, in_dim]
 
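A toy shape check for the shared _frame_stack helper factored out above (dimensions are arbitrary):

import torch

def _frame_stack(x: torch.Tensor, k: int) -> torch.Tensor:
    # Same logic as the helper in the diff: drop incomplete trailing windows,
    # then merge each group of k frames into one wider feature vector.
    batch, seq, dim = x.shape
    out_len = (seq - k) // k + 1
    return x[:, : out_len * k, :].reshape(batch, out_len, dim * k)

x = torch.randn(2, 101, 768)   # 101 encoder frames, feature dim 768
y = _frame_stack(x, k=4)
print(y.shape)                 # torch.Size([2, 25, 3072]); the stray 101st frame is dropped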