# lisa-on-cuda / train_ds.py
# x-lai
# Update new download links for refCOCO series data
# 1e37883
# raw
# history blame
# 16.6 kB
import argparse
import os
import shutil
import sys
import time
from functools import partial
import deepspeed
import numpy as np
import torch
import tqdm
import transformers
from torch.utils.tensorboard import SummaryWriter
from model.LISA import LISA
from utils.dataset import HybridDataset, ValDataset, collate_fn
from utils.utils import (
AverageMeter,
ProgressMeter,
Summary,
dict_to_cuda,
intersectionAndUnionGPU,
)
def parse_args(args):
    """Parse command-line options for LISA training / evaluation.

    Args:
        args: List of argument strings (typically ``sys.argv[1:]``).

    Returns:
        An ``argparse.Namespace`` holding every option declared below.
    """
    parser = argparse.ArgumentParser(description="LISA Model Training")
    parser.add_argument("--local_rank", default=0, type=int, help="node rank")
    parser.add_argument(
        "--version", default="liuhaotian/llava-llama-2-13b-chat-lightning-preview"
    )
    parser.add_argument("--vis_save_path", default="./vis_output", type=str)
    parser.add_argument(
        "--precision",
        default="bf16",
        type=str,
        choices=["fp32", "bf16", "fp16"],
        help="precision for inference",
    )
    parser.add_argument("--image_size", default=1024, type=int, help="image size")
    parser.add_argument("--model_max_length", default=512, type=int)
    parser.add_argument("--lora_r", default=8, type=int)
    # Every other flag is snake_case; accept ``--vision_tower`` as well so the
    # option is spelled consistently. The destination is taken from the first
    # long option and therefore stays ``args.vision_tower``.
    parser.add_argument(
        "--vision-tower",
        "--vision_tower",
        default="openai/clip-vit-large-patch14",
        type=str,
    )
    parser.add_argument("--load_in_8bit", action="store_true", default=False)
    parser.add_argument("--load_in_4bit", action="store_true", default=False)

    # Dataset selection. Names inside one option are separated by "||".
    parser.add_argument(
        "--dataset", default="sem_seg||refer_seg||vqa||reason_seg", type=str
    )
    parser.add_argument("--sample_rates", default="9,3,3,1", type=str)
    parser.add_argument(
        "--sem_seg_data",
        default="ade20k||cocostuff||pascal_part||paco_lvis||mapillary",
        type=str,
    )
    parser.add_argument(
        "--refer_seg_data", default="refclef||refcoco||refcoco+||refcocog", type=str
    )
    parser.add_argument("--vqa_data", default="llava_instruct_150k", type=str)
    parser.add_argument("--reason_seg_data", default="ReasonSeg|train", type=str)
    parser.add_argument("--val_dataset", default="ReasonSeg|val", type=str)
    parser.add_argument("--dataset_dir", default="./dataset", type=str)

    # Logging and schedule.
    parser.add_argument("--log_base_dir", default="./runs", type=str)
    parser.add_argument("--exp_name", default="lisa", type=str)
    parser.add_argument("--epochs", default=10, type=int)
    parser.add_argument("--steps_per_epoch", default=500, type=int)
    parser.add_argument(
        "--batch_size", default=2, type=int, help="batch size per device per step"
    )
    parser.add_argument(
        "--grad_accumulation_steps",
        default=10,
        type=int,
    )
    parser.add_argument("--val_batch_size", default=1, type=int)
    parser.add_argument("--workers", default=4, type=int)

    # Optimisation, loss weights and LoRA settings.
    parser.add_argument("--lr", default=0.0003, type=float)
    parser.add_argument("--ce_loss_weight", default=1.0, type=float)
    parser.add_argument("--dice_loss_weight", default=0.5, type=float)
    parser.add_argument("--bce_loss_weight", default=2.0, type=float)
    parser.add_argument("--lora_alpha", default=16, type=int)
    parser.add_argument("--lora_dropout", default=0.05, type=float)
    parser.add_argument("--lora_target_modules", default="q_proj,v_proj", type=str)
    parser.add_argument("--explanatory", default=0.1, type=float)
    parser.add_argument("--beta1", default=0.9, type=float)
    parser.add_argument("--beta2", default=0.95, type=float)
    parser.add_argument("--num_classes_per_sample", default=3, type=int)

    # Run-mode switches and checkpoints.
    parser.add_argument("--exclude_val", action="store_true", default=False)
    parser.add_argument("--no_eval", action="store_true", default=False)
    parser.add_argument("--eval_only", action="store_true", default=False)
    parser.add_argument("--vision_pretrained", default="PATH_TO_SAM_ViT-H", type=str)
    parser.add_argument("--weight", default="", type=str)
    parser.add_argument("--print_freq", default=1, type=int)
    parser.add_argument("--start_epoch", default=0, type=int)
    return parser.parse_args(args)
