"""Compute segmentation maps for images in the input folder. """ import os import glob import cv2 import argparse import torch import torch.nn.functional as F import util.io from torchvision.transforms import Compose from dpt.models import DPTSegmentationModel from dpt.transforms import Resize, NormalizeImage, PrepareForNet def run(input_path, output_path, model_path, model_type="dpt_hybrid", optimize=True): """Run segmentation network Args: input_path (str): path to input folder output_path (str): path to output folder model_path (str): path to saved model """ print("initialize") # select device device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print("device: %s" % device) net_w = net_h = 480 # load network if model_type == "dpt_large": model = DPTSegmentationModel( 150, path=model_path, backbone="vitl16_384", ) elif model_type == "dpt_hybrid": model = DPTSegmentationModel( 150, path=model_path, backbone="vitb_rn50_384", ) else: assert ( False ), f"model_type '{model_type}' not implemented, use: --model_type [dpt_large|dpt_hybrid]" transform = Compose( [ Resize( net_w, net_h, resize_target=None, keep_aspect_ratio=True, ensure_multiple_of=32, resize_method="minimal", image_interpolation_method=cv2.INTER_CUBIC, ), NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), PrepareForNet(), ] ) model.eval() if optimize == True and device == torch.device("cuda"): model = model.to(memory_format=torch.channels_last) model = model.half() model.to(device) # get input img_names = glob.glob(os.path.join(input_path, "*")) num_images = len(img_names) # create output folder os.makedirs(output_path, exist_ok=True) print("start processing") for ind, img_name in enumerate(img_names): print(" processing {} ({}/{})".format(img_name, ind + 1, num_images)) # input img = util.io.read_image(img_name) img_input = transform({"image": img})["image"] # compute with torch.no_grad(): sample = torch.from_numpy(img_input).to(device).unsqueeze(0) if optimize == True and device == torch.device("cuda"): sample = sample.to(memory_format=torch.channels_last) sample = sample.half() out = model.forward(sample) prediction = torch.nn.functional.interpolate( out, size=img.shape[:2], mode="bicubic", align_corners=False ) prediction = torch.argmax(prediction, dim=1) + 1 prediction = prediction.squeeze().cpu().numpy() # output filename = os.path.join( output_path, os.path.splitext(os.path.basename(img_name))[0] ) util.io.write_segm_img(filename, img, prediction, alpha=0.5) print("finished") if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "-i", "--input_path", default="input", help="folder with input images" ) parser.add_argument( "-o", "--output_path", default="output_semseg", help="folder for output images" ) parser.add_argument( "-m", "--model_weights", default=None, help="path to the trained weights of model", ) # 'vit_large', 'vit_hybrid' parser.add_argument("-t", "--model_type", default="dpt_hybrid", help="model type") parser.add_argument("--optimize", dest="optimize", action="store_true") parser.add_argument("--no-optimize", dest="optimize", action="store_false") parser.set_defaults(optimize=True) args = parser.parse_args() default_models = { "dpt_large": "weights/dpt_large-ade20k-b12dca68.pt", "dpt_hybrid": "weights/dpt_hybrid-ade20k-53898607.pt", } if args.model_weights is None: args.model_weights = default_models[args.model_type] # set torch options torch.backends.cudnn.enabled = True torch.backends.cudnn.benchmark = True # compute segmentation maps run( args.input_path, args.output_path, args.model_weights, args.model_type, args.optimize, )