import datetime
import os
import time
import torch
import torch.utils.data
from torch import nn
from bert.multimodal_bert import MultiModalBert
import torchvision
from lib import multimodal_segmentation_ppm
import transforms as T
import utils
import numpy as np
from PIL import Image
import torch.nn.functional as F
from modeling.MaskFormerModel import MaskFormerHead
from addict import Dict
from bert.modeling_bert import BertLMPredictionHead, BertEncoder
import cv2
import textwrap


def get_dataset(image_set, transform, args):
    """Build the evaluation split of the referring-expression dataset.

    Args:
        image_set: split name forwarded to ``ReferDataset`` as ``split``.
        transform: image transform pipeline (see ``get_transform``).
        args: parsed command-line arguments, forwarded to ``ReferDataset``.

    Returns:
        ``(dataset, num_classes)``; ``num_classes`` is always 2.
    """
    # Imported lazily so the module can be imported without the data package.
    from data.dataset_refer_bert_vis import ReferDataset
    ds = ReferDataset(args,
                      split=image_set,
                      image_transforms=transform,
                      target_transforms=None,
                      eval_mode=True)
    num_classes = 2
    return ds, num_classes


def overlay_davis(image, mask, colors=((0, 0, 0), (0, 255, 0)), cscale=1, alpha=0.4):
    """Blend a colour over every labelled region of ``mask`` and outline it in black.

    Args:
        image: ``(H, W, 3)`` image array; the result has the same dtype.
        mask: ``(H, W)`` integer label map. Label 0 is left untouched; every other
            label ``k`` is tinted with ``colors[k]``.
        colors: sequence of RGB triples indexed by label (row 0 is unused).
        cscale: multiplier applied to ``colors``.
        alpha: weight of the original image in the blend (``1 - alpha`` for the colour).

    Returns:
        A new array, ``image`` with tinted regions and black one-pixel contours.
    """
    # scipy.ndimage.morphology is a deprecated namespace; the public path is scipy.ndimage.
    from scipy.ndimage import binary_dilation

    colors = np.atleast_2d(np.reshape(colors, (-1, 3))) * cscale
    im_overlay = image.copy()
    object_ids = np.unique(mask)
    # Filter label 0 explicitly: slicing off the first unique id would drop the
    # foreground when the mask contains no background pixel at all.
    for object_id in object_ids[object_ids != 0]:
        binary_mask = mask == object_id
        # Blend only the masked pixels with the label colour.
        im_overlay[binary_mask] = image[binary_mask] * alpha + (1 - alpha) * colors[object_id]
        # One-pixel outer contour: dilated mask minus the mask itself.
        contours = binary_dilation(binary_mask) ^ binary_mask
        im_overlay[contours, :] = 0
    return im_overlay.astype(image.dtype)