def main(args):
    """Entry point: build the model and data, then train and/or validate.

    Steps, in order: parse options, create the TensorBoard writer on local
    rank 0, build the tokenizer (with an added ``[SEG]`` token) and the LISA
    model, optionally load weights, build the datasets, initialise DeepSpeed,
    and run either a single validation pass (``--eval_only``) or the epoch
    loop with optional validation and checkpointing.

    Args:
        args: List of command-line argument strings (without the program name).

    Raises:
        ValueError: If ``--eval_only`` is combined with ``--no_eval``; no
            validation loader would exist in that case.
    """
    args = parse_args(args)
    # Fail fast: previously this combination crashed on an unbound
    # ``val_loader`` only after the model had been loaded.
    if args.eval_only and args.no_eval:
        raise ValueError("--eval_only cannot be combined with --no_eval")

    args.log_dir = os.path.join(args.log_base_dir, args.exp_name)
    if args.local_rank == 0:
        os.makedirs(args.log_dir, exist_ok=True)
        writer = SummaryWriter(args.log_dir)
    else:
        # Only local rank 0 writes TensorBoard logs.
        writer = None

    # Create model
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.version,
        cache_dir=None,
        model_max_length=args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.add_tokens("[SEG]")
    # Id of the "[SEG]" token just added; the first id of its tokenisation is used.
    ret_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids
    args.seg_token_idx = ret_token_idx[0]

    model = LISA(
        args.local_rank,
        args.seg_token_idx,
        tokenizer,
        args.version,
        args.lora_r,
        args.precision,
        vision_tower=args.vision_tower,
        load_in_8bit=args.load_in_8bit,
        load_in_4bit=args.load_in_4bit,
        ce_loss_weight=args.ce_loss_weight,
        dice_loss_weight=args.dice_loss_weight,
        bce_loss_weight=args.bce_loss_weight,
        vision_pretrained=args.vision_pretrained,
    )

    if args.weight:
        state_dict = torch.load(args.weight, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)

    # NOTE(review): this is the number of CUDA devices visible to this
    # process, used here as the world size -- confirm for multi-node runs.
    world_size = torch.cuda.device_count()
    args.distributed = world_size > 1

    train_dataset = HybridDataset(
        args.dataset_dir,
        tokenizer,
        args.vision_tower,
        # One "epoch" of samples covers every micro-batch on every device.
        samples_per_epoch=args.batch_size
        * args.grad_accumulation_steps
        * args.steps_per_epoch
        * world_size,
        precision=args.precision,
        image_size=args.image_size,
        num_classes_per_sample=args.num_classes_per_sample,
        exclude_val=args.exclude_val,
        dataset=args.dataset,
        sample_rate=[float(x) for x in args.sample_rates.split(",")],
        sem_seg_data=args.sem_seg_data,
        refer_seg_data=args.refer_seg_data,
        vqa_data=args.vqa_data,
        reason_seg_data=args.reason_seg_data,
        explanatory=args.explanatory,
    )

    if not args.no_eval:
        val_dataset = ValDataset(
            args.dataset_dir,
            tokenizer,
            args.vision_tower,
            args.val_dataset,
            args.image_size,
        )
        print(
            f"Training with {len(train_dataset)} examples and validating with {len(val_dataset)} examples."
        )
    else:
        val_dataset = None
        print(f"Training with {len(train_dataset)} examples.")

    ds_config = {
        "train_micro_batch_size_per_gpu": args.batch_size,
        "gradient_accumulation_steps": args.grad_accumulation_steps,
        "optimizer": {
            "type": "AdamW",
            "params": {
                "lr": args.lr,
                "weight_decay": 0.0,
                "betas": (args.beta1, args.beta2),
            },
        },
        "scheduler": {
            "type": "WarmupDecayLR",
            "params": {
                "total_num_steps": args.epochs * args.steps_per_epoch,
                "warmup_min_lr": 0,
                "warmup_max_lr": args.lr,
                "warmup_num_steps": 100,
                "warmup_type": "linear",
            },
        },
        "fp16": {
            "enabled": args.precision == "fp16",
        },
        "bf16": {
            "enabled": args.precision == "bf16",
        },
        "gradient_clipping": 1.0,
        "zero_optimization": {
            "stage": 2,
            "contiguous_gradients": True,
            "overlap_comm": True,
            "reduce_scatter": True,
            "reduce_bucket_size": 5e8,
            "allgather_bucket_size": 5e8,
        },
    }
    model_engine, optimizer, train_loader, scheduler = deepspeed.initialize(
        model=model,
        model_parameters=model.parameters(),
        training_data=train_dataset,
        collate_fn=partial(collate_fn, tokenizer=tokenizer),
        config=ds_config,
    )

    if val_dataset is not None:
        # validate() asserts one prediction per batch, so batch size must be 1.
        assert args.val_batch_size == 1
        val_sampler = torch.utils.data.distributed.DistributedSampler(
            val_dataset, shuffle=False, drop_last=False
        )
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=args.val_batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=False,
            sampler=val_sampler,
            collate_fn=partial(collate_fn, tokenizer=tokenizer),
        )

    train_iter = iter(train_loader)
    best_score, cur_ciou = 0.0, 0.0

    if args.eval_only:
        giou, ciou = validate(val_loader, model_engine, 0, writer, args)
        # sys.exit instead of the site-provided ``exit`` builtin, which is
        # not guaranteed to exist in every interpreter configuration.
        sys.exit()

    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train_iter = train(
            train_loader,
            model_engine,
            epoch,
            scheduler,
            writer,
            train_iter,
            args,
        )

        if not args.no_eval:
            giou, ciou = validate(val_loader, model_engine, epoch, writer, args)
            is_best = giou > best_score
            best_score = max(giou, best_score)
            cur_ciou = ciou if is_best else cur_ciou

        # Without evaluation every epoch is checkpointed; otherwise only
        # epochs that improve gIoU are. ``is_best`` is not evaluated when
        # --no_eval is set (short-circuit).
        if args.no_eval or is_best:
            save_dir = os.path.join(args.log_dir, "ckpt_model")
            if args.local_rank == 0:
                torch.save(
                    {"epoch": epoch},
                    os.path.join(
                        args.log_dir,
                        "meta_log_giou{:.3f}_ciou{:.3f}.pth".format(
                            best_score, cur_ciou
                        ),
                    ),
                )
                # Drop the previous checkpoint so only the latest one is kept.
                if os.path.exists(save_dir):
                    shutil.rmtree(save_dir)
            # Wait until rank 0 has finished removing the old directory.
            torch.distributed.barrier()
            model_engine.save_checkpoint(save_dir)
