P-DFD / inference.py
mrneuralnet's picture
Modify inference script args
f4c3cd9
raw
history blame contribute delete
No virus
5.31 kB
import cv2
import torch
import random
import argparse
from glob import glob
from os.path import join
from model.network import Recce
from model.common import freeze_weights
from albumentations import Compose, Normalize, Resize
from albumentations.pytorch.transforms import ToTensorV2
import os
# Presumably works around the Intel OpenMP "duplicate runtime" abort that some
# torch/cv2 installations trigger.
# NOTE(review): this is assigned after torch and cv2 are already imported above;
# confirm it still takes effect where it is needed.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# Fix the random seed of Python's `random` module and of PyTorch (CPU generator,
# current CUDA device, and all CUDA devices) so random draws repeat across runs.
seed = 0
random.seed(seed)
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed)
del _seed_fn
# Command-line interface of the inference script.
parser = argparse.ArgumentParser(description="This code helps you use a trained model to "
                                             "do inference.")

# '--weight' and '--bin' are alternative sources for the model weights; argparse
# now rejects passing both with a usage error instead of failing later in main().
_weight_args = parser.add_mutually_exclusive_group()
_weight_args.add_argument("--weight", "-w",
                          type=str,
                          default=None,
                          help="Specify the path to the model weight (the state dict file). "
                               "Do not use this argument when '--bin' is set.")
_weight_args.add_argument("--bin", "-b",
                          type=str,
                          default=None,
                          help="Specify the path to the model bin which ends up with '.bin' "
                               "(which is generated by the trainer of this project). "
                               "Do not use this argument when '--weight' is set.")

# '--image' and '--image_folder' are alternative input sources; likewise enforced
# as mutually exclusive at parse time.
_input_args = parser.add_mutually_exclusive_group()
_input_args.add_argument("--image", "-i",
                         type=str,
                         default=None,
                         help="Specify the path to the input image. "
                              "Do not use this argument when '--image_folder' is set.")
_input_args.add_argument("--image_folder", "-f",
                         type=str,
                         default=None,
                         help="Specify the directory to evaluate all the images. "
                              "Do not use this argument when '--image' is set.")

parser.add_argument('--device', '-d', type=str,
                    default="cpu",
                    help="Specify the device to load the model. Default: 'cpu'.")
parser.add_argument('--image_size', '-s', type=int,
                    default=299,
                    help="Specify the spatial size of the input image(s). Default: 299.")
# store_true already defaults to False.
parser.add_argument('--visualize', '-v', action="store_true",
                    help='Visualize images.')
def preprocess(file_path, image_size=None):
    """Load an image file and turn it into a normalized model input tensor.

    The image is read with OpenCV, converted from BGR to RGB, resized to
    ``image_size`` x ``image_size``, normalized with mean 0.5 / std 0.5 per
    channel, converted with ``ToTensorV2`` and given a leading batch axis.

    Args:
        file_path: Path of the image file to read.
        image_size: Target height and width in pixels. Defaults to the
            module-global ``args.image_size``.

    Returns:
        The transformed image with a batch dimension prepended
        (presumably shape (1, 3, H, W) — confirm against ToTensorV2).

    Raises:
        ValueError: if OpenCV cannot read ``file_path``.
    """
    if image_size is None:
        image_size = args.image_size
    img = cv2.imread(file_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise ValueError(f"Could not read image: {file_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    compose = Compose([Resize(height=image_size, width=image_size),
                       Normalize(mean=[0.5] * 3, std=[0.5] * 3),
                       ToTensorV2()])
    return compose(image=img)['image'].unsqueeze(0)
def _list_images(folder):
    """Return sorted paths of non-hidden .jpg/.jpeg/.png files (any case) directly in *folder*."""
    if not os.path.isdir(folder):
        raise FileNotFoundError(f"Image folder not found: '{folder}'")
    extensions = {".jpg", ".jpeg", ".png"}
    # Hidden files (e.g. macOS '._x.jpg' stubs) are skipped, as the original
    # glob patterns did, since they are usually not decodable images.
    found = sorted(
        join(folder, name)
        for name in os.listdir(folder)
        if not name.startswith(".")
        and os.path.splitext(name)[1].lower() in extensions
        and os.path.isfile(join(folder, name))
    )
    if not found:
        raise ValueError(f"No .jpg/.jpeg/.png images found in '{folder}'.")
    return found


