# Usage: deepspeed train_lora.py --deepspeed <$PATH_TO_DEEPSPEED_CONFIG>
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
#
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
import logging
import pathlib
import typing

from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from peft import LoraConfig, get_peft_model
import transformers
from transformers import Trainer

from fastchat.train.train import (
    DataArguments,
    ModelArguments,
    TrainingArguments,
    make_supervised_data_module,
)
from fastchat.train.llama_flash_attn_monkey_patch import (
    replace_llama_attn_with_flash_attn,
)

replace_llama_attn_with_flash_attn()


@dataclass
class LoraArguments:
    lora_r: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    lora_target_modules: typing.List[str] = field(
        default_factory=lambda: ["q_proj", "v_proj"]
    )
    lora_weight_path: str = ""
    bias: str = "none"


def maybe_zero_3(param):
    # Under DeepSpeed ZeRO-3, parameters are partitioned across ranks; gather
    # the full tensor before copying it to CPU so it can be saved.
    if hasattr(param, "ds_id"):
        assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE
        with zero.GatheredParameters([param]):
            param = param.data.cpu().clone().detach()
    return param


# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(state_dict, bias):
    # Keep only the LoRA weights (and, depending on `bias`, the bias terms),
    # gathering each tensor if it is ZeRO-3 partitioned.
    if bias == "none":
        to_return = {
            k: state_dict[k].cpu().clone().detach()
            for k in state_dict
            if "lora_" in k
        }
    elif bias == "all":
        to_return = {
            k: state_dict[k] for k in state_dict if "lora_" in k or "bias" in k
        }
    elif bias == "lora_only":
        to_return = {}
        for k in state_dict:
            if "lora_" in k:
                to_return[k] = state_dict[k]
                bias_name = k.split("lora_")[0] + "bias"
                if bias_name in state_dict:
                    to_return[bias_name] = state_dict[bias_name]
    else:
        raise NotImplementedError
    to_return = {k: maybe_zero_3(v) for k, v in to_return.items()}
    return to_return


def train():
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments, LoraArguments)
    )
    (
        model_args,
        data_args,
        training_args,
        lora_args,
    ) = parser.parse_args_into_dataclasses()

    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
    )
    # Wrap the base model with LoRA adapters; only the adapter weights are trained.
    lora_config = LoraConfig(
        r=lora_args.lora_r,
        lora_alpha=lora_args.lora_alpha,
        target_modules=lora_args.lora_target_modules,
        lora_dropout=lora_args.lora_dropout,
        bias=lora_args.bias,
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)
    if training_args.deepspeed is not None and training_args.local_rank == 0:
        model.print_trainable_parameters()

    if training_args.gradient_checkpointing:
        logging.warning(
            "gradient checkpointing with lora makes requires_grad "
            "incorrect and needs a monkey patch in Trainer or the "
            "wrapped model's forward. ref: "
            "https://github.com/lm-sys/FastChat/pull/138#issuecomment-1509172198"
        )

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    # The LLaMA tokenizer has no pad token; reuse the unk token for padding.
    tokenizer.pad_token = tokenizer.unk_token

    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
    trainer = Trainer(
        model=model, tokenizer=tokenizer, args=training_args, **data_module
    )

    # The KV cache is only useful for generation; disable it during training.
    model.config.use_cache = False

    # Resume from the latest checkpoint if one exists in the output directory.
    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()

    # Save states. Weights might be a placeholder in zero3 and need a gather
    state_dict = get_peft_state_maybe_zero_3(model.state_dict(), lora_args.bias)
    if training_args.local_rank == 0:
        model.save_pretrained(training_args.output_dir, state_dict=state_dict)


if __name__ == "__main__":
    train()