# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
import pathlib
from typing import Optional, List
import transformers
from pointllm.train.pointllm_trainer import PointLLMTrainer
from pointllm import conversation as conversation_lib
from pointllm.model import *
from pointllm.data import make_object_point_data_module
# * logger
from pointllm.utils import build_logger
IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"
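# * special-token defaults carried over from the Alpaca/FastChat training scripts; none of them
# * are referenced below (padding is handled by reusing the tokenizer's unk token inside train()).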
@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="")
    version: Optional[str] = field(default="v1")


@dataclass
class DataArguments:
    data_path: str = field(default="ScanNet", metadata={"help": "Path to the training data."})
    anno_path: Optional[str] = field(default=None, metadata={"help": "Path to the utterance data. If None, will use referit3d by default."})
    use_color: bool = field(default=False, metadata={"help": "Whether to use color."})
    data_debug_num: int = field(default=0, metadata={"help": "Number of data to use in debug mode. If larger than 0, use debug mode, else use the whole data."})
    split_train_val: bool = field(default=False, metadata={"help": "Whether to split train and val."})
    split_ratio: float = field(default=0.9, metadata={"help": "Ratio of train and val."})
    pointnum: int = field(default=8192, metadata={"help": "Number of points."})
    conversation_types: List[str] = field(default_factory=lambda: ["simple_description"], metadata={"help": "Conversation types to use."})
    is_multimodal: bool = True
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    # * can refer to https://huggingface.co/docs/transformers/v4.28.1/en/main_classes/trainer#transformers.TrainingArgument
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=2048,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )
    model_debug: bool = field(default=False, metadata={"help": "Whether to use small model."})  # * when True, build the model from its config instead of loading checkpoint weights
    fix_llm: bool = field(default=True, metadata={"help": "Whether to fix the LLM."})
    fix_pointnet: bool = field(default=True, metadata={"help": "Whether to fix the PointNet."})
    remove_unused_columns: bool = field(default=False)
    force_fsdp: bool = field(default=False)
    # * for two-stage training
    tune_mm_mlp_adapter: bool = field(default=True)  # * set True when pre-training, and False when fine-tuning
    stage_2: bool = field(default=False)  # * set True when fine-tuning
    pretrained_mm_mlp_adapter: Optional[str] = field(default=None)  # * path to the pre-trained projector & output_embed & input_embed
    detatch_point_token: bool = field(default=False)  # * deprecated
    # * point backbone ckpt path
    point_backbone_ckpt: Optional[str] = field(default=None)
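# * A minimal launch sketch (illustrative only; paths and the exact flag recipe are placeholders,
# * not taken from this repo's docs). HfArgumentParser exposes every field above as a CLI flag:
# *   stage 1: python <this_script>.py --model_name_or_path <llm_ckpt> --data_path <pcd_dir> \
# *            --anno_path <anno.json> --output_dir <out_dir> --point_backbone_ckpt <point_ckpt> \
# *            --fix_llm True --fix_pointnet True --tune_mm_mlp_adapter True
# *   stage 2: add --stage_2 True and point --model_name_or_path at the stage-1 output directory.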
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
                                   output_dir: str):
    """Collects the state dict and dumps it to disk."""
    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {
            key: value.cpu()
            for key, value in state_dict.items()
        }
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa
def train():
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    training_args.log_level = "info"  # * default is passive(warning)

    # * build logger
    logger = build_logger(__name__, training_args.output_dir + '/train.log')

    if training_args.model_debug:
        # * do not load checkpoint, load from config
        config = transformers.AutoConfig.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
        )
        model = PointLLMLlamaForCausalLM._from_config(config)
    else:
        model = PointLLMLlamaForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
        )
    model.config.use_cache = False
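    # * the KV cache only helps autoregressive generation, so it is turned off above for training.

    # * Parameter-freezing strategy: with the defaults (fix_llm=True, fix_pointnet=True,
    # * tune_mm_mlp_adapter=True) only the point projector, plus any newly added token
    # * embeddings (see initialize_tokenizer_point_backbone_config), receives gradients;
    # * the flags below carve that out explicitly.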
    if training_args.fix_llm:
        # * This will fix all the parameters
        logger.info("LLM is fixed. Fix_llm flag is set to True")
        # * fix llama, lm_head, pointnet, projection layer here
        model.requires_grad_(False)
        model.get_model().fix_llm = True
        model.get_model().point_proj.requires_grad_(True)
        model.get_model().point_backbone.requires_grad_(True)  # * set as True for fsdp, use fix_pointnet flag to control
    else:
        model.get_model().fix_llm = False
        logger.warning("LLM is trainable. Fix_llm flag is set to False")
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    if model_args.version == "v0" or "v0" in model_args.model_name_or_path:
        raise ValueError("v0 is deprecated.")
    else:
        tokenizer.pad_token = tokenizer.unk_token
        conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1_1"]
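    # * the LLaMA tokenizer ships without a pad token, so unk is reused for padding, and
    # * training conversations are rendered with the vicuna_v1_1 prompt template.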
    if not training_args.fix_pointnet:
        # * not fix pointnet
        logger.info("Point backbone is trainable. Fix_pointnet flag is set to False, pointnet grad will be recorded.")
        model.get_model().fix_pointnet = False
    else:
        logger.info("Point backbone is fixed. Fix_pointnet flag is set to True, pointnet grad will not be recorded.")
        model.get_model().fix_pointnet = True  # * use with torch.inference_mode to control, not requires_grad for fsdp for second stage
        if not training_args.stage_2:
            logger.info("Set requires_grad of point backbone to False")
            model.get_model().point_backbone.requires_grad_(False)  # * fix pointnet for first stage, need for fsdp in stage2
    if training_args.tune_mm_mlp_adapter:
        # * not fix the projection layer
        # * may need to set the embed_tokens to require_grad = True if added new tokens
        # * this is done in initialize_tokenizer_point_backbone_config
        logger.info("Point projection layer is trainable.")
    else:
        model.get_model().point_proj.requires_grad_(False)
        logger.info("Point projection layer is fixed.")
    if not training_args.stage_2:
        # * we assume in stage2, llm, point_backbone, and projection layer can be loaded from the model checkpoint
        print(f"Default point_backbone_ckpt is {training_args.point_backbone_ckpt}.")
        model.get_model().load_point_backbone_checkpoint(training_args.point_backbone_ckpt)
        model.initialize_tokenizer_point_backbone_config(tokenizer=tokenizer, device=training_args.device, fix_llm=training_args.fix_llm)
    else:
        # * stage2
        model.initialize_tokenizer_point_backbone_config_wo_embedding(tokenizer=tokenizer)
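    # * stage 1 loads the point-backbone weights from point_backbone_ckpt and registers the point
    # * tokens (tokenizer entries + embeddings); stage 2 assumes all of that already lives inside
    # * the checkpoint loaded above, so only the tokenizer-side configuration is rebuilt.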
    point_backbone_config = model.get_model().point_backbone_config

    data_args.point_token_len = point_backbone_config['point_token_len']
    data_args.mm_use_point_start_end = point_backbone_config['mm_use_point_start_end']
    data_args.point_backbone_config = point_backbone_config
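    # * data_args now carries the point-token layout, so the dataset/collator can splice the
    # * right number of point placeholder tokens (and optional start/end tokens) into prompts.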
    params_no_grad = [n for n, p in model.named_parameters() if not p.requires_grad]
    if len(params_no_grad) > 0:
        if training_args.fsdp is not None and len(training_args.fsdp) > 0:
            if len(params_no_grad) < 10:
                print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}'.format(len(params_no_grad), params_no_grad))
            else:
                print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}...(omitted)'.format(len(params_no_grad), ', '.join(params_no_grad[:10])))
            print("[WARNING] Attempting to use FSDP with partially frozen parameters, this is experimental.")
            print("[WARNING] As of 4/30/23, this feature requires PyTorch-nightly build. See here for details: https://github.com/haotian-liu/LLaVA#experimental-use-fsdp-to-save-memory-in-pretraining")
            from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP

            def patch_FSDP_use_orig_params(func):
                def wrap_func(*args, **kwargs):
                    use_orig_params = kwargs.pop('use_orig_params', True)
                    return func(*args, **kwargs, use_orig_params=use_orig_params)
                return wrap_func

            FSDP.__init__ = patch_FSDP_use_orig_params(FSDP.__init__)
    data_module = make_object_point_data_module(tokenizer=tokenizer,
                                                data_args=data_args)

    trainer = PointLLMTrainer(model=model,
                              tokenizer=tokenizer,
                              args=training_args,
                              **data_module)
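    # * resume automatically if the output dir already contains a checkpoint-* folder;
    # * otherwise start training from scratch.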
    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()

    safe_save_model_for_hf_trainer(trainer=trainer,
                                   output_dir=training_args.output_dir)


if __name__ == "__main__":
    train()