def evaluate(model, data_loader, device, args):
    """Run the model over ``data_loader``, print IoU metrics and dump visualisations.

    For every image and every referring sentence of that image, this computes the
    predicted mask and updates the metrics (mean IoU, precision@{0.5..0.9}, overall
    IoU). It also writes a side-by-side JPEG (GT overlay with the sentence | prediction
    overlay) to ``vis/<args.vis_dir>/``. Nothing is returned; results are printed.

    Assumes ``batch_size == 1`` (``main`` builds the loader that way); only
    element 0 of the batch is visualised.
    """
    model.eval()
    metric_logger = utils.MetricLogger(delimiter=" ")

    # evaluation variables
    cum_I, cum_U = 0, 0
    eval_seg_iou_list = [.5, .6, .7, .8, .9]
    iou_thresholds = np.asarray(eval_seg_iou_list)
    seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32)
    seg_total = 0
    mean_IoU = []
    header = 'Test:'

    # Text rendering settings for the caption drawn on the GT panel.
    font = cv2.FONT_HERSHEY_DUPLEX
    font_scale = 1.0
    font_color = (255, 0, 0)
    thickness = 1
    line_type = 2

    with torch.no_grad():
        for idx, data in enumerate(metric_logger.log_every(data_loader, 100, header), start=1):
            image, target, sentences, attentions, raw_sentences, this_img, orig_img = data
            image, target, sentences, attentions = image.to(device), target.to(device), \
                sentences.to(device), attentions.to(device)
            sentences = sentences.squeeze(1)
            attentions = attentions.squeeze(1)

            # orig_img is unpacked as (batch, h, w, channels); presumably RGB uint8 — confirm.
            _, h, w, _ = orig_img.shape
            orig_img = orig_img.data.cpu().numpy()[0, :, :, :].astype(np.uint8)
            vis = np.zeros((h, w * 2, 3)).astype(np.uint8)
            target_numpy = target.cpu().numpy()

            # The GT overlay does not depend on the sentence: build it once per image.
            gt_masks = F.interpolate(target.unsqueeze(0).float(), (h, w))
            gt_overlay = overlay_davis(orig_img, gt_masks[0][0].cpu().numpy().astype(np.uint8),
                                       colors=((0, 0, 0), (0, 0, 255)))
            stem = os.path.splitext(this_img[0])[0]

            for j in range(sentences.size(-1)):
                output = model(image, sentences[:, :, j], attentions[:, :, j])
                mask_cls_results = output["pred_logits"]
                mask_pred_results = output["pred_masks"]

                target_shape = target.shape[-2:]
                mask_pred_results = F.interpolate(mask_pred_results, size=target_shape,
                                                  mode='bilinear', align_corners=True)
                pred_masks = model.semantic_inference(mask_cls_results, mask_pred_results)

                # semantic_inference yields soft scores in (0, 1). computeIoU treats any
                # nonzero value as foreground, so binarise first (same 0.5 threshold as
                # the visualisation below).
                pred_binary = (pred_masks > 0.5).cpu().numpy()
                I, U = computeIoU(pred_binary, target_numpy)
                this_iou = I * 1.0 / U if U > 0 else 0.0
                mean_IoU.append(this_iou)
                cum_I += I
                cum_U += U
                seg_correct += this_iou >= iou_thresholds
                seg_total += 1

                # Prediction overlay at the original image resolution.
                pred_vis = F.interpolate(pred_masks, (h, w))
                pred_vis = (pred_vis > 0.5).data.cpu().numpy()
                predict_imgs = overlay_davis(orig_img, pred_vis[0][0].astype(np.uint8))

                # Copy: cv2.putText draws in place and the GT overlay is reused per sentence.
                gt_imgs = gt_overlay.copy()
                # raw_sentences[j] is presumably a batch-sized tuple of strings — confirm.
                wrapped_text = textwrap.wrap(' '.join(raw_sentences[j]), width=35)
                for k, line in enumerate(wrapped_text):
                    bottom_left = (10, h - 60 + k * 30)
                    gt_imgs = cv2.putText(gt_imgs, line, bottom_left, font, font_scale,
                                          font_color, thickness, line_type)

                vis[0:h, 0:w, :] = gt_imgs
                vis[0:h, w:2 * w, :] = predict_imgs
                out_path = "vis/{:s}/{:s}_{:d}_{:d}_{:.2f}.jpg".format(
                    args.vis_dir, stem, idx, j, this_iou)
                # Reverse channels (RGB -> BGR) for OpenCV; imwrite reports failure only via
                # its return value.
                if not cv2.imwrite(out_path, vis[:, :, ::-1].astype(np.uint8)):
                    print('Warning: failed to write {}'.format(out_path))

    if seg_total == 0:
        print('No samples were evaluated.')
        return

    mIoU = np.mean(np.array(mean_IoU))
    print('Final results:')
    print('Mean IoU is %.2f\n' % (mIoU * 100.))
    results_str = ''
    for threshold, n_correct in zip(eval_seg_iou_list, seg_correct):
        results_str += ' precision@%s = %.2f\n' % (str(threshold), n_correct * 100. / seg_total)
    overall_iou = cum_I * 100. / cum_U if cum_U > 0 else 0.0
    results_str += ' overall IoU = %.2f\n' % overall_iou
    print(results_str)
