|
import os |
|
import argparse |
|
import numpy as np |
|
from tqdm import tqdm |
|
from collections import OrderedDict |
|
import torch |
|
import torch.nn.functional as F |
|
from torch.utils import data |
|
import torchvision.transforms as transform |
|
from torch.nn.parallel.scatter_gather import gather |
|
import encoding.utils as utils |
|
from encoding.nn import SegmentationLosses, SyncBatchNorm |
|
from encoding.parallel import DataParallelModel, DataParallelCriterion |
|
from encoding.datasets import test_batchify_fn |
|
from encoding.models.sseg import BaseNet |
|
from modules.lseg_module import LSegModule |
|
from utils import Resize |
|
import cv2 |
|
import math |
|
import types |
|
import functools |
|
import torchvision.transforms as torch_transforms |
|
import copy |
|
import itertools |
|
from PIL import Image |
|
import matplotlib.pyplot as plt |
|
import clip |
|
import matplotlib as mpl |
|
import matplotlib.colors as mplc |
|
import matplotlib.figure as mplfigure |
|
import matplotlib.patches as mpatches |
|
from matplotlib.backends.backend_agg import FigureCanvasAgg |
|
from data import get_dataset |
|
from additional_utils.encoding_models import MultiEvalModule as LSeg_MultiEvalModule |
|
import torchvision.transforms as transforms |
|
|
|
class Options:
    """Command-line option container for the LSeg segmentation test script.

    Builds an ``argparse.ArgumentParser`` with all model, dataset, training
    and evaluation flags; call :meth:`parse` to obtain the populated
    namespace (with the derived ``cuda`` flag added).
    """

    def __init__(self):
        """Construct the argument parser with every supported flag."""
        parser = argparse.ArgumentParser(description="PyTorch Segmentation")

        # --- model / backbone / dataset selection ---
        parser.add_argument(
            "--model", type=str, default="encnet", help="model name (default: encnet)"
        )
        parser.add_argument(
            "--backbone",
            type=str,
            default="clip_vitl16_384",
            help="backbone name (default: resnet50)",
        )
        parser.add_argument(
            "--dataset",
            type=str,
            default="ade20k",
            help="dataset name (default: pascal12)",
        )
        parser.add_argument(
            "--workers", type=int, default=16, metavar="N", help="dataloader threads"
        )
        parser.add_argument(
            "--base-size", type=int, default=520, help="base image size"
        )
        parser.add_argument(
            "--crop-size", type=int, default=480, help="crop image size"
        )
        parser.add_argument(
            "--train-split",
            type=str,
            default="train",
            help="dataset train split (default: train)",
        )

        # --- auxiliary / SE loss options (training-time settings) ---
        parser.add_argument(
            "--aux", action="store_true", default=False, help="Auxilary Loss"
        )
        parser.add_argument(
            "--se-loss",
            action="store_true",
            default=False,
            help="Semantic Encoding Loss SE-loss",
        )
        parser.add_argument(
            "--se-weight", type=float, default=0.2, help="SE-loss weight (default: 0.2)"
        )
        parser.add_argument(
            "--batch-size",
            type=int,
            default=16,
            metavar="N",
            help="input batch size for \
                    training (default: auto)",
        )
        parser.add_argument(
            "--test-batch-size",
            type=int,
            default=16,
            metavar="N",
            help="input batch size for \
                    testing (default: same as batch size)",
        )

        # --- cuda, seed, checkpoint ---
        parser.add_argument(
            "--no-cuda",
            action="store_true",
            default=False,
            help="disables CUDA training",
        )
        parser.add_argument(
            "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
        )
        parser.add_argument(
            "--weights", type=str, default=None, help="checkpoint to test"
        )
        parser.add_argument(
            "--eval", action="store_true", default=False, help="evaluating mIoU"
        )
        parser.add_argument(
            "--export",
            type=str,
            default=None,
            help="put the path to resuming file if needed",
        )

        # --- evaluation behaviour toggles ---
        parser.add_argument(
            "--acc-bn",
            action="store_true",
            default=False,
            help="Re-accumulate BN statistics",
        )
        parser.add_argument(
            "--test-val",
            action="store_true",
            default=False,
            help="generate masks on val set",
        )
        parser.add_argument(
            "--no-val",
            action="store_true",
            default=False,
            help="skip validation during training",
        )
        parser.add_argument(
            "--module",
            default='lseg',
            help="select model definition",
        )

        # --- data paths and model/head architecture flags ---
        parser.add_argument(
            "--data-path", type=str, default=None, help="path to test image folder"
        )
        # NOTE: store_false flags — passing --no-scaleinv sets scale_inv=False.
        parser.add_argument(
            "--no-scaleinv",
            dest="scale_inv",
            default=True,
            action="store_false",
            help="turn off scaleinv layers",
        )
        parser.add_argument(
            "--widehead", default=False, action="store_true", help="wider output head"
        )
        parser.add_argument(
            "--widehead_hr",
            default=False,
            action="store_true",
            help="wider output head",
        )
        parser.add_argument(
            "--ignore_index",
            type=int,
            default=-1,
            help="numeric value of ignore label in gt",
        )
        parser.add_argument(
            "--label_src",
            type=str,
            default="default",
            help="how to get the labels",
        )
        parser.add_argument(
            "--jobname",
            type=str,
            default="default",
            help="select which dataset",
        )
        # Passing --no-strict sets strict=False (lenient state-dict loading).
        parser.add_argument(
            "--no-strict",
            dest="strict",
            default=True,
            action="store_false",
            help="no-strict copy the model",
        )
        parser.add_argument(
            "--arch_option",
            type=int,
            default=0,
            help="which kind of architecture to be used",
        )
        parser.add_argument(
            "--block_depth",
            type=int,
            default=0,
            help="how many blocks should be used",
        )
        parser.add_argument(
            "--activation",
            choices=['lrelu', 'tanh'],
            default="lrelu",
            help="use which activation to activate the block",
        )

        self.parser = parser

    def parse(self):
        """Parse ``sys.argv``, derive ``args.cuda``, print and return args."""
        args = self.parser.parse_args()
        # cuda is enabled only when requested AND a device is available.
        args.cuda = not args.no_cuda and torch.cuda.is_available()
        print(args)
        return args
