"""Command-line inference script for a trained Recce model.

Loads model weights (either a raw state dict or a trainer-generated '.bin'
checkpoint), preprocesses one image or every '.jpg'/'.png' image in a folder,
and prints the per-image fake probability together with the mean prediction.
"""
import argparse
import os
import random
from glob import glob
from os.path import join

import cv2
import torch
from albumentations import Compose, Normalize, Resize
from albumentations.pytorch.transforms import ToTensorV2

from model.network import Recce
from model.common import freeze_weights

# Tolerate more than one OpenMP runtime being loaded into the process.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# fix random seed
seed = 0
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

parser = argparse.ArgumentParser(description="This code helps you use a trained model to "
                                             "do inference.")
parser.add_argument("--weight", "-w",
                    type=str,
                    default=None,
                    help="Specify the path to the model weight (the state dict file). "
                         "Do not use this argument when '--bin' is set.")
parser.add_argument("--bin", "-b",
                    type=str,
                    default=None,
                    help="Specify the path to the model bin which ends up with '.bin' "
                         "(which is generated by the trainer of this project). "
                         "Do not use this argument when '--weight' is set.")
parser.add_argument("--image", "-i",
                    type=str,
                    default=None,
                    help="Specify the path to the input image. "
                         "Do not use this argument when '--image_folder' is set.")
parser.add_argument("--image_folder", "-f",
                    type=str,
                    default=None,
                    help="Specify the directory to evaluate all the images. "
                         "Do not use this argument when '--image' is set.")
parser.add_argument('--device', '-d',
                    type=str,
                    default="cpu",
                    help="Specify the device to load the model. Default: 'cpu'.")
parser.add_argument('--image_size', '-s',
                    type=int,
                    default=299,
                    help="Specify the spatial size of the input image(s). Default: 299.")
parser.add_argument('--visualize', '-v',
                    action="store_true",
                    default=False,
                    help='Visualize images.')


def preprocess(file_path):
    """Read an image from disk and turn it into a batched model input.

    The image is converted from BGR to RGB, resized to
    ``args.image_size`` x ``args.image_size``, normalized with mean 0.5 and
    std 0.5 per channel, converted to a tensor and given a leading batch
    dimension of size 1.

    Args:
        file_path: Path to the image file.

    Returns:
        The preprocessed image tensor with a leading batch dimension.

    Raises:
        ValueError: If OpenCV cannot read the file.
    """
    img = cv2.imread(file_path)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising;
        # fail here with the offending path instead of inside cvtColor.
        raise ValueError(f"Failed to read image: '{file_path}'.")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    compose = Compose([Resize(height=args.image_size, width=args.image_size),
                       Normalize(mean=[0.5] * 3, std=[0.5] * 3),
                       ToTensorV2()])
    return compose(image=img)['image'].unsqueeze(0)


def prepare_data():
    """Collect and preprocess the input image(s) named on the command line.

    Returns:
        A ``(paths, images)`` tuple of equal-length lists: the image file
        paths and their preprocessed tensors.

    Raises:
        ValueError: If both or neither of '--image' and '--image_folder' are
            set, or if the given folder contains no '.jpg' or '.png' files.
    """
    # check the console arguments
    if args.image and args.image_folder:
        raise ValueError("Only one of '--image' or '--image_folder' can be set.")
    if args.image:
        paths = [args.image]
    elif args.image_folder:
        paths = glob(join(args.image_folder, "*.jpg"))
        paths.extend(glob(join(args.image_folder, "*.png")))
        if not paths:
            raise ValueError(f"No '.jpg' or '.png' images found in "
                             f"'{args.image_folder}'.")
    else:
        raise ValueError("Neither of '--image' nor '--image_folder' is set. Please specify either "
                         "one of these two arguments to load input image(s) properly.")
    images = [preprocess(path) for path in paths]
    return paths, images


def inference(model, images, paths, device):
    """Run the model on every image and report the predictions.

    For each image the sigmoid of the model output is printed as the fake
    probability; values >= 0.5 are labelled 'fake', otherwise 'real'. When
    ``args.visualize`` is set, each image is also shown in an OpenCV window
    with the prediction drawn on it (waits for a key press per image).

    Args:
        model: The network to evaluate.
        images: Preprocessed image tensors, one per path.
        paths: Image file paths, aligned with ``images``.
        device: Device on which to run the forward pass.

    Returns:
        The mean fake probability over all images.

    Raises:
        ValueError: If ``images`` is empty.
    """
    if not images:
        # Avoid an opaque ZeroDivisionError when computing the mean below.
        raise ValueError("No images to run inference on.")

    total = 0
    # Inference only: make sure no autograd graph is recorded.
    with torch.no_grad():
        for img, pt in zip(images, paths):
            prediction = torch.sigmoid(model(img.to(device))).cpu()
            prob = prediction.item()
            fake = prob >= 0.5
            label = 'fake' if fake else 'real'
            total += prob
            print(f"path: {pt} \t\t| fake probability: {prob:.4f} \t| "
                  f"prediction: {label}")
            if args.visualize:
                cvimg = cv2.imread(pt)
                # Colors are BGR: red text for fake, blue text for real.
                cvimg = cv2.putText(cvimg, f'p: {prob:.2f}, {label}',
                                    (5, 50), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                                    (0, 0, 255) if fake else (255, 0, 0), 2)
                cv2.imshow("image", cvimg)
                cv2.waitKey(0)
                cv2.destroyWindow("image")
    return total / len(images)


def main():
    """Load the model, prepare the input image(s) and print the predictions.

    Raises:
        ValueError: If both or neither of '--weight' and '--bin' are set.
    """
    print("Arguments:\n", args, end="\n\n")
    # set device
    device = torch.device(args.device)
    # load model
    model = Recce(num_classes=1)
    # check the console arguments
    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # from trusted sources.
    if args.weight and args.bin:
        raise ValueError("Only one of '--weight' or '--bin' can be set.")
    elif args.weight:
        weights = torch.load(args.weight, map_location=device)
    elif args.bin:
        # The trainer's '.bin' checkpoint keeps the state dict under "model".
        weights = torch.load(args.bin, map_location=device)["model"]
    else:
        raise ValueError("Neither of '--weight' nor '--bin' is set. Please specify either "
                         "one of these two arguments to load model's weight properly.")
    model.load_state_dict(weights)
    model = model.to(device)
    freeze_weights(model)
    model.eval()

    paths, images = prepare_data()
    print("Inference:")
    mean_pred = inference(model, images=images, paths=paths, device=device)
    print("Mean prediction:", mean_pred)


if __name__ == '__main__':
    # `args` is intentionally module-global: the helper functions read it.
    args = parser.parse_args()
    main()