import os
import cv2
import torch
import random
import argparse
from glob import glob
from os.path import join
from model.network import Recce
from model.common import freeze_weights
from albumentations import Compose, Normalize, Resize
from albumentations.pytorch.transforms import ToTensorV2

# Work around "OMP: Error #15" crashes caused by duplicate OpenMP runtimes
# (e.g. when OpenCV and PyTorch each bundle their own libiomp).
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# fix random seed
seed = 0
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

parser = argparse.ArgumentParser(description="Run inference with a trained model "
                                             "on a single image or a folder of images.")
parser.add_argument("--weight", "-w",
                    type=str,
                    default=None,
                    help="Specify the path to the model weight (the state dict file). "
                         "Do not use this argument when '--bin' is set.")
parser.add_argument("--bin", "-b",
                    type=str,
                    default=None,
                    help="Specify the path to the model bin which ends up with '.bin' "
                         "(which is generated by the trainer of this project). "
                         "Do not use this argument when '--weight' is set.")
parser.add_argument("--image", "-i",
                    type=str,
                    default=None,
                    help="Specify the path to the input image. "
                         "Do not use this argument when '--image_folder' is set.")
parser.add_argument("--image_folder", "-f",
                    type=str,
                    default=None,
                    help="Specify the directory to evaluate all the images. "
                         "Do not use this argument when '--image' is set.")

parser.add_argument('--device', '-d', type=str,
                    default="cpu",
                    help="Specify the device to load the model. Default: 'cpu'.")
parser.add_argument('--image_size', '-s', type=int,
                    default=299,
                    help="Specify the spatial size of the input image(s). Default: 299.")
parser.add_argument('--visualize', '-v', action="store_true",
                    default=False, help="Visualize each image with its predicted label.")

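# Example invocations (assuming this script is saved as inference.py; the checkpoint
# and image paths below are illustrative placeholders, not files shipped with the repo):
#   python inference.py --bin path/to/checkpoint.bin --image path/to/face.jpg
#   python inference.py --weight path/to/state_dict.pth --image_folder path/to/images --device cuda:0 --visualize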

def preprocess(file_path):
    # read the image from disk, convert it to RGB, and turn it into a
    # resized, normalized 1 x 3 x H x W tensor
    img = cv2.imread(file_path)
    if img is None:
        raise FileNotFoundError(f"Cannot read image: '{file_path}'.")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    compose = Compose([Resize(height=args.image_size, width=args.image_size),
                       Normalize(mean=[0.5] * 3, std=[0.5] * 3),
                       ToTensorV2()])
    img = compose(image=img)['image'].unsqueeze(0)
    return img


def prepare_data():
    paths = list()
    images = list()
    # check the console arguments
    if args.image and args.image_folder:
        raise ValueError("Only one of '--image' or '--image_folder' can be set.")
    elif args.image:
        images.append(preprocess(args.image))
        paths.append(args.image)
    elif args.image_folder:
        image_paths = glob(join(args.image_folder, "*.jpg"))
        image_paths.extend(glob(join(args.image_folder, "*.png")))
        if not image_paths:
            raise ValueError(f"No '.jpg' or '.png' images found in '{args.image_folder}'.")
        for image_path in image_paths:
            images.append(preprocess(image_path))
            paths.append(image_path)
    else:
        raise ValueError("Neither '--image' nor '--image_folder' is set. Please specify "
                         "one of these two arguments to load the input image(s).")
    return paths, images


def inference(model, images, paths, device):
    mean_pred = 0.
    for img, pt in zip(images, paths):
        img = img.to(device)
        # the model outputs a single logit; sigmoid turns it into a fake probability
        with torch.no_grad():
            prediction = model(img)
        prediction = torch.sigmoid(prediction).cpu()
        fake = prediction.item() >= 0.5

        mean_pred += prediction.item()

        print(f"path: {pt} \t\t| fake probability: {prediction.item():.4f} \t| "
              f"prediction: {'fake' if fake else 'real'}")
        if args.visualize:
            cvimg = cv2.imread(pt)
            cvimg = cv2.putText(cvimg, f'p: {prediction.item():.2f}, ' + f"{'fake' if fake else 'real'}",
                                (5, 50), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                                (0, 0, 255) if fake else (255, 0, 0), 2)
            cv2.imshow("image", cvimg)
            cv2.waitKey(0)
            cv2.destroyWindow("image")
    mean_pred = mean_pred / len(images)
    return mean_pred


def main():
    print("Arguments:\n", args, end="\n\n")
    # set device
    device = torch.device(args.device)
    # load model
    model = Recce(num_classes=1)
    # check the console arguments
    if args.weight and args.bin:
        raise ValueError("Only one of '--weight' or '--bin' can be set.")
    elif args.weight:
        weights = torch.load(args.weight, map_location=device)
    elif args.bin:
        weights = torch.load(args.bin, map_location=device)["model"]
    else:
        raise ValueError("Neither '--weight' nor '--bin' is set. Please specify "
                         "one of these two arguments to load the model weights.")
    model.load_state_dict(weights)
    model = model.to(device)
    freeze_weights(model)
    model.eval()

    paths, images = prepare_data()
    print("Inference:")
    mean_pred = inference(model, images=images, paths=paths, device=device)
    print("Mean prediction:", mean_pred)


if __name__ == '__main__':
    args = parser.parse_args()
    main()