Update custom model files, README, and requirements
Browse files- alignment.py +8 -7
- asr_config.py +17 -24
- asr_modeling.py +3 -3
- diarization.py +35 -8
- handler.py +0 -8
alignment.py
CHANGED
|
@@ -73,7 +73,7 @@ class ForcedAligner:
|
|
| 73 |
# Force alignment to use all tokens by preventing staying in blank
|
| 74 |
# at the end when there are still tokens to emit
|
| 75 |
if num_tokens > 1:
|
| 76 |
-
trellis[-num_tokens + 1:, 0] = float("inf")
|
| 77 |
|
| 78 |
for t in range(num_frames):
|
| 79 |
for j in range(num_tokens + 1):
|
|
@@ -113,7 +113,12 @@ class ForcedAligner:
|
|
| 113 |
# Alignment failed - fall back to uniform distribution
|
| 114 |
frames_per_token = num_frames / num_tokens
|
| 115 |
return [
|
| 116 |
-
(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
for i in range(num_tokens)
|
| 118 |
]
|
| 119 |
|
|
@@ -280,11 +285,7 @@ class ForcedAligner:
|
|
| 280 |
last_char_peak = peak_frame
|
| 281 |
|
| 282 |
# Don't forget the last word
|
| 283 |
-
if (
|
| 284 |
-
first_char_peak is not None
|
| 285 |
-
and last_char_peak is not None
|
| 286 |
-
and word_idx < len(words)
|
| 287 |
-
):
|
| 288 |
start_time = max(0.0, first_char_peak * frame_duration - start_offset)
|
| 289 |
end_time = max(0.0, (last_char_peak + 1) * frame_duration - end_offset)
|
| 290 |
word_timestamps.append(
|
|
|
|
| 73 |
# Force alignment to use all tokens by preventing staying in blank
|
| 74 |
# at the end when there are still tokens to emit
|
| 75 |
if num_tokens > 1:
|
| 76 |
+
trellis[-num_tokens + 1 :, 0] = float("inf")
|
| 77 |
|
| 78 |
for t in range(num_frames):
|
| 79 |
for j in range(num_tokens + 1):
|
|
|
|
| 113 |
# Alignment failed - fall back to uniform distribution
|
| 114 |
frames_per_token = num_frames / num_tokens
|
| 115 |
return [
|
| 116 |
+
(
|
| 117 |
+
tokens[i],
|
| 118 |
+
i * frames_per_token,
|
| 119 |
+
(i + 1) * frames_per_token,
|
| 120 |
+
(i + 0.5) * frames_per_token,
|
| 121 |
+
)
|
| 122 |
for i in range(num_tokens)
|
| 123 |
]
|
| 124 |
|
|
|
|
| 285 |
last_char_peak = peak_frame
|
| 286 |
|
| 287 |
# Don't forget the last word
|
| 288 |
+
if first_char_peak is not None and last_char_peak is not None and word_idx < len(words):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
start_time = max(0.0, first_char_peak * frame_duration - start_offset)
|
| 290 |
end_time = max(0.0, (last_char_peak + 1) * frame_duration - end_offset)
|
| 291 |
word_timestamps.append(
|
asr_config.py
CHANGED
|
@@ -21,7 +21,7 @@ class ASRConfig(transformers.PretrainedConfig):
|
|
| 21 |
self,
|
| 22 |
audio_model_id: str = "zai-org/GLM-ASR-Nano-2512",
|
| 23 |
text_model_id: str = "Qwen/Qwen3-0.6B",
|
| 24 |
-
attn_implementation: str = "
|
| 25 |
model_dtype: str = "bfloat16",
|
| 26 |
num_beams: Optional[int] = None,
|
| 27 |
system_prompt: str = "You are a helpful assistant.",
|
|
@@ -64,6 +64,7 @@ class ASRConfig(transformers.PretrainedConfig):
|
|
| 64 |
lora_target_modules: Optional[list] = None, # Default: all linear layers
|
| 65 |
freeze_projector: bool = False, # True for Stage 2 (LoRA-only training)
|
| 66 |
do_sample: bool = False,
|
|
|
|
| 67 |
temperature: Optional[float] = None,
|
| 68 |
top_p: Optional[float] = None,
|
| 69 |
top_k: Optional[int] = None,
|
|
@@ -80,7 +81,7 @@ class ASRConfig(transformers.PretrainedConfig):
|
|
| 80 |
Args:
|
| 81 |
audio_model_id: HuggingFace model ID for audio encoder (GLM-ASR/Whisper)
|
| 82 |
text_model_id: HuggingFace model ID for text decoder (Qwen)
|
| 83 |
-
attn_implementation: Attention implementation ("
|
| 84 |
model_dtype: Model dtype ("bfloat16", "float16", "float32")
|
| 85 |
projector_type: Projector architecture ("mlp", "mosa", "moe", "qformer")
|
| 86 |
use_lora: Enable LoRA adapters for Stage 2 fine-tuning
|
|
@@ -151,29 +152,21 @@ class ASRConfig(transformers.PretrainedConfig):
|
|
| 151 |
]
|
| 152 |
self.freeze_projector = freeze_projector
|
| 153 |
|
| 154 |
-
# Generation parameters
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
)
|
| 162 |
-
self.
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
)
|
| 167 |
-
self.length_penalty = (
|
| 168 |
-
length_penalty if length_penalty is not None else generation_defaults["length_penalty"]
|
| 169 |
-
)
|
| 170 |
-
self.no_repeat_ngram_size = (
|
| 171 |
-
no_repeat_ngram_size
|
| 172 |
-
if no_repeat_ngram_size is not None
|
| 173 |
-
else generation_defaults["no_repeat_ngram_size"]
|
| 174 |
-
)
|
| 175 |
-
self.use_cache = use_cache if use_cache is not None else generation_defaults["use_cache"]
|
| 176 |
self.do_sample = do_sample
|
|
|
|
| 177 |
self.temperature = temperature
|
| 178 |
self.top_p = top_p
|
| 179 |
self.top_k = top_k
|
|
|
|
| 21 |
self,
|
| 22 |
audio_model_id: str = "zai-org/GLM-ASR-Nano-2512",
|
| 23 |
text_model_id: str = "Qwen/Qwen3-0.6B",
|
| 24 |
+
attn_implementation: str = "sdpa",
|
| 25 |
model_dtype: str = "bfloat16",
|
| 26 |
num_beams: Optional[int] = None,
|
| 27 |
system_prompt: str = "You are a helpful assistant.",
|
|
|
|
| 64 |
lora_target_modules: Optional[list] = None, # Default: all linear layers
|
| 65 |
freeze_projector: bool = False, # True for Stage 2 (LoRA-only training)
|
| 66 |
do_sample: bool = False,
|
| 67 |
+
enable_thinking: bool = False, # Enable Qwen3 thinking mode for omni models
|
| 68 |
temperature: Optional[float] = None,
|
| 69 |
top_p: Optional[float] = None,
|
| 70 |
top_k: Optional[int] = None,
|
|
|
|
| 81 |
Args:
|
| 82 |
audio_model_id: HuggingFace model ID for audio encoder (GLM-ASR/Whisper)
|
| 83 |
text_model_id: HuggingFace model ID for text decoder (Qwen)
|
| 84 |
+
attn_implementation: Attention implementation ("sdpa", "flash_attention_2", "eager")
|
| 85 |
model_dtype: Model dtype ("bfloat16", "float16", "float32")
|
| 86 |
projector_type: Projector architecture ("mlp", "mosa", "moe", "qformer")
|
| 87 |
use_lora: Enable LoRA adapters for Stage 2 fine-tuning
|
|
|
|
| 152 |
]
|
| 153 |
self.freeze_projector = freeze_projector
|
| 154 |
|
| 155 |
+
# Generation parameters: check named param first, then kwargs (from config.json), then default
|
| 156 |
+
def get_gen_param(name, named_value):
|
| 157 |
+
if named_value is not None:
|
| 158 |
+
return named_value
|
| 159 |
+
return kwargs.get(name, generation_defaults[name])
|
| 160 |
+
|
| 161 |
+
self.num_beams = get_gen_param("num_beams", num_beams)
|
| 162 |
+
self.max_new_tokens = get_gen_param("max_new_tokens", max_new_tokens)
|
| 163 |
+
self.min_new_tokens = get_gen_param("min_new_tokens", min_new_tokens)
|
| 164 |
+
self.repetition_penalty = get_gen_param("repetition_penalty", repetition_penalty)
|
| 165 |
+
self.length_penalty = get_gen_param("length_penalty", length_penalty)
|
| 166 |
+
self.no_repeat_ngram_size = get_gen_param("no_repeat_ngram_size", no_repeat_ngram_size)
|
| 167 |
+
self.use_cache = get_gen_param("use_cache", use_cache)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
self.do_sample = do_sample
|
| 169 |
+
self.enable_thinking = enable_thinking
|
| 170 |
self.temperature = temperature
|
| 171 |
self.top_p = top_p
|
| 172 |
self.top_k = top_k
|
asr_modeling.py
CHANGED
|
@@ -582,7 +582,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
|
|
| 582 |
tokenize=True,
|
| 583 |
add_generation_prompt=True,
|
| 584 |
return_tensors="pt",
|
| 585 |
-
enable_thinking=
|
| 586 |
)
|
| 587 |
input_ids = chat_result.input_ids.to(device)
|
| 588 |
|
|
@@ -665,7 +665,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
|
|
| 665 |
tokenize=True,
|
| 666 |
add_generation_prompt=True,
|
| 667 |
return_tensors="pt",
|
| 668 |
-
enable_thinking=
|
| 669 |
)
|
| 670 |
input_ids = chat_result.input_ids.to(device)
|
| 671 |
|
|
@@ -764,7 +764,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
|
|
| 764 |
tokenize=True,
|
| 765 |
add_generation_prompt=True,
|
| 766 |
return_tensors="pt",
|
| 767 |
-
enable_thinking=
|
| 768 |
).to(device)
|
| 769 |
|
| 770 |
if input_ids.dim() == 1:
|
|
|
|
| 582 |
tokenize=True,
|
| 583 |
add_generation_prompt=True,
|
| 584 |
return_tensors="pt",
|
| 585 |
+
enable_thinking=getattr(self.config, "enable_thinking", False),
|
| 586 |
)
|
| 587 |
input_ids = chat_result.input_ids.to(device)
|
| 588 |
|
|
|
|
| 665 |
tokenize=True,
|
| 666 |
add_generation_prompt=True,
|
| 667 |
return_tensors="pt",
|
| 668 |
+
enable_thinking=getattr(self.config, "enable_thinking", False),
|
| 669 |
)
|
| 670 |
input_ids = chat_result.input_ids.to(device)
|
| 671 |
|
|
|
|
| 764 |
tokenize=True,
|
| 765 |
add_generation_prompt=True,
|
| 766 |
return_tensors="pt",
|
| 767 |
+
enable_thinking=getattr(self.config, "enable_thinking", False),
|
| 768 |
).to(device)
|
| 769 |
|
| 770 |
if input_ids.dim() == 1:
|
diarization.py
CHANGED
|
@@ -91,20 +91,47 @@ class SpectralCluster:
|
|
| 91 |
def get_spec_embs(
|
| 92 |
self, laplacian: np.ndarray, k_oracle: int | None = None
|
| 93 |
) -> tuple[np.ndarray, int]:
|
| 94 |
-
"""Extract spectral embeddings from Laplacian.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
lambdas, eig_vecs = scipy.linalg.eigh(laplacian)
|
| 96 |
|
| 97 |
-
if k_oracle is not None
|
| 98 |
-
num_of_spk = k_oracle
|
| 99 |
-
else:
|
| 100 |
-
lambda_gap_list = self.get_eigen_gaps(
|
| 101 |
-
lambdas[self.min_num_spks - 1 : self.max_num_spks + 1]
|
| 102 |
-
)
|
| 103 |
-
num_of_spk = np.argmax(lambda_gap_list) + self.min_num_spks
|
| 104 |
|
| 105 |
emb = eig_vecs[:, :num_of_spk]
|
| 106 |
return emb, num_of_spk
|
| 107 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
def cluster_embs(self, emb: np.ndarray, k: int) -> np.ndarray:
|
| 109 |
"""Cluster spectral embeddings using k-means."""
|
| 110 |
_, labels, _ = k_means(emb, k, n_init=10)
|
|
|
|
| 91 |
def get_spec_embs(
|
| 92 |
self, laplacian: np.ndarray, k_oracle: int | None = None
|
| 93 |
) -> tuple[np.ndarray, int]:
|
| 94 |
+
"""Extract spectral embeddings from Laplacian.
|
| 95 |
+
|
| 96 |
+
Uses the eigengap heuristic to estimate the number of clusters:
|
| 97 |
+
The number of clusters k is chosen where the gap between consecutive
|
| 98 |
+
eigenvalues is largest, indicating a transition from "cluster" eigenvalues
|
| 99 |
+
(near 0) to "noise" eigenvalues.
|
| 100 |
+
"""
|
| 101 |
lambdas, eig_vecs = scipy.linalg.eigh(laplacian)
|
| 102 |
|
| 103 |
+
num_of_spk = k_oracle if k_oracle is not None else self._estimate_num_speakers(lambdas)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
emb = eig_vecs[:, :num_of_spk]
|
| 106 |
return emb, num_of_spk
|
| 107 |
|
| 108 |
+
def _estimate_num_speakers(self, lambdas: np.ndarray) -> int:
|
| 109 |
+
"""Estimate number of speakers using refined eigengap heuristic.
|
| 110 |
+
|
| 111 |
+
For spectral clustering, we look for the largest gap in eigenvalues.
|
| 112 |
+
The eigenvalues corresponding to clusters are close to 0, and there
|
| 113 |
+
should be a significant jump to the remaining eigenvalues.
|
| 114 |
+
"""
|
| 115 |
+
# Consider eigenvalues from index 1 to max_num_spks (skip first, it's always ~0)
|
| 116 |
+
# We need gaps between positions, so look at indices 1 to max_num_spks+1
|
| 117 |
+
max_idx = min(self.max_num_spks + 1, len(lambdas))
|
| 118 |
+
relevant_lambdas = lambdas[1:max_idx] # Skip first eigenvalue
|
| 119 |
+
|
| 120 |
+
if len(relevant_lambdas) < 2:
|
| 121 |
+
return self.min_num_spks
|
| 122 |
+
|
| 123 |
+
# Compute absolute gaps (not ratios - ratios are unstable near 0)
|
| 124 |
+
gaps = np.diff(relevant_lambdas)
|
| 125 |
+
|
| 126 |
+
# Find the largest gap - the index gives us (k-1) since we skipped first
|
| 127 |
+
# Add 1 to convert from gap index to number of speakers
|
| 128 |
+
# Add 1 again because we skipped the first eigenvalue
|
| 129 |
+
max_gap_idx = int(np.argmax(gaps))
|
| 130 |
+
num_of_spk = max_gap_idx + 2 # +1 for gap->count, +1 for skipped eigenvalue
|
| 131 |
+
|
| 132 |
+
# Clamp between min and max
|
| 133 |
+
return max(self.min_num_spks, min(num_of_spk, self.max_num_spks))
|
| 134 |
+
|
| 135 |
def cluster_embs(self, emb: np.ndarray, k: int) -> np.ndarray:
|
| 136 |
"""Cluster spectral embeddings using k-means."""
|
| 137 |
_, labels, _ = k_means(emb, k, n_init=10)
|
handler.py
CHANGED
|
@@ -39,8 +39,6 @@ class EndpointHandler:
|
|
| 39 |
"torch_dtype": "auto",
|
| 40 |
"low_cpu_mem_usage": True,
|
| 41 |
}
|
| 42 |
-
if self._is_flash_attn_available():
|
| 43 |
-
model_kwargs["attn_implementation"] = "flash_attention_2"
|
| 44 |
|
| 45 |
# Load model (this loads the model, tokenizer, and feature extractor)
|
| 46 |
self.model = ASRModel.from_pretrained(path, **model_kwargs)
|
|
@@ -56,12 +54,6 @@ class EndpointHandler:
|
|
| 56 |
device=self.device,
|
| 57 |
)
|
| 58 |
|
| 59 |
-
def _is_flash_attn_available(self):
|
| 60 |
-
"""Check if flash attention is available."""
|
| 61 |
-
import importlib.util
|
| 62 |
-
|
| 63 |
-
return importlib.util.find_spec("flash_attn") is not None
|
| 64 |
-
|
| 65 |
def __call__(self, data: Dict[str, Any]) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
|
| 66 |
"""Process an inference request.
|
| 67 |
|
|
|
|
| 39 |
"torch_dtype": "auto",
|
| 40 |
"low_cpu_mem_usage": True,
|
| 41 |
}
|
|
|
|
|
|
|
| 42 |
|
| 43 |
# Load model (this loads the model, tokenizer, and feature extractor)
|
| 44 |
self.model = ASRModel.from_pretrained(path, **model_kwargs)
|
|
|
|
| 54 |
device=self.device,
|
| 55 |
)
|
| 56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
def __call__(self, data: Dict[str, Any]) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
|
| 58 |
"""Process an inference request.
|
| 59 |
|