# vstar / VisualSearch / merge_lora_weights_and_save_hf_model.py
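"""Merge trained LoRA weights into the VSM base model and save the result as a
standard HuggingFace checkpoint.

Example invocation (paths are illustrative; defaults are shown in parse_args):

    python merge_lora_weights_and_save_hf_model.py \
        --version="path/to/LLaVA-7B-v1.1" \
        --weight="./runs/vsm/pytorch_model.bin" \
        --save_path="./seal_vsm_7b"
"""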
import argparse
import os
import sys
import torch
import transformers
from peft import LoraConfig, get_peft_model
from VisualSearch.model.VSM import VSMForCausalLM
from VisualSearch.utils.utils import DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN

def parse_args(args):
    parser = argparse.ArgumentParser(
        description="merge lora weights and save model with hf format"
    )
    parser.add_argument("--version", default="LLaVA-7B-v1.1")
    parser.add_argument(
        "--precision",
        default="bf16",
        type=str,
        choices=["fp32", "bf16", "fp16"],
        help="precision for inference",
    )
    parser.add_argument("--out_dim", default=512, type=int)
    parser.add_argument("--image_size", default=1024, type=int, help="image size")
    parser.add_argument("--model_max_length", default=512, type=int)
    parser.add_argument(
        "--vision-tower", default="openai/clip-vit-large-patch14", type=str
    )
    parser.add_argument("--lora_r", default=8, type=int)
    parser.add_argument("--lora_alpha", default=16, type=int)
    parser.add_argument("--lora_dropout", default=0.05, type=float)
    parser.add_argument("--lora_target_modules", default="q_proj,v_proj", type=str)
    parser.add_argument("--local-rank", default=0, type=int, help="node rank")
    # Note: store_true combined with default=True means these two flags are
    # effectively always on here; they are kept for interface parity with the
    # training script.
    parser.add_argument("--train_mask_decoder", action="store_true", default=True)
    parser.add_argument("--use_mm_start_end", action="store_true", default=True)
    parser.add_argument(
        "--conv_type",
        default="llava_v1",
        type=str,
        choices=["llava_v1", "llava_llama_2"],
    )
    parser.add_argument("--weight", default="./runs/vsm/pytorch_model.bin", type=str)
    parser.add_argument("--save_path", default="./seal_vsm_7b", type=str)
    return parser.parse_args(args)
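
# Overall flow: rebuild the training-time model (tokenizer with the extra
# tokens, VSM base model, LoRA wrapper), load the trained checkpoint, fold the
# LoRA deltas into the base weights, and save a plain HF checkpoint.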
def main(args):
    args = parse_args(args)

    # Create the tokenizer and register the special tokens used during training.
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.version,
        cache_dir=None,
        model_max_length=args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.add_tokens("[LOC]")
    args.loc_token_idx = tokenizer("[LOC]", add_special_tokens=False).input_ids[0]
    if args.use_mm_start_end:
        tokenizer.add_tokens(
            [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
        )
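
    # Extra keyword arguments forwarded through from_pretrained() to the VSM
    # model on top of the standard HF arguments.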
    model_args = {
        "train_mask_decoder": args.train_mask_decoder,
        "out_dim": args.out_dim,
        "loc_token_idx": args.loc_token_idx,
        "vision_tower": args.vision_tower,
    }

    torch_dtype = torch.float32
    if args.precision == "bf16":
        torch_dtype = torch.bfloat16
    elif args.precision == "fp16":
        torch_dtype = torch.half

    model = VSMForCausalLM.from_pretrained(
        args.version, torch_dtype=torch_dtype, low_cpu_mem_usage=True, **model_args
    )
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id

    model.get_model().initialize_vision_modules(model.get_model().config)
    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(dtype=torch_dtype)
    model.get_model().initialize_lisa_modules(model.get_model().config)
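
    # Re-apply the same LoRA configuration used at training time so the module
    # names in the PEFT-wrapped model line up with the keys in --weight
    # (load_state_dict below runs with strict=True).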
    lora_r = args.lora_r
    if lora_r > 0:

        def find_linear_layers(model, lora_target_modules):
            cls = torch.nn.Linear
            lora_module_names = set()
            for name, module in model.named_modules():
                if (
                    isinstance(module, cls)
                    and all(
                        x not in name
                        for x in [
                            "owlvit",
                            "visual_projection",
                            "prompt_encoder",
                            "mask_decoder",
                            "vision_tower",
                            "mm_projector",
                            "text_hidden_fcs_seg",
                            "text_hidden_fcs_det",
                        ]
                    )
                    and any(x in name for x in lora_target_modules)
                ):
                    lora_module_names.add(name)
            return sorted(lora_module_names)
        lora_alpha = args.lora_alpha
        lora_dropout = args.lora_dropout
        lora_target_modules = find_linear_layers(
            model, args.lora_target_modules.split(",")
        )
        lora_config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
            target_modules=lora_target_modules,
            lora_dropout=lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()
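
    # The checkpoint was saved after the extra tokens were added, so the token
    # embeddings must be resized to len(tokenizer) before the strict load.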
    model.resize_token_embeddings(len(tokenizer))

    state_dict = torch.load(args.weight, map_location="cpu")
    model.load_state_dict(state_dict, strict=True)

    # Fold the LoRA deltas into the base weights and drop the PEFT wrapper.
    model = model.merge_and_unload()
    # The CLIP vision tower is excluded from the saved state dict; it is
    # re-initialized from the pretrained --vision-tower weights at load time
    # (see initialize_vision_modules above).
    state_dict = {
        k: v for k, v in model.state_dict().items() if "vision_tower" not in k
    }
    model.save_pretrained(args.save_path, state_dict=state_dict)
    tokenizer.save_pretrained(args.save_path)

if __name__ == "__main__":
    main(sys.argv[1:])
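
# Sanity check (illustrative, not part of this script): the merged checkpoint
# in --save_path is a plain HF model directory and can be reloaded, e.g.:
#
#   tokenizer = transformers.AutoTokenizer.from_pretrained("./seal_vsm_7b")
#   model = VSMForCausalLM.from_pretrained("./seal_vsm_7b", torch_dtype=torch.bfloat16)
#
# (depending on VSM's constructor, the same model_args kwargs as above may be
# required when reloading).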