mazesmazes committed on
Commit
a24fa14
·
verified ·
1 Parent(s): 2acae26

Training in progress - step 364

Browse files
Files changed (3) hide show
  1. asr_config.py +8 -0
  2. asr_modeling.py +6 -6
  3. diarization.py +1 -1
asr_config.py CHANGED
@@ -63,6 +63,10 @@ class ASRConfig(transformers.PretrainedConfig):
63
  lora_dropout: float = 0.0,
64
  lora_target_modules: Optional[list] = None, # Default: all linear layers
65
  freeze_projector: bool = False, # True for Stage 2 (LoRA-only training)
 
 
 
 
66
  max_new_tokens: Optional[int] = None,
67
  min_new_tokens: Optional[int] = None,
68
  repetition_penalty: Optional[float] = None,
@@ -169,6 +173,10 @@ class ASRConfig(transformers.PretrainedConfig):
169
  else generation_defaults["no_repeat_ngram_size"]
170
  )
171
  self.use_cache = use_cache if use_cache is not None else generation_defaults["use_cache"]
 
 
 
 
172
 
173
  if "audio_config" not in kwargs:
174
  self.audio_config = transformers.AutoConfig.from_pretrained(audio_model_id)
 
63
  lora_dropout: float = 0.0,
64
  lora_target_modules: Optional[list] = None, # Default: all linear layers
65
  freeze_projector: bool = False, # True for Stage 2 (LoRA-only training)
66
+ do_sample: bool = False,
67
+ temperature: Optional[float] = None,
68
+ top_p: Optional[float] = None,
69
+ top_k: Optional[int] = None,
70
  max_new_tokens: Optional[int] = None,
71
  min_new_tokens: Optional[int] = None,
72
  repetition_penalty: Optional[float] = None,
 
173
  else generation_defaults["no_repeat_ngram_size"]
174
  )
175
  self.use_cache = use_cache if use_cache is not None else generation_defaults["use_cache"]
176
+ self.do_sample = do_sample
177
+ self.temperature = temperature
178
+ self.top_p = top_p
179
+ self.top_k = top_k
180
 
181
  if "audio_config" not in kwargs:
182
  self.audio_config = transformers.AutoConfig.from_pretrained(audio_model_id)
asr_modeling.py CHANGED
@@ -136,11 +136,11 @@ class ASRModel(PreTrainedModel, GenerationMixin):
136
  self.generation_config.max_new_tokens = config.max_new_tokens
137
  self.generation_config.min_new_tokens = config.min_new_tokens
138
  self.generation_config.num_beams = config.num_beams
139
- self.generation_config.do_sample = False
140
- # Clear sampling params (inherited from LLM) since we use greedy decoding
141
- self.generation_config.temperature = None
142
- self.generation_config.top_p = None
143
- self.generation_config.top_k = None
144
  self.generation_config.use_cache = config.use_cache
145
  self.generation_config.length_penalty = config.length_penalty
146
  self.generation_config.repetition_penalty = config.repetition_penalty
@@ -730,7 +730,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
730
  tokenize=True,
731
  add_generation_prompt=True,
732
  return_tensors="pt",
733
- enable_thinking=False,
734
  ).to(device)
735
 
736
  if input_ids.dim() == 1:
 
136
  self.generation_config.max_new_tokens = config.max_new_tokens
137
  self.generation_config.min_new_tokens = config.min_new_tokens
138
  self.generation_config.num_beams = config.num_beams
139
+ self.generation_config.do_sample = config.do_sample
140
+ # Set sampling params from config (None means use model defaults)
141
+ self.generation_config.temperature = config.temperature
142
+ self.generation_config.top_p = config.top_p
143
+ self.generation_config.top_k = config.top_k
144
  self.generation_config.use_cache = config.use_cache
145
  self.generation_config.length_penalty = config.length_penalty
146
  self.generation_config.repetition_penalty = config.repetition_penalty
 
730
  tokenize=True,
731
  add_generation_prompt=True,
732
  return_tensors="pt",
733
+ enable_thinking=False, # Disable Qwen3 thinking mode for ASR
734
  ).to(device)
735
 
736
  if input_ids.dim() == 1:
diarization.py CHANGED
@@ -737,7 +737,7 @@ class SpeakerDiarizer:
737
 
738
  cls._pyannote_pipeline = Pipeline.from_pretrained(
739
  "pyannote/speaker-diarization-3.1",
740
- use_auth_token=hf_token,
741
  )
742
  cls._pyannote_pipeline.to(torch.device(_get_device()))
743
 
 
737
 
738
  cls._pyannote_pipeline = Pipeline.from_pretrained(
739
  "pyannote/speaker-diarization-3.1",
740
+ token=hf_token,
741
  )
742
  cls._pyannote_pipeline.to(torch.device(_get_device()))
743