mazesmazes committed on
Commit
c03ad04
·
verified ·
1 Parent(s): 67d278a

Training in progress - step 1000

Browse files
Files changed (5) hide show
  1. asr_config.py +2 -0
  2. asr_modeling.py +11 -5
  3. asr_processing.py +1 -1
  4. config.json +1 -0
  5. model.safetensors +2 -2
asr_config.py CHANGED
@@ -66,6 +66,7 @@ class ASRConfig(transformers.PretrainedConfig):
66
  lora_dropout: float = 0.0,
67
  lora_target_modules: Optional[list] = None, # Default: all linear layers
68
  freeze_projector: bool = False, # True for Stage 2 (LoRA-only training)
 
69
  do_sample: bool = False,
70
  temperature: Optional[float] = None,
71
  top_p: Optional[float] = None,
@@ -152,6 +153,7 @@ class ASRConfig(transformers.PretrainedConfig):
152
  "down_proj",
153
  ]
154
  self.freeze_projector = freeze_projector
 
155
 
156
  explicit_generation_args = {
157
  "num_beams": num_beams,
 
66
  lora_dropout: float = 0.0,
67
  lora_target_modules: Optional[list] = None, # Default: all linear layers
68
  freeze_projector: bool = False, # True for Stage 2 (LoRA-only training)
69
+ freeze_language_model: bool = True, # False = full decoder fine-tuning
70
  do_sample: bool = False,
71
  temperature: Optional[float] = None,
72
  top_p: Optional[float] = None,
 
153
  "down_proj",
154
  ]
155
  self.freeze_projector = freeze_projector
156
+ self.freeze_language_model = freeze_language_model
157
 
158
  explicit_generation_args = {
159
  "num_beams": num_beams,
asr_modeling.py CHANGED
@@ -58,7 +58,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
58
  _is_loading_from_pretrained: bool = False
59
  _pretrained_model_path: Optional[str] = None
60
 
61
- TRANSCRIBE_PROMPT = ""
62
 
63
  @classmethod
64
  def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs) -> "ASRModel":
@@ -265,8 +265,9 @@ class ASRModel(PreTrainedModel, GenerationMixin):
265
 
266
  decoder = AutoModelForCausalLM.from_pretrained(config.text_model_id, **decoder_kwargs)
267
  decoder.config.use_cache = getattr(config, "use_cache", True)
268
- decoder.requires_grad_(False)
269
- decoder.eval()
 
270
  return decoder
271
 
272
  def _create_projector(self, config: ASRConfig, dtype: torch.dtype) -> nn.Module:
@@ -395,8 +396,13 @@ class ASRModel(PreTrainedModel, GenerationMixin):
395
  )
396
 
397
  def state_dict(self, *args, **kwargs) -> dict[str, torch.Tensor]:
398
- """Only save trainable projector weights."""
399
- return {f"projector.{k}": v for k, v in self.projector.state_dict().items()}
 
 
 
 
 
400
 
401
  def _compute_encoder_output_lengths(
402
  self,
 
58
  _is_loading_from_pretrained: bool = False
59
  _pretrained_model_path: Optional[str] = None
60
 
61
+ TRANSCRIBE_PROMPT = "Transcribe the speech to text"
62
 
63
  @classmethod
64
  def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs) -> "ASRModel":
 
265
 
266
  decoder = AutoModelForCausalLM.from_pretrained(config.text_model_id, **decoder_kwargs)
267
  decoder.config.use_cache = getattr(config, "use_cache", True)
268
+ if getattr(config, "freeze_language_model", True):
269
+ decoder.requires_grad_(False)
270
+ decoder.train(False)
271
  return decoder
272
 
273
  def _create_projector(self, config: ASRConfig, dtype: torch.dtype) -> nn.Module:
 
396
  )
397
 
398
  def state_dict(self, *args, **kwargs) -> dict[str, torch.Tensor]:
399
+ """Save trainable weights: projector, plus the language model when fine-tuned."""
400
+ sd = {f"projector.{k}": v for k, v in self.projector.state_dict().items()}
401
+ if not getattr(self.config, "freeze_language_model", True):
402
+ sd.update(
403
+ {f"language_model.{k}": v for k, v in self.language_model.state_dict().items()}
404
+ )
405
+ return sd
406
 
407
  def _compute_encoder_output_lengths(
408
  self,
asr_processing.py CHANGED
@@ -17,7 +17,7 @@ class ASRProcessor(ProcessorMixin):
17
  feature_extractor_class = "AutoFeatureExtractor"
18
  tokenizer_class = "AutoTokenizer"
19
  AUDIO_TOKEN = "<audio>"
20
- TRANSCRIBE_PROMPT = ""
21
 
22
  def __init__(
23
  self,
 
17
  feature_extractor_class = "AutoFeatureExtractor"
18
  tokenizer_class = "AutoTokenizer"
19
  AUDIO_TOKEN = "<audio>"
20
+ TRANSCRIBE_PROMPT = "Transcribe the speech to text"
21
 
22
  def __init__(
23
  self,
config.json CHANGED
@@ -235,6 +235,7 @@
235
  ],
236
  "encoder_dim": 1280,
237
  "eos_token_id": 151645,
 
238
  "freeze_projector": false,
239
  "freq_mask_length": 27,
240
  "inference_warmup_tokens": 10,
 
235
  ],
236
  "encoder_dim": 1280,
237
  "eos_token_id": 151645,
238
+ "freeze_language_model": false,
239
  "freeze_projector": false,
240
  "freq_mask_length": 27,
241
  "inference_warmup_tokens": 10,
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:dfbd8328dc22b130f2558cd3cf634a711688a57837e4f8d18bce72a38398dd4c
3
- size 25170248
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4121706cbf6dcdefc550c39232b722d9f8b6b180ebd04e75056da125de9ee705
3
+ size 1216765200