# elia / visualize.py
# yxchng
# add files
# a166479
import datetime
import os
import time
import torch
import torch.utils.data
from torch import nn
from bert.multimodal_bert import MultiModalBert
import torchvision
from lib import multimodal_segmentation_ppm
import transforms as T
import utils
import numpy as np
from PIL import Image
import torch.nn.functional as F
from modeling.MaskFormerModel import MaskFormerHead
from addict import Dict
from bert.modeling_bert import BertLMPredictionHead, BertEncoder
import cv2
import textwrap
def get_dataset(image_set, transform, args):
    """Build the visualisation variant of ReferDataset for one split.

    Args:
        image_set: split name forwarded as ``split``.
        transform: image transform pipeline (targets are left untransformed).
        args: parsed CLI namespace, forwarded to the dataset.

    Returns:
        ``(dataset, num_classes)``; ``num_classes`` is always 2.
    """
    from data.dataset_refer_bert_vis import ReferDataset

    dataset_kwargs = dict(
        split=image_set,
        image_transforms=transform,
        target_transforms=None,
        eval_mode=True,
    )
    return ReferDataset(args, **dataset_kwargs), 2
def overlay_davis(image, mask, colors=((0, 0, 0), (0, 255, 0)), cscale=1, alpha=0.4):
    """Blend a coloured, black-outlined overlay of ``mask`` onto ``image``.

    Args:
        image: (H, W, C) array whose last axis matches the colour length (3).
        mask: (H, W) integer label map. Label 0 is background and left
            untouched; every other label ``k`` is tinted with ``colors[k]``.
        colors: per-label colours, indexed by label value.
        cscale: multiplier applied to ``colors``.
        alpha: weight of the original image in the blend; ``1 - alpha`` goes
            to the colour.

    Returns:
        A copy of ``image`` (same dtype) where every labelled region is blended
        towards its colour and the pixels bordering it from outside are set to 0.
    """
    # The public scipy.ndimage namespace; scipy.ndimage.morphology is deprecated.
    from scipy.ndimage import binary_dilation

    colors = np.atleast_2d(np.reshape(colors, (-1, 3))) * cscale
    im_overlay = image.copy()
    for object_id in np.unique(mask):
        # Skip background by *value*. Dropping the first unique label instead
        # would silently skip a real object whenever the mask has no 0 pixels
        # (e.g. an all-foreground prediction).
        if object_id == 0:
            continue
        # Overlay color on binary mask
        foreground = image * alpha + np.ones(image.shape) * (1 - alpha) * np.array(colors[object_id])
        binary_mask = mask == object_id
        # Compose image
        im_overlay[binary_mask] = foreground[binary_mask]
        # Pixels gained by one dilation step are the region's outer contour.
        contours = binary_dilation(binary_mask) ^ binary_mask
        im_overlay[contours, :] = 0
    return im_overlay.astype(image.dtype)
