mazesmazes committed on
Commit
f47bbc6
·
verified ·
1 Parent(s): bee1fc8

Update custom model files, README, and requirements

Browse files
Files changed (5) hide show
  1. alignment.py +8 -7
  2. asr_config.py +17 -24
  3. asr_modeling.py +3 -3
  4. diarization.py +35 -8
  5. handler.py +0 -8
alignment.py CHANGED
@@ -73,7 +73,7 @@ class ForcedAligner:
73
  # Force alignment to use all tokens by preventing staying in blank
74
  # at the end when there are still tokens to emit
75
  if num_tokens > 1:
76
- trellis[-num_tokens + 1:, 0] = float("inf")
77
 
78
  for t in range(num_frames):
79
  for j in range(num_tokens + 1):
@@ -113,7 +113,12 @@ class ForcedAligner:
113
  # Alignment failed - fall back to uniform distribution
114
  frames_per_token = num_frames / num_tokens
115
  return [
116
- (tokens[i], i * frames_per_token, (i + 1) * frames_per_token, (i + 0.5) * frames_per_token)
 
 
 
 
 
117
  for i in range(num_tokens)
118
  ]
119
 
@@ -280,11 +285,7 @@ class ForcedAligner:
280
  last_char_peak = peak_frame
281
 
282
  # Don't forget the last word
283
- if (
284
- first_char_peak is not None
285
- and last_char_peak is not None
286
- and word_idx < len(words)
287
- ):
288
  start_time = max(0.0, first_char_peak * frame_duration - start_offset)
289
  end_time = max(0.0, (last_char_peak + 1) * frame_duration - end_offset)
290
  word_timestamps.append(
 
73
  # Force alignment to use all tokens by preventing staying in blank
74
  # at the end when there are still tokens to emit
75
  if num_tokens > 1:
76
+ trellis[-num_tokens + 1 :, 0] = float("inf")
77
 
78
  for t in range(num_frames):
79
  for j in range(num_tokens + 1):
 
113
  # Alignment failed - fall back to uniform distribution
114
  frames_per_token = num_frames / num_tokens
115
  return [
116
+ (
117
+ tokens[i],
118
+ i * frames_per_token,
119
+ (i + 1) * frames_per_token,
120
+ (i + 0.5) * frames_per_token,
121
+ )
122
  for i in range(num_tokens)
123
  ]
124
 
 
285
  last_char_peak = peak_frame
286
 
287
  # Don't forget the last word
288
+ if first_char_peak is not None and last_char_peak is not None and word_idx < len(words):
 
 
 
 
289
  start_time = max(0.0, first_char_peak * frame_duration - start_offset)
290
  end_time = max(0.0, (last_char_peak + 1) * frame_duration - end_offset)
291
  word_timestamps.append(
asr_config.py CHANGED
@@ -21,7 +21,7 @@ class ASRConfig(transformers.PretrainedConfig):
21
  self,
22
  audio_model_id: str = "zai-org/GLM-ASR-Nano-2512",
23
  text_model_id: str = "Qwen/Qwen3-0.6B",
24
- attn_implementation: str = "flash_attention_2",
25
  model_dtype: str = "bfloat16",
26
  num_beams: Optional[int] = None,
27
  system_prompt: str = "You are a helpful assistant.",
@@ -64,6 +64,7 @@ class ASRConfig(transformers.PretrainedConfig):
64
  lora_target_modules: Optional[list] = None, # Default: all linear layers
65
  freeze_projector: bool = False, # True for Stage 2 (LoRA-only training)
66
  do_sample: bool = False,
 
67
  temperature: Optional[float] = None,
68
  top_p: Optional[float] = None,
69
  top_k: Optional[int] = None,
@@ -80,7 +81,7 @@ class ASRConfig(transformers.PretrainedConfig):
80
  Args:
81
  audio_model_id: HuggingFace model ID for audio encoder (GLM-ASR/Whisper)
82
  text_model_id: HuggingFace model ID for text decoder (Qwen)
83
- attn_implementation: Attention implementation ("flash_attention_2", "sdpa", "eager")
84
  model_dtype: Model dtype ("bfloat16", "float16", "float32")
85
  projector_type: Projector architecture ("mlp", "mosa", "moe", "qformer")
86
  use_lora: Enable LoRA adapters for Stage 2 fine-tuning
@@ -151,29 +152,21 @@ class ASRConfig(transformers.PretrainedConfig):
151
  ]
152
  self.freeze_projector = freeze_projector
153
 
154
- # Generation parameters (use explicit value if provided, else use default)
155
- self.num_beams = num_beams if num_beams is not None else generation_defaults["num_beams"]
156
- self.max_new_tokens = (
157
- max_new_tokens if max_new_tokens is not None else generation_defaults["max_new_tokens"]
158
- )
159
- self.min_new_tokens = (
160
- min_new_tokens if min_new_tokens is not None else generation_defaults["min_new_tokens"]
161
- )
162
- self.repetition_penalty = (
163
- repetition_penalty
164
- if repetition_penalty is not None
165
- else generation_defaults["repetition_penalty"]
166
- )
167
- self.length_penalty = (
168
- length_penalty if length_penalty is not None else generation_defaults["length_penalty"]
169
- )
170
- self.no_repeat_ngram_size = (
171
- no_repeat_ngram_size
172
- if no_repeat_ngram_size is not None
173
- else generation_defaults["no_repeat_ngram_size"]
174
- )
175
- self.use_cache = use_cache if use_cache is not None else generation_defaults["use_cache"]
176
  self.do_sample = do_sample
 
177
  self.temperature = temperature
178
  self.top_p = top_p
179
  self.top_k = top_k
 
21
  self,
22
  audio_model_id: str = "zai-org/GLM-ASR-Nano-2512",
23
  text_model_id: str = "Qwen/Qwen3-0.6B",
24
+ attn_implementation: str = "sdpa",
25
  model_dtype: str = "bfloat16",
26
  num_beams: Optional[int] = None,
27
  system_prompt: str = "You are a helpful assistant.",
 
64
  lora_target_modules: Optional[list] = None, # Default: all linear layers
65
  freeze_projector: bool = False, # True for Stage 2 (LoRA-only training)
66
  do_sample: bool = False,
67
+ enable_thinking: bool = False, # Enable Qwen3 thinking mode for omni models
68
  temperature: Optional[float] = None,
69
  top_p: Optional[float] = None,
70
  top_k: Optional[int] = None,
 
81
  Args:
82
  audio_model_id: HuggingFace model ID for audio encoder (GLM-ASR/Whisper)
83
  text_model_id: HuggingFace model ID for text decoder (Qwen)
84
+ attn_implementation: Attention implementation ("sdpa", "flash_attention_2", "eager")
85
  model_dtype: Model dtype ("bfloat16", "float16", "float32")
86
  projector_type: Projector architecture ("mlp", "mosa", "moe", "qformer")
87
  use_lora: Enable LoRA adapters for Stage 2 fine-tuning
 
152
  ]
