Spaces:
Running
Running
File size: 3,907 Bytes
bc97962 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
import argparse
from pathlib import Path
from PIL import Image
import torch
import torchvision
from skimage.morphology import erosion
import matplotlib.pyplot as plt
import time
from segroot.utils import init_weights
from segroot.dataloader import pad_pair_256, normalize
from segroot.model import SegRoot
# Command-line interface of the prediction script.  Each entry pairs a flag
# with the keyword arguments handed to ``ArgumentParser.add_argument``; the
# flags are registered in the order listed here.
_CLI_OPTIONS = (
    (
        "--image",
        {"default": "test.jpg", "type": str, "help": "filename of one test image"},
    ),
    (
        "--thres",
        {
            "default": 0.9,
            "type": float,
            "help": "threshold of the final binarization",
        },
    ),
    (
        "--all",
        {
            "action": "store_true",
            "help": "make prediction on all images in the folder",
        },
    ),
    (
        "--data_dir",
        {
            "default": "../data/prediction",
            "type": Path,
            "help": "define the data directory",
        },
    ),
    (
        "--weights",
        {
            "default": "../weights/best_segnet-(8,5)-0.6441.pt",
            "type": Path,
            "help": "path of pretrained weights",
        },
    ),
    ("--width", {"default": 8, "type": int, "help": "width of SegRoot"}),
    ("--depth", {"default": 5, "type": int, "help": "depth of SegRoot"}),
)

parser = argparse.ArgumentParser()
for _flag, _options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_options)
def pad_256(img_path):
    """Load an image file and turn it into a padded, normalized tensor.

    Args:
        img_path: Path of the image file to open.

    Returns:
        A tuple ``(img, (H, W, NH, NW))``.  ``img`` is the result of
        ``ToTensor`` followed by ``normalize`` applied to the padded image.
        ``H``/``W`` are the height and width of the image as loaded and
        ``NH``/``NW`` the height and width after ``pad_pair_256``.
    """
    # Open through a context manager so the file handle is released even if
    # decoding fails; ``convert`` loads the pixel data and returns an
    # independent image, so nothing below depends on the open file.
    with Image.open(img_path) as source:
        # NOTE(review): the model is assumed to take 3-channel input, so
        # grayscale/CMYK/palette files are normalised to RGB here — confirm
        # against the training pipeline.
        image = source.convert("RGB")
    W, H = image.size
    # pad_pair_256 operates on a pair; there is no mask at prediction time,
    # so the image is passed twice and the second result is discarded.
    # (presumably pads each side up to a multiple of 256 — TODO confirm)
    img, _ = pad_pair_256(image, image)
    NW, NH = img.size
    img = torchvision.transforms.ToTensor()(img)
    img = normalize(img)
    return img, (H, W, NH, NW)
def predict(model, test_img, device):
    """Run a single forward pass of ``model`` on one unbatched image tensor.

    Args:
        model: The network to evaluate.  It is switched to eval mode and is
            left in eval mode when the function returns.
        test_img: Unbatched image tensor, e.g. of shape ``(3, H, W)``.  It is
            used as-is, so it must already be on the same device as ``model``.
        device: Not used by this function; kept so existing callers that pass
            it keep working.

    Returns:
        The model output with every size-1 dimension squeezed out.  The
        tensor carries no autograd history.
    """
    model.eval()
    # Disable autograd for the forward pass only.  Setting
    # ``requires_grad = False`` on every parameter would have the same effect
    # on the output but permanently freezes the caller's model.
    with torch.no_grad():
        # test_img.shape = (3, 2304, 2560)
        batch = test_img.unsqueeze(0)
        output = model(batch)
        # output.shape = (1, 1, 2304, 2560)
        output = torch.squeeze(output)
    # Hand cached GPU blocks back to the allocator between images.
    torch.cuda.empty_cache()
    return output
def predict_gen(model, img_path, thres, device, info):
    """Predict a binary mask for one image and save it beside the image.

    The image is padded and normalized by ``pad_256``, passed through
    ``predict``, binarized at ``thres``, eroded, cropped back to the size of
    the input image and written as a grayscale JPEG in the image's directory.

    Args:
        model: Network handed to ``predict``.
        img_path: ``pathlib.Path`` of the input image.
        thres: Binarization threshold; outputs ``>= thres`` become 1, all
            others 0.
        device: ``torch.device`` the input tensor is moved to.
        info: Value formatted into the output file name.
    """
    img, (H, W, NH, NW) = pad_256(img_path)
    prediction = predict(model, img.to(device), device)

    # Binarize with a single comparison.  Two in-place assignments in
    # sequence are order-dependent: with thres > 1 the values just set to 1.0
    # would be zeroed again by the second assignment.
    binary = (prediction >= thres).to(prediction.dtype)

    # ``.cpu()`` is a no-op for tensors already on the CPU, so no device
    # branch is needed.
    mask = erosion(binary.detach().cpu().numpy())

    # reverse padding
    # NOTE(review): assumes pad_pair_256 centres the image with the smaller
    # half of the padding on the top/left — confirm against the dataloader.
    top = (NH - H) // 2
    left = (NW - W) // 2
    mask = mask[top : top + H, left : left + W]

    # ``stem`` drops only the final suffix, so names such as "a.JPG" or
    # "a.jpeg" no longer leak their extension into the output name.
    # NOTE(review): the second number in the name is hard-coded to 5.
    save_path = img_path.parent / (
        img_path.stem + "-pre-mask-segnet-({},5).jpg".format(info)
    )
    plt.imsave(save_path.as_posix(), mask, cmap="gray")
    print("{} generated!".format(save_path.parts[-1]))
if __name__ == "__main__":
    # Script entry point: build the model, load its weights and write a
    # predicted mask for one image (--image) or every image (--all) found in
    # the data directory.
    args = parser.parse_args()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # define model
    print("using segnet, width : {}, depth : {}".format(args.width, args.depth))
    model = SegRoot(args.width, args.depth).to(device)

    weights_path = args.weights
    print("load weights to {}".format("cpu" if device.type == "cpu" else "gpu"))
    print(weights_path.as_posix())
    # ``map_location=device`` handles both the CPU and the GPU case, so the
    # load call does not need to be duplicated per device.
    # NOTE(review): torch.load unpickles the file; only load trusted weights.
    model.load_state_dict(torch.load(weights_path.as_posix(), map_location=device))

    # define the prediction's saving directory
    # Use --data_dir in both modes (its default is "../data/prediction") so
    # single-image mode no longer ignores the option.
    pre_dir = args.data_dir
    pre_dir.mkdir(parents=True, exist_ok=True)

    if args.all:
        # Materialize the listing before predicting: masks are written into
        # this same directory, and a lazy glob could pick them up mid-run.
        # Masks left by earlier runs are skipped so they are not re-predicted.
        img_paths = sorted(
            path
            for path in pre_dir.glob("*.jpg")
            if "-pre-mask-segnet-" not in path.name
        )
    else:
        img_paths = [pre_dir / args.image]

    for img_path in img_paths:
        start_time = time.perf_counter()
        # Label the output with the width actually used, not a literal 8.
        predict_gen(model, img_path, args.thres, device, args.width)
        end_time = time.perf_counter()
        print("{:.4f}s for one image".format(end_time - start_time))
|