import datetime
import os
import time

import torch
import torch.utils.data
from torch import nn

from bert.modeling_bert import BertModel

import torchvision
from lib import segmentation
import transforms as T
import utils

import numpy as np
from PIL import Image
import torch.nn.functional as F


def get_dataset(image_set, transform, args):
    from data.dataset_refer_bert import ReferDataset
    ds = ReferDataset(args,
                      split=image_set,
                      image_transforms=transform,
                      target_transforms=None,
                      eval_mode=True)
    num_classes = 2
    return ds, num_classes


def evaluate(model, data_loader, bert_model, device):
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")

    # evaluation variables
    cum_I, cum_U = 0, 0
    eval_seg_iou_list = [.5, .6, .7, .8, .9]
    seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32)
    seg_total = 0
    mean_IoU = []
    header = 'Test:'

    with torch.no_grad():
        for data in metric_logger.log_every(data_loader, 100, header):
            image, target, sentences, attentions = data
            image, target, sentences, attentions = \
                image.to(device), target.to(device), sentences.to(device), attentions.to(device)
            sentences = sentences.squeeze(1)
            attentions = attentions.squeeze(1)
            target = target.cpu().data.numpy()
            # in eval mode the dataset returns every referring expression for the image;
            # evaluate them one at a time
            for j in range(sentences.size(-1)):
                if bert_model is not None:
                    last_hidden_states = bert_model(sentences[:, :, j],
                                                    attention_mask=attentions[:, :, j])[0]
                    embedding = last_hidden_states.permute(0, 2, 1)
                    output = model(image, embedding, l_mask=attentions[:, :, j].unsqueeze(-1))
                else:
                    output = model(image, sentences[:, :, j], l_mask=attentions[:, :, j])

                output = output.cpu()
                output_mask = output.argmax(1).data.numpy()
                I, U = computeIoU(output_mask, target)
                if U == 0:
                    this_iou = 0.0
                else:
                    this_iou = I * 1.0 / U
                mean_IoU.append(this_iou)
                cum_I += I
                cum_U += U
                for n_eval_iou in range(len(eval_seg_iou_list)):
                    eval_seg_iou = eval_seg_iou_list[n_eval_iou]
                    seg_correct[n_eval_iou] += (this_iou >= eval_seg_iou)
                seg_total += 1

            del image, target, sentences, attentions, output, output_mask
            if bert_model is not None:
                del last_hidden_states, embedding

    mean_IoU = np.array(mean_IoU)
    mIoU = np.mean(mean_IoU)
    print('Final results:')
    print('Mean IoU is %.2f\n' % (mIoU * 100.))
    results_str = ''
    for n_eval_iou in range(len(eval_seg_iou_list)):
        results_str += '    precision@%s = %.2f\n' % \
                       (str(eval_seg_iou_list[n_eval_iou]),
                        seg_correct[n_eval_iou] * 100. / seg_total)
    results_str += '    overall IoU = %.2f\n' % (cum_I * 100. / cum_U)
    print(results_str)
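
# Metric conventions used in evaluate():
#   mean IoU    - average of the per-expression IoU values
#   overall IoU - cumulative intersection divided by cumulative union over the split
#   precision@X - fraction of expressions whose IoU is at least the threshold X
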

def get_transform(args):
    transforms = [T.Resize(args.img_size, args.img_size),
                  T.ToTensor(),
                  T.Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225])]

    return T.Compose(transforms)


def computeIoU(pred_seg, gd_seg):
    # intersection and union of two binary masks
    I = np.sum(np.logical_and(pred_seg, gd_seg))
    U = np.sum(np.logical_or(pred_seg, gd_seg))

    return I, U


def main(args):
    device = torch.device(args.device)
    dataset_test, _ = get_dataset(args.split, get_transform(args=args), args)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)
    data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=1,
                                                   sampler=test_sampler,
                                                   num_workers=args.workers)
    print(args.model)
    single_model = segmentation.__dict__[args.model](pretrained='', args=args)
    checkpoint = torch.load(args.resume, map_location='cpu')
    single_model.load_state_dict(checkpoint['model'])
    model = single_model.to(device)

    if args.model != 'lavt_one':
        # the language encoder is a separate BERT model whose weights are stored
        # under 'bert_model' in the checkpoint
        model_class = BertModel
        single_bert_model = model_class.from_pretrained(args.ck_bert)
        # work-around for a transformers bug; need to update to a newer version of
        # transformers to remove these two lines
        if args.ddp_trained_weights:
            single_bert_model.pooler = None
        single_bert_model.load_state_dict(checkpoint['bert_model'])
        bert_model = single_bert_model.to(device)
    else:
        bert_model = None

    evaluate(model, data_loader_test, bert_model, device=device)


if __name__ == "__main__":
    from args import get_parser
    parser = get_parser()
    args = parser.parse_args()
    print('Image size: {}'.format(str(args.img_size)))
    main(args)
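
# Example invocation (a sketch only; the actual flag names are defined in
# args.get_parser, which is not shown here, so treat them as assumptions):
#
#   python test.py --model lavt --split val --img_size 480 \
#       --resume ./checkpoints/model_best.pth --ck_bert bert-base-uncased \
#       --device cuda:0 --workers 4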