def get_transform(args):
    """Return the evaluation transform: resize to ``args.img_size`` square, to-tensor, ImageNet normalisation."""
    return T.Compose([
        T.Resize(args.img_size, args.img_size),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


def computeIoU(pred_seg, gd_seg):
    """Return ``(intersection, union)`` pixel counts of two masks.

    Inputs are interpreted as boolean (any nonzero element is foreground) and are
    broadcast against each other by NumPy, so soft predictions must be binarised
    by the caller first.
    """
    I = np.sum(np.logical_and(pred_seg, gd_seg))
    U = np.sum(np.logical_or(pred_seg, gd_seg))
    return I, U


class WrapperModel(nn.Module):
    """Image backbone + language model with PWAM fusion, followed by a mask-classification head.

    The ``mlm_*`` and ``lang_proj`` submodules are not used in ``forward`` here;
    presumably they exist so that checkpoints trained with an MLM auxiliary task
    load under a strict ``load_state_dict`` — confirm before removing.
    """

    def __init__(self, image_model, language_model, classifier, args):
        super(WrapperModel, self).__init__()
        self.image_model = image_model
        self.language_model = language_model
        self.classifier = classifier
        self.lang_proj = nn.Linear(768, 256)

        config = Dict({
            "architectures": [
                "BertForMaskedLM"
            ],
            "attention_probs_dropout_prob": 0.1,
            "gradient_checkpointing": False,
            "hidden_act": "gelu",
            "hidden_dropout_prob": 0.1,
            "hidden_size": 512,
            "initializer_range": 0.02,
            "intermediate_size": 3072,
            "layer_norm_eps": 1e-12,
            #"max_position_embeddings": 16+20,
            "model_type": "bert",
            "num_attention_heads": 8,
            "num_hidden_layers": 8,
            "pad_token_id": 0,
            "position_embedding_type": "absolute",
            "transformers_version": "4.6.0.dev0",
            "type_vocab_size": 2,
            "use_cache": True,
            "vocab_size": 30522
        })
        self.mlm_transformer = BertEncoder(config)
        self.mlm_vis_proj = nn.Conv2d(1024, 512, 1)
        self.mlm_lang_proj = nn.Linear(768, 512)
        self.mlm_head = BertLMPredictionHead(config)

        if args.img_size % 4 != 0:
            raise ValueError('args.img_size must be divisible by 4, got {}'.format(args.img_size))
        num_img_tokens = 20 + ((args.img_size // 4) // 8) ** 2
        print(num_img_tokens)
        self.mlm_pos_embeds = nn.Embedding(num_img_tokens + 1, 512)
        self.mlm_modal_embeds = nn.Embedding(3, 512)
        self.mlm_mask_embed = nn.Embedding(1, 512)
        self.mlm_pos_mlp = nn.Sequential(
            nn.Linear(2, 512),
            nn.LayerNorm(512),
            nn.Linear(512, 512),
            nn.GELU()
        )

    def _get_binary_mask(self, target, num_classes=None):
        """Return a binary mask per class for an ``(H, W)`` integer label map.

        Builds a one-hot tensor of ``num_classes + 1`` channels and drops channel 0.
        ``num_classes`` defaults to ``self.num_classes``, which this class does not
        set itself. ``target`` must hold int64 labels in ``[0, num_classes]``, as
        required by ``scatter``.
        """
        if num_classes is None:
            num_classes = getattr(self, 'num_classes', None)
            if num_classes is None:
                raise AttributeError('WrapperModel.num_classes is not set; pass num_classes explicitly')
        y, x = target.size()
        # Allocate on the label map's device so scatter does not mix devices.
        target_onehot = torch.zeros(num_classes + 1, y, x, device=target.device)
        target_onehot = target_onehot.scatter(dim=0, index=target.unsqueeze(0), value=1)
        return target_onehot[1:]

    def semantic_inference(self, mask_cls, mask_pred):
        """Combine class logits and mask logits into soft per-class masks.

        Per the einsum, ``mask_cls`` is ``(b, q, c)`` and ``mask_pred`` is
        ``(b, q, h, w)``; the result is ``(b, c - 1, h, w)`` (class 0 is dropped).

        NOTE(review): the softmax runs over dim=1 (queries), not the class dim
        (MaskFormer uses dim=-1). With NUM_OBJECT_QUERIES = 1, as configured in
        ``main``, it is identically 1 and the result equals ``mask_pred.sigmoid()``.
        Left unchanged because altering it changes evaluated predictions — confirm intent.
        """
        mask_cls = F.softmax(mask_cls, dim=1)[..., 1:]
        mask_pred = mask_pred.sigmoid()
        semseg = torch.einsum("bqc,bqhw->bchw", mask_cls, mask_pred)
        return semseg

    def forward(self, image, sentences, attentions):
        """Run the four fused image/language stages and the head; return the head's predictions."""
        l_mask = attentions.unsqueeze(dim=-1)
        x, Wh, Ww = self.image_model.forward_stem(image)
        lang, extended_attention_mask = self.language_model.forward_stem(sentences, attentions)

        features = {}
        for stage in range(1, 5):
            x = getattr(self.image_model, 'forward_stage{}'.format(stage))(x, Wh, Ww)
            lang = getattr(self.language_model, 'forward_stage{}'.format(stage))(lang, extended_attention_mask)
            # Image-side fusion: yields the residual feature for the head and the
            # input (plus spatial size) for the next stage.
            residual, _, _, x_next, Wh, Ww = getattr(
                self.image_model, 'forward_pwam{}'.format(stage))(x, Wh, Ww, lang, l_mask)
            # Language-side fusion uses the pre-fusion image feature of this stage.
            _, lang = getattr(self.language_model, 'forward_pwam{}'.format(stage))(
                x, lang, extended_attention_mask)
            x = x_next
            features['s{}'.format(stage)] = residual

        predictions, _ = self.classifier(features)
        return predictions


def _build_maskformer_head():
    """Construct the MaskFormer head with the fixed evaluation configuration."""
    # Channels/strides of backbone outputs s1..s4 fed to the head; must match the
    # backbone selected by args.model (128..1024 suggests embed_dim 128 — confirm).
    input_shape = {
        's1': Dict({'channel': 128, 'stride': 4}),
        's2': Dict({'channel': 256, 'stride': 8}),
        's3': Dict({'channel': 512, 'stride': 16}),
        's4': Dict({'channel': 1024, 'stride': 32}),
    }

    cfg = Dict()
    cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE = 4
    cfg.MODEL.MASK_FORMER.DROPOUT = 0.0
    cfg.MODEL.MASK_FORMER.NHEADS = 8
    cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 4
    cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM = 256
    cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256
    cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES = ["s1", "s2", "s3", "s4"]
    cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 1
    cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256
    cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = 1
    cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = 2048
    cfg.MODEL.MASK_FORMER.DEC_LAYERS = 10
    cfg.MODEL.MASK_FORMER.PRE_NORM = False

    return MaskFormerHead(cfg, input_shape)


def main(args):
    """Build the dataset and model, load ``args.resume``, and run ``evaluate``."""
    if args.model == 'lavt_one':
        # WrapperModel needs a separate language model; the 'lavt_one' path never
        # builds one (previously this surfaced later as a NameError).
        raise ValueError("args.model == 'lavt_one' is not supported by this script")

    device = 'cuda'
    dataset_test, _ = get_dataset(args.split, get_transform(args=args), args)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)
    # evaluate() visualises batch element 0 only, so keep batch_size=1.
    data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=1,
                                                   sampler=test_sampler, num_workers=args.workers)

    print(args.model)
    single_model = multimodal_segmentation_ppm.__dict__[args.model](pretrained='', args=args)
    # torch.load unpickles the file: only load checkpoints from trusted sources.
    checkpoint = torch.load(args.resume, map_location='cpu')

    single_bert_model = MultiModalBert.from_pretrained(
        args.ck_bert, embed_dim=single_model.backbone.embed_dim)
    # work-around for a transformers bug; need to update to a newer version of
    # transformers to remove these two lines
    if args.ddp_trained_weights:
        single_bert_model.pooler = None

    maskformer_head = _build_maskformer_head()
    model = WrapperModel(single_model.backbone, single_bert_model, maskformer_head, args)
    model.load_state_dict(checkpoint['model'])
    # Weights are copied into the model; drop the CPU copy to free host memory.
    del checkpoint
    model.to(device)

    evaluate(model, data_loader_test, device=device, args=args)


if __name__ == "__main__":
    from args import get_parser
    parser = get_parser()
    args = parser.parse_args()
    print('Image size: {}'.format(str(args.img_size)))
    print(args)
    # evaluate() writes its visualisations into vis/<args.vis_dir>/.
    os.makedirs(os.path.join('vis', args.vis_dir), exist_ok=True)
    main(args)