|
import datetime |
|
import os |
|
import time |
|
|
|
import torch |
|
import torch.utils.data |
|
from torch import nn |
|
|
|
from functools import reduce |
|
import operator |
|
from bert.modeling_bert import BertModel |
|
|
|
import torchvision |
|
from lib import segmentation |
|
|
|
import transforms as T |
|
import utils |
|
import numpy as np |
|
|
|
import torch.nn.functional as F |
|
|
|
import gc |
|
from collections import OrderedDict |
|
from data.utils import MosaicVisualization, COCOVisualization |
|
|
|
import albumentations as A |
|
from albumentations.pytorch import ToTensorV2 |
|
|
|
def get_dataset(image_set, transform, args):
    """Build the evaluation dataset for ``image_set`` and dump a visualization.

    ``args.dataset == "grefcoco"`` selects ``GReferDataset``; any other value
    selects ``ReferDataset``. For either choice, ``MosaicVisualization`` is
    called with the dataset and the directory
    ``coco-data-vis-retrieval/<args.model_id>/<image_set>``.

    Args:
        image_set: split name, forwarded to the dataset as ``split``.
        transform: NOTE(review): accepted but never used here -- neither
            dataset constructor receives it; confirm the datasets build their
            own transforms from ``args``.
        args: experiment namespace/config.

    Returns:
        ``(dataset, num_classes)`` with ``num_classes`` fixed at 2.
    """
    if args.dataset == "grefcoco":
        # Imported lazily so only the selected dataset module gets loaded.
        from data.dataset_grefer_mosaic_retrieval import GReferDataset

        dataset = GReferDataset(
            args=args,
            refer_root=args.refer_data_root,
            dataset_name=args.dataset,
            splitby=args.splitBy,
            split=image_set,
            image_root=os.path.join(args.refer_data_root, 'images/train2014'),
        )
    else:
        from data.dataset_refer_bert_mosaic_retrieval import ReferDataset

        dataset = ReferDataset(args, split=image_set)

    # Both dataset flavours are visualized identically, so do it once here.
    vis_dir = os.path.join('coco-data-vis-retrieval', args.model_id, image_set)
    MosaicVisualization(dataset, vis_dir)

    num_classes = 2
    return dataset, num_classes
|
|
|
|
|
|
|
def IoU(pred, gt):
    """Compute intersection-over-union between predicted scores and a mask.

    Args:
        pred: per-class scores; the predicted label map is ``pred.argmax(1)``
            (presumably shape (N, 2, H, W) -- TODO confirm).
        gt: ground-truth labels matching the argmax's shape (assumed 0/1).

    Returns:
        ``(iou, intersection, union)``. ``iou`` is a Python number: ``0`` when
        the intersection or the union is zero, otherwise their float ratio.
        ``intersection`` and ``union`` are the 0-d tensors from the sums.
    """
    labels = pred.argmax(1)
    overlap = (labels * gt).sum()
    # sum(a + b) - sum(a * b) equals |a OR b| for 0/1 masks.
    combined = (labels + gt).sum() - overlap

    if overlap == 0 or combined == 0:
        return 0, overlap, combined
    return float(overlap) / float(combined), overlap, combined
|
|
|
|
|
def get_transform(args):
    """Return a ``transforms``-module pipeline: resize, to-tensor, normalize.

    The image is resized to ``args.img_size`` x ``args.img_size``, converted
    with ``T.ToTensor`` and normalized with the ImageNet mean/std values.

    NOTE(review): a later ``def get_transform`` in this module rebinds the
    name, so once the module is loaded this version is unreachable by name.
    """
    side = args.img_size
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    return T.Compose([
        T.Resize(side, side),
        T.ToTensor(),
        T.Normalize(mean=imagenet_mean, std=imagenet_std),
    ])
