vstar / VisualSearch /train.py
Penghao Wu
init
3672502
import argparse
import os
import shutil
import sys
import time
from functools import partial
import deepspeed
import torch
import tqdm
import transformers
from peft import LoraConfig, get_peft_model
from torch.utils.tensorboard import SummaryWriter
from VisualSearch.model.VSM import VSMForCausalLM
from VisualSearch.model.llava import conversation as conversation_lib
from VisualSearch.utils.dataset import HybridDataset, ValDataset, collate_fn
from VisualSearch.utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
AverageMeter, ProgressMeter, Summary, dict_to_cuda,
intersectionAndUnionGPU)
def parse_args(args):
    """Parse the training command line.

    Args:
        args: Sequence of command-line strings (without the program name),
            handed straight to ``ArgumentParser.parse_args``.

    Returns:
        The populated ``argparse.Namespace``.

    Note:
        ``--gradient_checkpointing``, ``--train_mask_decoder`` and
        ``--use_mm_start_end`` are ``store_true`` switches whose default is
        already ``True``, so their value is ``True`` whether or not the flag
        is given.
    """
    parser = argparse.ArgumentParser(description="VisualSearch Model Training")
    add = parser.add_argument

    def add_switch(flag, default=False):
        # Boolean ``store_true`` option with an explicit default.
        add(flag, action="store_true", default=default)

    # Runtime and base model.
    add("--local_rank", default=0, type=int, help="node rank")
    add("--version", default="LLaVA-7B-v1.1")
    add(
        "--precision",
        default="bf16",
        type=str,
        choices=["fp32", "bf16", "fp16"],
        help="precision for training",
    )
    add("--model_max_length", default=512, type=int)
    add("--lora_r", default=8, type=int)
    add("--vision-tower", default="openai/clip-vit-large-patch14", type=str)
    add_switch("--load_in_8bit")
    add_switch("--load_in_4bit")

    # Dataset mixture ("||"-separated names, ","-separated sample rates).
    add("--dataset", default="general_segdet||refer_seg||mixed_grounding||vqa", type=str)
    add("--sample_rates", default="15,4,4,15", type=str)
    add("--general_segdet_data", default="objects365||cocostuff||paco_lvis", type=str)
    add("--general_segdet_sample_rates", default="2,1,1", type=str)
    add("--refer_seg_data", default="refclef||refcoco||refcoco+||refcocog", type=str)
    add("--vqa_data", default="possible_locations_conv_86k||llava_instruct_80k", type=str)
    add("--vqa_sample_rates", default="2,1", type=str)
    add("--val_dataset", default="refcoco|unc|val", type=str)
    add("--dataset_dir", default="data", type=str)

    # Logging and schedule.
    add("--log_base_dir", default="./runs", type=str)
    add("--exp_name", default="vsm", type=str)
    add("--epochs", default=40, type=int)
    add("--steps_per_epoch", default=2500, type=int)
    add("--batch_size", default=4, type=int, help="batch size per device per step")
    add("--grad_accumulation_steps", default=2, type=int)
    add("--val_batch_size", default=1, type=int)
    add("--workers", default=2, type=int)

    # Optimisation, loss weights and LoRA hyper-parameters.
    add("--lr", default=0.0001, type=float)
    add("--ce_loss_weight", default=1.0, type=float)
    add("--dice_loss_weight", default=0.5, type=float)
    add("--bce_loss_weight", default=2.0, type=float)
    add("--det_loss_weight", default=0.1, type=float)
    add("--lora_alpha", default=16, type=int)
    add("--lora_dropout", default=0.05, type=float)
    add("--lora_target_modules", default="q_proj,v_proj", type=str)
    add("--explanatory", default=0.1, type=float)
    add("--beta1", default=0.9, type=float)
    add("--beta2", default=0.95, type=float)
    add("--num_classes_per_sample", default=3, type=int)

    # Evaluation, checkpoints and miscellaneous switches.
    add_switch("--exclude_val")
    add_switch("--no_eval")
    add("--out_dim", default=512, type=int)
    add("--weight", type=str)
    add("--resume", default="", type=str)
    add("--print_freq", default=1, type=int)
    add("--start_epoch", default=0, type=int)
    add_switch("--gradient_checkpointing", default=True)
    add_switch("--train_mask_decoder", default=True)
    add_switch("--use_mm_start_end", default=True)
    add_switch("--auto_resume")
    add(
        "--conv_type",
        default="llava_v1",
        type=str,
        choices=["llava_v1", "llava_llama_2"],
    )
    return parser.parse_args(args)
