DuyTa commited on
Commit
5819083
1 Parent(s): d601d6f

Update src/merge_lora.py

Browse files
Files changed (1) hide show
  1. src/merge_lora.py +6 -6
src/merge_lora.py CHANGED
@@ -9,17 +9,17 @@ from utils.utils import print_arguments, add_arguments
9
 
10
  parser = argparse.ArgumentParser(description=__doc__)
11
  add_arg = functools.partial(add_arguments, argparser=parser)
12
- add_arg("lora_model", type=str, default="output/whisper-tiny/checkpoint-best/", help="微调保存的模型路径")
13
- add_arg('output_dir', type=str, default='models/', help="合并模型的保存目录")
14
- add_arg("local_files_only", type=bool, default=False, help="是否只在本地加载模型,不尝试下载")
15
  args = parser.parse_args()
16
  print_arguments(args)
17
 
18
- assert os.path.exists(args.lora_model), f"模型文件{args.lora_model}不存在"
19
 
20
  peft_config = PeftConfig.from_pretrained(args.lora_model)
21
  #
22
- base_model = WhisperForConditionalGeneration.from_pretrained(peft_config.base_model_name_or_path, device_map={"": "cpu"},
23
  local_files_only=args.local_files_only)
24
 
25
  model = PeftModel.from_pretrained(base_model, args.lora_model, local_files_only=args.local_files_only)
@@ -41,4 +41,4 @@ model.save_pretrained(save_directory)
41
  feature_extractor.save_pretrained(save_directory)
42
  tokenizer.save_pretrained(save_directory)
43
  processor.save_pretrained(save_directory)
44
- print(f'合并模型保持在:{save_directory}')
 
9
 
10
  parser = argparse.ArgumentParser(description=__doc__)
11
  add_arg = functools.partial(add_arguments, argparser=parser)
12
+ add_arg("lora_model", type=str, default="output/whisper-tiny/checkpoint-best/", help="model directory")
13
+ add_arg('output_dir', type=str, default='models/', help="output directory")
14
+ add_arg("local_files_only", type=bool, default=False, help="Load the model from local files only; if False, try to download it from the HF Hub")
15
  args = parser.parse_args()
16
  print_arguments(args)
17
 
18
+ assert os.path.exists(args.lora_model), f"model file {args.lora_model} does not exist"
19
 
20
  peft_config = PeftConfig.from_pretrained(args.lora_model)
21
  #
22
+ base_model = WhisperForConditionalGeneration.from_pretrained(peft_config.base_model_name_or_path, device_map={"": "cuda"},
23
  local_files_only=args.local_files_only)
24
 
25
  model = PeftModel.from_pretrained(base_model, args.lora_model, local_files_only=args.local_files_only)
 
41
  feature_extractor.save_pretrained(save_directory)
42
  tokenizer.save_pretrained(save_directory)
43
  processor.save_pretrained(save_directory)
44
+ print(f'Merged model saved at: {save_directory}')