# elia / train_elia.py
# yxchng
# add files
# a166479
# raw history blame
# No virus
# 36.4 kB
import datetime
import os
import time
import torch
import torch.utils.data
from torch import nn
from functools import reduce
import operator
from bert.multimodal_bert import MultiModalBert
import torchvision
from lib import multimodal_segmentation_ppm
import transforms as T
import utils
import numpy as np
import torch.nn.functional as F
import gc
from collections import OrderedDict
import torch.backends.cudnn as cudnn
#from ffrecord.torch import DataLoader,Dataset
from modeling.MaskFormerModel import MaskFormerHead
from addict import Dict
from mask2former_utils.criterion import SetCriterion, Criterion
from mask2former_utils.matcher import HungarianMatcher
from bert.modeling_bert import BertLMPredictionHead, BertEncoder
class WrapperModel(nn.Module):
    """Referring-segmentation model with an auxiliary masked-language-modelling branch.

    Combines a staged image backbone, a staged language model and a
    segmentation head. Adds a small BERT encoder (``mlm_transformer``) that
    reads language tokens, projected stage-4 visual tokens and one
    position-derived token, and predicts the masked words.

    Args:
        image_model: backbone exposing ``forward_stem``, ``forward_stage{1..4}``
            and ``forward_pwam{1..4}``.
        language_model: language model exposing ``forward_stem``,
            ``forward_stage{1..4}``, ``forward_pwam{1..4}`` and ``config.vocab_size``.
        classifier: segmentation head; called with a dict ``{'s1'..'s4'}`` of
            image residual features and returns ``(predictions, mask_features)``.
        args: namespace; only ``args.img_size`` is read here.

    Raises:
        ValueError: if ``args.img_size`` is not divisible by 4.
    """

    # Number of language-token slots in the MLM token sequence.
    # ``mlm_pos_embeds`` is sized for it, and forward() slices the first this-many
    # transformer outputs as the language positions.
    # NOTE(review): presumably the dataset's max sentence length - confirm.
    _MLM_LANG_TOKENS = 20

    def __init__(self, image_model, language_model, classifier, args):
        super(WrapperModel, self).__init__()
        self.image_model = image_model
        self.language_model = language_model
        self.classifier = classifier
        # Projects the pooled 768-d language residual to 256-d (returned as
        # avg_lang_feature). Created exactly once.
        self.lang_proj = nn.Linear(768, 256)
        config = Dict({
            "architectures": [
                "BertForMaskedLM"
            ],
            "attention_probs_dropout_prob": 0.1,
            "gradient_checkpointing": False,
            "hidden_act": "gelu",
            "hidden_dropout_prob": 0.1,
            "hidden_size": 512,
            "initializer_range": 0.02,
            "intermediate_size": 3072,
            "layer_norm_eps": 1e-12,
            "model_type": "bert",
            "num_attention_heads": 8,
            "num_hidden_layers": 8,
            "pad_token_id": 0,
            "position_embedding_type": "absolute",
            "transformers_version": "4.6.0.dev0",
            "type_vocab_size": 2,
            "use_cache": True,
            "vocab_size": 30522
        })
        self.mlm_transformer = BertEncoder(config)
        # forward() builds a head_mask with one entry per encoder layer.
        self._mlm_num_layers = config.num_hidden_layers
        self.mlm_vis_proj = nn.Conv2d(1024, 512, 1)
        self.mlm_lang_proj = nn.Linear(768, 512)
        self.mlm_head = BertLMPredictionHead(config)
        if args.img_size % 4 != 0:
            raise ValueError(f"img_size must be divisible by 4, got {args.img_size}")
        # Language slots + ((img_size // 4) // 8) ** 2 visual tokens.
        # NOTE(review): presumably the spatial size of the stride-32 stage-4 map
        # that mlm_vis_proj consumes - confirm against the backbone.
        num_img_tokens = self._MLM_LANG_TOKENS + ((args.img_size // 4) // 8) ** 2
        print(num_img_tokens)
        # +1 for the single token produced from `position` by mlm_pos_mlp.
        self.mlm_pos_embeds = nn.Embedding(num_img_tokens + 1, 512)
        # Modality embeddings: row 0 = language, 1 = visual, 2 = position token.
        self.mlm_modal_embeds = nn.Embedding(3, 512)
        # Not used by forward(); kept so the state-dict keys stay unchanged.
        self.mlm_mask_embed = nn.Embedding(1, 512)
        self.mlm_pos_mlp = nn.Sequential(
            nn.Linear(2, 512),
            nn.LayerNorm(512),
            nn.Linear(512, 512),
            nn.GELU()
        )

    def _get_binary_mask(self, target, num_classes=None):
        """Return one binary mask per foreground class of a 2-D label map.

        Args:
            target: integer label map of shape (H, W) with values in
                ``[0, num_classes]``. Label 0 is dropped from the result.
            num_classes: number of non-zero classes. Defaults to
                ``self.num_classes``, which this class does not set, so pass it
                explicitly (or set the attribute) before relying on the default.

        Returns:
            Float tensor of shape (num_classes, H, W) on ``target``'s device.
        """
        if num_classes is None:
            num_classes = self.num_classes
        y, x = target.size()
        target_onehot = torch.zeros(num_classes + 1, y, x, device=target.device)
        # scatter requires an int64 index tensor.
        target_onehot = target_onehot.scatter(dim=0, index=target.unsqueeze(0).long(), value=1)
        return target_onehot[1:]

    def semantic_inference(self, mask_cls, mask_pred):
        """Fuse per-query class scores with per-query mask probabilities.

        Args:
            mask_cls: class logits, indexed (B, Q, C) by the einsum below.
            mask_pred: mask logits of shape (B, Q, H, W).

        Returns:
            (B, C - 1, H, W) tensor: for each remaining class, the sum over queries
            of class score times sigmoid mask probability.
        """
        # NOTE(review): softmax is taken over dim=1 (the query axis "q" in the
        # einsum) and class column 0 is dropped. Mask2Former's reference
        # semantic_inference uses softmax(dim=-1)[..., :-1]. Confirm this
        # variant is intentional before changing it.
        mask_cls = F.softmax(mask_cls, dim=1)[..., 1:]
        mask_pred = mask_pred.sigmoid()
        semseg = torch.einsum("bqc,bqhw->bchw", mask_cls, mask_pred)
        return semseg

    def forward(self, image, sentences, attentions, mlm_targets, mlm_masks, position):
        """Run the segmentation branch and the MLM branch.

        Args:
            image: image batch for ``image_model``.
            sentences: token ids; re-encoded by the language model alone for the
                MLM branch.
            attentions: (B, N_l) language attention mask (non-zero = attended).
            mlm_targets: token ids. ``mlm_targets.squeeze(1)`` feeds the
                segmentation branch's language stream and also provides the MLM labels.
            mlm_masks: label positions with ``mlm_masks > 0`` keep their token id.
                All others become the ignore label -1.
            position: input to ``mlm_pos_mlp`` (last dimension 2).

        Returns:
            predictions: output of ``classifier`` (callers index it with
                ``'pred_logits'`` / ``'pred_masks'``).
            mask_features: mask embeddings from ``classifier``.
            avg_lang_feature: attention-masked mean of the last language
                residual, projected to 256 dims by ``lang_proj``.
            mlm_predictions: (N_attended, vocab_size) MLM logits for every
                attended language position in the batch.
            mlm_targets: (N_attended,) labels aligned with ``mlm_predictions``.
        """
        l_mask = attentions.unsqueeze(dim=-1)

        # Segmentation branch: interleaved vision / language stages with PWAM fusion.
        img, Wh, Ww = self.image_model.forward_stem(image)
        lang, extended_attention_mask = self.language_model.forward_stem(mlm_targets.squeeze(1), attentions)
        img_residuals = []
        lang_residual = None
        for stage in range(1, 5):
            img = getattr(self.image_model, f'forward_stage{stage}')(img, Wh, Ww)
            lang = getattr(self.language_model, f'forward_stage{stage}')(lang, extended_attention_mask)
            img_residual, _, _, img_next, Wh, Ww = getattr(self.image_model, f'forward_pwam{stage}')(
                img, Wh, Ww, lang, l_mask)
            # The language-side PWAM receives this stage's pre-PWAM image features.
            lang_residual, lang = getattr(self.language_model, f'forward_pwam{stage}')(
                img, lang, extended_attention_mask)
            img = img_next
            img_residuals.append(img_residual)
        l4_residual = lang_residual
        i4_residual = img_residuals[-1]

        features = {f's{stage}': feat for stage, feat in enumerate(img_residuals, start=1)}
        predictions, mask_features = self.classifier(features)

        # MLM branch: re-encode `sentences` with the language stages only.
        lang, extended_attention_mask = self.language_model.forward_stem(sentences, attentions)
        for stage in range(1, 5):
            lang = getattr(self.language_model, f'forward_stage{stage}')(lang, extended_attention_mask)
        l4 = lang

        mlp_embed = self.mlm_pos_mlp(position)
        # Positions not selected by mlm_masks get -1 (the ignore index used by
        # the MLM loss in train_one_epoch).
        mlm_targets = torch.where(
            mlm_masks > 0,
            mlm_targets,
            torch.ones_like(mlm_targets) * (-1)
        )

        vis_features = self.mlm_vis_proj(i4_residual).flatten(2).permute(0, 2, 1)
        lang_features = self.mlm_lang_proj(l4)
        # Token sequence: [language tokens | visual tokens | one position token].
        mm_features = torch.cat([lang_features, vis_features, mlp_embed.unsqueeze(1)], dim=1)
        modal_weight = self.mlm_modal_embeds.weight
        modal_embeds = torch.cat([
            modal_weight[0].unsqueeze(0).repeat(1, lang_features.shape[1], 1),
            modal_weight[1].unsqueeze(0).repeat(1, vis_features.shape[1], 1),
            modal_weight[2].unsqueeze(0).repeat(1, 1, 1),
        ], dim=1)

        # Attend to language tokens per `attentions` and to every visual and
        # position token. Converted to an additive mask: 0 kept, -10000 masked.
        mixed_attention_mask = torch.cat(
            [attentions.unsqueeze(-1),
             torch.ones(attentions.shape[0], vis_features.shape[1] + 1, 1).to(attentions.device)],
            dim=1)
        mixed_attention_mask = mixed_attention_mask.permute(0, 2, 1).unsqueeze(1)
        mixed_attention_mask = (1 - mixed_attention_mask) * -10000.0
        head_mask = [None] * self._mlm_num_layers

        head_features = self.mlm_transformer(
            mm_features + self.mlm_pos_embeds.weight.unsqueeze(0) + modal_embeds,
            mixed_attention_mask, head_mask)[0]
        # Keep only the attended language positions.
        head_features = head_features[:, :self._MLM_LANG_TOKENS][attentions.bool()]
        mlm_predictions = self.mlm_head(head_features)
        mlm_predictions = mlm_predictions.reshape(-1, self.language_model.config.vocab_size)
        mlm_targets = mlm_targets.squeeze(1)[attentions.bool()]

        avg_lang_feature = self.lang_proj((l4_residual * l_mask).sum(1) / l_mask.sum(1))
        return predictions, mask_features, avg_lang_feature, mlm_predictions, mlm_targets
# IoU calculation for validation
def IoU(pred, gt):
    """Threshold ``pred`` at 0.5 and compare it with ``gt``.

    Returns:
        ``(iou, intersection, union)``. ``iou`` is a Python float, or the int 0
        when either the intersection or the union is zero. ``intersection``
        and ``union`` are the tensor sums, returned so callers can accumulate
        an overall IoU.
    """
    binary = pred > 0.5
    inter = (binary * gt).sum()
    uni = (binary + gt).sum() - inter
    if inter == 0 or uni == 0:
        return 0, inter, uni
    return float(inter) / float(uni), inter, uni
def get_dataset(image_set, transform, args):
    """Build the ``ReferDataset`` split ``image_set`` with the given image transform.

    Returns:
        ``(dataset, num_classes)``. ``num_classes`` is always 2 here (presumably
        background vs. referred object).
    """
    # Imported lazily, as in the original, so the module loads without the data package.
    from data.dataset_refer_bert_mlm import ReferDataset

    dataset = ReferDataset(
        args,
        split=image_set,
        image_transforms=transform,
        target_transforms=None,
    )
    return dataset, 2
def get_transform(args):
    """Compose the image pipeline: resize, convert to tensor, normalise.

    ``T.Resize`` receives ``args.img_size`` twice (presumably height and width
    of a square output). The normalisation constants are the usual ImageNet
    channel statistics.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    return T.Compose([
        T.Resize(args.img_size, args.img_size),
        T.ToTensor(),
        T.Normalize(mean=channel_mean, std=channel_std),
    ])
#def criterion(input, target):
# weight = torch.FloatTensor([0.9, 1.1]).cuda()
# return nn.functional.cross_entropy(input, target, weight=weight)
def evaluate(model, data_loader):
    """Evaluate referring segmentation and print mIoU, precision@K and overall IoU.

    ``model`` must be a wrapper exposing ``.module`` (e.g. DistributedDataParallel),
    because ``model.module.semantic_inference`` is called. Only the first
    sample of every batch is scored, so ``data_loader`` should use batch_size=1.

    Returns:
        ``(mean_iou_percent, overall_iou_percent)`` as Python floats. Both are
        0.0 when the loader is empty or the accumulated union is zero.
    """
    model.eval()
    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Test:'

    # evaluation variables
    eval_seg_iou_list = [.5, .6, .7, .8, .9]
    seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32)
    seg_total = 0
    cum_I, cum_U = 0, 0
    per_sample_iou = []

    with torch.no_grad():
        for data in metric_logger.log_every(data_loader, 100, header):
            image, target, sentences, attentions, mlm_targets, mlm_masks, position = [
                tensor.cuda(non_blocking=True) for tensor in data]
            sentences = sentences.squeeze(1)
            attentions = attentions.squeeze(1)
            output, _, _, _, _ = model(image, sentences, attentions, mlm_targets, mlm_masks, position)

            mask_pred_results = F.interpolate(output["pred_masks"], size=target.shape[-2:],
                                              mode='bilinear', align_corners=True)
            pred_masks = model.module.semantic_inference(output["pred_logits"], mask_pred_results)
            iou, I, U = IoU(pred_masks[0], target)

            per_sample_iou.append(iou)
            cum_I += I
            cum_U += U
            for n_eval_iou, eval_seg_iou in enumerate(eval_seg_iou_list):
                seg_correct[n_eval_iou] += (iou >= eval_seg_iou)
            seg_total += 1

    # Guard every division so an empty loader / empty union cannot crash.
    mIoU = float(np.mean(per_sample_iou)) if per_sample_iou else 0.0
    overall_iou = float(cum_I) / float(cum_U) if float(cum_U) > 0 else 0.0

    print('Final results:')
    print('Mean IoU is %.2f\n' % (mIoU * 100.))
    results_str = ''
    for eval_seg_iou, correct in zip(eval_seg_iou_list, seg_correct):
        results_str += ' precision@%s = %.2f\n' % \
            (str(eval_seg_iou), correct * 100. / max(seg_total, 1))
    results_str += ' overall IoU = %.2f\n' % (overall_iou * 100.)
    print(results_str)
    return 100 * mIoU, 100 * overall_iou
def _as_float(value):
    """Return ``value`` as a Python float, whether it is a tensor or a plain number."""
    return value.item() if torch.is_tensor(value) else float(value)


def _focal_contrastive_loss(matrix, alpha, temp):
    """Focal-weighted InfoNCE loss where column 0 of ``matrix`` is the positive.

    Args:
        matrix: (1, N, 1 + K) similarity scores. For each of N anchors, one
            positive score followed by K negative scores.
        alpha: exponent of the focal weight ``(1 - p_pos) ** alpha``. ``p_pos``
            is the softmax of the un-tempered scores.
        temp: temperature dividing the scores inside the cross-entropy.

    Returns:
        Scalar tensor: mean over anchors of focal weight times cross-entropy.
    """
    labels = torch.zeros(matrix.shape[0] * matrix.shape[1], dtype=torch.long, device=matrix.device)
    score = 1.0 - torch.softmax(matrix, -1)[:, :, 0]
    ce = F.cross_entropy(matrix.view(-1, matrix.shape[-1]) / temp, labels, reduction='none')
    return (torch.pow(score, alpha) * ce).mean()


def _plic_losses(mask_features, target, avg_lang_feature, args):
    """Compute the weighted pixel-level contrastive (PLIC) losses for one batch.

    Args:
        mask_features: (B, C, H, W) mask embeddings.
        target: label map batch. It is resized (nearest) to (H, W); pixels
            equal to 1 are positives and pixels equal to 0 are negatives.
        avg_lang_feature: (B, C) L2-normalised sentence features.
        args: reads ``plic_{pos,neg,lang}_{alpha,temp,weight}``.

    Returns:
        ``(plic_pos_loss, plic_neg_loss, plic_lang_loss)``, each scaled by its
        ``args.plic_*_weight`` and divided by B. Samples without both positive
        and negative pixels are skipped. If all are skipped, the terms are the
        float 0.0.
    """
    B, C, H, W = mask_features.shape
    target_reshape = F.interpolate(target.unsqueeze(1).float(), size=mask_features.shape[-2:], mode='nearest').long()
    target_reshape = target_reshape.repeat(1, C, 1, 1)

    plic_pos_loss = 0.0
    plic_neg_loss = 0.0
    plic_lang_loss = 0.0
    for i in range(B):
        tgt = target_reshape[[i]]
        if (tgt == 0).sum() == 0 or (tgt == 1).sum() == 0:
            continue
        feats = mask_features[[i]]

        # Normalised prototype (masked mean) of the foreground and background regions.
        avg_pos_feature = (feats * tgt).sum(-1).sum(-1) / tgt.sum(-1).sum(-1)
        avg_neg_feature = (feats * (1.0 - tgt)).sum(-1).sum(-1) / (1.0 - tgt).sum(-1).sum(-1)
        avg_pos_feature = F.normalize(avg_pos_feature, dim=1)
        avg_neg_feature = F.normalize(avg_neg_feature, dim=1)

        # The mask is identical across channels, so boolean indexing yields C
        # channel-major runs of equal length that view() regroups as (1, C, N).
        pos_features = F.normalize(feats[tgt == 1].view(1, C, -1), dim=1)
        neg_features = F.normalize(feats[tgt == 0].view(1, C, -1), dim=1)
        lang = avg_lang_feature[[i]]

        # Language anchor vs. every positive pixel (each row: 1 positive + all negatives).
        lang_pos_scores = torch.einsum("bq,bqn->bn", lang, pos_features)
        lang_neg_scores = torch.einsum("bq,bqn->bn", lang, neg_features)
        lang_matrix = torch.cat([lang_pos_scores.unsqueeze(-1),
                                 lang_neg_scores.unsqueeze(1).repeat(1, lang_pos_scores.shape[1], 1)], dim=2)

        # Each positive pixel: foreground prototype as positive, all negative pixels as negatives.
        pos_pos_scores = torch.einsum("bq,bqn->bn", avg_pos_feature, pos_features)
        pos_neg_scores = torch.einsum("bqn,bqm->bnm", pos_features, neg_features)
        pos_matrix = torch.cat([pos_pos_scores.unsqueeze(-1), pos_neg_scores], dim=2)

        # Each negative pixel: background prototype as positive, all positive pixels as negatives.
        neg_neg_scores = torch.einsum("bq,bqn->bn", avg_neg_feature, neg_features)
        neg_pos_scores = torch.einsum("bqn,bqm->bnm", neg_features, pos_features)
        neg_matrix = torch.cat([neg_neg_scores.unsqueeze(-1), neg_pos_scores], dim=2)

        plic_pos_loss += _focal_contrastive_loss(pos_matrix, args.plic_pos_alpha, args.plic_pos_temp)
        plic_neg_loss += _focal_contrastive_loss(neg_matrix, args.plic_neg_alpha, args.plic_neg_temp)
        plic_lang_loss += _focal_contrastive_loss(lang_matrix, args.plic_lang_alpha, args.plic_lang_temp)

    plic_pos_loss = (args.plic_pos_weight * plic_pos_loss) / B
    plic_neg_loss = (args.plic_neg_weight * plic_neg_loss) / B
    plic_lang_loss = (args.plic_lang_weight * plic_lang_loss) / B
    return plic_pos_loss, plic_neg_loss, plic_lang_loss


def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, epoch, print_freq,
                    iterations, args):
    """Train ``model`` for one epoch over ``data_loader``.

    Per step, the loss is the sum of:
    - the set-criterion losses kept by ``criterion.weight_dict`` (grouped as ce / dice / mask);
    - the three PLIC contrastive losses;
    - ``args.smlm_weight`` times the MLM cross-entropy (ignore index -1).

    The LR scheduler is stepped every iteration.

    Returns:
        ``iterations`` plus the number of optimisation steps taken. Python ints
        are immutable, so the caller's own counter is not updated in place.
    """
    model.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
    header = 'Epoch: [{}]'.format(epoch)

    for data in metric_logger.log_every(data_loader, print_freq, header):
        image, target, sentences, attentions, mlm_targets, mlm_masks, position = [
            tensor.cuda(non_blocking=True) for tensor in data]
        sentences = sentences.squeeze(1)
        attentions = attentions.squeeze(1)

        output, mask_features, avg_lang_feature, mlm_predictions, mlm_targets = model(
            image, sentences, attentions, mlm_targets, mlm_masks, position)
        avg_lang_feature = F.normalize(avg_lang_feature, dim=1)

        # Upsample predicted masks (including deep-supervision outputs) to the label resolution.
        target_shape = target.shape[-2:]
        output['pred_masks'] = F.interpolate(output['pred_masks'], size=target_shape,
                                             mode='bilinear', align_corners=True)
        for aux_outputs in output.get("aux_outputs", []):
            aux_outputs['pred_masks'] = F.interpolate(aux_outputs['pred_masks'], size=target_shape,
                                                      mode='bilinear', align_corners=True)

        plic_pos_loss, plic_neg_loss, plic_lang_loss = _plic_losses(mask_features, target, avg_lang_feature, args)
        plic_loss = plic_pos_loss + plic_neg_loss + plic_lang_loss

        losses = criterion(output, target)
        weight_dict = criterion.weight_dict
        loss_ce = 0.0
        loss_dice = 0.0
        loss_mask = 0.0
        for k in list(losses.keys()):
            if k not in weight_dict:
                # remove this loss if not specified in `weight_dict`
                losses.pop(k)
                continue
            losses[k] *= weight_dict[k]
            if '_ce' in k:
                loss_ce += losses[k]
            elif '_dice' in k:
                loss_dice += losses[k]
            else:
                loss_mask += losses[k]

        smlm_loss = args.smlm_weight * nn.CrossEntropyLoss(ignore_index=-1)(mlm_predictions, mlm_targets)
        loss = loss_ce + loss_dice + loss_mask + plic_loss + smlm_loss

        optimizer.zero_grad()  # set_to_none=True is only available in pytorch 1.6+
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        torch.cuda.synchronize()
        iterations += 1

        # Any term can still be a plain float (e.g. PLIC when no sample had both
        # classes), so convert through _as_float rather than calling .item().
        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"],
                             loss_ce=_as_float(loss_ce), loss_dice=_as_float(loss_dice),
                             loss_mask=_as_float(loss_mask), plic_loss=_as_float(plic_loss),
                             plic_lang_loss=_as_float(plic_lang_loss),
                             plic_pos_loss=_as_float(plic_pos_loss),
                             plic_neg_loss=_as_float(plic_neg_loss),
                             smlm_loss=_as_float(smlm_loss))
    return iterations
def main(args):
    """Build data, model, criterion and optimizer, then run the train/eval loop.

    Assumes distributed mode is already initialised (the ``__main__`` block calls
    ``utils.init_distributed_mode(args)`` first). It builds a
    ``DistributedSampler`` from ``utils.get_world_size()`` / ``utils.get_rank()``
    and wraps the model in ``DistributedDataParallel``.

    Side effects: writes ``checkpoint-{epoch}.pth`` to ``args.output_dir`` every
    epoch, deletes older ones beyond the newest ``args.max_ckpt``, and writes
    ``model_best_{args.model_id}.pth`` whenever the overall IoU improves.
    """
    # fix the seed for reproducibility (offset by rank so processes differ)
    seed = args.seed + utils.get_rank()
    print('seed', seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    #cudnn.benchmark = True
    dataset, num_classes = get_dataset("train",
                                       get_transform(args=args),
                                       args=args)
    dataset_test, _ = get_dataset("val",
                                  get_transform(args=args),
                                  args=args)
    # batch sampler
    print(f"local rank {args.local_rank} / global rank {utils.get_rank()} successfully built train dataset.")
    num_tasks = utils.get_world_size()
    global_rank = utils.get_rank()
    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank,
                                                                    shuffle=True)
    # NOTE(review): the validation set is not sharded, so every rank evaluates the whole set.
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)
    # data loader
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size,
        sampler=train_sampler, num_workers=args.workers, pin_memory=True, drop_last=True)
    # batch_size=1: evaluate() scores only the first sample of each batch.
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1, sampler=test_sampler, pin_memory=True, num_workers=args.workers)
    # model initialization
    print(args.model)
    model = multimodal_segmentation_ppm.__dict__[args.model](pretrained=args.pretrained_swin_weights,
                                                              args=args)
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    if args.model != 'lavt_one':
        model_class = MultiModalBert
        bert_model = model_class.from_pretrained(args.ck_bert, embed_dim=model.backbone.embed_dim)
        bert_model.pooler = None # a work-around for a bug in Transformers = 3.0.2 that appears for DistributedDataParallel
        bert_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(bert_model)
    else:
        # NOTE(review): WrapperModel.forward calls self.language_model.forward_stem and the
        # optimizer groups below read single_model.language_model.*, so 'lavt_one'
        # (language model = None) cannot work here. single_bert_model is unused.
        bert_model = None
        single_bert_model = None
    # Channel count / stride of the four backbone feature maps passed to the head.
    input_shape = dict()
    input_shape['s1'] = Dict({'channel': 128, 'stride': 4})
    input_shape['s2'] = Dict({'channel': 256, 'stride': 8})
    input_shape['s3'] = Dict({'channel': 512, 'stride': 16})
    input_shape['s4'] = Dict({'channel': 1024, 'stride': 32})
    # Mask2Former-style head configuration (detectron2-like key layout).
    cfg = Dict()
    cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE = 4
    cfg.MODEL.MASK_FORMER.DROPOUT = 0.0
    cfg.MODEL.MASK_FORMER.NHEADS = 8
    cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = args.transformer_enc_layers
    cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM = 256
    cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256
    cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES = ["s1", "s2", "s3", "s4"]
    cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 1
    cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256
    cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = args.num_object_queries
    cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = args.dim_feedforward
    cfg.MODEL.MASK_FORMER.DEC_LAYERS = args.dec_layers
    cfg.MODEL.MASK_FORMER.PRE_NORM = False
    cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION = True
    cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT = args.no_object_weight
    cfg.MODEL.MASK_FORMER.CLASS_WEIGHT = args.class_weight
    cfg.MODEL.MASK_FORMER.DICE_WEIGHT = args.dice_weight
    cfg.MODEL.MASK_FORMER.MASK_WEIGHT = args.mask_weight
    cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS = args.train_num_points
    cfg.MODEL.MASK_FORMER.OVERSAMPLE_RATIO = 3.0
    cfg.MODEL.MASK_FORMER.IMPORTANCE_SAMPLE_RATIO = 0.75
    print(cfg)
    maskformer_head = MaskFormerHead(cfg, input_shape)
    maskformer_head = torch.nn.SyncBatchNorm.convert_sync_batchnorm(maskformer_head)
    # Only the backbone of the segmentation model is reused; its own head is discarded.
    model = WrapperModel(model.backbone, bert_model, maskformer_head, args)
    model.cuda()
    # find_unused_parameters=True: presumably needed because some registered
    # parameters (e.g. WrapperModel.mlm_mask_embed) never receive gradients.
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=True)
    single_model = model.module
    # mask2former loss
    deep_supervision = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION
    no_object_weight = cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT
    # loss weights
    class_weight = cfg.MODEL.MASK_FORMER.CLASS_WEIGHT
    dice_weight = cfg.MODEL.MASK_FORMER.DICE_WEIGHT
    mask_weight = cfg.MODEL.MASK_FORMER.MASK_WEIGHT
    # building criterion
    matcher = HungarianMatcher(
        cost_class=class_weight,
        cost_mask=mask_weight,
        cost_dice=dice_weight,
        num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
    )
    weight_dict = {"loss_ce": class_weight, "loss_mask": mask_weight, "loss_dice": dice_weight}
    if deep_supervision:
        # Same weights for every intermediate decoder layer, keyed "<loss>_<i>".
        dec_layers = cfg.MODEL.MASK_FORMER.DEC_LAYERS
        aux_weight_dict = {}
        for i in range(dec_layers - 1):
            aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
        weight_dict.update(aux_weight_dict)
    losses = ["labels", "masks"]
    criterion = SetCriterion(
        cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
        matcher=matcher,
        weight_dict=weight_dict,
        eos_coef=no_object_weight,
        losses=losses,
        num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
        oversample_ratio=cfg.MODEL.MASK_FORMER.OVERSAMPLE_RATIO,
        importance_sample_ratio=cfg.MODEL.MASK_FORMER.IMPORTANCE_SAMPLE_RATIO,
        device='cuda'
    )
    # "auto": resume from the highest-numbered checkpoint-{e}.pth that exists
    # (empty string, i.e. no resume, if none is found).
    if args.resume == "auto":
        last_ckpt = ""
        for e in range(args.epochs):
            ckpt_path = os.path.join(args.output_dir, f'checkpoint-{e}.pth')
            if os.path.exists(ckpt_path):
                last_ckpt = ckpt_path
        args.resume = last_ckpt
    # resume training
    if args.resume:
        # NOTE(review): the checkpoint stores a pickled argparse Namespace under 'args';
        # PyTorch versions where torch.load defaults to weights_only=True may refuse
        # to load it - confirm the torch version in use.
        checkpoint = torch.load(args.resume, map_location='cpu')
        single_model.load_state_dict(checkpoint['model'])
    # parameters to optimize: no weight decay for norm layers, absolute position
    # embeddings and relative position bias tables of the image backbone.
    backbone_no_decay = list()
    backbone_decay = list()
    for name, m in single_model.image_model.named_parameters():
        if 'norm' in name or 'absolute_pos_embed' in name or 'relative_position_bias_table' in name:
            backbone_no_decay.append(m)
        else:
            backbone_decay.append(m)
    params_to_optimize = [
        {'params': backbone_no_decay, 'weight_decay': 0.0},
        {'params': backbone_decay},
        {"params": [p for p in single_model.classifier.parameters() if p.requires_grad]},
        # the following are the parameters of bert
        # NOTE(review): only encoder layers 0..9 are optimised; BERT embeddings and any
        # later encoder layers are in no param group - confirm this is intended.
        {"params": reduce(operator.concat,
                          [[p for p in single_model.language_model.encoder.layer[i].parameters()
                            if p.requires_grad] for i in range(10)])},
        {"params": single_model.language_model.pwams.parameters()},
        {"params": single_model.language_model.res_gates.parameters()},
        {"params": single_model.language_model.norms.parameters()},
        {"params": single_model.lang_proj.parameters()},
        # MLM branch
        {'params': single_model.mlm_head.parameters()},
        {'params': single_model.mlm_vis_proj.parameters()},
        {'params': single_model.mlm_lang_proj.parameters()},
        {'params': single_model.mlm_transformer.parameters()},
        {'params': single_model.mlm_pos_embeds.parameters()},
        {'params': single_model.mlm_modal_embeds.parameters()},
        {'params': single_model.mlm_mask_embed.parameters()},
        {'params': single_model.mlm_pos_mlp.parameters()},
    ]
    # optimizer
    optimizer = torch.optim.AdamW(params_to_optimize,
                                  lr=args.lr,
                                  weight_decay=args.weight_decay,
                                  amsgrad=args.amsgrad
                                  )
    # learning rate scheduler: polynomial decay (power 0.9) over all iterations,
    # stepped once per iteration in train_one_epoch.
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                     lambda x: (1 - x / (len(data_loader) * args.epochs)) ** 0.9)
    # housekeeping
    start_time = time.time()
    # NOTE(review): `iterations` is passed to train_one_epoch but, as a Python int,
    # it is never updated here; it stays 0 for every epoch.
    iterations = 0
    # NOTE(review): not restored on resume, so the first epoch after resuming
    # always overwrites model_best_*.pth.
    best_oIoU = -0.1
    # resume training (optimizer, lr scheduler, and the epoch)
    if args.resume:
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        resume_epoch = checkpoint['epoch']
    else:
        resume_epoch = -999
    # training loops
    for epoch in range(max(0, resume_epoch+1), args.epochs):
        data_loader.sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, epoch, args.print_freq,
                        iterations, args)
        iou, overallIoU = evaluate(model, data_loader_test)
        print('Average object IoU {}'.format(iou))
        print('Overall IoU {}'.format(overallIoU))
        # Per-epoch checkpoint, written to a _TEMP file then renamed by the main process.
        dict_to_save = {'model': single_model.state_dict(),
                        'optimizer': optimizer.state_dict(), 'epoch': epoch, 'args': args,
                        'lr_scheduler': lr_scheduler.state_dict()}
        checkpoint_path = os.path.join(args.output_dir, 'checkpoint-{}.pth'.format(epoch))
        utils.save_on_master(dict_to_save, str(checkpoint_path) + '_TEMP')
        if utils.is_main_process():
            os.rename(str(checkpoint_path) + '_TEMP', str(checkpoint_path))
        # Keep only the newest args.max_ckpt per-epoch checkpoints.
        # NOTE(review): with args.max_ckpt == 0 the slice [:-0] is empty, so nothing is removed.
        if utils.is_main_process():
            ckpt_paths = []
            for e in range(args.epochs):
                ckpt_path = os.path.join(args.output_dir, f'checkpoint-{e}.pth')
                print(ckpt_path)
                if os.path.exists(ckpt_path):
                    ckpt_paths.append(ckpt_path)
            print(ckpt_paths)
            for ckpt_path in ckpt_paths[:-args.max_ckpt]:
                os.remove(ckpt_path)
                print("remove {:s}".format(ckpt_path))
        # Track the best overall IoU and save it separately.
        save_checkpoint = (best_oIoU < overallIoU)
        if save_checkpoint:
            print('Better epoch: {}\n'.format(epoch))
            dict_to_save = {'model': single_model.state_dict(),
                            'optimizer': optimizer.state_dict(), 'epoch': epoch, 'args': args,
                            'lr_scheduler': lr_scheduler.state_dict()}
            checkpoint_path = os.path.join(args.output_dir, 'model_best_{}.pth'.format(args.model_id))
            utils.save_on_master(dict_to_save, checkpoint_path + '_TEMP')
            if utils.is_main_process():
                os.rename(str(checkpoint_path) + '_TEMP', str(checkpoint_path))
            best_oIoU = overallIoU
        torch.cuda.empty_cache()
    # summarize
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
if __name__ == "__main__":
    from args import get_parser

    # Parse CLI options and make sure the checkpoint directory exists.
    args = get_parser().parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    # Set up distributed learning before main(), which wraps the model in
    # DistributedDataParallel (presumably init_distributed_mode creates the
    # process group).
    utils.init_distributed_mode(args)
    print('Image size: {}'.format(str(args.img_size)))
    main(args)