def box_cxcywh_to_xyxy(x):
    """Convert boxes from (center_x, center_y, width, height) to corner form.

    Args:
        x: Tensor whose dimension 1 holds exactly the four values
            ``(cx, cy, w, h)``.

    Returns:
        A tensor with ``(x_min, y_min, x_max, y_max)`` stacked along
        dimension 1.
    """
    center_x, center_y, width, height = x.unbind(1)
    half_w = 0.5 * width
    half_h = 0.5 * height
    corners = (
        center_x - half_w,
        center_y - half_h,
        center_x + half_w,
        center_y + half_h,
    )
    return torch.stack(corners, dim=1)
def iou(bbox1, bbox2):
    """Return the intersection-over-union of two axis-aligned boxes.

    Args:
        bbox1: Indexable ``(x_min, y_min, x_max, y_max)``; elements may be
            Python numbers or scalar tensors.
        bbox2: Same layout as ``bbox1``.

    Returns:
        ``intersection / union``. When the union area is not positive (for
        example two zero-area boxes) ``0.0`` is returned instead of dividing
        by zero, which would raise ``ZeroDivisionError`` for Python numbers
        and produce NaN for tensors.
    """
    # Corners of the overlap rectangle (may be inverted when boxes are disjoint).
    x1 = max(bbox1[0], bbox2[0])
    y1 = max(bbox1[1], bbox2[1])
    x2 = min(bbox1[2], bbox2[2])
    y2 = min(bbox1[3], bbox2[3])

    # Clamp at zero so disjoint boxes contribute no intersection.
    inter_area = max(0, x2 - x1) * max(0, y2 - y1)

    area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
    area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
    union_area = area1 + area2 - inter_area

    if union_area <= 0:
        # Degenerate boxes: there is nothing to overlap with.
        return 0.0
    return inter_area / union_area
