# Andres Felipe Ruiz-Hurtado
# initial
# bc97962
import argparse
from pathlib import Path
from PIL import Image
import torch
import torchvision
from skimage.morphology import erosion
import matplotlib.pyplot as plt
import time
from segroot.utils import init_weights
from segroot.dataloader import pad_pair_256, normalize
from segroot.model import SegRoot
# Command-line interface of the prediction script.
# Each entry is (flag, keyword arguments for ``add_argument``); the entries are
# registered in the order listed so the generated help keeps the same layout.
parser = argparse.ArgumentParser()
_CLI_OPTIONS = (
    (
        "--image",
        dict(type=str, default="test.jpg", help="filename of one test image"),
    ),
    (
        "--thres",
        dict(type=float, default=0.9, help="threshold of the final binarization"),
    ),
    (
        "--all",
        dict(action="store_true", help="make prediction on all images in the folder"),
    ),
    (
        "--data_dir",
        dict(type=Path, default="../data/prediction", help="define the data directory"),
    ),
    (
        "--weights",
        dict(
            type=Path,
            default="../weights/best_segnet-(8,5)-0.6441.pt",
            help="path of pretrained weights",
        ),
    ),
    ("--width", dict(type=int, default=8, help="width of SegRoot")),
    ("--depth", dict(type=int, default=5, help="depth of SegRoot")),
)
for _flag, _options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_options)
def pad_256(img_path):
    """Load an image, pad it with ``pad_pair_256`` and convert it to a tensor.

    Args:
        img_path: path of the image file to open.

    Returns:
        A tuple ``(img, (H, W, NH, NW))`` where ``img`` is the normalized
        tensor of the padded image, ``H``/``W`` are the height and width of
        the image as stored on disk and ``NH``/``NW`` are the height and
        width after padding.
    """
    # Image.open is lazy and keeps the underlying file open; the context
    # manager releases the handle deterministically, which matters when many
    # images are processed in one run.  Everything that touches the PIL image
    # stays inside the block because closing it invalidates its pixel data.
    with Image.open(img_path) as image:
        W, H = image.size
        # pad_pair_256 works on (image, ground-truth) pairs; the image is
        # passed twice because there is no ground truth at prediction time.
        # NOTE(review): presumably pads both sides up to a multiple of 256 --
        # confirm against segroot.dataloader.
        img, _ = pad_pair_256(image, image)
        NW, NH = img.size
        img = torchvision.transforms.ToTensor()(img)
    img = normalize(img)
    return img, (H, W, NH, NW)
def predict(model, test_img, device):
    """Run ``model`` on a single image tensor and return the squeezed output.

    Args:
        model: network to evaluate; its parameters are frozen
            (``requires_grad`` set to False) and it is switched to eval mode
            as a side effect.
        test_img: un-batched image tensor, e.g. of shape (3, 2304, 2560).
        device: accepted for interface compatibility; not used here, the
            caller is expected to have moved ``test_img`` already.

    Returns:
        The model output with all size-1 dimensions removed.
    """
    # Inference only: no gradients are needed for any weight.
    for weight in model.parameters():
        weight.requires_grad_(False)
    model.eval()

    # Add the batch dimension: (3, 2304, 2560) -> (1, 3, 2304, 2560).
    batch = test_img[None]
    # e.g. (1, 1, 2304, 2560) -> (2304, 2560) once the unit dims are dropped.
    result = model(batch).squeeze()

    torch.cuda.empty_cache()
    return result
def predict_gen(model, img_path, thres, device, info, depth=5):
    """Segment one image and save the binarized mask next to it.

    The image is padded via ``pad_256``, run through ``predict``, thresholded,
    eroded, cropped back to its original size and written as a grayscale JPEG
    named ``<stem>-pre-mask-segnet-(<info>,<depth>).jpg`` in the same folder
    as the input.

    Args:
        model: network passed on to ``predict``.
        img_path: ``pathlib.Path`` of the input image.
        thres: values ``>= thres`` become 1.0, everything else 0.0.
        device: torch device the image tensor is moved to.
        info: first number embedded in the output file name (the callers in
            this file pass the model width).
        depth: second number embedded in the output file name; defaults to 5,
            the value that used to be hard-coded.
    """
    img, (H, W, NH, NW) = pad_256(img_path)
    prediction = predict(model, img.to(device), device)

    # Binarize into a fresh tensor instead of mutating the model output.
    # Anything that is not >= thres (NaN included) maps to 0.
    prediction = (prediction >= thres).to(prediction.dtype)
    # .cpu() is a no-op for CPU tensors, so one path serves both devices.
    prediction = prediction.detach().cpu().numpy()
    prediction = erosion(prediction)

    # reverse padding: keep the centred H x W window of the padded result
    top = (NH - H) // 2
    left = (NW - W) // 2
    prediction = prediction[top : top + H, left : left + W]

    # Path.stem drops only the final suffix, so inputs such as "x.JPG" or
    # "x.jpeg" no longer keep their extension inside the output name.
    save_path = img_path.parent / "{}-pre-mask-segnet-({},{}).jpg".format(
        img_path.stem, info, depth
    )
    plt.imsave(save_path.as_posix(), prediction, cmap="gray")
    print("{} generated!".format(save_path.parts[-1]))
if __name__ == "__main__":
    # Script entry point: build the model, load the weights and write a mask
    # for one image (default) or for every *.jpg in the data directory (--all).
    args = parser.parse_args()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # define model
    print("using segnet, width : {}, depth : {}".format(args.width, args.depth))
    model = SegRoot(args.width, args.depth).to(device)

    weights_path = args.weights
    print("load weights to {}".format("cpu" if device.type == "cpu" else "gpu"))
    print(weights_path.as_posix())
    # map_location=device remaps the stored tensors onto the device in use,
    # so one call covers both the CPU and the GPU case.
    # NOTE: torch.load unpickles the file -- only load weights you trust.
    model.load_state_dict(torch.load(weights_path.as_posix(), map_location=device))

    # Directory holding the input images; masks are written next to them.
    # --data_dir is honoured in single-image mode as well (its default is the
    # previously hard-coded "../data/prediction").
    data_dir = args.data_dir
    data_dir.mkdir(parents=True, exist_ok=True)

    if args.all:
        # Materialize the listing before predicting and skip masks written by
        # earlier runs: they also end in ".jpg" and would otherwise be fed
        # back into the model as inputs.
        img_paths = sorted(
            path
            for path in data_dir.glob("*.jpg")
            if "-pre-mask-segnet-" not in path.name
        )
    else:
        img_paths = [data_dir / args.image]

    for img_path in img_paths:
        start_time = time.time()
        # Label the output with the width actually used, not a fixed 8.
        predict_gen(model, img_path, args.thres, device, args.width)
        end_time = time.time()
        print("{:.4f}s for one image".format(end_time - start_time))