Update src/merge_lora.py
Browse files- src/merge_lora.py +6 -6
src/merge_lora.py
CHANGED
@@ -9,17 +9,17 @@ from utils.utils import print_arguments, add_arguments
|
|
9 |
|
10 |
parser = argparse.ArgumentParser(description=__doc__)
|
11 |
add_arg = functools.partial(add_arguments, argparser=parser)
|
12 |
-
add_arg("lora_model", type=str, default="output/whisper-tiny/checkpoint-best/", help="
|
13 |
-
add_arg('output_dir', type=str, default='models/', help="
|
14 |
-
add_arg("local_files_only", type=bool, default=False, help="
|
15 |
args = parser.parse_args()
|
16 |
print_arguments(args)
|
17 |
|
18 |
-
assert os.path.exists(args.lora_model), f"
|
19 |
|
20 |
peft_config = PeftConfig.from_pretrained(args.lora_model)
|
21 |
#
|
22 |
-
base_model = WhisperForConditionalGeneration.from_pretrained(peft_config.base_model_name_or_path, device_map={"": "
|
23 |
local_files_only=args.local_files_only)
|
24 |
|
25 |
model = PeftModel.from_pretrained(base_model, args.lora_model, local_files_only=args.local_files_only)
|
@@ -41,4 +41,4 @@ model.save_pretrained(save_directory)
|
|
41 |
feature_extractor.save_pretrained(save_directory)
|
42 |
tokenizer.save_pretrained(save_directory)
|
43 |
processor.save_pretrained(save_directory)
|
44 |
-
print(f'
|
|
|
9 |
|
10 |
parser = argparse.ArgumentParser(description=__doc__)
|
11 |
add_arg = functools.partial(add_arguments, argparser=parser)
|
12 |
+
add_arg("lora_model", type=str, default="output/whisper-tiny/checkpoint-best/", help="model directory")
|
13 |
+
add_arg('output_dir', type=str, default='models/', help="output directory")
|
14 |
+
add_arg("local_files_only", type=bool, default=False, help="Choose local file , if not , find the model on HF")
|
15 |
args = parser.parse_args()
|
16 |
print_arguments(args)
|
17 |
|
18 |
+
assert os.path.exists(args.lora_model), f"model file{args.lora_model}does not exist"
|
19 |
|
20 |
peft_config = PeftConfig.from_pretrained(args.lora_model)
|
21 |
#
|
22 |
+
base_model = WhisperForConditionalGeneration.from_pretrained(peft_config.base_model_name_or_path, device_map={"": "cuda"},
|
23 |
local_files_only=args.local_files_only)
|
24 |
|
25 |
model = PeftModel.from_pretrained(base_model, args.lora_model, local_files_only=args.local_files_only)
|
|
|
41 |
feature_extractor.save_pretrained(save_directory)
|
42 |
tokenizer.save_pretrained(save_directory)
|
43 |
processor.save_pretrained(save_directory)
|
44 |
+
print(f'Merged model save at :{save_directory}')
|