import numpy as np
import cv2
import torch
import glob
import os
import time
import argparse
import yaml
import matplotlib.pyplot as plt

from models.create_fasterrcnn_model import create_model
from utils.annotations import inference_annotations
from utils.general import set_infer_dir
from utils.transforms import infer_transforms

def collect_all_images(dir_test):
    """
    Function to return a list of image paths.

    :param dir_test: Directory containing images or a single image path.

    Returns:
        test_images: List containing all image paths.
    """
    test_images = []
    if os.path.isdir(dir_test):
        image_file_types = ['*.jpg', '*.jpeg', '*.png', '*.ppm']
        for file_type in image_file_types:
            test_images.extend(glob.glob(f"{dir_test}/{file_type}"))
    else:
        test_images.append(dir_test)
    return test_images

def parse_opt():
    # Construct the argument parser.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-i', '--input',
        help='path to the input image or a folder of images',
    )
    parser.add_argument(
        '-c', '--config',
        default=None,
        help='(optional) path to the data config file'
    )
    parser.add_argument(
        '-m', '--model',
        default=None,
        help='name of the model'
    )
    parser.add_argument(
        '-w', '--weights',
        default=None,
        help='path to trained checkpoint weights if providing custom YAML file'
    )
    parser.add_argument(
        '-th', '--threshold',
        default=0.3,
        type=float,
        help='detection threshold'
    )
    parser.add_argument(
        '-si', '--show-image',
        dest='show_image',
        action='store_true',
        help='visualize output only if this argument is passed'
    )
    parser.add_argument(
        '-mpl', '--mpl-show',
        dest='mpl_show',
        action='store_true',
        help='visualize using matplotlib, helpful in notebooks'
    )
    parser.add_argument(
        '-d', '--device',
        default=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),
        help='computation/training device, default is GPU if GPU present'
    )
    args = vars(parser.parse_args())
    return args

def main(args):
    # Seed so that annotation colors stay the same across runs.
    np.random.seed(42)
    # Load the data configurations.
    data_configs = None
    if args['config'] is not None:
        with open(args['config']) as file:
            data_configs = yaml.safe_load(file)
        NUM_CLASSES = data_configs['NC']
        CLASSES = data_configs['CLASSES']

    DEVICE = args['device']
    OUT_DIR = set_infer_dir()

    # Load the pretrained model.
    if args['weights'] is None:
        # If the config file is still None,
        # then load the default one for COCO.
        if data_configs is None:
            with open(os.path.join('data_configs', 'test_image_config.yaml')) as file:
                data_configs = yaml.safe_load(file)
            NUM_CLASSES = data_configs['NC']
            CLASSES = data_configs['CLASSES']
        try:
            build_model = create_model[args['model']]
        except KeyError:
            # Fall back to the default architecture if the requested
            # model name is not registered.
            build_model = create_model['fasterrcnn_resnet50_fpn']
        model = build_model(num_classes=NUM_CLASSES, coco_model=True)

    # Load weights if a checkpoint path is provided.
    if args['weights'] is not None:
        checkpoint = torch.load(args['weights'], map_location=DEVICE)
        # If no config file is given, load the class names from the
        # checkpoint's stored configuration dictionary.
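        # NOTE (assumption): the checkpoint layout read below (a 'config'
        # dict holding 'NC' and 'CLASSES', a 'model_name' string, and a
        # 'model_state_dict') matches what the accompanying training
        # script saves. Checkpoints produced elsewhere may lack these keys.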
        if data_configs is None:
            # Use the checkpoint's config dict so later lookups such as
            # data_configs['image_path'] can work.
            data_configs = checkpoint['config']
            NUM_CLASSES = checkpoint['config']['NC']
            CLASSES = checkpoint['config']['CLASSES']
        try:
            print('Building from model name arguments...')
            build_model = create_model[str(args['model'])]
        except KeyError:
            # Fall back to the architecture name stored in the checkpoint.
            build_model = create_model[checkpoint['model_name']]
        model = build_model(num_classes=NUM_CLASSES, coco_model=False)
        model.load_state_dict(checkpoint['model_state_dict'])

    model.to(DEVICE).eval()

    COLORS = np.random.uniform(0, 255, size=(len(CLASSES), 3))

    if args['input'] is None:
        DIR_TEST = data_configs['image_path']
    else:
        DIR_TEST = args['input']
    test_images = collect_all_images(DIR_TEST)
    print(f"Test instances: {len(test_images)}")

    # Define the detection threshold; any detection with a score
    # below this value will be discarded.
    detection_threshold = args['threshold']
    # To count the total number of frames iterated through.
    frame_count = 0
    # To keep adding the frames' FPS.
    total_fps = 0

    for i in range(len(test_images)):
        # Get the image file name for saving output later on.
        image_name = os.path.splitext(os.path.basename(test_images[i]))[0]
        image = cv2.imread(test_images[i])
        orig_image = image.copy()
        # BGR to RGB.
        image = cv2.cvtColor(orig_image, cv2.COLOR_BGR2RGB)
        image = infer_transforms(image)
        # Add batch dimension.
        image = torch.unsqueeze(image, 0)

        # Time the forward pass only.
        start_time = time.time()
        with torch.no_grad():
            outputs = model(image.to(DEVICE))
        end_time = time.time()

        # Get the current fps.
        fps = 1 / (end_time - start_time)
        # Add `fps` to `total_fps`.
        total_fps += fps
        # Increment frame count.
        frame_count += 1

        # Move all detections to the CPU for further operations.
        outputs = [{k: v.to('cpu') for k, v in t.items()} for t in outputs]
        # Carry on only if there are detected boxes.
        if len(outputs[0]['boxes']) != 0:
            orig_image = inference_annotations(
                outputs, detection_threshold, CLASSES, COLORS, orig_image
            )
            if args['show_image']:
                cv2.imshow('Prediction', orig_image)
                cv2.waitKey(1)
            if args['mpl_show']:
                plt.imshow(orig_image[:, :, ::-1])
                plt.axis('off')
                plt.show()
        cv2.imwrite(f"{OUT_DIR}/{image_name}.jpg", orig_image)
        print(f"Image {i+1} done...")
        print('-'*50)

    print('TEST PREDICTIONS COMPLETE')
    cv2.destroyAllWindows()

    # Calculate and print the average FPS, guarding against an
    # empty input directory.
    if frame_count > 0:
        avg_fps = total_fps / frame_count
        print(f"Average FPS: {avg_fps:.3f}")

if __name__ == '__main__':
    args = parse_opt()
    main(args)
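
# Example invocations (a sketch: the script name `inference.py`, the weights
# path, and the config path are placeholders, not files shipped with this
# script):
#
#   # COCO-pretrained model on a folder of images:
#   python inference.py --input data/test_images
#
#   # Custom checkpoint with its data config, a 0.5 threshold, and an
#   # on-screen preview:
#   python inference.py --input data/test_images \
#       --weights outputs/training/best_model.pth \
#       --config data_configs/custom_data.yaml \
#       --threshold 0.5 --show-image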