File size: 701 Bytes
80bb308
 
 
 
 
 
 
 
 
 
f9f8b5c
80bb308
 
ad4887b
80bb308
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import torch
import numpy as np
import cv2

def _model_device(model):
    """Return the device holding *model*'s first parameter, or CPU if there is none."""
    parameters = getattr(model, "parameters", None)
    if parameters is None:
        return torch.device("cpu")
    first_param = next(iter(parameters()), None)
    return first_param.device if first_param is not None else torch.device("cpu")


def make_predictions(model, input_img, threshold=0.05, input_size=(256, 256), device=None):
    """Run *model* on one BGR image and return a binary 0/255 mask.

    Pipeline: BGR -> RGB, scale to [0, 1] as float32, resize to *input_size*,
    reorder to channels-first with a leading batch axis of 1, run the model
    without gradient tracking, apply a sigmoid, then threshold.

    Note: this calls ``model.eval()``, which switches the model to evaluation
    mode as a side effect that persists after the call.

    Args:
        model: A torch model (assumed ``nn.Module``-like: it must provide
            ``.eval()`` and be callable on a float32 tensor of shape
            ``(1, 3, H, W)``).
        input_img: A 3-channel BGR image as loaded by ``cv2.imread``; presumably
            uint8 in [0, 255], since it is divided by 255.0 — confirm for callers.
        threshold: Probability cut-off applied after the sigmoid. Pixels with
            probability strictly greater than this become 255.
        input_size: ``dsize`` passed to ``cv2.resize``, i.e. ``(width, height)``.
            Defaults to the previously hard-coded ``(256, 256)``.
        device: Device to run inference on. When ``None``, the device of the
            model's first parameter is used (CPU if the model has none), so the
            input always lands where the weights live.

    Returns:
        A ``np.uint8`` array with values 0 or 255. It is the squeezed model output
        at the network resolution, *not* the original image size. Assumes the
        model emits a single-channel logit map, e.g. ``(1, 1, H, W)`` — verify.
    """
    if device is None:
        device = _model_device(model)

    model.eval()
    with torch.no_grad():
        rgb = cv2.cvtColor(input_img, cv2.COLOR_BGR2RGB)
        scaled = rgb.astype("float32") / 255.0
        # cv2.resize takes (width, height), not (rows, cols).
        resized = cv2.resize(scaled, input_size)
        # HWC -> CHW, then add the batch axis expected by the model.
        batch = np.expand_dims(np.transpose(resized, (2, 0, 1)), 0)
        input_tensor = torch.from_numpy(batch).to(device)

        logits = model(input_tensor).squeeze()
        probs = torch.sigmoid(logits).cpu().numpy()

        # Build the 0/255 mask directly in uint8 (no int64 intermediate).
        return (probs > threshold).astype(np.uint8) * np.uint8(255)