Training in progress - step 1000
- asr_config.py +2 -0
- asr_modeling.py +11 -5
- asr_processing.py +1 -1
- config.json +1 -0
- model.safetensors +2 -2
asr_config.py
CHANGED
@@ -66,6 +66,7 @@ class ASRConfig(transformers.PretrainedConfig):
         lora_dropout: float = 0.0,
         lora_target_modules: Optional[list] = None,  # Default: all linear layers
         freeze_projector: bool = False,  # True for Stage 2 (LoRA-only training)
+        freeze_language_model: bool = True,  # False = full decoder fine-tuning
         do_sample: bool = False,
         temperature: Optional[float] = None,
         top_p: Optional[float] = None,
@@ -152,6 +153,7 @@ class ASRConfig(transformers.PretrainedConfig):
             "down_proj",
         ]
         self.freeze_projector = freeze_projector
+        self.freeze_language_model = freeze_language_model
 
         explicit_generation_args = {
             "num_beams": num_beams,
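The asr_config.py change adds a freeze_language_model flag (default True, i.e. decoder frozen) next to freeze_projector. A minimal usage sketch, assuming the remaining ASRConfig arguments keep their defaults; the argument combination shown is illustrative, not a documented recipe:

from asr_config import ASRConfig

# Default: the language model stays frozen; only the projector (and any LoRA
# adapters) receive gradients.
frozen_cfg = ASRConfig()

# Full decoder fine-tuning: unfreeze the language model and keep the
# projector trainable too.
full_ft_cfg = ASRConfig(
    freeze_projector=False,
    freeze_language_model=False,
)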
asr_modeling.py
CHANGED
@@ -58,7 +58,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
     _is_loading_from_pretrained: bool = False
     _pretrained_model_path: Optional[str] = None
 
-    TRANSCRIBE_PROMPT = ""
+    TRANSCRIBE_PROMPT = "Transcribe the speech to text"
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs) -> "ASRModel":
@@ -265,8 +265,9 @@ class ASRModel(PreTrainedModel, GenerationMixin):
 
         decoder = AutoModelForCausalLM.from_pretrained(config.text_model_id, **decoder_kwargs)
         decoder.config.use_cache = getattr(config, "use_cache", True)
-
-
+        if getattr(config, "freeze_language_model", True):
+            decoder.requires_grad_(False)
+            decoder.train(False)
         return decoder
 
     def _create_projector(self, config: ASRConfig, dtype: torch.dtype) -> nn.Module:
@@ -395,8 +396,13 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         )
 
     def state_dict(self, *args, **kwargs) -> dict[str, torch.Tensor]:
-        """
-
+        """Save trainable weights: projector, plus the language model when fine-tuned."""
+        sd = {f"projector.{k}": v for k, v in self.projector.state_dict().items()}
+        if not getattr(self.config, "freeze_language_model", True):
+            sd.update(
+                {f"language_model.{k}": v for k, v in self.language_model.state_dict().items()}
+            )
+        return sd
 
     def _compute_encoder_output_lengths(
         self,
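Together, the asr_modeling.py changes freeze the decoder with requires_grad_(False) and train(False) when the flag is set, and narrow state_dict() so checkpoints only carry trainable weights. A self-contained sketch of that pattern using toy modules (the class and layer names below are illustrative, not the repository's real ones):

import torch.nn as nn

class ToyASRModel(nn.Module):
    def __init__(self, freeze_language_model: bool = True):
        super().__init__()
        self.projector = nn.Linear(8, 8)       # always trainable
        self.language_model = nn.Linear(8, 8)  # stand-in for the decoder
        self.freeze_language_model = freeze_language_model
        if freeze_language_model:
            self.language_model.requires_grad_(False)  # no gradients flow to it
            self.language_model.train(False)           # keep it in eval mode

    def state_dict(self, *args, **kwargs):
        # Mirror the diff: save the projector, and the language model only
        # when it is actually being fine-tuned.
        sd = {f"projector.{k}": v for k, v in self.projector.state_dict().items()}
        if not self.freeze_language_model:
            sd.update({f"language_model.{k}": v for k, v in self.language_model.state_dict().items()})
        return sd

print(list(ToyASRModel(freeze_language_model=True).state_dict()))   # projector.* only
print(list(ToyASRModel(freeze_language_model=False).state_dict()))  # projector.* and language_model.*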
asr_processing.py
CHANGED
@@ -17,7 +17,7 @@ class ASRProcessor(ProcessorMixin):
     feature_extractor_class = "AutoFeatureExtractor"
     tokenizer_class = "AutoTokenizer"
     AUDIO_TOKEN = "<audio>"
-    TRANSCRIBE_PROMPT = ""
+    TRANSCRIBE_PROMPT = "Transcribe the speech to text"
 
     def __init__(
         self,
config.json
CHANGED
@@ -235,6 +235,7 @@
   ],
   "encoder_dim": 1280,
   "eos_token_id": 151645,
+  "freeze_language_model": false,
   "freeze_projector": false,
   "freq_mask_length": 27,
   "inference_warmup_tokens": 10,
model.safetensors
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:4121706cbf6dcdefc550c39232b722d9f8b6b180ebd04e75056da125de9ee705
+size 1216765200