# VRIS_vip / LAVT-RIS / train_refzom_angle.py
# Uploaded by dianecy ("Upload folder using huggingface_hub", commit 8d82201, 23.5 kB).
import datetime
import os
import time
import torch
import torch.utils.data
from torch import nn
from functools import reduce
import operator
from bert.modeling_bert import BertModel
import torchvision
from lib import segmentation
import transforms as T
import utils
import numpy as np
import torch.nn.functional as F
from data.dataset_refer_zom import Referzom_Dataset, Refzom_DistributedSampler
from data.dataset_refer_bert_rev import ReferDataset
import gc
from collections import OrderedDict
from torch.utils.tensorboard import SummaryWriter
def get_dataset(image_set, transform, args, eval_mode=False):
    """Build the referring-segmentation dataset for ``image_set``.

    ``args.dataset == 'ref-zom'`` selects ``Referzom_Dataset`` (which receives
    ``eval_mode`` as given); any other value selects ``ReferDataset``.

    Returns:
        ``(dataset, num_classes)``; ``num_classes`` is always 2.
    """
    shared_kwargs = dict(split=image_set, image_transforms=transform, target_transforms=None)
    if args.dataset == 'ref-zom':
        dataset = Referzom_Dataset(args, eval_mode=eval_mode, **shared_kwargs)
    else:
        # NOTE(review): the caller's `eval_mode` is not forwarded here; ReferDataset
        # is put in eval mode only for the 'val' split. Confirm this is intended.
        dataset = ReferDataset(args, eval_mode=(image_set == 'val'), **shared_kwargs)
    return dataset, 2
# IoU calculation for validation
# def IoU(pred, gt):
# pred = pred.argmax(1)
# intersection = torch.sum(torch.mul(pred, gt))
# union = torch.sum(torch.add(pred, gt)) - intersection
# if intersection == 0 or union == 0:
# iou = 0
# else:
# iou = float(intersection) / float(union)
# return iou, intersection, union
def IoU(pred, gt):
    """Intersection-over-union between the argmax of ``pred`` and the binary ``gt``.

    ``pred`` holds per-class scores along dim 1; its argmax gives the predicted
    label map, which is compared with ``gt``.

    Returns:
        ``(iou, intersection, union)``: ``iou`` is a Python number (0 when either
        the intersection or the union is zero); the other two are the summed tensors.
    """
    labels = pred.argmax(1)
    inter = (labels * gt).sum()
    union = (labels + gt).sum() - inter
    iou = 0 if (inter == 0 or union == 0) else float(inter) / float(union)
    return iou, inter, union
