Spaces:
Sleeping
Sleeping
File size: 5,312 Bytes
982865f f4c3cd9 982865f eecef1d 982865f eecef1d 982865f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
import cv2
import torch
import random
import argparse
from glob import glob
from os.path import join
from model.network import Recce
from model.common import freeze_weights
from albumentations import Compose, Normalize, Resize
from albumentations.pytorch.transforms import ToTensorV2
import os
# Let the process keep running when more than one OpenMP runtime is loaded.
# NOTE(review): presumably a workaround for torch and OpenCV each shipping an
# OpenMP library (Intel's "OMP: Error #15"); confirm it is still required.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# Seed every random number generator the script may touch (Python's `random`
# plus torch's CPU and CUDA generators) so repeated runs are reproducible.
seed = 0
for _seed_fn in (random.seed, torch.manual_seed,
                 torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed)
del _seed_fn
# Command-line interface. Exactly one weight source (--weight / --bin) and
# exactly one image source (--image / --image_folder) is expected; the
# mutual-exclusion checks are performed in main() and prepare_data().
parser = argparse.ArgumentParser(
    description="This code helps you use a trained model to do inference."
)

# --- model weight source -----------------------------------------------------
parser.add_argument(
    "--weight", "-w",
    default=None,
    type=str,
    help=("Specify the path to the model weight (the state dict file). "
          "Do not use this argument when '--bin' is set."),
)
parser.add_argument(
    "--bin", "-b",
    default=None,
    type=str,
    help=("Specify the path to the model bin which ends up with '.bin' "
          "(which is generated by the trainer of this project). "
          "Do not use this argument when '--weight' is set."),
)

# --- input image source ------------------------------------------------------
parser.add_argument(
    "--image", "-i",
    default=None,
    type=str,
    help=("Specify the path to the input image. "
          "Do not use this argument when '--image_folder' is set."),
)
parser.add_argument(
    "--image_folder", "-f",
    default=None,
    type=str,
    help=("Specify the directory to evaluate all the images. "
          "Do not use this argument when '--image' is set."),
)

# --- runtime options -----------------------------------------------------------
parser.add_argument(
    "--device", "-d",
    default="cpu",
    type=str,
    help="Specify the device to load the model. Default: 'cpu'.",
)
parser.add_argument(
    "--image_size", "-s",
    default=299,
    type=int,
    help="Specify the spatial size of the input image(s). Default: 299.",
)
parser.add_argument(
    "--visualize", "-v",
    default=False,
    action="store_true",
    help="Visualize images.",
)
def preprocess(file_path, image_size=None):
    """Load an image from disk and turn it into a batched model input.

    Pipeline: read with OpenCV (BGR), convert to RGB, resize to a square
    ``image_size`` x ``image_size``, normalize each channel with mean 0.5 and
    std 0.5, convert to a tensor with ``ToTensorV2`` and add a leading batch
    dimension of size 1.

    Args:
        file_path: Path to the image file.
        image_size: Target height and width in pixels. ``None`` (the default)
            falls back to the ``--image_size`` console argument.

    Returns:
        torch.Tensor: The preprocessed image with a leading batch dimension.

    Raises:
        OSError: If OpenCV cannot read ``file_path`` (missing file, no read
            permission, or an unsupported/corrupt image).
    """
    if image_size is None:
        image_size = args.image_size
    img = cv2.imread(file_path)
    # cv2.imread signals failure by returning None instead of raising; without
    # this check the error would surface as an opaque cvtColor assertion.
    if img is None:
        raise OSError(f"Could not read image '{file_path}' "
                      f"(missing, unreadable or unsupported format).")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    compose = Compose([Resize(height=image_size, width=image_size),
                       Normalize(mean=[0.5] * 3, std=[0.5] * 3),
                       ToTensorV2()])
    return compose(image=img)['image'].unsqueeze(0)
def prepare_data():
    """Collect and preprocess the input image(s) named on the command line.

    Exactly one of ``--image`` (a single file) or ``--image_folder`` (every
    ``*.jpg``, ``*.jpeg`` and ``*.png`` file directly inside that directory)
    must be set. Folder contents are processed in sorted path order so that
    repeated runs report results in the same order.

    Returns:
        tuple[list[str], list[torch.Tensor]]: The image paths and the matching
        preprocessed tensors (from ``preprocess``), in the same order.

    Raises:
        ValueError: If both or neither of ``--image`` / ``--image_folder`` are
            set, or if the folder contains no matching image files.
        Any exception raised by ``preprocess`` for an unreadable image
        propagates unchanged.
    """
    # Local import: escapes glob metacharacters ([, ], *, ?) in the folder name.
    from glob import escape

    # check the console arguments
    if args.image and args.image_folder:
        raise ValueError("Only one of '--image' or '--image_folder' can be set.")
    if args.image:
        paths = [args.image]
    elif args.image_folder:
        folder = escape(args.image_folder)
        paths = []
        for pattern in ("*.jpg", "*.jpeg", "*.png"):
            paths.extend(glob(join(folder, pattern)))
        # glob's order is filesystem-dependent; sort for reproducible output.
        paths.sort()
        if not paths:
            raise ValueError(f"No '.jpg', '.jpeg' or '.png' images found in "
                             f"'{args.image_folder}'.")
    else:
        raise ValueError("Neither of '--image' nor '--image_folder' is set. Please specify either "
                         "one of these two arguments to load input image(s) properly.")
    images = [preprocess(path) for path in paths]
    return paths, images
