spacenet3-unet-1024-1024 / make_predictions.py
shivambhosale's picture
Update make_predictions.py
ad4887b
raw history blame
No virus
701 Bytes
import torch
import numpy as np
import cv2
def make_predictions(model, input_img, threshold=0.05, input_size=(256, 256)):
    """Run ``model`` on a single image and return a thresholded binary mask.

    The image is converted from BGR to RGB, scaled by 1/255 to float32,
    resized to ``input_size``, rearranged to channels-first with a leading
    batch axis, and passed through ``model``. The squeezed model output is
    passed through a sigmoid and compared against ``threshold``.

    Args:
        model: Callable network exposing ``eval()``. It is switched to eval
            mode and left in that mode when this function returns.
        input_img: Image array accepted by ``cv2.cvtColor`` with
            ``COLOR_BGR2RGB`` (presumably an HxWx3 uint8 BGR image as
            returned by ``cv2.imread`` -- confirm against callers).
        threshold: Sigmoid outputs strictly greater than this value are
            marked as foreground.
        input_size: ``(width, height)`` passed to ``cv2.resize`` before
            inference. Defaults to the previously hard-coded ``(256, 256)``.

    Returns:
        A ``numpy.uint8`` array holding 255 where the sigmoid output exceeds
        ``threshold`` and 0 elsewhere. Its shape is whatever the squeezed
        model output is (presumably ``(height, width)`` for a single-channel
        segmentation head -- confirm against the model).
    """
    model.eval()

    # Send the input to the device that holds the model's weights instead of
    # forcing CPU; a model moved to a GPU would otherwise fail with a device
    # mismatch. Parameter-less or non-Module callables fall back to CPU,
    # which matches the previous behaviour.
    parameters = getattr(model, "parameters", None)
    first_param = next(iter(parameters()), None) if callable(parameters) else None
    device = first_param.device if first_param is not None else torch.device("cpu")

    with torch.no_grad():
        rgb = cv2.cvtColor(input_img, cv2.COLOR_BGR2RGB)
        scaled = rgb.astype("float32") / 255.0
        resized = cv2.resize(scaled, tuple(input_size))
        # HWC -> CHW, then add the batch axis: (1, C, H, W).
        chw = np.transpose(resized, (2, 0, 1))
        batch = torch.from_numpy(np.expand_dims(chw, 0)).to(device)

        logits = model(batch).squeeze()
        probabilities = torch.sigmoid(logits)

    # Bring the result back to host memory before converting to NumPy.
    mask = probabilities.cpu().numpy()
    mask = (mask > threshold) * 255
    return mask.astype(np.uint8)