def get_transform(args):
    """Return the image transform: square resize to ``args.img_size``, tensor
    conversion, and ImageNet mean/std normalization."""
    size = args.img_size
    return T.Compose([
        T.Resize(size, size),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
def criterion(input, target, class_weights=(0.9, 1.1)):
    """Class-weighted cross-entropy for the 2-class (background/foreground) output.

    Args:
        input: logits with classes along dim 1.
        target: integer class labels.
        class_weights: per-class weights; the default (0.9, 1.1) matches the
            previously hard-coded values.

    Returns:
        Scalar loss tensor.
    """
    # Build the weights on the logits' device/dtype instead of hard-coding
    # `.cuda()` / float32, so the loss also works on CPU and with half inputs.
    weight = torch.as_tensor(class_weights, dtype=input.dtype, device=input.device)
    return nn.functional.cross_entropy(input, target, weight=weight)
def return_mask(emb_distance, verb_mask=None):
    """Build positive/negative pair masks for a square similarity matrix.

    The diagonal is always positive. When the matrix is smaller than
    ``verb_mask`` (i.e. only verb-selected embeddings were kept), consecutive
    rows ``(0, 1), (2, 3), ...`` are marked as positive pairs. Otherwise
    ``verb_mask`` is walked left to right and each flagged index ``i`` is paired
    with ``i + 1``, after which both indices are skipped.

    Returns:
        ``(positive_mask, negative_mask)`` with ``negative_mask = 1 - positive_mask``.
    """
    _, size = emb_distance.shape
    positive_mask = torch.zeros_like(emb_distance)
    positive_mask.fill_diagonal_(1)

    def _link(a, b):
        # mark the pair symmetrically
        positive_mask[a, b] = 1
        positive_mask[b, a] = 1

    if size < len(verb_mask):
        for first in range(0, 2 * (size // 2), 2):
            _link(first, first + 1)
    else:
        idx = 0
        while idx < size:
            if verb_mask[idx] == 1:
                _link(idx, idx + 1)
                idx += 2
            else:
                idx += 1

    negative_mask = torch.ones_like(emb_distance) - positive_mask
    return positive_mask, negative_mask
def UniAngularContrastLoss(total_fq, verb_mask, alpha=0.5, verbonly=True, m=0.5, tau=0.05, args=None):
    """Supervised contrastive loss with an additive angular margin on positives.

    Feature maps are average-pooled over their two trailing axes into one
    embedding per sample. Pairwise cosine similarities are clamped to
    [-0.9999, 0.9999]. For every positive pair (from ``return_mask``, including
    the diagonal) the similarity is replaced by ``cos(acos(sim) + m)``. Each row
    is then scored as ``-log(sum_pos exp(s / tau) / sum_all exp(s / tau))`` and
    averaged.

    Args:
        total_fq: 4-D feature tensor ``(N, C, H, W)``.
        verb_mask: mask over the N samples, used both to select embeddings
            (``verbonly=True``) and to define positive pairs.
        alpha, args: unused; kept for call compatibility.
        verbonly: embed only the samples selected by ``verb_mask``.
        m: angular margin in degrees (converted via ``/ 57.2958``).
        tau: softmax temperature.

    Returns:
        Scalar loss tensor, or a zero tensor when there are no embeddings.
    """
    _, C, _, _ = total_fq.shape
    if verbonly:
        selected = total_fq[verb_mask]
        emb = torch.mean(selected, dim=(-1, -2)).reshape(selected.shape[0], C)
        assert emb.shape[0] % 2 == 0, f"Embedding count {emb.shape[0]} is not divisible by 2."
    else:
        # Pool both spatial axes -> (N, C). Pooling only dim=-1 left an (N, C, H)
        # tensor that the 3-element repeat() below rejects.
        emb = torch.mean(total_fq, dim=(-1, -2))

    B_ = emb.shape[0]
    if B_ == 0:
        return torch.tensor(0.0, device=total_fq.device)

    emb_i = emb.unsqueeze(1).repeat(1, B_, 1)  # (B_, B_, C)
    emb_j = emb.unsqueeze(0).repeat(B_, 1, 1)  # (B_, B_, C)
    sim_matrix = nn.CosineSimilarity(dim=-1, eps=1e-6)(emb_i, emb_j).reshape(B_, B_)
    # keep acos() away from +/-1 where its gradient is infinite
    sim_matrix = torch.clamp(sim_matrix, min=-0.9999, max=0.9999)

    positive_mask, _ = return_mask(sim_matrix, verb_mask)
    is_pos = positive_mask.bool()
    sim_matrix_with_margin = sim_matrix.clone()
    sim_matrix_with_margin[is_pos] = torch.cos(torch.acos(sim_matrix[is_pos]) + m / 57.2958)

    exp_logits = torch.exp(sim_matrix_with_margin / tau)
    pos_exp_logits = (exp_logits * positive_mask).sum(dim=-1)
    total_exp_logits = exp_logits.sum(dim=-1)
    return (-torch.log(pos_exp_logits / total_exp_logits)).mean()
def UniAngularLogitContrastLoss(total_fq, verb_mask, alpha=0.5, verbonly=True, m=0.5, tau=0.05, args=None):
    """Contrastive loss on angular logits with a margin subtracted for positives.

    Feature maps are average-pooled over their two trailing axes into one
    embedding per sample. With ``theta = acos(cos_sim)`` (cosine clamped to
    [-0.9999, 0.9999]), each pair's logit is ``(pi/2 - theta) / tau``. For
    positive pairs (from ``return_mask``, including the diagonal) the margin
    ``m`` (degrees) is subtracted before scaling. Each row is scored as
    ``-log(sum_pos exp(logit) / sum_all exp(logit))`` and averaged.

    Args:
        total_fq: 4-D feature tensor ``(N, C, H, W)``.
        verb_mask: mask over the N samples, used both to select embeddings
            (``verbonly=True``) and to define positive pairs.
        alpha, args: unused; kept for call compatibility.
        verbonly: embed only the samples selected by ``verb_mask``.
        m: angular margin in degrees.
        tau: temperature.

    Returns:
        Scalar loss tensor, or a zero tensor when there are no embeddings
        (previously the mean of an empty tensor, i.e. NaN).
    """
    _, C, _, _ = total_fq.shape
    if verbonly:
        selected = total_fq[verb_mask]
        emb = torch.mean(selected, dim=(-1, -2)).reshape(selected.shape[0], C)
        assert emb.shape[0] % 2 == 0, f"Embedding count {emb.shape[0]} is not divisible by 2."
    else:
        # Pool both spatial axes -> (N, C); pooling only dim=-1 broke repeat() below.
        emb = torch.mean(total_fq, dim=(-1, -2))

    B_ = emb.shape[0]
    if B_ == 0:
        return torch.tensor(0.0, device=total_fq.device)

    emb_i = emb.unsqueeze(1).repeat(1, B_, 1)  # (B_, B_, C)
    emb_j = emb.unsqueeze(0).repeat(B_, 1, 1)  # (B_, B_, C)
    sim_matrix = nn.CosineSimilarity(dim=-1, eps=1e-6)(emb_i, emb_j).reshape(B_, B_)
    sim_matrix = torch.clamp(sim_matrix, min=-0.9999, max=0.9999)

    margin_in_radians = m / 57.2958  # degrees -> radians
    theta_matrix = (torch.pi / 2) - torch.acos(sim_matrix)
    positive_mask, _ = return_mask(sim_matrix, verb_mask)
    theta_with_margin = theta_matrix.clone()
    theta_with_margin[positive_mask.bool()] -= margin_in_radians  # harder positives

    exp_logits = torch.exp(theta_with_margin / tau)
    pos_exp_logits = (exp_logits * positive_mask).sum(dim=-1)
    total_exp_logits = exp_logits.sum(dim=-1)
    return (-torch.log(pos_exp_logits / total_exp_logits)).mean()
def evaluate(model, data_loader, bert_model):
    """Run inference over ``data_loader`` and print/return segmentation metrics.

    Each batch unpacks as ``(image, target, source_type, sentences, attentions)``.
    Only ``source_type[0]`` is inspected, so this assumes ``batch_size=1`` (as
    ``main`` configures the test loader).

    Samples with ``source_type[0] == 'zero'`` (no referred object) are scored by
    accuracy: correct only when no pixel is predicted as foreground. All other
    samples contribute to mean IoU, overall (cumulative) IoU and precision@X.

    Args:
        model: segmentation model; called with ``l_mask=...`` (plus
            ``is_train=False`` when no separate BERT model is used).
        data_loader: evaluation loader.
        bert_model: separate text encoder, or ``None`` when the model encodes text itself.

    Returns:
        ``(mIoU, overall_IoU_percent, precs)``, where ``precs`` holds
        precision@{0.5, 0.6, 0.7, 0.8, 0.9} in percent. The metrics are 0.0 when
        the split contains no samples with a target object.
    """
    model.eval()
    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Test:'

    eval_seg_iou_list = [.5, .6, .7, .8, .9]
    seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32)
    seg_total = 0
    cum_I, cum_U = 0, 0
    mean_IoU = []  # per-sample IoU of samples that have a target object
    mean_acc = []  # per-sample 0/1 accuracy of zero-target samples

    with torch.no_grad():
        for data in metric_logger.log_every(data_loader, 100, header):
            image, target, source_type, sentences, attentions = data
            image = image.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)
            sentences = sentences.cuda(non_blocking=True).squeeze(-1)
            attentions = attentions.cuda(non_blocking=True).squeeze(-1)

            if bert_model is not None:
                last_hidden_states = bert_model(sentences, attention_mask=attentions)[0]
                embedding = last_hidden_states.permute(0, 2, 1)  # [B, N, 768] -> [B, 768, N]
                output = model(image, embedding, l_mask=attentions.unsqueeze(-1))  # mask [B, N, 1]
            else:
                output = model(image, sentences, l_mask=attentions, is_train=False)

            if source_type[0] == 'zero':
                # Correct only if nothing is predicted as foreground.
                pred = output.argmax(1)
                mean_acc.append(1 if torch.sum(pred).item() == 0 else 0)
                continue

            this_iou, I, U = IoU(output, target)
            mean_IoU.append(this_iou)
            cum_I += I
            cum_U += U
            for n_eval_iou, eval_seg_iou in enumerate(eval_seg_iou_list):
                seg_correct[n_eval_iou] += (this_iou >= eval_seg_iou)
            seg_total += 1

    # Guard the aggregates: a split without target samples used to raise
    # ZeroDivisionError on the overall IoU.
    mIoU = np.mean(mean_IoU) if mean_IoU else 0.0
    precs = [correct * 100. / seg_total if seg_total else 0.0 for correct in seg_correct]
    overall_iou = cum_I * 100. / cum_U if cum_U > 0 else 0.0

    print('Final results:')
    print('Mean IoU is %.2f\n' % (mIoU * 100.))
    results_str = ''
    for thr, prec in zip(eval_seg_iou_list, precs):
        results_str += ' precision@%s = %.2f\n' % (str(thr), prec)
    results_str += ' overall IoU = %.2f\n' % overall_iou
    results_str += ' mean IoU = %.2f\n' % (mIoU * 100.)
    print(results_str)
    # Report zero-target accuracy whenever such samples were evaluated. This
    # used to test a module-level `args`, which exists only when the file runs
    # as a script (NameError otherwise).
    if mean_acc:
        print('Mean accuracy for one-to-zero sample is %.2f\n' % (np.mean(mean_acc) * 100))
    return mIoU, overall_iou, precs
def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, epoch, print_freq,
                    iterations, bert_model, metric_learning=False, args=None):
    """Train ``model`` for one epoch, optionally adding the angular contrastive loss.

    Each batch unpacks as ``(image, target, source_type, sentences, attentions,
    pos_sent, pos_attn_mask, pos_type)``. Each sample flagged as a hard positive
    is followed in the expanded batch by a copy of its image/target paired with
    ``pos_sent``, so (original, positive) pairs sit at adjacent indices.
    Cross-entropy uses only the original samples (``cl_masks``). The metric loss
    uses the paired samples (``verb_masks``) and is applied only when at least 3
    samples in the batch are hard positives.

    Args:
        criterion: ``criterion(logits, target)`` segmentation loss.
        iterations: running step counter, returned incremented by the number of steps.
        bert_model: separate text encoder, or ``None`` when the model encodes text
            itself and also returns the feature maps for the metric loss.
        metric_learning: enable the angular contrastive loss.
        args: must provide ``metric_loss_weight`` (when ``metric_learning``),
            ``addzero``, ``margin_value`` and ``temperature``.

    Returns:
        ``(iterations, {'loss': average training loss over the epoch})``.
    """
    model.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
    header = 'Epoch: [{}]'.format(epoch)
    mlw = args.metric_loss_weight if metric_learning else 0

    for data in metric_logger.log_every(data_loader, print_freq, header):
        image, target, source_type, sentences, attentions, pos_sent, pos_attn_mask, pos_type = data
        # Default collation turns per-sample strings into Python lists. Both must
        # be numpy arrays so the comparisons below are element-wise:
        # `list == 'hardpos'` is plain False, which disabled every hard positive
        # (and produced an unindexable 0-d flag with `addzero`).
        source_type = np.array(source_type)
        pos_type = np.array(pos_type)
        is_hardpos = pos_type == 'hardpos'
        if not args.addzero:
            # default: only samples that actually have a target get a hard positive
            is_hardpos &= source_type != 'zero'
        hardpos_flag = torch.tensor(np.where(is_hardpos, 1, 0))

        sentences = sentences.squeeze(1)
        attentions = attentions.squeeze(1)
        pos_sent = pos_sent.squeeze(1)
        pos_attn_mask = pos_attn_mask.squeeze(1)

        images, targets, sentences_, attentions_ = [], [], [], []
        verb_masks, cl_masks = [], []
        for idx in range(len(image)):
            images.append(image[idx])
            targets.append(target[idx])
            sentences_.append(sentences[idx])
            attentions_.append(attentions[idx])
            if hardpos_flag[idx]:
                # (original, hard positive) pair: both feed the metric loss,
                # only the original feeds the cross-entropy.
                images.append(image[idx])
                targets.append(target[idx])
                sentences_.append(pos_sent[idx])
                attentions_.append(pos_attn_mask[idx])
                verb_masks.extend([1, 1])
                cl_masks.extend([1, 0])
            else:
                verb_masks.append(0)
                cl_masks.append(1)

        image = torch.stack(images).cuda(non_blocking=True)
        target = torch.stack(targets).cuda(non_blocking=True)
        sentences = torch.stack(sentences_).cuda(non_blocking=True)
        attentions = torch.stack(attentions_).cuda(non_blocking=True)
        verb_masks = torch.tensor(verb_masks, dtype=torch.bool, device='cuda')
        cl_masks = torch.tensor(cl_masks, dtype=torch.bool, device='cuda')

        metric_tensors = None  # only produced by the model when it encodes text itself
        if bert_model is not None:
            last_hidden_states = bert_model(sentences, attention_mask=attentions)[0]
            embedding = last_hidden_states.permute(0, 2, 1)  # (B, 768, N_l) to make Conv1d happy
            attentions = attentions.unsqueeze(dim=-1)  # (batch, N_l, 1)
            output = model(image, embedding, l_mask=attentions)
        else:
            output, metric_tensors = model(image, sentences, l_mask=attentions)

        ce_loss = criterion(output[cl_masks], target[cl_masks])

        metric_loss = 0
        divn = 1
        if metric_learning and metric_tensors is not None and int(hardpos_flag.sum()) >= 3:
            metric_loss = UniAngularLogitContrastLoss(metric_tensors, verb_masks, m=args.margin_value,
                                                      tau=args.temperature, verbonly=True, args=args)
            divn += mlw  # normalize by (1 + mlw)
        loss = (ce_loss + metric_loss * mlw) / divn

        optimizer.zero_grad()  # set_to_none=True is only available in pytorch 1.6+
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        torch.cuda.synchronize()

        iterations += 1
        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])

        del image, target, sentences, attentions, loss, output, data, metric_tensors
        if bert_model is not None:
            del last_hidden_states, embedding
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    loss_log = {
        'loss': metric_logger.meters['loss'].global_avg
    }
    return iterations, loss_log
