|
import argparse |
|
import os, sys |
|
import torch |
|
import cv2 |
|
from torchvision import transforms |
|
from PIL import Image |
|
import torch.nn.functional as F |
|
import numpy as np |
|
from matplotlib import pyplot as plt |
|
from tqdm import tqdm |
|
|
|
|
|
# Make the current working directory importable so the project-local imports
# that follow (``opt``, ``dataset_curation_pipeline``) resolve when the script
# is launched from the repository root.
root_path = os.path.abspath(os.curdir)
sys.path.append(root_path)
|
from opt import opt |
|
from dataset_curation_pipeline.IC9600.ICNet import ICNet |
|
|
|
|
|
|
|
# Per-channel mean / std used for normalisation (the standard ImageNet values).
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every image before it is passed to the model in
# ``infer_one_image``: resize to a fixed 512x512, convert the PIL image to a
# float tensor, then normalise each channel.
inference_transform = transforms.Compose(
    [
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
|
|
|
def blend(ori_img, ic_img, alpha=0.8, cm=plt.get_cmap("magma")):
    """Overlay a colour-mapped version of ``ic_img`` on top of ``ori_img``.

    Args:
        ori_img: image array accepted by ``PIL.Image.fromarray``. It must
            produce an image with the same size and mode as the heatmap
            (presumably an HxWx3 uint8 array in BGR order, e.g. from ``cv2``
            -- TODO confirm with callers).
        ic_img: 2-D map fed through the colormap; values are presumably in
            [0, 1] (TODO confirm).
        alpha: weight of the heatmap in the blend (0 keeps only ``ori_img``,
            1 keeps only the heatmap).
        cm: matplotlib colormap; "magma" by default.

    Returns:
        The blended image as a numpy array.
    """
    # The colormap yields RGBA floats; keep R, G, B and flip them to B, G, R.
    rgba = cm(ic_img)
    bgr = rgba[:, :, :3][:, :, ::-1]
    heatmap = Image.fromarray((bgr * 255).astype(np.uint8))

    base = Image.fromarray(ori_img)
    return np.array(Image.blend(base, heatmap, alpha=alpha))
|
|
|
|
|
def infer_one_image(model, img_path, device=None):
    """Return the model's image-complexity score for a single image file.

    The image is loaded as RGB and preprocessed with the module-level
    ``inference_transform``. It is then run through ``model`` as a batch of
    one without gradient tracking.

    Args:
        model: network whose forward returns a ``(score, map)`` pair. It
            should already be in eval mode; this function does not switch it.
        img_path: path of the image to score.
        device: device the input tensor is moved to. Defaults to the device
            of the model's first parameter, or the CPU if it has none.

    Returns:
        The predicted score as a Python number (via ``Tensor.item()``).
    """
    if device is None:
        # Follow the model's placement instead of assuming the default CUDA
        # device; the CLI may put the model on any GPU index.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    with torch.no_grad():
        # The context manager closes the file once the pixels are decoded;
        # convert() returns an independent image, so it stays valid afterwards.
        with Image.open(img_path) as raw:
            ori_img = raw.convert("RGB")
        batch = inference_transform(ori_img).unsqueeze(0).to(device)
        ic_score, _ic_map = model(batch)

    return ic_score.item()
|
|
|
|
|
|
|
def infer_directory(img_dir, model=None, top_k=50):
    """Score every file in ``img_dir`` and print the highest-scoring ones.

    Each file is scored with ``infer_one_image``, and each result is printed
    as it is produced. Afterwards the ``top_k`` entries with the highest
    score are printed as ``(score, path)`` tuples.

    Args:
        img_dir: directory whose files are scored, in sorted filename order.
            Sub-directories are skipped.
        model: network forwarded to ``infer_one_image``. It must be supplied;
            the ``None`` default only keeps the original signature compatible.
        top_k: number of highest-scoring entries printed at the end.

    Returns:
        List of ``(score, path)`` tuples sorted by descending score.

    Raises:
        TypeError: if ``model`` is not given.
    """
    if model is None:
        raise TypeError("infer_directory() requires a model to score images with")

    paths = [os.path.join(img_dir, name) for name in sorted(os.listdir(img_dir))]
    # os.listdir also returns sub-directories, which Image.open cannot read.
    paths = [path for path in paths if os.path.isfile(path)]

    scores = []
    for img_path in tqdm(paths):
        score = infer_one_image(model, img_path)
        scores.append((score, img_path))
        # tqdm.write keeps per-image output from breaking the progress bar.
        tqdm.write(f"{img_path} {score}")

    scores.sort(key=lambda entry: entry[0], reverse=True)
    for entry in scores[:top_k]:
        print(entry)
    return scores


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input', type=str, default='./example')
    # NOTE(review): --output is currently unused; kept so existing command lines still parse.
    parser.add_argument('-o', '--output', type=str, default='./out')
    parser.add_argument('-d', '--device', type=int, default=0)
    parser.add_argument('-c', '--checkpoint', type=str, default='./checkpoint/ck.pth')
    args = parser.parse_args()

    model = ICNet()
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints
    # (on torch >= 1.13 consider weights_only=True).
    model.load_state_dict(torch.load(args.checkpoint, map_location=torch.device('cpu')))
    model.eval()

    device = torch.device(args.device)
    model.to(device)

    # The module-level ``inference_transform`` is used by infer_one_image, so it
    # is not redefined here.
    if os.path.isfile(args.input):
        print(args.input, infer_one_image(model, args.input))
    else:
        infer_directory(args.input, model)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|