import argparse
import glob
import os
import sys

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import transformers
from peft import LoraConfig, get_peft_model
from transformers import AutoTokenizer

from model.LISA import LISAForCausalLM
from utils.utils import DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN


def parse_args(args):
    parser = argparse.ArgumentParser(
        description="merge lora weights and save model with hf format"
    )
    parser.add_argument(
        "--version", default="liuhaotian/llava-llama-2-13b-chat-lightning-preview"
    )
    parser.add_argument("--vis_save_path", default="./vis_output", type=str)
    parser.add_argument(
        "--precision",
        default="bf16",
        type=str,
        choices=["fp32", "bf16", "fp16"],
        help="precision for inference",
    )
    parser.add_argument("--vision_pretrained", default="PATH_TO_SAM_ViT-H", type=str)
    parser.add_argument("--out_dim", default=256, type=int)
    parser.add_argument("--image_size", default=1024, type=int, help="image size")
    parser.add_argument("--model_max_length", default=512, type=int)
    parser.add_argument(
        "--vision-tower", default="openai/clip-vit-large-patch14", type=str
    )
    parser.add_argument("--lora_r", default=8, type=int)
    parser.add_argument("--lora_alpha", default=16, type=int)
    parser.add_argument("--lora_dropout", default=0.05, type=float)
    parser.add_argument("--lora_target_modules", default="q_proj,v_proj", type=str)
    parser.add_argument("--local-rank", default=0, type=int, help="node rank")
    parser.add_argument("--train_mask_decoder", action="store_true", default=True)
    parser.add_argument("--use_mm_start_end", action="store_true", default=True)
    parser.add_argument(
        "--conv_type",
        default="llava_v1",
        type=str,
        choices=["llava_v1", "llava_llama_2"],
    )
    parser.add_argument("--weight", default="", type=str, required=True)
    parser.add_argument("--save_path", default="./lisa_model", type=str, required=True)
    return parser.parse_args(args)


def main(args):
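    # Rebuild the tokenizer/model setup used during fine-tuning, attach the LoRA
    # adapters, load the trained weights, merge them into the base weights, and
    # save the result in Hugging Face format.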
    args = parse_args(args)
    os.makedirs(args.vis_save_path, exist_ok=True)

    # Build the tokenizer and register the extra [SEG] token
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.version,
        cache_dir=None,
        model_max_length=args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token
    num_added_tokens = tokenizer.add_tokens("[SEG]")
    args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
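    # Optionally register the image start/end marker tokens expected by the prompt format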
    if args.use_mm_start_end:
        tokenizer.add_tokens(
            [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
        )

    model_args = {
        "train_mask_decoder": args.train_mask_decoder,
        "out_dim": args.out_dim,
        "seg_token_idx": args.seg_token_idx,
        "vision_tower": args.vision_tower,
    }
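    # Map the requested precision to a torch dtype and load the LISA checkpoint with it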
    torch_dtype = torch.float32
    if args.precision == "bf16":
        torch_dtype = torch.bfloat16
    elif args.precision == "fp16":
        torch_dtype = torch.half

    model = LISAForCausalLM.from_pretrained(
        args.version, torch_dtype=torch_dtype, low_cpu_mem_usage=True, **model_args
    )
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id

    model.get_model().initialize_vision_modules(model.get_model().config)
    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(dtype=torch_dtype)
    model.get_model().initialize_lisa_modules(model.get_model().config)
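    # Recreate the same LoRA wrapper used for training so the fine-tuned state dict
    # can be loaded and then merged back into the base weights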
    lora_r = args.lora_r
    if lora_r > 0:

        def find_linear_layers(model, lora_target_modules):
            cls = torch.nn.Linear
            lora_module_names = set()
            for name, module in model.named_modules():
                if (
                    isinstance(module, cls)
                    and all(
                        [
                            x not in name
                            for x in [
                                "visual_model",
                                "vision_tower",
                                "mm_projector",
                                "text_hidden_fcs",
                            ]
                        ]
                    )
                    and any([x in name for x in lora_target_modules])
                ):
                    lora_module_names.add(name)
            return sorted(list(lora_module_names))

        lora_alpha = args.lora_alpha
        lora_dropout = args.lora_dropout
        lora_target_modules = find_linear_layers(
            model, args.lora_target_modules.split(",")
        )
        lora_config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
            target_modules=lora_target_modules,
            lora_dropout=lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()
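    # Account for the added tokens, load the fine-tuned weights, and fold the LoRA
    # deltas into the base weights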
    model.resize_token_embeddings(len(tokenizer))

    state_dict = torch.load(args.weight, map_location="cpu")
    model.load_state_dict(state_dict, strict=True)

    model = model.merge_and_unload()

    # Exclude the vision tower weights from the saved checkpoint; they are restored
    # from the pretrained vision tower when the merged model is loaded
    state_dict = {}
    for k, v in model.state_dict().items():
        if "vision_tower" not in k:
            state_dict[k] = v
    model.save_pretrained(args.save_path, state_dict=state_dict)
    tokenizer.save_pretrained(args.save_path)


if __name__ == "__main__":
    main(sys.argv[1:])
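# Example invocation (the script name and paths below are placeholders):
#   python merge_lora_weights_and_save_hf_model.py \
#       --version liuhaotian/llava-llama-2-13b-chat-lightning-preview \
#       --weight /path/to/finetuned/pytorch_model.bin \
#       --save_path ./lisa_model_merged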