|
|
|
|
|
def test(args):
    """Evaluate an LSeg checkpoint on a segmentation dataset.

    With ``args.eval`` set, computes pixel accuracy and per-class/mean IoU
    over the validation split and appends the results to
    ``logs/log_test_<jobname>_<dataset>.txt``.  Otherwise, predicted masks
    are rendered with the dataset palette and saved under ``outdir_ours/``.
    With ``args.export`` set, only the model state dict is saved and the
    function returns early.

    Args:
        args: parsed namespace produced by ``Options.parse()``.
    """
    module = LSegModule.load_from_checkpoint(
        checkpoint_path=args.weights,
        data_path=args.data_path,
        dataset=args.dataset,
        backbone=args.backbone,
        aux=args.aux,
        num_features=256,
        aux_weight=0,
        se_loss=False,
        se_weight=0,
        base_lr=0,
        batch_size=1,
        max_epochs=0,
        ignore_index=args.ignore_index,
        dropout=0.0,
        scale_inv=args.scale_inv,
        augment=False,
        no_batchnorm=False,
        widehead=args.widehead,
        widehead_hr=args.widehead_hr,
        # BUGFIX: was misspelled "map_locatin", so it was swallowed as an
        # hparam instead of mapping checkpoint tensors onto the CPU.
        map_location="cpu",
        arch_option=args.arch_option,
        strict=args.strict,
        block_depth=args.block_depth,
        activation=args.activation,
    )
    input_transform = module.val_transform
    num_classes = module.num_classes

    # "testval" mode yields (image, target) pairs from the validation split.
    testset = get_dataset(
        args.dataset,
        root=args.data_path,
        split="val",
        mode="testval",
        transform=input_transform,
    )

    loader_kwargs = (
        {"num_workers": args.workers, "pin_memory": True} if args.cuda else {}
    )
    test_data = data.DataLoader(
        testset,
        batch_size=args.test_batch_size,
        drop_last=False,
        shuffle=False,
        collate_fn=test_batchify_fn,
        **loader_kwargs
    )

    # Prefer the bare segmentation network when the Lightning module wraps
    # one, so the multi-scale evaluator can drive it directly.
    if isinstance(module.net, BaseNet):
        model = module.net
    else:
        model = module

    model = model.eval()
    model = model.cpu()

    print(model)
    if args.acc_bn:
        # Re-accumulate BatchNorm running statistics over the train split.
        from encoding.utils.precise_bn import update_bn_stats

        data_kwargs = {
            "transform": input_transform,
            "base_size": args.base_size,
            "crop_size": args.crop_size,
        }
        trainset = get_dataset(
            args.dataset, split=args.train_split, mode="train", **data_kwargs
        )
        # BUGFIX: the original also passed root=args.data_path here; `root`
        # is not a DataLoader argument and raised TypeError at runtime.
        trainloader = data.DataLoader(
            ReturnFirstClosure(trainset),
            batch_size=args.batch_size,
            drop_last=True,
            shuffle=True,
            **loader_kwargs
        )
        print("Reseting BN statistics")
        model.cuda()
        update_bn_stats(model, trainloader)

    if args.export:
        torch.save(model.state_dict(), args.export + ".pth")
        return

    # Larger test-time scales for Cityscapes, the standard set otherwise.
    scales = (
        [0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.25]
        if args.dataset == "citys"
        else [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
    )

    evaluator = LSeg_MultiEvalModule(
        model, num_classes, scales=scales, flip=True
    ).cuda()
    evaluator.eval()

    metric = utils.SegmentationMetric(testset.num_class)
    tbar = tqdm(test_data)

    # Debug-only manual multi-scale inference path; the MultiEval wrapper is
    # used in normal operation.  (Was a literal `if False:` in the original.)
    manual_multiscale = False

    # Ensure the log directory exists before opening the append-mode log.
    os.makedirs("logs", exist_ok=True)
    f = open("logs/log_test_{}_{}.txt".format(args.jobname, args.dataset), "a+")
    per_class_iou = np.zeros(testset.num_class)
    cnt = 0
    try:
        for i, (image, dst) in enumerate(tbar):
            if args.eval:
                with torch.no_grad():
                    if manual_multiscale:
                        sample = {"image": image[0].cpu().permute(1, 2, 0).numpy()}
                        out = torch.zeros(
                            1, testset.num_class, image[0].shape[1], image[0].shape[2]
                        ).cuda()

                        H, W = image[0].shape[1], image[0].shape[2]
                        for scale in scales:
                            # Resize so the long image side is ~520 * scale.
                            long_size = int(math.ceil(520 * scale))
                            if H > W:
                                height = long_size
                                width = int(1.0 * W * long_size / H + 0.5)
                                short_size = width
                            else:
                                width = long_size
                                height = int(1.0 * H * long_size / W + 0.5)
                                short_size = height

                            rs = Resize(
                                width,
                                height,
                                resize_target=False,
                                keep_aspect_ratio=True,
                                ensure_multiple_of=32,
                                resize_method="minimal",
                                image_interpolation_method=cv2.INTER_AREA,
                            )

                            inf_image = (
                                torch.from_numpy(rs(sample)["image"])
                                .cuda()
                                .permute(2, 0, 1)
                                .unsqueeze(0)
                            )
                            # Horizontal-flip test-time augmentation: batch of
                            # [original, flipped].
                            inf_image = torch.cat(
                                (inf_image, torch.fliplr(inf_image)), 0
                            )
                            try:
                                pred = model(inf_image)
                            # BUGFIX: the original bare `except:` printed the
                            # undefined name `sz` and would itself crash.
                            except Exception:
                                print(H, W, short_size, i)
                                exit()

                            # pred[0]/pred[1] are two model outputs —
                            # presumably main and auxiliary heads (aux
                            # weighted 0.2); TODO confirm against the model.
                            pred0 = F.softmax(pred[0], dim=1)
                            pred1 = F.softmax(pred[1], dim=1)

                            pred = pred0 + 0.2 * pred1

                            # Sum over the orig/flip batch dim, then upsample
                            # to full resolution and accumulate over scales.
                            out += F.interpolate(
                                pred.sum(0, keepdim=True),
                                (out.shape[2], out.shape[3]),
                                mode="bilinear",
                                align_corners=True,
                            )

                        predicts = [out]
                    else:
                        predicts = evaluator.parallel_forward(image)

                metric.update(dst, predicts)
                pixAcc, mIoU = metric.get()

                _, _, total_inter, total_union = metric.get_all()
                per_class_iou += 1.0 * total_inter / (np.spacing(1) + total_union)
                cnt += 1

                tbar.set_description("pixAcc: %.4f, mIoU: %.4f" % (pixAcc, mIoU))
            else:
                with torch.no_grad():
                    outputs = evaluator.parallel_forward(image)
                    predicts = [
                        testset.make_pred(torch.max(output, 1)[1].cpu().numpy())
                        for output in outputs
                    ]

                outdir = "outdir_ours"
                if not os.path.exists(outdir):
                    os.makedirs(outdir)

                # Save one palette-colored mask per image in the batch.
                for predict, impath in zip(predicts, dst):
                    mask = utils.get_mask_pallete(predict, args.dataset)
                    outname = os.path.splitext(impath)[0] + ".png"
                    mask.save(os.path.join(outdir, outname))

        if args.eval:
            # Average the accumulated per-batch IoU vectors and log results.
            each_classes_iou = per_class_iou / cnt
            print("pixAcc: %.4f, mIoU: %.4f" % (pixAcc, mIoU))
            print(each_classes_iou)
            f.write("dataset {} ==> pixAcc: {:.4f}, mIoU: {:.4f}\n".format(args.dataset, pixAcc, mIoU))
            for per_iou in each_classes_iou: f.write('{:.4f}, '.format(per_iou))
            f.write('\n')
    finally:
        # BUGFIX: the log file handle was never closed in the original.
        f.close()
|
|
|
|
|
class ReturnFirstClosure(object):
    """Dataset wrapper that exposes only the first element of each sample.

    Wraps an indexable dataset whose items are tuples (e.g. ``(image,
    target)``) and yields just the first component, so image-only batches
    can be fed to BN-statistics re-accumulation.
    """

    def __init__(self, data):
        self._data = data

    def __len__(self):
        return len(self._data)

    def __getitem__(self, idx):
        # Drop everything but the leading element of the underlying sample.
        return self._data[idx][0]
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse flags, fix the RNG seed, then evaluate.
    cli_args = Options().parse()
    torch.manual_seed(cli_args.seed)
    # One image per visible GPU: the multi-scale evaluator scatters the test
    # batch across devices, so the batch size tracks the device count.
    cli_args.test_batch_size = torch.cuda.device_count()
    test(cli_args)
|
|