def main(args):
    """Build data, model, optimizer and LR schedule, then run the train/eval loop.

    After each epoch the model is evaluated on ``args.split``. Whenever the
    overall IoU improves, ``model_best_<model_id>.pth`` is saved to
    ``args.output_dir`` via ``utils.save_on_master``. The main process also logs
    validation metrics and the training loss to TensorBoard.
    """
    # Only the main process logs scalars, so only it opens a writer; the writer
    # is closed in the `finally` below (it was previously never closed).
    writer = None
    if utils.is_main_process():
        writer = SummaryWriter('./experiments/{}/{}'.format("_".join([args.dataset, args.splitBy]), args.model_id))

    dataset, num_classes = get_dataset("train", get_transform(args=args), args=args, eval_mode=False)
    dataset_test, _ = get_dataset(args.split, get_transform(args=args), args=args, eval_mode=True)

    # batch sampler
    print(f"local rank {args.local_rank} / global rank {utils.get_rank()} successfully built train dataset.")
    num_tasks = utils.get_world_size()
    global_rank = utils.get_rank()
    if args.dataset == 'ref-zom':
        train_sampler = Refzom_DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank,
                                                  shuffle=True)
    else:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=num_tasks,
                                                                        rank=global_rank, shuffle=True)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    # data loaders (batch_size=1 for test: evaluate() inspects source_type[0] only)
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size,
        sampler=train_sampler, num_workers=args.workers, pin_memory=args.pin_mem, drop_last=True)
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers)

    # model initialization
    print(args.model)
    model = segmentation.__dict__[args.model](pretrained=args.pretrained_swin_weights, args=args)
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model.cuda()
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                      find_unused_parameters=True)
    single_model = model.module

    if args.model != 'lavt_one':
        bert_model = BertModel.from_pretrained(args.ck_bert)
        bert_model.pooler = None  # a work-around for a bug in Transformers = 3.0.2 that appears for DistributedDataParallel
        bert_model.cuda()
        bert_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(bert_model)
        bert_model = torch.nn.parallel.DistributedDataParallel(bert_model, device_ids=[args.local_rank])
        single_bert_model = bert_model.module
    else:
        bert_model = None
        single_bert_model = None

    # resume training (weights)
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        single_model.load_state_dict(checkpoint['model'])
        if args.model != 'lavt_one':
            single_bert_model.load_state_dict(checkpoint['bert_model'])

    # Parameters to optimize: no weight decay for norms and position tables.
    backbone_no_decay, backbone_decay = [], []
    for name, m in single_model.backbone.named_parameters():
        if 'norm' in name or 'absolute_pos_embed' in name or 'relative_position_bias_table' in name:
            backbone_no_decay.append(m)
        else:
            backbone_decay.append(m)

    # The text encoder is the separate BERT model, or the one inside 'lavt_one'.
    # Only its first 10 encoder layers are optimized.
    # NOTE: an earlier variant trained these layers at lr / 10.
    text_encoder = single_bert_model if args.model != 'lavt_one' else single_model.text_encoder
    params_to_optimize = [
        {'params': backbone_no_decay, 'weight_decay': 0.0},
        {'params': backbone_decay},
        {"params": [p for p in single_model.classifier.parameters() if p.requires_grad]},
        {"params": reduce(operator.concat,
                          [[p for p in text_encoder.encoder.layer[i].parameters() if p.requires_grad]
                           for i in range(10)])},
    ]

    optimizer = torch.optim.AdamW(params_to_optimize,
                                  lr=args.lr,
                                  weight_decay=args.weight_decay,
                                  amsgrad=args.amsgrad)
    # polynomial decay (power 0.9) over all training steps
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lambda x: (1 - x / (len(data_loader) * args.epochs)) ** 0.9)

    # housekeeping
    start_time = time.time()
    iterations = 0
    best_oIoU = -0.1

    # resume training (optimizer, lr scheduler, and the epoch)
    if args.resume:
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        resume_epoch = checkpoint['epoch']
    else:
        resume_epoch = -999

    try:
        for epoch in range(max(0, resume_epoch + 1), args.epochs):
            data_loader.sampler.set_epoch(epoch)
            iterations, loss_log = train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler,
                                                   epoch, args.print_freq, iterations, bert_model,
                                                   metric_learning=args.metric_learning, args=args)
            iou, overallIoU, precs = evaluate(model, data_loader_test, bert_model)

            print('Average object IoU {}'.format(iou))
            print('Overall IoU {}'.format(overallIoU))

            if best_oIoU < overallIoU:
                print('Better epoch: {}\n'.format(epoch))
                dict_to_save = {'model': single_model.state_dict(),
                                'optimizer': optimizer.state_dict(), 'epoch': epoch, 'args': args,
                                'lr_scheduler': lr_scheduler.state_dict()}
                if single_bert_model is not None:
                    dict_to_save['bert_model'] = single_bert_model.state_dict()
                utils.save_on_master(dict_to_save, os.path.join(args.output_dir,
                                                                'model_best_{}.pth'.format(args.model_id)))
                best_oIoU = overallIoU

            if writer is not None:
                writer.add_scalar('val/mIoU', iou, epoch)
                writer.add_scalar('val/oIoU', overallIoU, epoch)
                for thr, prec in zip((50, 60, 70, 80, 90), precs):
                    writer.add_scalar('val/Prec/{}'.format(thr), prec, epoch)
                writer.add_scalar('train/loss', loss_log['loss'], epoch)
                writer.flush()
    finally:
        if writer is not None:
            writer.close()

    # summarize
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
if __name__ == "__main__":
    from args import get_parser

    parser = get_parser()
    args = parser.parse_args()

    # Distributed setup: use LOCAL_RANK from the environment when present, else 0.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    print(f"Local Rank: {local_rank}, World Size: {os.environ.get('WORLD_SIZE', '1')}")
    utils.init_distributed_mode(args)

    print('Image size: {}'.format(str(args.img_size)))
    print('Metric Learning Ops')
    # (label, attribute names) pairs, printed in order; attributes are read lazily.
    for label, attr_names in (
        ('metric learning flag : ', ('metric_learning',)),
        ('metric loss weight : ', ('metric_loss_weight',)),
        ('metric mode and hardpos selection : ', ('metric_mode', 'hp_selection')),
        ('margin value : ', ('margin_value',)),
        ('temperature : ', ('temperature',)),
        ('add zero in ACE loss : ', ('addzero',)),
    ):
        print(label, *(getattr(args, attr) for attr in attr_names))
    print(args)
    main(args)