mjschock commited on
Commit
b21080c
·
unverified ·
1 Parent(s): 611c848

Refactor SFTTrainer configuration in train.py to remove data_collator from the SFT config, preventing duplication and enhancing clarity in trainer setup.

Browse files
Files changed (1) hide show
  1. train.py +5 -1
train.py CHANGED
@@ -200,6 +200,10 @@ def create_trainer(
200
  **cfg.training.sft.data_collator,
201
  )
202
 
 
 
 
 
203
  trainer = SFTTrainer(
204
  model=model,
205
  tokenizer=tokenizer,
@@ -207,7 +211,7 @@ def create_trainer(
207
  eval_dataset=dataset["validation"],
208
  args=training_args,
209
  data_collator=data_collator,
210
- **cfg.training.sft,
211
  )
212
  logger.info("Trainer created successfully")
213
  return trainer
 
200
  **cfg.training.sft.data_collator,
201
  )
202
 
203
+ # Create SFT config without data_collator to avoid duplication
204
+ sft_config = OmegaConf.to_container(cfg.training.sft, resolve=True)
205
+ sft_config.pop('data_collator', None) # Remove data_collator from config
206
+
207
  trainer = SFTTrainer(
208
  model=model,
209
  tokenizer=tokenizer,
 
211
  eval_dataset=dataset["validation"],
212
  args=training_args,
213
  data_collator=data_collator,
214
+ **sft_config,
215
  )
216
  logger.info("Trainer created successfully")
217
  return trainer