File size: 4,958 Bytes
6ef31de |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
# Usage: deepspeed train_lora.py --deepspeed <$PATH_TO_DEEPSPEED_CONFIG>
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
import logging
import pathlib
import typing
from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from peft import LoraConfig, get_peft_model
import transformers
from transformers import Trainer
from fastchat.train.train import (
DataArguments,
ModelArguments,
TrainingArguments,
make_supervised_data_module,
)
from fastchat.train.llama_flash_attn_monkey_patch import (
replace_llama_attn_with_flash_attn,
)
# Apply the flash-attention monkey patch at import time, i.e. before train()
# loads any model.
# NOTE(review): presumably the patch has to be installed before the LLaMA
# model is constructed to take effect — confirm against the helper module.
replace_llama_attn_with_flash_attn()
@dataclass
class LoraArguments:
    """Command-line options that configure the LoRA adapter.

    The values are forwarded to ``peft.LoraConfig`` inside ``train()``;
    ``bias`` additionally selects which tensors are exported when the
    adapter is saved.
    """

    # Passed to LoraConfig as ``r``.
    lora_r: int = 8
    # Passed to LoraConfig as ``lora_alpha``.
    lora_alpha: int = 16
    # Passed to LoraConfig as ``lora_dropout``.
    lora_dropout: float = 0.05
    # Passed to LoraConfig as ``target_modules``. A factory is used so that
    # every instance gets its own list.
    lora_target_modules: typing.List[str] = field(
        default_factory=lambda: ["q_proj", "v_proj"]
    )
    # Not read anywhere in this file.
    # NOTE(review): presumably a path to existing LoRA weights — confirm.
    lora_weight_path: str = ""
    # Passed to LoraConfig as ``bias``; the export helper in this file
    # understands "none", "all" and "lora_only".
    bias: str = "none"
def maybe_zero_3(param):
    """Return a detached CPU copy of ``param``, gathering it first if needed.

    A tensor that carries a ``ds_id`` attribute is treated as a
    DeepSpeed-managed parameter: it is materialised inside
    ``zero.GatheredParameters`` and copied while the gathered data is
    available. Any other tensor is copied directly, so callers always get
    a tensor that no longer shares storage or autograd history with the
    model.

    Args:
        param: A tensor or parameter, optionally DeepSpeed-managed.

    Returns:
        A detached clone of the (gathered) values, located on the CPU.

    Raises:
        AssertionError: If a DeepSpeed-managed parameter is not in the
            ``ZeroParamStatus.NOT_AVAILABLE`` state on entry.
    """
    if hasattr(param, "ds_id"):
        assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE
        with zero.GatheredParameters([param]):
            # Copy inside the context: the full data is only guaranteed to
            # be present while the parameter is gathered.
            param = param.data.detach().cpu().clone()
    else:
        # Plain tensors get the same treatment so that both paths return an
        # independent CPU copy instead of a live reference into the model.
        param = param.detach().cpu().clone()
    return param
# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(state_dict, bias):
    """Extract the LoRA-related tensors from ``state_dict`` as CPU copies.

    Args:
        state_dict: Mapping from parameter name to tensor.
        bias: Which bias tensors to export alongside the LoRA weights:
            ``"none"`` keeps only keys containing ``"lora_"``;
            ``"all"`` also keeps every key containing ``"bias"``;
            ``"lora_only"`` keeps, for each LoRA key, the entry named
            ``<prefix before "lora_">bias`` when it exists.

    Returns:
        A dict, in ``state_dict`` iteration order, mapping the selected
        names to detached CPU copies of their tensors.

    Raises:
        NotImplementedError: If ``bias`` is not one of the modes above.
    """
    # Select the raw entries first. Copying here (as opposed to after the
    # gather below) would produce fresh tensors that no longer carry the
    # ``ds_id`` attribute maybe_zero_3 relies on.
    if bias == "none":
        selected = {k: v for k, v in state_dict.items() if "lora_" in k}
    elif bias == "all":
        selected = {
            k: v for k, v in state_dict.items() if "lora_" in k or "bias" in k
        }
    elif bias == "lora_only":
        selected = {}
        for k, v in state_dict.items():
            if "lora_" in k:
                selected[k] = v
                bias_name = k.split("lora_")[0] + "bias"
                if bias_name in state_dict:
                    selected[bias_name] = state_dict[bias_name]
    else:
        raise NotImplementedError(f"unsupported bias mode: {bias!r}")

    to_return = {}
    for name, tensor in selected.items():
        gathered = maybe_zero_3(tensor)
        if gathered is tensor:
            # The helper handed the input back untouched; make the copy here
            # so every bias mode returns independent CPU tensors.
            gathered = tensor.detach().cpu().clone()
        to_return[name] = gathered
    return to_return
def train():
    """Fine-tune a causal language model with LoRA and save the adapter.

    Steps, in order: parse the command line into the four argument
    dataclasses, load the base model and wrap it with a LoRA adapter, load
    the tokenizer, build the supervised data module, run the Hugging Face
    ``Trainer`` (resuming when the output directory already contains a
    ``checkpoint-*`` entry), then export the LoRA tensors and save them
    from the process whose ``local_rank`` is 0.
    """
    arg_parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments, LoraArguments)
    )
    model_args, data_args, training_args, lora_args = (
        arg_parser.parse_args_into_dataclasses()
    )

    base_model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
    )
    peft_config = LoraConfig(
        r=lora_args.lora_r,
        lora_alpha=lora_args.lora_alpha,
        target_modules=lora_args.lora_target_modules,
        lora_dropout=lora_args.lora_dropout,
        bias=lora_args.bias,
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(base_model, peft_config)

    # Report the trainable-parameter count once, from local rank 0, and only
    # when a DeepSpeed config was supplied.
    if training_args.deepspeed is not None and training_args.local_rank == 0:
        model.print_trainable_parameters()

    if training_args.gradient_checkpointing:
        logging.warning(
            "gradient checkpointing with lora makes requires_grad "
            "incorrect and needs a monkey patch in Trainer or the "
            "wrapped model's forward. ref: "
            "https://github.com/lm-sys/FastChat/pull/138#issuecomment-1509172198"
        )

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    # Reuse the unknown token as the padding token.
    tokenizer.pad_token = tokenizer.unk_token

    dataset_kwargs = make_supervised_data_module(
        tokenizer=tokenizer, data_args=data_args
    )
    trainer = Trainer(
        model=model, tokenizer=tokenizer, args=training_args, **dataset_kwargs
    )

    model.config.use_cache = False

    # Resume when at least one "checkpoint-*" entry exists in the output dir.
    checkpoint_matches = pathlib.Path(training_args.output_dir).glob("checkpoint-*")
    has_checkpoint = next(checkpoint_matches, None) is not None
    if has_checkpoint:
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()

    # Save states. Weights might be a placeholder in zero3 and need a gather.
    # NOTE(review): tensors returned by model.state_dict() may not carry the
    # ``ds_id`` attribute that maybe_zero_3 checks for — confirm the gather
    # actually runs under ZeRO-3.
    adapter_state = get_peft_state_maybe_zero_3(model.state_dict(), lora_args.bias)
    if training_args.local_rank == 0:
        model.save_pretrained(training_args.output_dir, state_dict=adapter_state)
if __name__ == "__main__":
    # Script entry point: start training only when run directly, not on import.
    train()
|