def train(
    train_loader,
    model,
    epoch,
    scheduler,
    writer,
    train_iter,
    args,
):
    """Run one training epoch of ``args.steps_per_epoch`` optimisation steps.

    Each step consumes ``args.grad_accumulation_steps`` micro-batches; for
    every micro-batch the model is run, ``model.backward`` and ``model.step``
    are called, and the loss meters are updated.

    Args:
        train_loader: Loader used to restart iteration once it is exhausted.
        model: Engine exposing ``train()``, ``__call__``, ``backward`` and
            ``step``.
        epoch: Index of the current epoch (used for display and log steps).
        scheduler: LR scheduler; only ``get_last_lr()`` is read.
        writer: TensorBoard writer; used only when ``args.local_rank == 0``.
        train_iter: Iterator over ``train_loader`` carried across epochs.
        args: Parsed options namespace.

    Returns:
        The data iterator to continue from in the next epoch.
    """
    batch_time = AverageMeter("Time", ":6.3f")
    data_time = AverageMeter("Data", ":6.3f")
    losses = AverageMeter("Loss", ":.4f")
    ce_losses = AverageMeter("CeLoss", ":.4f")
    mask_bce_losses = AverageMeter("MaskBCELoss", ":.4f")
    mask_dice_losses = AverageMeter("MaskDICELoss", ":.4f")
    mask_losses = AverageMeter("MaskLoss", ":.4f")

    progress = ProgressMeter(
        args.steps_per_epoch,
        [
            batch_time,
            losses,
            ce_losses,
            mask_losses,
            mask_bce_losses,
            mask_dice_losses,
        ],
        prefix="Epoch: [{}]".format(epoch),
    )

    # switch to train mode
    model.train()
    end = time.time()
    for global_step in range(args.steps_per_epoch):
        # Monotonic TensorBoard step across epochs. Logging the per-epoch
        # index would make every epoch overwrite the same x-range.
        tb_step = epoch * args.steps_per_epoch + global_step

        for _ in range(args.grad_accumulation_steps):
            try:
                input_dict = next(train_iter)
            except StopIteration:
                # Loader exhausted: restart it. Any other error (or Ctrl-C)
                # must propagate instead of silently re-creating the iterator.
                train_iter = iter(train_loader)
                input_dict = next(train_iter)

            data_time.update(time.time() - end)
            input_dict = dict_to_cuda(input_dict)

            # Cast both image tensors to the configured precision.
            if args.precision == "fp16":
                input_dict["images"] = input_dict["images"].half()
                input_dict["images_clip"] = input_dict["images_clip"].half()
            elif args.precision == "bf16":
                input_dict["images"] = input_dict["images"].bfloat16()
                input_dict["images_clip"] = input_dict["images_clip"].bfloat16()
            else:
                input_dict["images"] = input_dict["images"].float()
                input_dict["images_clip"] = input_dict["images_clip"].float()

            output_dict = model(**input_dict)

            loss = output_dict["loss"]
            ce_loss = output_dict["ce_loss"]
            mask_bce_loss = output_dict["mask_bce_loss"]
            mask_dice_loss = output_dict["mask_dice_loss"]
            mask_loss = output_dict["mask_loss"]

            batch_count = input_dict["images"].size(0)
            losses.update(loss.item(), batch_count)
            ce_losses.update(ce_loss.item(), batch_count)
            mask_bce_losses.update(mask_bce_loss.item(), batch_count)
            mask_dice_losses.update(mask_dice_loss.item(), batch_count)
            mask_losses.update(mask_loss.item(), batch_count)
            model.backward(loss)
            model.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if global_step % args.print_freq == 0:
            if args.distributed:
                batch_time.all_reduce()
                data_time.all_reduce()

                losses.all_reduce()
                ce_losses.all_reduce()
                mask_bce_losses.all_reduce()
                mask_dice_losses.all_reduce()
                mask_losses.all_reduce()

            if args.local_rank == 0:
                progress.display(global_step + 1)
                writer.add_scalar("train/loss", losses.avg, tb_step)
                writer.add_scalar("train/ce_loss", ce_losses.avg, tb_step)
                writer.add_scalar(
                    "train/mask_bce_loss", mask_bce_losses.avg, tb_step
                )
                writer.add_scalar(
                    "train/mask_dice_loss", mask_dice_losses.avg, tb_step
                )
                writer.add_scalar("train/mask_loss", mask_losses.avg, tb_step)
                writer.add_scalar(
                    "metrics/total_secs_per_batch", batch_time.avg, tb_step
                )
                writer.add_scalar(
                    "metrics/data_secs_per_batch", data_time.avg, tb_step
                )

            # Start a fresh averaging window after each report.
            batch_time.reset()
            data_time.reset()
            losses.reset()
            ce_losses.reset()
            mask_bce_losses.reset()
            mask_dice_losses.reset()
            mask_losses.reset()

        if global_step != 0:
            curr_lr = scheduler.get_last_lr()
            if args.local_rank == 0:
                writer.add_scalar("train/lr", curr_lr[0], tb_step)

    return train_iter
