File size: 2,685 Bytes
a7ce59e
 
 
 
 
 
d299b84
9471efd
c6b5997
 
a7ce59e
 
 
c6b5997
9471efd
83b866c
9471efd
c6b5997
 
 
83b866c
c6b5997
 
 
 
 
 
83b866c
 
 
c6b5997
a7ce59e
6cb57f7
83b866c
 
c6b5997
 
83b866c
 
 
 
 
 
 
 
c6b5997
 
83b866c
 
 
c6b5997
83b866c
c6b5997
 
 
 
 
 
 
d299b84
 
 
83b866c
 
d299b84
 
 
176687d
c6b5997
d299b84
c6b5997
7aad423
83b866c
6cb57f7
7aad423
 
6cb57f7
 
d299b84
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import json
from typing import Any, Dict, List

import tensorflow as tf
from tensorflow import keras
import base64
import io
import os
import numpy as np
from PIL import Image



class PreTrainedPipeline():
    """Semantic-segmentation inference pipeline around a saved Keras model.

    The model is loaded from ``<path>/tf_model.h5``. Calling the pipeline on an
    image returns one entry per output class of the model, each carrying a
    base64-encoded PNG binary mask in the format the widget consumes.
    """

    def __init__(self, path: str):
        """Load the Keras model stored as ``tf_model.h5`` inside ``path``."""
        self.model = keras.models.load_model(os.path.join(path, "tf_model.h5"))

    def __call__(self, inputs: "Image.Image") -> List[Dict[str, Any]]:
        """Segment ``inputs`` and return one base64 PNG mask per class.

        Args:
            inputs: a ``PIL.Image.Image``, or anything ``PIL.Image.open``
                accepts (a file path or a binary file-like object).

        Returns:
            One dict per class index ``cls`` of the model output, in class
            order: ``{"label": f"LABEL_{cls}", "mask": <base64 PNG str>,
            "score": 1.0}``. Each mask is an 8-bit grayscale image that is
            255 where ``cls`` is the arg-max class and 0 elsewhere.
        """
        img = self._load_image(inputs)
        # NOTE(review): grayscale ("L") inputs become 2-D arrays and RGBA
        # inputs have 4 channels; confirm the channel layout the model expects.

        # resize to 128x128 and scale pixel values to [0, 1] before predicting
        im = tf.image.resize(img, (128, 128))
        im = tf.cast(im, tf.float32) / 255.0
        pred_mask = self.model.predict(im[tf.newaxis, ...])

        # best-scoring class for each pixel of the single image in the batch;
        # result is a 2-D array of class indices with the spatial shape of
        # the prediction
        pred_mask_arg = np.argmax(pred_mask, axis=-1)[0]

        labels = []
        for cls in range(pred_mask.shape[-1]):
            # vectorized binary mask; uint8 so 255 is representable and PIL
            # infers an 8-bit grayscale ("L") image
            mask = np.where(pred_mask_arg == cls, 255, 0).astype(np.uint8)

            # widget needs this format: for each class a label and mask string
            labels.append({
                "label": f"LABEL_{cls}",
                "mask": self._encode_png(mask),
                "score": 1.0,
            })
        return labels

    @staticmethod
    def _load_image(inputs: Any) -> np.ndarray:
        """Return ``inputs`` as a numpy array, opening it with PIL if needed."""
        if isinstance(inputs, Image.Image):
            return np.array(inputs)
        with Image.open(inputs) as img:
            return np.array(img)

    @staticmethod
    def _encode_png(mask: np.ndarray) -> str:
        """Encode a 2-D uint8 array as a base64 (UTF-8 text) PNG string."""
        with io.BytesIO() as out:
            Image.fromarray(mask).save(out, format="PNG")
            return base64.b64encode(out.getvalue()).decode("utf-8")