winglian committed
Commit 902dd0a
Parent(s): 80b2ed2

fix issue with completed model being empty


see https://github.com/huggingface/peft/issues/286#issuecomment-1501617281
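For context: per the linked issue, the alpaca-lora-style state_dict override removed below had started returning an empty dict with recent peft releases, so the final save wrote an empty adapter_model.bin. peft's own PeftModel.save_pretrained already extracts the adapter-only weights via get_peft_model_state_dict, so no override is needed. A minimal sketch of that supported save path, not part of this commit, with "gpt2" and the LoRA hyperparameters as illustrative stand-ins:

# Sketch of the supported save path (illustrative model/hyperparameters).
from peft import LoraConfig, get_peft_model, get_peft_model_state_dict
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("gpt2")
lora_config = LoraConfig(r=8, lora_alpha=16, target_modules=["c_attn"])
model = get_peft_model(base, lora_config)

# the adapter-only state dict is non-empty when extracted the supported way
adapter_state = get_peft_model_state_dict(model)
assert adapter_state, "expected LoRA weights in the adapter state dict"

# save_pretrained performs the same extraction internally, so no
# model.state_dict override is needed for adapter_model.bin to be populated
model.save_pretrained("./lora-out")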

Files changed (1):
  scripts/finetune.py +3 -7
scripts/finetune.py CHANGED
@@ -369,22 +369,18 @@ def train(
     )
     model.config.use_cache = False
 
-    old_state_dict = model.state_dict
-    model.state_dict = (
-        lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
-    ).__get__(model, type(model))
-
     if torch.__version__ >= "2" and sys.platform != "win32":
         model = torch.compile(model)
 
+    # go ahead and presave, so we have the adapter config available to inspect
+    lora_config.save_pretrained(cfg.output_dir)
+
     # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
     signal.signal(
         signal.SIGINT,
         lambda signal, frame: (model.save_pretrained(cfg.output_dir), exit(0)),
     )
 
-    # go ahead and presave the adapter config
-    lora_config.save_pretrained(cfg.output_dir)
     trainer.train(resume_from_checkpoint=cfg.resume_from_checkpoint)
 
     model.save_pretrained(cfg.output_dir)
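
The reordering of the presave matters for the ctrl+c handler: writing adapter_config.json before training starts means an interrupted run still leaves a complete, loadable adapter directory. A self-contained sketch of that ordering, with "./lora-out" and "gpt2" standing in for the real cfg.output_dir and base model:

# Sketch of the presave ordering this commit adopts (illustrative names).
import signal
import sys

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

output_dir = "./lora-out"
lora_config = LoraConfig(r=8, lora_alpha=16, target_modules=["c_attn"])
model = get_peft_model(AutoModelForCausalLM.from_pretrained("gpt2"), lora_config)

# presave, so adapter_config.json is on disk before training begins
lora_config.save_pretrained(output_dir)

# on ctrl+c, save the adapter weights; together with the presaved config,
# the output directory is immediately loadable via PeftModel.from_pretrained
signal.signal(
    signal.SIGINT,
    lambda signum, frame: (model.save_pretrained(output_dir), sys.exit(0)),
)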