Zero-Shot-Material-Transfer / DPT /run_segmentation.py
fffiloni's picture
Upload 47 files
a9289c0
raw history blame
No virus
4.54 kB
"""Compute segmentation maps for images in the input folder.
"""
import os
import glob
import cv2
import argparse
import torch
import torch.nn.functional as F
import util.io
from torchvision.transforms import Compose
from dpt.models import DPTSegmentationModel
from dpt.transforms import Resize, NormalizeImage, PrepareForNet
def run(input_path, output_path, model_path, model_type="dpt_hybrid", optimize=True):
    """Run the segmentation network on every image file in a folder.

    Each regular file directly inside ``input_path`` is read, passed through
    the network, and the resulting per-pixel class map is written to
    ``output_path`` under the input's base name (without extension).

    Args:
        input_path (str): path to input folder
        output_path (str): path to output folder (created if it does not exist)
        model_path (str): path to saved model
        model_type (str): "dpt_large" or "dpt_hybrid"
        optimize (bool): when true and CUDA is available, run the model and
            its inputs in channels-last memory format and half precision

    Raises:
        AssertionError: if ``model_type`` is not a supported value.
    """
    print("initialize")

    # select device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("device: %s" % device)

    net_w = net_h = 480

    # load network
    if model_type == "dpt_large":
        backbone = "vitl16_384"
    elif model_type == "dpt_hybrid":
        backbone = "vitb_rn50_384"
    else:
        # Raised explicitly (not via `assert`) so the check survives `python -O`.
        raise AssertionError(
            f"model_type '{model_type}' not implemented, use: --model_type [dpt_large|dpt_hybrid]"
        )

    # NOTE(review): 150 is presumably the number of output classes -- confirm
    # against DPTSegmentationModel.
    model = DPTSegmentationModel(
        150,
        path=model_path,
        backbone=backbone,
    )

    transform = Compose(
        [
            Resize(
                net_w,
                net_h,
                resize_target=None,
                keep_aspect_ratio=True,
                ensure_multiple_of=32,
                resize_method="minimal",
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            PrepareForNet(),
        ]
    )

    model.eval()

    # The channels-last / half-precision path is only taken on CUDA; decide
    # once so the model and every input sample are treated consistently.
    use_half = bool(optimize) and device.type == "cuda"

    if use_half:
        model = model.to(memory_format=torch.channels_last)
        model = model.half()

    model.to(device)

    # get input: keep regular files only (sub-directories cannot be read as
    # images) and sort so the processing order is deterministic
    img_names = sorted(
        name
        for name in glob.glob(os.path.join(input_path, "*"))
        if os.path.isfile(name)
    )
    num_images = len(img_names)

    # create output folder
    os.makedirs(output_path, exist_ok=True)

    print("start processing")

    for ind, img_name in enumerate(img_names):
        print("  processing {} ({}/{})".format(img_name, ind + 1, num_images))

        # input
        img = util.io.read_image(img_name)
        img_input = transform({"image": img})["image"]

        # compute
        with torch.no_grad():
            sample = torch.from_numpy(img_input).to(device).unsqueeze(0)
            if use_half:
                sample = sample.to(memory_format=torch.channels_last)
                sample = sample.half()

            out = model(sample)

            # resize the network output back to the input image's first two
            # dimensions before picking the best class per pixel
            prediction = torch.nn.functional.interpolate(
                out, size=img.shape[:2], mode="bicubic", align_corners=False
            )
            # +1 shifts the argmax indices to start at 1; presumably the
            # labelling write_segm_img expects -- confirm.
            prediction = torch.argmax(prediction, dim=1) + 1
            # Drop only the batch dimension; a bare squeeze() would also
            # collapse a height or width of 1.
            prediction = prediction.squeeze(0).cpu().numpy()

        # output
        filename = os.path.join(
            output_path, os.path.splitext(os.path.basename(img_name))[0]
        )
        util.io.write_segm_img(filename, img, prediction, alpha=0.5)

    print("finished")
# Default weights file for each supported model type; used when no weights
# path is given on the command line.
DEFAULT_MODELS = {
    "dpt_large": "weights/dpt_large-ade20k-b12dca68.pt",
    "dpt_hybrid": "weights/dpt_hybrid-ade20k-53898607.pt",
}


def parse_args(argv=None):
    """Parse the command-line options of the segmentation script.

    Args:
        argv (list[str] | None): argument strings to parse; ``None`` makes
            argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: with ``input_path``, ``output_path``,
        ``model_weights``, ``model_type`` and ``optimize``. ``model_weights``
        is filled in from ``DEFAULT_MODELS`` when not given explicitly.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-i", "--input_path", default="input", help="folder with input images"
    )

    parser.add_argument(
        "-o", "--output_path", default="output_semseg", help="folder for output images"
    )

    parser.add_argument(
        "-m",
        "--model_weights",
        default=None,
        help="path to the trained weights of model",
    )

    # 'dpt_large', 'dpt_hybrid' -- restricting the choices makes argparse
    # reject an unknown type with a usage message instead of a KeyError in
    # the default-weights lookup below.
    parser.add_argument(
        "-t",
        "--model_type",
        default="dpt_hybrid",
        choices=tuple(DEFAULT_MODELS),
        help="model type",
    )

    parser.add_argument("--optimize", dest="optimize", action="store_true")
    parser.add_argument("--no-optimize", dest="optimize", action="store_false")
    parser.set_defaults(optimize=True)

    args = parser.parse_args(argv)

    if args.model_weights is None:
        args.model_weights = DEFAULT_MODELS[args.model_type]

    return args


if __name__ == "__main__":
    args = parse_args()

    # set torch options
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    # compute segmentation maps
    run(
        args.input_path,
        args.output_path,
        args.model_weights,
        args.model_type,
        args.optimize,
    )