def validate(val_loader, model_engine, epoch, writer, args):
    """Evaluate the model on ``val_loader`` and report gIoU / cIoU.

    For every batch the predicted mask logits are thresholded at 0 and
    compared with the ground-truth masks via ``intersectionAndUnionGPU``.
    Intersection, union and per-mask IoU are accumulated in meters that are
    all-reduced before the final scores are computed.

    Args:
        val_loader: Iterable of validation batches (one sample per batch).
        model_engine: Model engine; switched to eval mode here.
        epoch: Step value used for the TensorBoard scalars.
        writer: TensorBoard writer; used only when ``args.local_rank == 0``.
        args: Parsed options namespace (``precision``, ``local_rank``).

    Returns:
        Tuple ``(giou, ciou)``.
    """
    intersection_meter = AverageMeter("Intersec", ":6.3f", Summary.SUM)
    union_meter = AverageMeter("Union", ":6.3f", Summary.SUM)
    acc_iou_meter = AverageMeter("gIoU", ":6.3f", Summary.SUM)

    model_engine.eval()

    for input_dict in tqdm.tqdm(val_loader):
        input_dict = dict_to_cuda(input_dict)
        # Cast both image tensors to the configured precision.
        if args.precision == "fp16":
            input_dict["images"] = input_dict["images"].half()
            input_dict["images_clip"] = input_dict["images_clip"].half()
        elif args.precision == "bf16":
            input_dict["images"] = input_dict["images"].bfloat16()
            input_dict["images_clip"] = input_dict["images_clip"].bfloat16()
        else:
            input_dict["images"] = input_dict["images"].float()
            input_dict["images_clip"] = input_dict["images_clip"].float()

        # Gradients are never used during evaluation; skip building the graph.
        with torch.no_grad():
            output_dict = model_engine(**input_dict)

        pred_masks = output_dict["pred_masks"]
        # Check the one-sample-per-batch assumption before indexing.
        assert len(pred_masks) == 1
        masks_list = output_dict["gt_masks"][0].int()
        output_list = (pred_masks[0] > 0).int()

        intersection, union, acc_iou = 0.0, 0.0, 0.0
        for mask_i, output_i in zip(masks_list, output_list):
            intersection_i, union_i, _ = intersectionAndUnionGPU(
                output_i.contiguous().clone(), mask_i.contiguous(), 2, ignore_index=255
            )
            intersection += intersection_i
            union += union_i
            acc_iou += intersection_i / (union_i + 1e-5)
            acc_iou[union_i == 0] += 1.0  # no-object target
        intersection, union = intersection.cpu().numpy(), union.cpu().numpy()
        acc_iou = acc_iou.cpu().numpy() / masks_list.shape[0]

        intersection_meter.update(intersection)
        union_meter.update(union)
        acc_iou_meter.update(acc_iou, n=masks_list.shape[0])

    intersection_meter.all_reduce()
    union_meter.all_reduce()
    acc_iou_meter.all_reduce()

    iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
    # Index 1 of the two classes is reported -- presumably the foreground
    # class; confirm against intersectionAndUnionGPU.
    ciou = iou_class[1]
    giou = acc_iou_meter.avg[1]

    if args.local_rank == 0:
        writer.add_scalar("val/giou", giou, epoch)
        # Previously logged under "val/giou", which overwrote the gIoU series.
        writer.add_scalar("val/ciou", ciou, epoch)
        print("giou: {:.4f}, ciou: {:.4f}".format(giou, ciou))

    return giou, ciou
# Script entry point: hand the command-line arguments (without the program
# name) to main().
if __name__ == "__main__":
    cli_arguments = sys.argv[1:]
    main(cli_arguments)