Training in progress - step 364
Browse files
- asr_config.py +8 -0
- asr_modeling.py +6 -6
- diarization.py +1 -1
asr_config.py
CHANGED
|
@@ -63,6 +63,10 @@ class ASRConfig(transformers.PretrainedConfig):
|
|
| 63 |
lora_dropout: float = 0.0,
|
| 64 |
lora_target_modules: Optional[list] = None, # Default: all linear layers
|
| 65 |
freeze_projector: bool = False, # True for Stage 2 (LoRA-only training)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
max_new_tokens: Optional[int] = None,
|
| 67 |
min_new_tokens: Optional[int] = None,
|
| 68 |
repetition_penalty: Optional[float] = None,
|
|
@@ -169,6 +173,10 @@ class ASRConfig(transformers.PretrainedConfig):
|
|
| 169 |
else generation_defaults["no_repeat_ngram_size"]
|
| 170 |
)
|
| 171 |
self.use_cache = use_cache if use_cache is not None else generation_defaults["use_cache"]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
|
| 173 |
if "audio_config" not in kwargs:
|
| 174 |
self.audio_config = transformers.AutoConfig.from_pretrained(audio_model_id)
|
|
|
|
| 63 |
lora_dropout: float = 0.0,
|
| 64 |
lora_target_modules: Optional[list] = None, # Default: all linear layers
|
| 65 |
freeze_projector: bool = False, # True for Stage 2 (LoRA-only training)
|
| 66 |
+
do_sample: bool = False,
|
| 67 |
+
temperature: Optional[float] = None,
|
| 68 |
+
top_p: Optional[float] = None,
|
| 69 |
+
top_k: Optional[int] = None,
|
| 70 |
max_new_tokens: Optional[int] = None,
|
| 71 |
min_new_tokens: Optional[int] = None,
|
| 72 |
repetition_penalty: Optional[float] = None,
|
|
|
|
| 173 |
else generation_defaults["no_repeat_ngram_size"]
|
| 174 |
)
|
| 175 |
self.use_cache = use_cache if use_cache is not None else generation_defaults["use_cache"]
|
| 176 |
+
self.do_sample = do_sample
|
| 177 |
+
self.temperature = temperature
|
| 178 |
+
self.top_p = top_p
|
| 179 |
+
self.top_k = top_k
|
| 180 |
|
| 181 |
if "audio_config" not in kwargs:
|
| 182 |
self.audio_config = transformers.AutoConfig.from_pretrained(audio_model_id)
|
asr_modeling.py
CHANGED
|
@@ -136,11 +136,11 @@ class ASRModel(PreTrainedModel, GenerationMixin):
|
|
| 136 |
self.generation_config.max_new_tokens = config.max_new_tokens
|
| 137 |
self.generation_config.min_new_tokens = config.min_new_tokens
|
| 138 |
self.generation_config.num_beams = config.num_beams
|
| 139 |
-
self.generation_config.do_sample =
|
| 140 |
-
#
|
| 141 |
-
self.generation_config.temperature =
|
| 142 |
-
self.generation_config.top_p =
|
| 143 |
-
self.generation_config.top_k =
|
| 144 |
self.generation_config.use_cache = config.use_cache
|
| 145 |
self.generation_config.length_penalty = config.length_penalty
|
| 146 |
self.generation_config.repetition_penalty = config.repetition_penalty
|
|
@@ -730,7 +730,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
|
|
| 730 |
tokenize=True,
|
| 731 |
add_generation_prompt=True,
|
| 732 |
return_tensors="pt",
|
| 733 |
-
enable_thinking=False,
|
| 734 |
).to(device)
|
| 735 |
|
| 736 |
if input_ids.dim() == 1:
|
|
|
|
| 136 |
self.generation_config.max_new_tokens = config.max_new_tokens
|
| 137 |
self.generation_config.min_new_tokens = config.min_new_tokens
|
| 138 |
self.generation_config.num_beams = config.num_beams
|
| 139 |
+
self.generation_config.do_sample = config.do_sample
|
| 140 |
+
# Set sampling params from config (None means use model defaults)
|
| 141 |
+
self.generation_config.temperature = config.temperature
|
| 142 |
+
self.generation_config.top_p = config.top_p
|
| 143 |
+
self.generation_config.top_k = config.top_k
|
| 144 |
self.generation_config.use_cache = config.use_cache
|
| 145 |
self.generation_config.length_penalty = config.length_penalty
|
| 146 |
self.generation_config.repetition_penalty = config.repetition_penalty
|
|
|
|
| 730 |
tokenize=True,
|
| 731 |
add_generation_prompt=True,
|
| 732 |
return_tensors="pt",
|
| 733 |
+
enable_thinking=False, # Disable Qwen3 thinking mode for ASR
|
| 734 |
).to(device)
|
| 735 |
|
| 736 |
if input_ids.dim() == 1:
|
diarization.py
CHANGED
|
@@ -737,7 +737,7 @@ class SpeakerDiarizer:
|
|
| 737 |
|
| 738 |
cls._pyannote_pipeline = Pipeline.from_pretrained(
|
| 739 |
"pyannote/speaker-diarization-3.1",
|
| 740 |
-
|
| 741 |
)
|
| 742 |
cls._pyannote_pipeline.to(torch.device(_get_device()))
|
| 743 |
|
|
|
|
| 737 |
|
| 738 |
cls._pyannote_pipeline = Pipeline.from_pretrained(
|
| 739 |
"pyannote/speaker-diarization-3.1",
|
| 740 |
+
token=hf_token,
|
| 741 |
)
|
| 742 |
cls._pyannote_pipeline.to(torch.device(_get_device()))
|
| 743 |
|