lightita committed on
Commit
9a5ab5c
·
verified ·
1 Parent(s): a53b482

Update train_seallm_khm_sum.py

Browse files
Files changed (1) hide show
  1. train_seallm_khm_sum.py +6 -3
train_seallm_khm_sum.py CHANGED
@@ -86,11 +86,13 @@ def load_model_and_tokenizer():
86
  trust_remote_code=True,
87
  )
88
 
89
- model.gradient_checkpointing_enable()
 
90
 
91
  return model, tokenizer
92
 
93
 
 
94
  def main():
95
  train_ds, eval_ds = load_khm_dataset()
96
  model, tokenizer = load_model_and_tokenizer()
@@ -149,10 +151,11 @@ def main():
149
  save_total_limit=2,
150
  lr_scheduler_type="cosine",
151
  warmup_ratio=0.03,
152
- fp16=True, # safer for old transformers
153
- report_to="none", # remove if this crashes
154
  )
155
 
 
156
  trainer = Trainer(
157
  model=model,
158
  args=training_args,
 
86
  trust_remote_code=True,
87
  )
88
 
89
+ # Disable gradient checkpointing; old transformers breaks autograd here
90
+ # model.gradient_checkpointing_enable()
91
 
92
  return model, tokenizer
93
 
94
 
95
+
96
  def main():
97
  train_ds, eval_ds = load_khm_dataset()
98
  model, tokenizer = load_model_and_tokenizer()
 
151
  save_total_limit=2,
152
  lr_scheduler_type="cosine",
153
  warmup_ratio=0.03,
154
+ fp16=False, # turn off mixed precision for CPU
155
+ report_to="none",
156
  )
157
 
158
+
159
  trainer = Trainer(
160
  model=model,
161
  args=training_args,