def evaluate(model, data_loader, device, args):
    """Score ``model`` on ``data_loader`` and save per-sentence visualisations.

    For every batch and every referring expression ``j`` in it:

    * ``pred_masks`` are resized to the target resolution.
    * They are combined with ``pred_logits`` via ``model.semantic_inference``.
    * The result is thresholded at 0.5 and scored against ``target``.
    * A side-by-side image is written to
      ``vis/<args.vis_dir>/<image>_<idx>_<j>_<iou>.jpg``. The ground-truth
      overlay and wrapped sentence are on the left; the prediction overlay is
      on the right.

    Mean IoU, precision@{0.5..0.9} and overall (cumulative) IoU are printed
    at the end.

    Only batch element 0 is visualised, so ``data_loader`` is expected to use
    ``batch_size=1`` (as ``main`` does).

    Args:
        model: module returning a dict with "pred_logits" and "pred_masks"
            and exposing ``semantic_inference``.
        data_loader: yields ``(image, target, sentences, attentions,
            raw_sentences, this_img, orig_img)`` tuples.
        device: device the model inputs are moved to.
        args: namespace providing ``vis_dir``.
    """
    model.eval()
    metric_logger = utils.MetricLogger(delimiter=" ")

    # evaluation variables
    cum_I, cum_U = 0, 0
    eval_seg_iou_list = [.5, .6, .7, .8, .9]
    seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32)
    seg_total = 0
    mean_IoU = []
    header = 'Test:'

    with torch.no_grad():
        for idx, data in enumerate(metric_logger.log_every(data_loader, 100, header), start=1):
            image, target, sentences, attentions, raw_sentences, this_img, orig_img = data
            image, target = image.to(device), target.to(device)
            sentences = sentences.to(device).squeeze(1)
            attentions = attentions.to(device).squeeze(1)

            # orig_img is unpacked as (batch, h, w, c); only element 0 is drawn.
            _, h, w, _ = orig_img.shape
            orig_img = orig_img.data.cpu().numpy()[0, :, :, :].astype(np.uint8)
            vis = np.zeros((h, w * 2, 3), dtype=np.uint8)
            target_numpy = target.cpu().numpy()

            # The ground-truth overlay is identical for every sentence of this
            # image, so build it once; each sentence draws its text on a copy.
            gt_masks = F.interpolate(target.unsqueeze(0).float(), (h, w))
            gt_overlay = overlay_davis(orig_img, gt_masks[0][0].cpu().numpy().astype(np.uint8),
                                       colors=[[0, 0, 0], [0, 0, 255]])

            for j in range(sentences.size(-1)):
                output = model(image, sentences[:, :, j], attentions[:, :, j])
                mask_pred_results = F.interpolate(output["pred_masks"], size=target.shape[-2:],
                                                  mode='bilinear', align_corners=True)
                pred_masks = model.semantic_inference(output["pred_logits"], mask_pred_results)

                # Score the *binarised* prediction. semantic_inference returns
                # soft scores that are non-zero almost everywhere, and
                # computeIoU counts any non-zero value as foreground.
                I, U = computeIoU((pred_masks > 0.5).cpu().numpy(), target_numpy)
                this_iou = 0.0 if U == 0 else I * 1.0 / U
                mean_IoU.append(this_iou)
                cum_I += I
                cum_U += U
                for n_eval_iou, eval_seg_iou in enumerate(eval_seg_iou_list):
                    seg_correct[n_eval_iou] += (this_iou >= eval_seg_iou)
                seg_total += 1

                pred_vis = (F.interpolate(pred_masks, (h, w)) > 0.5).data.cpu().numpy()
                predict_imgs = overlay_davis(orig_img, pred_vis[0][0].astype(np.uint8))
                gt_imgs = _draw_wrapped_text(gt_overlay.copy(), ' '.join(raw_sentences[j]), h)

                vis[0:h, 0:w, :] = gt_imgs
                vis[0:h, w:2 * w, :] = predict_imgs
                # vis is channel-flipped before writing, so it is built as RGB.
                cv2.imwrite("vis/{:s}/{:s}_{:d}_{:d}_{:.2f}.jpg".format(
                    args.vis_dir, this_img[0].split('.')[0], idx, j, this_iou),
                    vis[:, :, ::-1].astype(np.uint8))

    _print_results(mean_IoU, seg_correct, seg_total, cum_I, cum_U, eval_seg_iou_list)


def _draw_wrapped_text(img, text, h, width=35, line_height=30):
    """Draw ``text``, wrapped to ``width`` characters, near the bottom of ``img``."""
    lines = textwrap.wrap(text, width=width)
    # One or two lines start 60px above the bottom edge. When there are more
    # lines, start higher so the last line still lands inside the image.
    top = h - line_height * max(2, len(lines))
    for k, line in enumerate(lines):
        img = cv2.putText(img, line, (10, top + k * line_height),
                          cv2.FONT_HERSHEY_DUPLEX, 1.0, (255, 0, 0), 1, 2)
    return img