153
  self.freeze_projector = freeze_projector
154
 
155
+ # Generation parameters: check named param first, then kwargs (from config.json), then default
156
+ def get_gen_param(name, named_value):
157
+ if named_value is not None:
158
+ return named_value
159
+ return kwargs.get(name, generation_defaults[name])
160
+
161
+ self.num_beams = get_gen_param("num_beams", num_beams)
162
+ self.max_new_tokens = get_gen_param("max_new_tokens", max_new_tokens)
163
+ self.min_new_tokens = get_gen_param("min_new_tokens", min_new_tokens)
164
+ self.repetition_penalty = get_gen_param("repetition_penalty", repetition_penalty)
165
+ self.length_penalty = get_gen_param("length_penalty", length_penalty)
166
+ self.no_repeat_ngram_size = get_gen_param("no_repeat_ngram_size", no_repeat_ngram_size)
167
+ self.use_cache = get_gen_param("use_cache", use_cache)
 
 
 
 
 
 
 
 
 
168
  self.do_sample = do_sample
169
+ self.enable_thinking = enable_thinking
170
  self.temperature = temperature
171
  self.top_p = top_p
172
  self.top_k = top_k
asr_modeling.py CHANGED
@@ -582,7 +582,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
582
  tokenize=True,
583
  add_generation_prompt=True,
584
  return_tensors="pt",
585
- enable_thinking=False, # Disable Qwen3 thinking mode for ASR
586
  )
587
  input_ids = chat_result.input_ids.to(device)
588
 
@@ -665,7 +665,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
665
  tokenize=True,
666
  add_generation_prompt=True,
667
  return_tensors="pt",
668
- enable_thinking=False, # Disable Qwen3 thinking mode for ASR
669
  )
670
  input_ids = chat_result.input_ids.to(device)
671
 
@@ -764,7 +764,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
764
  tokenize=True,
765
  add_generation_prompt=True,
766
  return_tensors="pt",
767
- enable_thinking=False, # Disable Qwen3 thinking mode for ASR
768
  ).to(device)
769
 
770
  if input_ids.dim() == 1:
 
582
  tokenize=True,
583
  add_generation_prompt=True,
584
  return_tensors="pt",
585
+ enable_thinking=getattr(self.config, "enable_thinking", False),
586
  )
587
  input_ids = chat_result.input_ids.to(device)
588
 
 
665
  tokenize=True,
666
  add_generation_prompt=True,
667
  return_tensors="pt",
668
+ enable_thinking=getattr(self.config, "enable_thinking", False),
669
  )
670
  input_ids = chat_result.input_ids.to(device)
671
 
 
764
  tokenize=True,
765
  add_generation_prompt=True,
766
  return_tensors="pt",
767
+ enable_thinking=getattr(self.config, "enable_thinking", False),
768
  ).to(device)
769
 
770
  if input_ids.dim() == 1:
diarization.py CHANGED
@@ -91,20 +91,47 @@ class SpectralCluster:
91
  def get_spec_embs(
92
  self, laplacian: np.ndarray, k_oracle: int | None = None
93
  ) -> tuple[np.ndarray, int]:
94
- """Extract spectral embeddings from Laplacian."""
 
 
 
 
 
 
95
  lambdas, eig_vecs = scipy.linalg.eigh(laplacian)
96
 
97
- if k_oracle is not None:
98
- num_of_spk = k_oracle
99
- else:
100
- lambda_gap_list = self.get_eigen_gaps(
101
- lambdas[self.min_num_spks - 1 : self.max_num_spks + 1]
102
- )
103
- num_of_spk = np.argmax(lambda_gap_list) + self.min_num_spks
104
 
105
  emb = eig_vecs[:, :num_of_spk]
106
  return emb, num_of_spk
107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  def cluster_embs(self, emb: np.ndarray, k: int) -> np.ndarray:
109
  """Cluster spectral embeddings using k-means."""
110
  _, labels, _ = k_means(emb, k, n_init=10)
 
91
  def get_spec_embs(
92
  self, laplacian: np.ndarray, k_oracle: int | None = None
93
  ) -> tuple[np.ndarray, int]:
94
+ """Extract spectral embeddings from Laplacian.
95
+
96
+ Uses the eigengap heuristic to estimate the number of clusters:
97
+ The number of clusters k is chosen where the gap between consecutive
98
+ eigenvalues is largest, indicating a transition from "cluster" eigenvalues
99
+ (near 0) to "noise" eigenvalues.
100
+ """
101
  lambdas, eig_vecs = scipy.linalg.eigh(laplacian)
102
 
103
+ num_of_spk = k_oracle if k_oracle is not None else self._estimate_num_speakers(lambdas)
 
 
 
 
 
 
104
 
105
  emb = eig_vecs[:, :num_of_spk]
106
  return emb, num_of_spk
107
 
108
+ def _estimate_num_speakers(self, lambdas: np.ndarray) -> int:
109
+ """Estimate number of speakers using refined eigengap heuristic.
110
+
111
+ For spectral clustering, we look for the largest gap in eigenvalues.
112
+ The eigenvalues corresponding to clusters are close to 0, and there
113
+ should be a significant jump to the remaining eigenvalues.
114
+ """
115
+ # Consider eigenvalues from index 1 to max_num_spks (skip first, it's always ~0)
116
+ # We need gaps between positions, so look at indices 1 to max_num_spks+1
117
+ max_idx = min(self.max_num_spks + 1, len(lambdas))
118
+ relevant_lambdas = lambdas[1:max_idx] # Skip first eigenvalue
119
+
120
+ if len(relevant_lambdas) < 2:
121
+ return self.min_num_spks
122
+
123
+ # Compute absolute gaps (not ratios - ratios are unstable near 0)
124
+ gaps = np.diff(relevant_lambdas)
125
+
126
+ # Find the largest gap - the index gives us (k-1) since we skipped first
127
+ # Add 1 to convert from gap index to number of speakers
128
+ # Add 1 again because we skipped the first eigenvalue
129
+ max_gap_idx = int(np.argmax(gaps))
130
+ num_of_spk = max_gap_idx + 2 # +1 for gap->count, +1 for skipped eigenvalue
131
+
132
+ # Clamp between min and max
133
+ return max(self.min_num_spks, min(num_of_spk, self.max_num_spks))
134
+
135
  def cluster_embs(self, emb: np.ndarray, k: int) -> np.ndarray:
136
  """Cluster spectral embeddings using k-means."""
137
  _, labels, _ = k_means(emb, k, n_init=10)
handler.py CHANGED
@@ -39,8 +39,6 @@ class EndpointHandler:
39
  "torch_dtype": "auto",
40
  "low_cpu_mem_usage": True,
41
  }
42
- if self._is_flash_attn_available():
43
- model_kwargs["attn_implementation"] = "flash_attention_2"
44
 
45
  # Load model (this loads the model, tokenizer, and feature extractor)
46
  self.model = ASRModel.from_pretrained(path, **model_kwargs)
@@ -56,12 +54,6 @@ class EndpointHandler:
56
  device=self.device,
57
  )
58
 
59
- def _is_flash_attn_available(self):
60
- """Check if flash attention is available."""
61
- import importlib.util
62
-
63
- return importlib.util.find_spec("flash_attn") is not None
64
-
65
  def __call__(self, data: Dict[str, Any]) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
66
  """Process an inference request.
67
 
 
39
  "torch_dtype": "auto",
40
  "low_cpu_mem_usage": True,
41
  }
 
 
42
 
43
  # Load model (this loads the model, tokenizer, and feature extractor)
44
  self.model = ASRModel.from_pretrained(path, **model_kwargs)
 
54
  device=self.device,
55
  )
56
 
 
 
 
 
 
 
57
  def __call__(self, data: Dict[str, Any]) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
58
  """Process an inference request.
59