|
|
|
|
|
def criterion(input, target, class_weights=(0.9, 1.1)):
    """Class-weighted cross-entropy loss.

    Args:
        input: logits with the class dimension at index 1.
        target: integer class labels.
        class_weights: per-class weights for ``cross_entropy``. The default
            ``(0.9, 1.1)`` matches the previously hard-coded values.

    Returns:
        The (mean-reduced) weighted cross-entropy as a scalar tensor.
    """
    # Create the weight on the logits' own device and dtype instead of forcing
    # ``.cuda()``. This way the loss also works on CPU, on a non-default GPU,
    # and for non-float32 logits.
    weight = torch.as_tensor(class_weights, dtype=input.dtype, device=input.device)
    return nn.functional.cross_entropy(input, target, weight=weight)
|
|
|
|
|
def _evaluate_forward(model, data, bert_model):
    """Move one batch to the GPU, run the model, and return ``(output, target)``."""
    image = data['image'].cuda(non_blocking=True)
    target = data['seg_target'].cuda(non_blocking=True)
    # squeeze(1): presumably a singleton per-sample sentence dim -- TODO confirm.
    sentences = data['sentence'].cuda(non_blocking=True).squeeze(1)
    attentions = data['attn_mask'].cuda(non_blocking=True).squeeze(1)

    if bert_model is not None:
        last_hidden_states = bert_model(sentences, attention_mask=attentions)[0]
        embedding = last_hidden_states.permute(0, 2, 1)
        output = model(image, embedding, l_mask=attentions.unsqueeze(dim=-1))
    else:
        output = model(image, sentences, l_mask=attentions)
    return output, target