def inference(model, images, paths, device, visualize=None):
    """Run the model on each preprocessed image and report its fake probability.

    For every ``(image, path)`` pair, the model output is passed through a
    sigmoid. A probability >= 0.5 is labelled "fake", otherwise "real". One
    line per image is printed. When visualization is enabled, the image at
    ``path`` is shown in an OpenCV window annotated with the result; a key
    press advances to the next image.

    Args:
        model: Callable mapping a batched image tensor to a single logit.
        images: Preprocessed image tensors, e.g. as produced by ``preprocess``.
        paths: File paths paired one-to-one with ``images``; used for printing
            and for re-reading the image when visualizing.
        device: ``torch.device`` each image is moved to before inference.
        visualize: Whether to display each annotated image. ``None`` (the
            default) falls back to the ``--visualize`` console argument.

    Returns:
        float: Mean fake probability over all evaluated images.

    Raises:
        ValueError: If there were no images to evaluate.
    """
    if visualize is None:
        visualize = args.visualize
    probabilities = []
    # Pure inference: no_grad avoids building an autograd graph per image.
    with torch.no_grad():
        for img, pt in zip(images, paths):
            prediction = torch.sigmoid(model(img.to(device))).cpu()
            probability = prediction.item()
            fake = probability >= 0.5
            label = 'fake' if fake else 'real'
            probabilities.append(probability)
            print(f"path: {pt} \t\t| fake probability: {probability:.4f} \t| "
                  f"prediction: {label}")
            if visualize:
                cvimg = cv2.imread(pt)
                # Colors are BGR: red text for fake, blue text for real.
                cvimg = cv2.putText(cvimg, f'p: {probability:.2f}, ' + label,
                                    (5, 50), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                                    (0, 0, 255) if fake else (255, 0, 0), 2)
                cv2.imshow("image", cvimg)
                # Wait indefinitely for a key press before the next image.
                cv2.waitKey(0)
                cv2.destroyWindow("image")
    if not probabilities:
        # Previously a ZeroDivisionError; make the empty-input case explicit.
        raise ValueError("No images were given to run inference on.")
    return sum(probabilities) / len(probabilities)
def main():
    """Load the Recce model, preprocess the input images and run inference.

    Reads the module-level ``args`` parsed from the command line. Exactly one
    of ``--weight`` (a state dict file) or ``--bin`` (a checkpoint whose state
    dict is stored under the ``"model"`` key) must be set.

    Raises:
        ValueError: If both or neither of ``--weight`` / ``--bin`` are set, or
            (from ``prepare_data``) if the image arguments are invalid.
    """
    print("Arguments:\n", args, end="\n\n")
    # check the console arguments before doing any expensive work
    if args.weight and args.bin:
        raise ValueError("Only one of '--weight' or '--bin' can be set.")
    if not (args.weight or args.bin):
        raise ValueError("Neither of '--weight' nor '--bin' is set. Please specify either "
                         "one of these two arguments to load model's weight properly.")
    # set device
    device = torch.device(args.device)
    # load model (direct reference; the former eval("Recce") added nothing)
    model = Recce(num_classes=1)
    # NOTE(review): torch.load unpickles the file and can execute arbitrary
    # code, so only load weights from trusted sources. Consider
    # weights_only=True once the checkpoint contents are confirmed compatible.
    if args.weight:
        weights = torch.load(args.weight, map_location=device)
    else:
        weights = torch.load(args.bin, map_location=device)["model"]
    model.load_state_dict(weights)
    model = model.to(device)
    # presumably disables gradient tracking on all parameters — see model.common
    freeze_weights(model)
    model.eval()
    paths, images = prepare_data()
    print("Inference:")
    mean_pred = inference(model, images=images, paths=paths, device=device)
    print("Mean prediction:", mean_pred)
if __name__ == "__main__":
    # Parse sys.argv into the module-level `args` global, which preprocess(),
    # prepare_data(), inference() and main() all read.
    args = parser.parse_args(None)
    main()
|