shivambhosale commited on
Commit
80bb308
1 Parent(s): ab2f1c9

Create new file

Browse files
Files changed (1) hide show
  1. make_predictions.py +20 -0
make_predictions.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import cv2
4
+
5
+ def make_predictions(model, input_img, threshold = 0.05):
6
+ model.eval()
7
+ with torch.no_grad():
8
+ image = input_img
9
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
10
+ image = image.astype("float32") / 255.0
11
+ image = cv2.resize(image, (height, width))
12
+ image = np.transpose(image, (2, 0, 1))
13
+ image = np.expand_dims(image, 0)
14
+ image = torch.from_numpy(image).to(device)
15
+ predMask = model(image).squeeze()
16
+ predMask = torch.sigmoid(predMask)
17
+ predMask = predMask.cpu().numpy()
18
+ predMask = (predMask > threshold) * 255
19
+ predMask = predMask.astype(np.uint8)
20
+ return predMask