merve HF Staff commited on
Commit
1d51d9d
·
1 Parent(s): c7b4a7c

Create new file

Browse files
Files changed (1) hide show
  1. app.py +59 -0
app.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import from_pretrained_keras
2
+ import gradio as gr
3
+ import tensorflow as tf
4
+ import numpy as np
5
+ import os
6
+
7
+ model = tf.keras.models.load_model(os.path.join(path, "tf_model.h5"))
8
+
9
+ inputs = gr.inputs.Image()
10
+ output = gr.outputs.Image()
11
+
12
+
13
+ def predict(image_input):
14
+ img = np.array(inputs)
15
+
16
+ im = tf.image.resize(img, (128, 128))
17
+ im = tf.cast(im, tf.float32) / 255.0
18
+ pred_mask = self.model.predict(im[tf.newaxis, ...])
19
+
20
+ # take the best performing class for each pixel
21
+ # the output of argmax looks like this [[1, 2, 0], ...]
22
+ pred_mask_arg = tf.argmax(pred_mask, axis=-1)
23
+
24
+ labels = []
25
+
26
+ # convert the prediction mask into binary masks for each class
27
+ binary_masks = {}
28
+ mask_codes = {}
29
+
30
+ # when we take tf.argmax() over pred_mask, it becomes a tensor object
31
+ # the shape becomes TensorShape object, looking like this TensorShape([128])
32
+ # we need to take get shape, convert to list and take the best one
33
+
34
+ rows = pred_mask_arg[0][1].get_shape().as_list()[0]
35
+ cols = pred_mask_arg[0][2].get_shape().as_list()[0]
36
+
37
+ for cls in range(pred_mask.shape[-1]):
38
+
39
+ binary_masks[f"mask_{cls}"] = np.zeros(shape = (pred_mask.shape[1], pred_mask.shape[2])) #create masks for each class
40
+
41
+ for row in range(rows):
42
+
43
+ for col in range(cols):
44
+
45
+ if pred_mask_arg[0][row][col] == cls:
46
+
47
+ binary_masks[f"mask_{cls}"][row][col] = 1
48
+ else:
49
+ binary_masks[f"mask_{cls}"][row][col] = 0
50
+
51
+ mask = binary_masks[f"mask_{cls}"]
52
+ mask *= 255
53
+ img = Image.fromarray(mask.astype(np.int8), mode="L")
54
+ return img
55
+
56
+
57
+
58
+
59
+ gr.Interface(predict, inputs = inputs, outputs = output).launch()