sofmi commited on
Commit
25b5e8c
1 Parent(s): ded6773

updated app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -3
app.py CHANGED
@@ -12,8 +12,49 @@ inputs = gr.inputs.Image()
12
  outputs = gr.outputs.Image()
13
 
14
  # inference fn
15
- def predict(image_input):
16
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- gr.Interface(predict, inputs = inputs, outputs = outputs)
 
12
  outputs = gr.outputs.Image()
13
 
14
  # inference fn
15
+ def predict(inputs):
16
+ # convert img to numpy array, resize and normalize to make the prediction
17
+ img = np.array(inputs)
18
+
19
+ im = tf.image.resize(img, (128, 128))
20
+ im = tf.cast(im, tf.float32) / 255.0
21
+ pred_mask = model.predict(im[tf.newaxis, ...])
22
+
23
+ # take the best performing class for each pixel
24
+ # the output of argmax looks like this [[1, 2, 0], ...]
25
+ pred_mask_arg = tf.argmax(pred_mask, axis=-1)
26
+
27
+ labels = []
28
+
29
+ # convert the prediction mask into binary masks for each class
30
+ binary_masks = {}
31
+ mask_codes = {}
32
+
33
+ # when we take tf.argmax() over pred_mask, it becomes a tensor object
34
+ # the shape becomes TensorShape object, looking like this TensorShape([128])
35
+ # we need to take get shape, convert to list and take the best one
36
+
37
+ rows = pred_mask_arg[0][1].get_shape().as_list()[0]
38
+ cols = pred_mask_arg[0][2].get_shape().as_list()[0]
39
 
40
+ for cls in range(pred_mask.shape[-1]):
41
+
42
+ binary_masks[f"mask_{cls}"] = np.zeros(shape = (pred_mask.shape[1], pred_mask.shape[2])) #create masks for each class
43
+
44
+ for row in range(rows):
45
+
46
+ for col in range(cols):
47
+
48
+ if pred_mask_arg[0][row][col] == cls:
49
+
50
+ binary_masks[f"mask_{cls}"][row][col] = 1
51
+ else:
52
+ binary_masks[f"mask_{cls}"][row][col] = 0
53
+
54
+ mask = binary_masks[f"mask_{cls}"]
55
+ mask *= 255
56
+ img = Image.fromarray(mask.astype(np.int8), mode="L")
57
+ return img
58
+
59
 
60
+ gr.Interface(predict, inputs = inputs, outputs = outputs).launch()