def _print_results(mean_IoU, seg_correct, seg_total, cum_I, cum_U, eval_seg_iou_list):
    """Print mean IoU, precision@K and overall IoU (0 when nothing was scored)."""
    mIoU = np.mean(np.array(mean_IoU)) if mean_IoU else 0.0
    print('Final results:')
    print('Mean IoU is %.2f\n' % (mIoU * 100.))
    results_str = ''
    for n_eval_iou, eval_seg_iou in enumerate(eval_seg_iou_list):
        precision = seg_correct[n_eval_iou] * 100. / seg_total if seg_total else 0.0
        results_str += ' precision@%s = %.2f\n' % (str(eval_seg_iou), precision)
    overall_iou = cum_I * 100. / cum_U if cum_U else 0.0
    results_str += ' overall IoU = %.2f\n' % overall_iou
    print(results_str)
#def evaluate(model, data_loader, device):
# model.eval()
# metric_logger = utils.MetricLogger(delimiter=" ")
#
# # evaluation variables
# cum_I, cum_U = 0, 0
# eval_seg_iou_list = [.5, .6, .7, .8, .9]
# seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32)
# seg_total = 0
# mean_IoU = []
# header = 'Test:'
#
# with torch.no_grad():
# for data in metric_logger.log_every(data_loader, 100, header):
# image, target, sentences, attentions = data
# image, target, sentences, attentions = image.to(device), target.to(device), \
# sentences.to(device), attentions.to(device)
# sentences = sentences.squeeze(1)
# attentions = attentions.squeeze(1)
# target = target.cpu().data.numpy()
# for j in range(sentences.size(-1)):
# #if bert_model is not None:
# # last_hidden_states = bert_model(sentences[:, :, j], attention_mask=attentions[:, :, j])[0]
# # embedding = last_hidden_states.permute(0, 2, 1)
# # output = model(image, embedding, l_mask=attentions[:, :, j].unsqueeze(-1))
# #else:
# output = model(image, sentences[:, :, j], attentions[:, :, j])
# mask_cls_results = output["pred_logits"]
# mask_pred_results = output["pred_masks"]
#
# target_shape = target.shape[-2:]
# mask_pred_results = F.interpolate(mask_pred_results, size=target_shape, mode='bilinear', align_corners=True)
#
# pred_masks = model.semantic_inference(mask_cls_results, mask_pred_results)
# output = pred_masks[0]
#
# output = output.cpu()
# #print(output.shape)
# #output_mask = output.argmax(1).data.numpy()
# output_mask = (output > 0.5).data.numpy()
# I, U = computeIoU(output_mask, target)
# if U == 0:
# this_iou = 0.0
# else:
# this_iou = I*1.0/U
# mean_IoU.append(this_iou)
# cum_I += I
# cum_U += U
# for n_eval_iou in range(len(eval_seg_iou_list)):
# eval_seg_iou = eval_seg_iou_list[n_eval_iou]
# seg_correct[n_eval_iou] += (this_iou >= eval_seg_iou)
# seg_total += 1
#
# #del image, target, sentences, attentions, output, output_mask
# #if bert_model is not None:
# # del last_hidden_states, embedding
#
# mean_IoU = np.array(mean_IoU)
# mIoU = np.mean(mean_IoU)
# print('Final results:')
# print('Mean IoU is %.2f\n' % (mIoU*100.))
# results_str = ''
# for n_eval_iou in range(len(eval_seg_iou_list)):
# results_str += ' precision@%s = %.2f\n' % \
# (str(eval_seg_iou_list[n_eval_iou]), seg_correct[n_eval_iou] * 100. / seg_total)
# results_str += ' overall IoU = %.2f\n' % (cum_I * 100. / cum_U)
# print(results_str)
def get_transform(args):
    """Return the evaluation image pipeline.

    The steps are: ``T.Resize(img_size, img_size)``, tensor conversion, then
    normalisation with the commonly used ImageNet mean/std constants.
    """
    size = args.img_size
    steps = [
        T.Resize(size, size),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return T.Compose(steps)
def computeIoU(pred_seg, gd_seg):
    """Return the ``(intersection, union)`` pixel counts of two masks.

    Both inputs are interpreted as booleans (any non-zero value is
    foreground) and broadcast against each other. The ratio is left to the
    caller so it can decide how to handle an empty union.
    """
    intersection = np.logical_and(pred_seg, gd_seg).sum()
    union = np.logical_or(pred_seg, gd_seg).sum()
    return intersection, union
class WrapperModel(nn.Module):
    """Run an image backbone and a language encoder stage by stage, then decode.

    ``image_model`` and ``language_model`` must expose ``forward_stem``,
    ``forward_stage1``..``forward_stage4`` and ``forward_pwam1``..``forward_pwam4``
    (see ``forward``). ``classifier`` receives a dict of the four image
    residual features keyed ``'s1'``..``'s4'``; ``main`` passes a
    ``MaskFormerHead``.

    Attribute names are also ``state_dict`` keys: ``main`` restores this module
    with ``load_state_dict(checkpoint['model'])`` (strict by default), so
    renaming or removing submodules breaks checkpoint loading.
    """
    def __init__(self, image_model, language_model, classifier, args) :
        super(WrapperModel, self).__init__()
        self.image_model = image_model
        self.language_model = language_model
        self.classifier = classifier
        # NOTE(review): lang_proj is assigned again below with the same shape;
        # the second assignment replaces this module.
        self.lang_proj = nn.Linear(768,256)
        # Config for the auxiliary BERT encoder (mlm_transformer) and LM
        # prediction head (mlm_head) built below.
        config = Dict({
            "architectures": [
                "BertForMaskedLM"
            ],
            "attention_probs_dropout_prob": 0.1,
            "gradient_checkpointing": False,
            "hidden_act": "gelu",
            "hidden_dropout_prob": 0.1,
            "hidden_size": 512,
            "initializer_range": 0.02,
            "intermediate_size": 3072,
            "layer_norm_eps": 1e-12,
            #"max_position_embeddings": 16+20,
            "model_type": "bert",
            "num_attention_heads": 8,
            "num_hidden_layers": 8,
            "pad_token_id": 0,
            "position_embedding_type": "absolute",
            "transformers_version": "4.6.0.dev0",
            "type_vocab_size": 2,
            "use_cache": True,
            "vocab_size": 30522
        })
        # NOTE(review): none of the mlm_* modules (nor lang_proj) are used in
        # forward(). Presumably they exist only so the training checkpoint
        # loads with strict key matching; confirm before removing any.
        self.mlm_transformer = BertEncoder(config)
        self.lang_proj = nn.Linear(768,256)
        self.mlm_vis_proj = nn.Conv2d(1024,512,1)
        self.mlm_lang_proj = nn.Linear(768,512)
        #print(vis_proj)
        self.mlm_head = BertLMPredictionHead(config)
        assert args.img_size % 4 == 0
        # Presumably 20 text tokens plus an (img_size / 32)^2 grid of image
        # tokens; TODO confirm against the training code.
        num_img_tokens = 20 + ((args.img_size // 4)//8) ** 2
        print(num_img_tokens)
        self.mlm_pos_embeds = nn.Embedding(num_img_tokens+1, 512)
        self.mlm_modal_embeds = nn.Embedding(3, 512)
        self.mlm_mask_embed = nn.Embedding(1, 512)
        # Maps a 2-D input to 512 features; presumably (x, y) positions, TODO confirm.
        self.mlm_pos_mlp = nn.Sequential(
            nn.Linear(2, 512),
            nn.LayerNorm(512),
            nn.Linear(512,512),
            nn.GELU()
        )
    def _get_binary_mask(self, target):
        # Return the binary mask for each class. It builds a one-hot tensor
        # of shape (num_classes + 1, y, x) by scattering along dim 0, then
        # drops channel 0. target must be 2-D and (presumably) an int64 label
        # map, as required by scatter's index argument.
        # NOTE(review): self.num_classes is never assigned in __init__, so
        # calling this raises AttributeError. It is not called anywhere in
        # the code shown here.
        y, x = target.size()
        target_onehot = torch.zeros(self.num_classes + 1, y, x)
        target_onehot = target_onehot.scatter(dim=0, index=target.unsqueeze(0), value=1)
        return target_onehot[1:]
    def semantic_inference(self, mask_cls, mask_pred) -> torch.Tensor:
        """Combine class logits and mask logits into per-class soft masks.

        Args:
            mask_cls: (b, q, c) class logits (axes fixed by the einsum below).
            mask_pred: (b, q, h, w) mask logits.

        Returns:
            (b, c - 1, h, w) tensor. For each remaining class it is a
            query-weighted sum of sigmoid masks. Because the weights are a
            softmax over queries, the values lie in (0, 1); they are soft
            scores, not a binary mask.
        """
        # NOTE(review): the softmax is over dim=1, the *query* axis, and the
        # first class channel is dropped. Upstream MaskFormer normalises over
        # classes (dim=-1) and drops the last channel instead. With a single
        # query (main sets NUM_OBJECT_QUERIES = 1) every weight is 1.0, so the
        # result is just sigmoid(mask_pred). Confirm intent before changing.
        mask_cls = F.softmax(mask_cls, dim=1)[...,1:]
        mask_pred = mask_pred.sigmoid()
        semseg = torch.einsum("bqc,bqhw->bchw", mask_cls, mask_pred)
        return semseg
    def forward(self, image, sentences, attentions):
        """Run the interleaved image/language stages, then the decoder head.

        Args:
            image: image batch passed to ``image_model.forward_stem``.
            sentences: presumably token ids for one expression per sample.
            attentions: attention mask for ``sentences``. It is also expanded
                with a trailing axis (``l_mask``) for the image-side PWAM calls.

        Returns:
            The first element of ``classifier(outputs)``; ``evaluate`` reads
            ``"pred_logits"`` and ``"pred_masks"`` from it. The second element
            (``mask_predictions``) is discarded.
        """
        input_shape = image.shape[-2:]  # unused; only the commented-out code below refers to it
        l_mask = attentions.unsqueeze(dim=-1)
        i0, Wh, Ww = self.image_model.forward_stem(image)
        l0, extended_attention_mask = self.language_model.forward_stem(sentences, attentions)
        # Each stage N proceeds as follows:
        #   1. Both encoders run stage N.
        #   2. image_model.forward_pwamN takes the language features and
        #      returns a residual (kept for the decoder) plus iN_temp, which
        #      feeds stage N + 1.
        #   3. language_model.forward_pwamN takes the image stage output from
        #      *before* it is replaced by iN_temp.
        # H, W and the lN_residual values are not used afterwards.
        i1 = self.image_model.forward_stage1(i0, Wh, Ww)
        l1 = self.language_model.forward_stage1(l0, extended_attention_mask)
        i1_residual, H, W, i1_temp, Wh, Ww = self.image_model.forward_pwam1(i1, Wh, Ww, l1, l_mask)
        l1_residual, l1 = self.language_model.forward_pwam1(i1, l1, extended_attention_mask)
        i1 = i1_temp
        i2 = self.image_model.forward_stage2(i1, Wh, Ww)
        l2 = self.language_model.forward_stage2(l1, extended_attention_mask)
        i2_residual, H, W, i2_temp, Wh, Ww = self.image_model.forward_pwam2(i2, Wh, Ww, l2, l_mask)
        l2_residual, l2 = self.language_model.forward_pwam2(i2, l2, extended_attention_mask)
        i2 = i2_temp
        i3 = self.image_model.forward_stage3(i2, Wh, Ww)
        l3 = self.language_model.forward_stage3(l2, extended_attention_mask)
        i3_residual, H, W, i3_temp, Wh, Ww = self.image_model.forward_pwam3(i3, Wh, Ww, l3, l_mask)
        l3_residual, l3 = self.language_model.forward_pwam3(i3, l3, extended_attention_mask)
        i3 = i3_temp
        i4 = self.image_model.forward_stage4(i3, Wh, Ww)
        l4 = self.language_model.forward_stage4(l3, extended_attention_mask)
        i4_residual, H, W, i4_temp, Wh, Ww = self.image_model.forward_pwam4(i4, Wh, Ww, l4, l_mask)
        l4_residual, l4 = self.language_model.forward_pwam4(i4, l4, extended_attention_mask)
        i4 = i4_temp
        #i1_residual, i2_residual, i3_residual, i4_residual = features
        #x = self.classifier(i4_residual, i3_residual, i2_residual, i1_residual)
        #x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=True)
        # The decoder consumes the four image residuals, keyed by stage name.
        outputs = {}
        outputs['s1'] = i1_residual
        outputs['s2'] = i2_residual
        outputs['s3'] = i3_residual
        outputs['s4'] = i4_residual
        predictions, mask_predictions = self.classifier(outputs)
        return predictions
def main(args):
    """Load the test split and a trained checkpoint, then run ``evaluate``.

    The model is rebuilt as ``WrapperModel(backbone, MultiModalBert,
    MaskFormerHead)``. Its keys must match ``checkpoint['model']`` because
    ``load_state_dict`` is strict by default.

    Args:
        args: parsed CLI namespace. This function reads ``split``,
            ``workers``, ``model``, ``resume``, ``ck_bert`` and
            ``ddp_trained_weights``, and forwards ``args`` to the dataset,
            transform, model factory, ``WrapperModel`` and ``evaluate``.

    Raises:
        ValueError: if ``args.model == 'lavt_one'``. WrapperModel drives a
            separate language model stage by stage, and none is built for
            that configuration. Previously this path fell through to a
            NameError after loading the checkpoint.
    """
    if args.model == 'lavt_one':
        raise ValueError("visualize.py requires a separate language model; "
                         "args.model == 'lavt_one' is not supported")

    device = 'cuda'

    dataset_test, _ = get_dataset(args.split, get_transform(args=args), args)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)
    data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=1,
                                                   sampler=test_sampler, num_workers=args.workers)

    print(args.model)
    single_model = multimodal_segmentation_ppm.__dict__[args.model](pretrained='', args=args)
    # torch.load unpickles the file: only point args.resume at trusted checkpoints.
    checkpoint = torch.load(args.resume, map_location='cpu')

    single_bert_model = MultiModalBert.from_pretrained(
        args.ck_bert, embed_dim=single_model.backbone.embed_dim)
    # work-around for a transformers bug; need to update to a newer version of
    # transformers to remove these two lines
    if args.ddp_trained_weights:
        single_bert_model.pooler = None

    maskformer_head = MaskFormerHead(_build_maskformer_cfg(), _maskformer_input_shape())
    model = WrapperModel(single_model.backbone, single_bert_model, maskformer_head, args)
    model.load_state_dict(checkpoint['model'])
    model.to(device)
    evaluate(model, data_loader_test, device=device, args=args)


def _maskformer_input_shape():
    """Describe the s1..s4 backbone features (channels, stride) for MaskFormerHead."""
    # Channel widths are hard-coded (128/256/512/1024); TODO derive them from
    # the backbone if a different width is ever used.
    input_shape = dict()
    input_shape['s1'] = Dict({'channel': 128, 'stride': 4})
    input_shape['s2'] = Dict({'channel': 256, 'stride': 8})
    input_shape['s3'] = Dict({'channel': 512, 'stride': 16})
    input_shape['s4'] = Dict({'channel': 1024, 'stride': 32})
    return input_shape


def _build_maskformer_cfg():
    """Return the MaskFormerHead config used here (NUM_CLASSES = 1, one object query)."""
    cfg = Dict()
    cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE = 4
    cfg.MODEL.MASK_FORMER.DROPOUT = 0.0
    cfg.MODEL.MASK_FORMER.NHEADS = 8
    cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 4
    cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM = 256
    cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256
    cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES = ["s1", "s2", "s3", "s4"]
    cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 1
    cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256
    cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = 1
    cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = 2048
    cfg.MODEL.MASK_FORMER.DEC_LAYERS = 10
    cfg.MODEL.MASK_FORMER.PRE_NORM = False
    return cfg
if __name__ == "__main__":
    # Script entry: parse CLI options, echo them, make sure the output folder
    # vis/<vis_dir> exists, then run the evaluation/visualisation pipeline.
    from args import get_parser

    args = get_parser().parse_args()
    print('Image size: ' + str(args.img_size))
    print(args)
    os.makedirs('vis/{}'.format(args.vis_dir), exist_ok=True)
    main(args)