def prepare_data():
    """Collect and preprocess the input image(s) selected on the command line.

    Exactly one of ``args.image`` / ``args.image_folder`` must be set. With
    ``--image`` that single file is loaded; with ``--image_folder`` every
    non-hidden .jpg/.jpeg/.png file (case-insensitive) directly inside the
    folder is loaded, in sorted path order.

    Returns:
        ``(paths, images)``: parallel lists of file paths and the tensors
        returned by ``preprocess`` for them.

    Raises:
        ValueError: if both or neither argument is set, or the folder holds
            no matching images.
        FileNotFoundError: if ``--image_folder`` is not an existing directory.
    """
    if args.image and args.image_folder:
        raise ValueError("Only one of '--image' or '--image_folder' can be set.")
    if args.image:
        paths = [args.image]
    elif args.image_folder:
        paths = _list_images(args.image_folder)
    else:
        raise ValueError("Neither of '--image' nor '--image_folder' is set. Please specify either "
                         "one of these two arguments to load input image(s) properly.")
    images = [preprocess(path) for path in paths]
    return paths, images
def _show_prediction(path, prob, fake):
    """Show *path* in an OpenCV window annotated with its prediction and wait for a key."""
    cvimg = cv2.imread(path)
    if cvimg is None:
        # cv2.imread returns None when the file cannot be read; skip the window.
        print(f"Could not reload '{path}' for visualization; skipping.")
        return
    label = f"p: {prob:.2f}, {'fake' if fake else 'real'}"
    # cv2.imread yields BGR images: (0, 0, 255) is red for fake, (255, 0, 0) blue for real.
    color = (0, 0, 255) if fake else (255, 0, 0)
    cvimg = cv2.putText(cvimg, label, (5, 50), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
    cv2.imshow("image", cvimg)
    cv2.waitKey(0)
    cv2.destroyWindow("image")


def inference(model, images, paths, device, visualize=None):
    """Run the model on each image, print per-image results and return the mean.

    Each model output is passed through a sigmoid and read as a fake
    probability; a probability >= 0.5 is reported as 'fake'.

    Args:
        model: Callable mapping one input tensor to a single logit.
        images: Preprocessed input tensors, one per entry of ``paths``.
        paths: File paths matching ``images``; printed, and re-read with
            OpenCV when visualizing.
        device: Device each input is moved to before calling ``model``.
        visualize: Show each image with its prediction in an OpenCV window.
            Defaults to the module-global ``args.visualize``.

    Returns:
        The mean fake probability over the processed (image, path) pairs.

    Raises:
        ValueError: if ``images`` is empty, since the mean would be undefined.
    """
    if not images:
        raise ValueError("No images to run inference on.")
    if visualize is None:
        visualize = args.visualize
    total = 0.0
    count = 0
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for img, pt in zip(images, paths):
            prob = torch.sigmoid(model(img.to(device))).item()
            fake = prob >= 0.5
            total += prob
            count += 1
            print(f"path: {pt} \t\t| fake probability: {prob:.4f} \t| "
                  f"prediction: {'fake' if fake else 'real'}")
            if visualize:
                _show_prediction(pt, prob, fake)
    # Divide by the number of pairs actually processed: zip stops at the
    # shorter of images/paths.
    return total / count
def main():
    """Load the model weights, run inference on the requested image(s) and print results.

    Reads the module-global ``args``. Exactly one of ``--weight`` (a state
    dict file) or ``--bin`` (a checkpoint whose ``"model"`` entry holds the
    state dict) must be set.

    Raises:
        ValueError: if both or neither of ``--weight`` / ``--bin`` are set.
    """
    print("Arguments:\n", args, end="\n\n")
    device = torch.device(args.device)

    # Validate the weight arguments before building the network.
    if args.weight and args.bin:
        raise ValueError("Only one of '--weight' or '--bin' can be set.")
    if not (args.weight or args.bin):
        raise ValueError("Neither of '--weight' nor '--bin' is set. Please specify either "
                         "one of these two arguments to load model's weight properly.")

    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code; only load weight files from trusted sources.
    if args.weight:
        weights = torch.load(args.weight, map_location=device)
    else:
        weights = torch.load(args.bin, map_location=device)["model"]

    model = Recce(num_classes=1)
    model.load_state_dict(weights)
    model = model.to(device)
    freeze_weights(model)
    model.eval()

    paths, images = prepare_data()
    print("Inference:")
    mean_pred = inference(model, images=images, paths=paths, device=device)
    print("Mean prediction:", mean_pred)
if __name__ == '__main__':
    # Parsed options are bound to the module-global `args`, which the functions
    # above (preprocess, prepare_data, inference, main) read directly.
    args = parser.parse_args()
    main()