mazesmazes committed (verified)
Commit fbf00b0 · 1 Parent(s): 8d8b876

Update custom model files, README, and requirements

Files changed (4):
  1. .gitattributes +0 -1
  2. asr_modeling.py +11 -3
  3. asr_processing.py +5 -4
  4. projectors.py +11 -9
.gitattributes CHANGED
@@ -1,4 +1,3 @@
 *.safetensors filter=lfs diff=lfs merge=lfs -text
 *.bin filter=lfs diff=lfs merge=lfs -text
 tokenizer_config.json -filter -diff -merge text
-tokenizer.json filter=lfs diff=lfs merge=lfs -text
asr_modeling.py CHANGED
@@ -38,7 +38,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
     _is_loading_from_pretrained: bool = False
     _pretrained_model_path: Optional[str] = None
 
-    TRANSCRIBE_PROMPT = "Please transcribe this audio into text: "
+    TRANSCRIBE_PROMPT = "Transcribe speech to text"  # Audio tokens come BEFORE this
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs) -> "ASRModel":
@@ -543,7 +543,10 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         messages: list[dict[str, str]] = []
         if system_prompt:
             messages.append({"role": "system", "content": system_prompt})
-        messages.append({"role": "user", "content": self.TRANSCRIBE_PROMPT + audio_placeholder})
+        # Audio BEFORE prompt for proper causal attention
+        messages.append(
+            {"role": "user", "content": audio_placeholder + " " + self.TRANSCRIBE_PROMPT}
+        )
 
         chat_result = self.tokenizer.apply_chat_template(
             messages,
@@ -618,7 +621,10 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         messages: list[dict[str, str]] = []
         if system_prompt:
             messages.append({"role": "system", "content": system_prompt})
-        messages.append({"role": "user", "content": self.TRANSCRIBE_PROMPT + audio_placeholder})
+        # Audio BEFORE prompt for proper causal attention
+        messages.append(
+            {"role": "user", "content": audio_placeholder + " " + self.TRANSCRIBE_PROMPT}
+        )
 
         chat_result = self.tokenizer.apply_chat_template(
             messages,
@@ -778,6 +784,8 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         shutil.copy(asr_file, save_dir / asr_file.name)
         # Copy projectors module
         shutil.copy(src_dir / "projectors.py", save_dir / "projectors.py")
+        # Copy diarization module
+        shutil.copy(src_dir / "diarization.py", save_dir / "diarization.py")
 
     def push_to_hub(self, repo_id: str, **kwargs) -> str:
         """Push model to HuggingFace Hub, ensuring adapter_config points to repo.
asr_processing.py CHANGED
@@ -17,7 +17,7 @@ class ASRProcessor(ProcessorMixin):
     feature_extractor_class = "AutoFeatureExtractor"
     tokenizer_class = "AutoTokenizer"
     AUDIO_TOKEN = "<audio>"
-    TRANSCRIBE_PROMPT = "Transcribe: "
+    TRANSCRIBE_PROMPT = "Transcribe speech to text"
     # Default conv layers for Whisper/GLM-ASR: [(pad, kernel, stride), ...]
     DEFAULT_ENCODER_CONV_LAYERS = [(1, 3, 1), (1, 3, 2)]
 
@@ -89,10 +89,11 @@ class ASRProcessor(ProcessorMixin):
         else:
             num_audio_tokens = 0
 
-        # Build prompt with audio token placeholders
-        user_content = self.TRANSCRIBE_PROMPT
+        # Build prompt with audio token placeholders (audio BEFORE prompt)
         if num_audio_tokens > 0:
-            user_content += self.AUDIO_TOKEN * num_audio_tokens
+            user_content = self.AUDIO_TOKEN * num_audio_tokens + " " + self.TRANSCRIBE_PROMPT
+        else:
+            user_content = self.TRANSCRIBE_PROMPT
 
         messages = []
         if system_prompt:
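
Note on the processor change (not part of the diff): a hedged sketch of the new user_content construction, paired with one plausible reading of how DEFAULT_ENCODER_CONV_LAYERS maps encoder frames to a placeholder count. The 3000-frame input and the conv_out_len helper are illustrative assumptions, not code from this repo.

    AUDIO_TOKEN = "<audio>"
    TRANSCRIBE_PROMPT = "Transcribe speech to text"
    DEFAULT_ENCODER_CONV_LAYERS = [(1, 3, 1), (1, 3, 2)]  # (pad, kernel, stride)

    def conv_out_len(length: int, layers) -> int:
        # Standard Conv1d length formula: out = (in + 2*pad - kernel) // stride + 1
        for pad, kernel, stride in layers:
            length = (length + 2 * pad - kernel) // stride + 1
        return length

    num_audio_tokens = conv_out_len(3000, DEFAULT_ENCODER_CONV_LAYERS)  # assumed frame count -> 1500

    # Audio tokens BEFORE the prompt, mirroring the hunk above
    if num_audio_tokens > 0:
        user_content = AUDIO_TOKEN * num_audio_tokens + " " + TRANSCRIBE_PROMPT
    else:
        user_content = TRANSCRIBE_PROMPT
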
projectors.py CHANGED
@@ -33,11 +33,12 @@ class MLPAudioProjector(nn.Module):
 
         encoder_dim = getattr(config, "encoder_dim", 768)
         llm_dim = getattr(config, "llm_dim", 2048)
-        self.k = getattr(config, "projector_pool_stride", 2)
+        self.k = getattr(config, "projector_pool_stride", 4)
 
         # Frame stacking: concat k adjacent frames then project
+        # Hidden dim uses 2x expansion like GLM-ASR's GlmAsrMultiModalProjector
         in_dim = encoder_dim * self.k
-        hidden_dim = llm_dim
+        hidden_dim = llm_dim * 2
         self.linear_1 = nn.Linear(in_dim, hidden_dim)
         self.act = nn.GELU()
         self.linear_2 = nn.Linear(hidden_dim, llm_dim)
@@ -85,6 +86,7 @@ class SimpleAdapter(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.fc2(self.act(self.fc1(x)))
 
+
 class MOSAProjector(nn.Module):
     """MOSA-Base projector: simple 2-layer ReLU router with 4 simple adapters.
 
@@ -126,7 +128,10 @@ class MOSAProjector(nn.Module):
         # --- 3. Experts (Simple 2-layer GELU adapters) ---
         # Each expert: llm_dim -> hidden -> llm_dim (much smaller than frame-stacking)
         self.experts = nn.ModuleList(
-            [SimpleAdapter(self.llm_dim, adapter_hidden, self.llm_dim) for _ in range(self.num_experts)]
+            [
+                SimpleAdapter(self.llm_dim, adapter_hidden, self.llm_dim)
+                for _ in range(self.num_experts)
+            ]
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -149,18 +154,15 @@ class MOSAProjector(nn.Module):
         routing_weights = F.softmax(self.router(x), dim=-1)  # (B, out_len, num_experts)
 
         # --- 3. Expert Mixture (Dense Execution) ---
-        expert_outputs = torch.stack(
-            [expert(x) for expert in self.experts]
-        )  # (E, B, out_len, D)
+        expert_outputs = torch.stack([expert(x) for expert in self.experts])  # (E, B, out_len, D)
         return torch.einsum("ebsd, bse -> bsd", expert_outputs, routing_weights)
 
     def get_output_length(self, input_length: int) -> int:
         """Calculate output sequence length after Conv1d downsampling (4x reduction)."""
         # Conv1d with stride 2, kernel 3, padding 1: out = (in + 2*1 - 3) // 2 + 1 = (in - 1) // 2 + 1
         # Applied twice for 4x total reduction
-        length = (input_length + 2 * 1 - 3) // 2 + 1  # First conv
-        length = (length + 2 * 1 - 3) // 2 + 1  # Second conv
-        return length
+        after_conv1 = (input_length + 2 * 1 - 3) // 2 + 1
+        return (after_conv1 + 2 * 1 - 3) // 2 + 1
 
 
 # =============================================================================
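
Note on the projector changes (not part of the diff): a hedged sanity check of the shape arithmetic behind the new defaults. The standalone nn.Sequential and the 100-frame input are illustrative; only the dimensions (k=4, 2x hidden expansion) and the length formula come from the commit.

    import torch
    import torch.nn as nn

    encoder_dim, llm_dim, k = 768, 2048, 4      # new default pool stride k=4
    in_dim = encoder_dim * k                    # 3072: k adjacent frames concatenated
    hidden_dim = llm_dim * 2                    # 4096: 2x expansion, GLM-ASR style

    proj = nn.Sequential(nn.Linear(in_dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, llm_dim))

    frames = torch.randn(1, 100, encoder_dim)   # (B, T, encoder_dim) with T divisible by k
    stacked = frames.reshape(1, 100 // k, in_dim)
    print(proj(stacked).shape)                  # torch.Size([1, 25, 2048])

    # MOSAProjector.get_output_length applies the Conv1d formula twice
    # (pad=1, kernel=3, stride=2), giving roughly a 4x reduction:
    def get_output_length(input_length: int) -> int:
        after_conv1 = (input_length + 2 * 1 - 3) // 2 + 1
        return (after_conv1 + 2 * 1 - 3) // 2 + 1

    print(get_output_length(100))               # 25
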