Spaces:
				
			
			
	
			
			
		Paused
		
	
	
	
			
			
	
	
	
	
		
		
		Paused
		
	| # Copyright 2020-2025 The HuggingFace Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from dataclasses import dataclass, field | |
| from typing import Optional | |
| import torch | |
| from peft import PeftConfig, PeftModel | |
| from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, HfArgumentParser | |
@dataclass
class ScriptArguments:
    """
    The input names representing the Adapter and Base model fine-tuned with PEFT, and the output name representing the
    merged model.

    ``HfArgumentParser`` only accepts dataclass types, so the ``@dataclass`` decorator is required for this class to be
    parsed from the command line. Every field defaults to ``None``.
    """

    # Name or path of the PEFT adapter to merge (loaded via PeftConfig / PeftModel ``from_pretrained``).
    adapter_model_name: Optional[str] = field(default=None, metadata={"help": "the adapter name"})
    # Name or path of the base model the adapter is applied to.
    base_model_name: Optional[str] = field(default=None, metadata={"help": "the base model name"})
    # Destination of the merged model: used as the local save directory and as the Hub repo id for the push.
    output_name: Optional[str] = field(default=None, metadata={"help": "the merged model name"})
# Merge a PEFT adapter into its base model, save the merged weights and tokenizer
# locally under `output_name`, and push the merged model to the Hub.
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]

# Validate required arguments explicitly. `assert` statements are removed when Python
# runs with -O, which would let a missing argument slip through to a confusing failure
# later; `parser.error` prints the usage string and exits with a non-zero status.
if script_args.adapter_model_name is None:
    parser.error("please provide the name of the Adapter you would like to merge")
if script_args.base_model_name is None:
    parser.error("please provide the name of the Base model")
if script_args.output_name is None:
    parser.error("please provide the output name of the merged model")

# The adapter config is read only to pick the model head that matches the adapter's task.
peft_config = PeftConfig.from_pretrained(script_args.adapter_model_name)
if peft_config.task_type == "SEQ_CLS":
    # The sequence classification task is used for the reward model in PPO
    # (num_labels=1 -> a single output per sequence).
    model = AutoModelForSequenceClassification.from_pretrained(
        script_args.base_model_name, num_labels=1, torch_dtype=torch.bfloat16
    )
else:
    model = AutoModelForCausalLM.from_pretrained(
        script_args.base_model_name, return_dict=True, torch_dtype=torch.bfloat16
    )

tokenizer = AutoTokenizer.from_pretrained(script_args.base_model_name)

# Load the PEFT model, then fold the adapter weights into the base model so the
# result can be saved and used as a plain transformers model.
model = PeftModel.from_pretrained(model, script_args.adapter_model_name)
model.eval()
model = model.merge_and_unload()

model.save_pretrained(script_args.output_name)
tokenizer.save_pretrained(script_args.output_name)
# NOTE(review): only the model is pushed to the Hub; the tokenizer is saved locally
# but not pushed — confirm this is intended.
model.push_to_hub(script_args.output_name, use_temp_dir=False)