def evaluate(model, data_loader, bert_model=None):
    """Evaluate ``model`` on ``data_loader``, print a summary, return IoUs.

    Each batch is a dict with ``'image'``, ``'seg_target'``, ``'sentence'`` and
    ``'attn_mask'`` entries. When ``bert_model`` is given, the token ids are
    encoded by it first and the permuted last hidden states are fed to
    ``model``. Otherwise the token ids are fed directly.

    Per-batch IoU comes from ``IoU``. It is treated as a per-sample score, so
    this presumably expects batch size 1 (``main`` builds such a loader).

    Returns:
        ``(mean_iou, overall_iou)`` as percentages. ``mean_iou`` is the mean
        per-sample IoU. ``overall_iou`` is cumulative intersection over
        cumulative union (0.0 when the cumulative union is zero).

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    model.eval()
    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Test:'

    eval_seg_iou_list = [.5, .6, .7, .8, .9]
    seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32)
    per_sample_iou = []
    cum_I, cum_U = 0, 0

    with torch.no_grad():
        for data in metric_logger.log_every(data_loader, 100, header):
            output, target = _evaluate_forward(model, data, bert_model)
            iou, I, U = IoU(output, target)
            per_sample_iou.append(iou)
            # Accumulate as Python ints so the totals (and the returned overall
            # IoU) are plain numbers rather than device tensors.
            cum_I += int(I)
            cum_U += int(U)
            for idx, threshold in enumerate(eval_seg_iou_list):
                seg_correct[idx] += (iou >= threshold)

    total_its = len(per_sample_iou)
    if total_its == 0:
        raise ValueError('evaluate() received an empty data loader')

    # Sequential sum / count, identical to the former running accumulator.
    mean_iou = sum(per_sample_iou) / total_its
    mIoU = np.mean(np.array(per_sample_iou))
    overall_iou = 100 * cum_I / cum_U if cum_U else 0.0

    print('Final results:')
    print('Mean IoU is %.2f\n' % (mIoU * 100.))
    results_str = ''
    for threshold, correct in zip(eval_seg_iou_list, seg_correct):
        results_str += ' precision@%s = %.2f\n' % \
            (str(threshold), correct * 100. / total_its)
    results_str += ' overall IoU = %.2f\n' % overall_iou
    print(results_str)

    return 100 * mean_iou, overall_iou
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_transform(args):
    """Return the albumentations evaluation pipeline.

    The pipeline resizes to ``args.img_size`` x ``args.img_size``, normalizes
    with the ImageNet mean/std values, and converts to a tensor with
    ``ToTensorV2``.
    """
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    # ``always_apply`` is deprecated in recent albumentations releases.
    # ``Resize`` is always applied with p=1.0, which is also its default.
    transforms = A.Compose([
        A.Resize(args.img_size, args.img_size, p=1.0),
        A.Normalize(mean=mean, std=std),
        ToTensorV2(),
    ])
    return transforms
|
|
|
|
|
def computeIoU(pred_seg, gd_seg):
    """Return ``(intersection, union)`` pixel counts of two masks.

    Non-zero entries count as foreground in both ``pred_seg`` and ``gd_seg``.
    """
    both = np.logical_and(pred_seg, gd_seg)
    either = np.logical_or(pred_seg, gd_seg)
    return np.sum(both), np.sum(either)
|
|
|
|
|
def main(args):
    """Evaluate the checkpoint ``args.resume`` on split ``args.split``.

    Steps:
    - Build the test dataset and a sequential, batch-size-1 loader.
    - Instantiate ``segmentation.<args.model>`` and convert its BatchNorm
      layers to SyncBatchNorm.
    - Resize the text encoder's token embeddings to the dataset tokenizer.
    - Load ``checkpoint['model']`` strictly and print the metrics returned
      by ``evaluate``.
    """
    dataset_test, _ = get_dataset(args.split, get_transform(args=args), args=args)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers)

    print(args.model)
    model = segmentation.__dict__[args.model](pretrained=args.pretrained_swin_weights,
                                              args=args)
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    # Match the embedding table to the dataset tokenizer's vocabulary before the
    # strict state-dict load (presumably the tokenizer adds tokens -- confirm).
    model.text_encoder.resize_token_embeddings(len(dataset_test.tokenizer))
    model.cuda()

    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(args.resume, map_location='cpu')
    model.load_state_dict(checkpoint['model'], strict=True)
    # The weights now live on the model, so drop the CPU copy before evaluation
    # to reduce peak host memory.
    del checkpoint

    iou, overallIoU = evaluate(model, data_loader_test)

    print('Average object IoU {}'.format(iou))
    print('Overall IoU {}'.format(overallIoU))
|
|
|
def parse_args():
    """Parse ``--config``/``--local_rank`` and return the loaded config.

    Returns:
        The config loaded from ``--config`` with ``local_rank`` set from the
        command line.

    NOTE(review): the ``__main__`` block of this module does not call this
    helper (it uses ``args.get_parser``). ``OmegaConf`` is not imported in
    the visible part of this module -- confirm it is available before use.
    """
    # argparse is not imported at module level in this file.
    import argparse

    parser = argparse.ArgumentParser(description='RefCOCO Test')
    # These two options are read below, so the parser must define them.
    parser.add_argument('--config', type=str, required=True,
                        help='path to the config file')
    parser.add_argument('--local_rank', type=int, default=0,
                        help='local process rank copied into the config')
    args = parser.parse_args()

    cfg = OmegaConf.load(args.config)
    cfg['local_rank'] = args.local_rank
    return cfg
|
|
|
if __name__ == "__main__":
    from args import get_parser

    parser = get_parser()
    args = parser.parse_args()

    # Optional YAML config: its values are copied onto the namespace. The
    # command line is then parsed again into that namespace, so flags given
    # explicitly on the CLI override the config values.
    config_path = args.config
    if config_path is not None:
        from config.utils import CfgNode

        overrides = CfgNode(CfgNode.load_yaml_with_base(config_path))
        for key, value in overrides.items():
            if not hasattr(args, key):
                print('Warning: key %s not in args' % key)
            setattr(args, key, value)
        args = parser.parse_args(namespace=args)

    print(args)
    print(f'Image size: {args.img_size}')
    main(args)
|
|