def main(args):
    """Parse the CLI arguments, build model and data, and run the training loop.

    In order: create the tokenizer (adding the ``[LOC]`` token and, when
    ``use_mm_start_end`` is set, the image start/end tokens), load
    ``VSMForCausalLM``, freeze the vision tower, optionally wrap the model
    with LoRA, build the training dataset (and the validation dataset unless
    ``--no_eval`` is given), initialize DeepSpeed, optionally resume from a
    checkpoint, then call ``train`` and ``validate`` once per epoch. A
    checkpoint is written when the detection accuracy improves, or after
    every epoch when evaluation is disabled.

    Args:
        args: List of command-line strings (without the program name).
    """
    args = parse_args(args)
    args.log_dir = os.path.join(args.log_base_dir, args.exp_name)
    # Only local rank 0 creates the log directory and the TensorBoard writer.
    if args.local_rank == 0:
        os.makedirs(args.log_dir, exist_ok=True)
        writer = SummaryWriter(args.log_dir)
    else:
        writer = None

    # Create model
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.version,
        cache_dir=None,
        model_max_length=args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.add_tokens("[LOC]")
    args.loc_token_idx = tokenizer("[LOC]", add_special_tokens=False).input_ids[0]

    if args.use_mm_start_end:
        tokenizer.add_tokens(
            [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
        )

    model_args = {
        "train_mask_decoder": args.train_mask_decoder,
        "out_dim": args.out_dim,
        "ce_loss_weight": args.ce_loss_weight,
        "dice_loss_weight": args.dice_loss_weight,
        "bce_loss_weight": args.bce_loss_weight,
        "det_loss_weight": args.det_loss_weight,
        "loc_token_idx": args.loc_token_idx,
        "vision_tower": args.vision_tower,
        "use_mm_start_end": args.use_mm_start_end,
    }

    torch_dtype = torch.float32
    if args.precision == "bf16":
        torch_dtype = torch.bfloat16
    elif args.precision == "fp16":
        torch_dtype = torch.half
    model = VSMForCausalLM.from_pretrained(
        args.version, torch_dtype=torch_dtype, low_cpu_mem_usage=True, **model_args
    )
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id

    model.enable_input_require_grads()
    model.gradient_checkpointing_enable()

    model.get_model().initialize_vision_modules(model.get_model().config)
    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(dtype=torch_dtype, device=args.local_rank)
    model.get_model().initialize_lisa_modules(model.get_model().config)

    # The vision tower stays frozen; the multimodal projector is trained.
    for p in vision_tower.parameters():
        p.requires_grad = False
    for p in model.get_model().mm_projector.parameters():
        p.requires_grad = True

    conversation_lib.default_conversation = conversation_lib.conv_templates[
        args.conv_type
    ]

    lora_r = args.lora_r
    if lora_r > 0:

        def find_linear_layers(model, lora_target_modules):
            """Return sorted names of Linear layers that LoRA should wrap."""
            # Sub-modules whose Linear layers must not receive LoRA adapters.
            excluded = [
                "owlvit",
                "visual_projection",
                "prompt_encoder",
                "mask_decoder",
                "vision_tower",
                "mm_projector",
                "text_hidden_fcs_seg",
                "text_hidden_fcs_det",
            ]
            lora_module_names = set()
            for name, module in model.named_modules():
                if (
                    isinstance(module, torch.nn.Linear)
                    and all(x not in name for x in excluded)
                    and any(x in name for x in lora_target_modules)
                ):
                    lora_module_names.add(name)
            return sorted(lora_module_names)

        lora_target_modules = find_linear_layers(
            model, args.lora_target_modules.split(",")
        )
        lora_config = LoraConfig(
            r=lora_r,
            lora_alpha=args.lora_alpha,
            target_modules=lora_target_modules,
            lora_dropout=args.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

    model.resize_token_embeddings(len(tokenizer))

    # Re-enable gradients for the modules that are trained in full.
    trainable_keys = [
        "lm_head",
        "embed_tokens",
        "visual_projection",
        "prompt_encoder",
        "mask_decoder",
        "text_hidden_fcs_seg",
        "text_hidden_fcs_det",
        "owlvit.class_head",
        "owlvit.layer_norm",
    ]
    for n, p in model.named_parameters():
        if any(x in n for x in trainable_keys):
            p.requires_grad = True

    world_size = torch.cuda.device_count()
    print('world_size', world_size)
    args.distributed = world_size > 1
    train_dataset = HybridDataset(
        args.dataset_dir,
        tokenizer,
        args.vision_tower,
        samples_per_epoch=args.batch_size
        * args.grad_accumulation_steps
        * args.steps_per_epoch
        * world_size,
        precision=args.precision,
        num_classes_per_sample=args.num_classes_per_sample,
        exclude_val=args.exclude_val,
        dataset=args.dataset,
        sample_rate=[float(x) for x in args.sample_rates.split(",")],
        general_segdet_data=args.general_segdet_data,
        general_segdet_sample_rate=[
            float(x) for x in args.general_segdet_sample_rates.split(",")
        ],
        refer_seg_data=args.refer_seg_data,
        vqa_data=args.vqa_data,
        vqa_sample_rate=[float(x) for x in args.vqa_sample_rates.split(",")],
    )

    if not args.no_eval:
        val_dataset = ValDataset(
            args.dataset_dir,
            tokenizer,
            args.vision_tower,
            args.val_dataset,
        )
        print(
            f"Training with {len(train_dataset)} examples and validating with {len(val_dataset)} examples."
        )
    else:
        # Bind the name on this path too: the `val_dataset is not None` check
        # further down would otherwise raise UnboundLocalError with --no_eval.
        val_dataset = None
        print(f"Training with {len(train_dataset)} examples.")

    ds_config = {
        "train_micro_batch_size_per_gpu": args.batch_size,
        "gradient_accumulation_steps": args.grad_accumulation_steps,
        "optimizer": {
            "type": "AdamW",
            "params": {
                "lr": args.lr,
                "weight_decay": 0.0,
                "betas": (args.beta1, args.beta2),
            },
        },
        "scheduler": {
            "type": "WarmupDecayLR",
            "params": {
                "total_num_steps": args.epochs * args.steps_per_epoch,
                "warmup_min_lr": 0,
                "warmup_max_lr": args.lr,
                "warmup_num_steps": 100,
                "warmup_type": "linear",
            },
        },
        "fp16": {
            "enabled": args.precision == "fp16",
        },
        "bf16": {
            "enabled": args.precision == "bf16",
        },
        "gradient_clipping": 1.0,
        "zero_optimization": {
            "stage": 2,
            "contiguous_gradients": True,
            "overlap_comm": True,
            "reduce_scatter": True,
            "reduce_bucket_size": 5e8,
            "allgather_bucket_size": 5e8,
        },
    }
    model_engine, optimizer, train_loader, scheduler = deepspeed.initialize(
        model=model,
        model_parameters=model.parameters(),
        training_data=train_dataset,
        collate_fn=partial(
            collate_fn,
            tokenizer=tokenizer,
            conv_type=args.conv_type,
            use_mm_start_end=args.use_mm_start_end,
            local_rank=args.local_rank,
        ),
        config=ds_config,
    )

    # resume deepspeed checkpoint
    if args.auto_resume and len(args.resume) == 0:
        resume = os.path.join(args.log_dir, "ckpt_model")
        if os.path.exists(resume):
            args.resume = resume

    if args.resume:
        load_path, client_state = model_engine.load_checkpoint(args.resume)
        # The "latest" file names the checkpoint tag ("global_step<N>"); the
        # starting epoch is derived from that step count.
        with open(os.path.join(args.resume, "latest"), "r") as f:
            ckpt_dir = f.readlines()[0].strip()
        args.start_epoch = (
            int(ckpt_dir.replace("global_step", "")) // args.steps_per_epoch
        )
        print(
            "resume training from {}, start from epoch {}".format(
                args.resume, args.start_epoch
            )
        )

    # validation dataset
    val_loader = None
    if val_dataset is not None:
        assert args.val_batch_size == 1
        val_sampler = torch.utils.data.distributed.DistributedSampler(
            val_dataset, shuffle=False, drop_last=False
        )
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=args.val_batch_size,
            shuffle=False,
            pin_memory=False,
            sampler=val_sampler,
            collate_fn=partial(
                collate_fn,
                tokenizer=tokenizer,
                conv_type=args.conv_type,
                use_mm_start_end=args.use_mm_start_end,
                local_rank=args.local_rank,
            ),
        )

    train_iter = iter(train_loader)
    best_score, cur_ciou, cur_giou = 0.0, 0.0, 0.0
    is_best = False

    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train_iter = train(
            train_loader,
            model_engine,
            epoch,
            scheduler,
            writer,
            train_iter,
            args,
        )

        if not args.no_eval:
            giou, ciou, det_acc = validate(val_loader, model_engine, epoch, writer, args)
            # Detection accuracy is the model-selection metric; gIoU / cIoU
            # are recorded for the best epoch only.
            is_best = det_acc > best_score
            best_score = max(det_acc, best_score)
            cur_giou = giou if is_best else cur_giou
            cur_ciou = ciou if is_best else cur_ciou

        if args.no_eval or is_best:
            save_dir = os.path.join(args.log_dir, "ckpt_model")
            if args.local_rank == 0:
                torch.save(
                    {"epoch": epoch},
                    os.path.join(
                        args.log_dir,
                        "meta_log_detacc{:.3f}_giou{:.3f}_ciou{:.3f}.pth".format(
                            best_score, cur_giou, cur_ciou
                        ),
                    ),
                )
                if os.path.exists(save_dir):
                    shutil.rmtree(save_dir)
            # Wait until rank 0 has removed the old checkpoint before every
            # rank writes the new one.
            torch.distributed.barrier()
            model_engine.save_checkpoint(save_dir)
def train(
    train_loader,
    model,
    epoch,
    scheduler,
    writer,
    train_iter,
    args,
):
    """Run one training epoch and return the (possibly re-created) iterator.

    Each of the ``args.steps_per_epoch`` steps consumes
    ``args.grad_accumulation_steps`` batches, calling ``model.backward`` and
    ``model.step`` on every batch. Loss meters are reduced, displayed, written
    to TensorBoard and reset every ``args.print_freq`` steps.

    Args:
        train_loader: Loader used to create a fresh iterator once
            ``train_iter`` is exhausted.
        model: Engine exposing ``train``, ``backward`` and ``step``; ``main``
            passes the engine returned by ``deepspeed.initialize``.
        epoch: Zero-based epoch index, used for the display prefix and the
            TensorBoard step.
        scheduler: Learning-rate scheduler queried through ``get_last_lr``.
        writer: TensorBoard writer; only used when ``args.local_rank == 0``.
        train_iter: Iterator over ``train_loader`` carried across epochs.
        args: Parsed command-line namespace.

    Returns:
        The iterator to pass to the next call.
    """
    batch_time = AverageMeter("Time", ":6.3f")
    data_time = AverageMeter("Data", ":6.3f")
    losses = AverageMeter("Loss", ":.4f")
    ce_losses = AverageMeter("CeLoss", ":.4f")
    mask_bce_losses = AverageMeter("MaskBCELoss", ":.4f")
    mask_dice_losses = AverageMeter("MaskDICELoss", ":.4f")
    mask_losses = AverageMeter("MaskLoss", ":.4f")
    detection_losses = AverageMeter("DetectionLoss", ":.4f")
    detection_ce_losses = AverageMeter("DetectionCELoss", ":.4f")
    detection_bbox_losses = AverageMeter("DetectionBBOXLoss", ":.4f")
    detection_giou_losses = AverageMeter("DetectionGIOULoss", ":.4f")

    # (key in the model output, TensorBoard tag, meter). The order fixes the
    # order of updates, all-reduces, log writes and resets.
    tracked_losses = [
        ("loss", "train/loss", losses),
        ("ce_loss", "train/ce_loss", ce_losses),
        ("mask_bce_loss", "train/mask_bce_loss", mask_bce_losses),
        ("mask_dice_loss", "train/mask_dice_loss", mask_dice_losses),
        ("mask_loss", "train/mask_loss", mask_losses),
        ("detection_loss", "train/detection_loss", detection_losses),
        ("detection_loss_ce", "train/detection_ce_loss", detection_ce_losses),
        ("detection_loss_bbox", "train/detection_bbox_loss", detection_bbox_losses),
        ("detection_loss_giou", "train/detection_giou_loss", detection_giou_losses),
    ]
    all_meters = [batch_time, data_time] + [meter for _, _, meter in tracked_losses]

    progress = ProgressMeter(
        args.steps_per_epoch,
        [
            batch_time,
            losses,
            ce_losses,
            mask_losses,
            mask_bce_losses,
            mask_dice_losses,
            detection_losses,
            detection_ce_losses,
            detection_bbox_losses,
            detection_giou_losses,
        ],
        prefix="Epoch: [{}]".format(epoch),
    )

    # Image tensors are cast to the dtype selected by --precision.
    if args.precision == "fp16":
        image_dtype = torch.half
    elif args.precision == "bf16":
        image_dtype = torch.bfloat16
    else:
        image_dtype = torch.float32

    # switch to train mode
    model.train()
    end = time.time()
    for global_step in range(args.steps_per_epoch):
        for _ in range(args.grad_accumulation_steps):
            try:
                input_dict = next(train_iter)
            except StopIteration:
                # Loader exhausted: start a new pass over the data. Any other
                # exception (data errors, KeyboardInterrupt) must propagate
                # instead of being hidden by a silent restart.
                train_iter = iter(train_loader)
                input_dict = next(train_iter)

            data_time.update(time.time() - end)
            input_dict = dict_to_cuda(input_dict)
            input_dict["images"] = input_dict["images"].to(image_dtype)
            input_dict["images_clip"] = input_dict["images_clip"].to(image_dtype)

            output_dict = model(**input_dict)
            loss = output_dict["loss"]
            for key, _, meter in tracked_losses:
                meter.update(output_dict[key].item(), 1)

            model.backward(loss)
            model.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # Step index that keeps increasing across epochs, for TensorBoard.
        log_step = global_step + args.steps_per_epoch * epoch

        if global_step % args.print_freq == 0:
            if args.distributed:
                for meter in all_meters:
                    meter.all_reduce()

            if args.local_rank == 0:
                progress.display(global_step + 1)
                for _, tag, meter in tracked_losses:
                    writer.add_scalar(tag, meter.avg, log_step)
                writer.add_scalar("metrics/total_secs_per_batch", batch_time.avg, log_step)
                writer.add_scalar("metrics/data_secs_per_batch", data_time.avg, log_step)

            for meter in all_meters:
                meter.reset()

        if global_step != 0:
            curr_lr = scheduler.get_last_lr()
            if args.local_rank == 0:
                writer.add_scalar("train/lr", curr_lr[0], log_step)

    return train_iter
def validate(val_loader, model_engine, epoch, writer, args):
    """Evaluate segmentation and detection quality on the validation loader.

    For every batch the predicted masks are thresholded at 0 and compared to
    the ground-truth masks, and the highest-scoring predicted box is compared
    to the ground-truth box (a hit when IoU > 0.5).

    Args:
        val_loader: Iterable of validation batches (one sample per batch is
            asserted via ``len(pred_masks) == 1``).
        model_engine: Model/engine called as ``model_engine(**batch)``.
        epoch: Epoch index used as the TensorBoard step.
        writer: TensorBoard writer; only used when ``args.local_rank == 0``.
        args: Parsed command-line namespace (``precision``, ``local_rank``).

    Returns:
        ``(giou, ciou, det_acc)``. ``giou`` and ``ciou`` are taken from
        index 1 of the two-class statistics (presumably the foreground class;
        confirm against ``intersectionAndUnionGPU``).
    """
    intersection_meter = AverageMeter("Intersec", ":6.3f", Summary.SUM)
    union_meter = AverageMeter("Union", ":6.3f", Summary.SUM)
    acc_iou_meter = AverageMeter("gIoU", ":6.3f", Summary.SUM)
    det_acc_meter = AverageMeter("DetAcc", ":6.3f", Summary.SUM)

    # Image tensors are cast to the dtype selected by --precision.
    if args.precision == "fp16":
        image_dtype = torch.half
    elif args.precision == "bf16":
        image_dtype = torch.bfloat16
    else:
        image_dtype = torch.float32

    model_engine.eval()

    for batch in tqdm.tqdm(val_loader):
        torch.cuda.empty_cache()

        batch = dict_to_cuda(batch)
        batch["images"] = batch["images"].to(image_dtype)
        batch["images_clip"] = batch["images_clip"].to(image_dtype)

        with torch.no_grad():
            outputs = model_engine(**batch)

        pred_masks = outputs["pred_masks"]
        gt_masks = outputs["gt_masks"][0].int()
        binary_preds = (pred_masks[0] > 0).int()
        assert len(pred_masks) == 1

        # Detection: score the single highest-logit box against the target.
        all_logits = outputs['pred_logits']
        all_boxes = outputs['pred_boxes']
        all_gt_boxes = outputs['gt_bboxes']
        for logits_i, boxes_i, gt_box_i in zip(all_logits, all_boxes, all_gt_boxes):
            best_idx = logits_i.view(-1).argmax()
            best_box = boxes_i[best_idx].view(1, 4)
            target_box = gt_box_i.view(1, 4)
            overlap = iou(
                box_cxcywh_to_xyxy(best_box).view(4),
                box_cxcywh_to_xyxy(target_box).view(4),
            )
            hit = 1.0 if overlap > 0.5 else 0.0
            det_acc_meter.update(hit, 1)

        # Segmentation: accumulate per-mask intersection / union statistics.
        intersection, union, acc_iou = 0.0, 0.0, 0.0
        for gt_mask, pred_mask in zip(gt_masks, binary_preds):
            inter_i, union_i, _ = intersectionAndUnionGPU(
                pred_mask.contiguous().clone(), gt_mask.contiguous(), 2, ignore_index=255
            )
            intersection += inter_i
            union += union_i
            acc_iou += inter_i / (union_i + 1e-5)
            acc_iou[union_i == 0] += 1.0  # no-object target

        num_masks = gt_masks.shape[0]
        intersection = intersection.cpu().numpy()
        union = union.cpu().numpy()
        acc_iou = acc_iou.cpu().numpy() / num_masks

        intersection_meter.update(intersection)
        union_meter.update(union)
        acc_iou_meter.update(acc_iou, n=num_masks)

    intersection_meter.all_reduce()
    union_meter.all_reduce()
    acc_iou_meter.all_reduce()
    det_acc_meter.all_reduce()

    iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
    ciou = iou_class[1]
    giou = acc_iou_meter.avg[1]
    det_acc = det_acc_meter.avg

    if args.local_rank == 0:
        writer.add_scalar("val/giou", giou, epoch)
        writer.add_scalar("val/ciou", ciou, epoch)
        writer.add_scalar("val/det_acc", det_acc, epoch)
        print("giou: {:.4f}, ciou: {:.4f}, det_acc: {:.4f}".format(giou, ciou, det_acc))

    return giou, ciou, det_acc
# Script entry point: pass the command-line arguments (without the program
# name) to main() when the file is executed directly.
if __name__ == "__main__":
    main(sys.argv[1:])