File size: 3,830 Bytes
f7cc41e
 
05d2fe5
f7cc41e
05d2fe5
f7cc41e
 
 
6d061ac
 
 
 
 
 
0e69446
 
f7cc41e
0e69446
c547fc5
0e69446
 
 
05d2fe5
 
c547fc5
e40462a
f7cc41e
7bc0418
 
5075a9e
7bc0418
 
 
f7cc41e
7bc0418
 
 
5075a9e
7bc0418
f7cc41e
7bc0418
 
 
f7cc41e
7bc0418
 
 
f7cc41e
7bc0418
 
f7cc41e
7bc0418
f7cc41e
7bc0418
f7cc41e
7bc0418
f7cc41e
7bc0418
f7cc41e
7bc0418
f7cc41e
7bc0418
 
 
f7cc41e
7bc0418
 
 
f7cc41e
7bc0418
 
 
 
 
f7cc41e
7bc0418
f7cc41e
 
7bc0418
 
 
 
 
 
f7cc41e
af3661d
 
 
 
 
 
 
 
 
c547fc5
f03ca1d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import json
from typing import Any, Dict, List

import tensorflow as tf
from tensorflow import keras
import base64
import io
import os
import numpy as np
from PIL import Image

# Much of this code was adapted from Datature's bounding-box prediction script:
# https://github.com/datature/resources/blob/main/scripts/bounding_box/prediction.py

# def load_model():
# 	return tf.saved_model.load('./saved_model')

# model = load_model()

class PreTrainedPipeline:
    """Inference pipeline wrapping a Keras model saved as ``tf_model.h5``.

    The model is loaded once, at construction time. ``__call__`` is currently
    a placeholder that always returns an empty prediction list; the
    post-processing that turns model output into results is not implemented.
    """

    def __init__(self, path: str):
        """Load the Keras model from ``<path>/tf_model.h5``.

        Args:
            path: Directory that contains the ``tf_model.h5`` file.
        """
        self.model = keras.models.load_model(os.path.join(path, "tf_model.h5"))
        # Alternative loader for the SavedModel format:
        # self.model = tf.saved_model.load(os.path.join(path, "saved_model"))

    def __call__(self, inputs: "Image.Image") -> List[Dict[str, Any]]:
        """Run the pipeline on a single image.

        Args:
            inputs: The image to predict on. Not used yet.

        Returns:
            A list of prediction dicts. Currently always a new, empty list.
        """
        # TODO: implement prediction. An earlier draft (removed) did:
        #   - np.array(inputs) -> tf.image.resize to (128, 128) -> cast to
        #     float32 and divide by 255.0 -> self.model.predict on a batch of 1;
        #   - tf.argmax(pred_mask, axis=-1) to pick a class per pixel;
        #   - per class, a binary mask scaled to 255, saved as a PNG via PIL,
        #     base64-encoded, and returned as
        #     {"label": f"LABEL_{cls}", "mask": <base64 str>, "score": 1.0}.
        #   If revived, build each mask vectorised as
        #   (pred_mask_arg[0] == cls) and cast with np.uint8: the draft cast
        #   with np.int8, which wraps 255 to -1.
        # The draft also sketched the object-detection output format:
        #   {"score": float, "label": str,
        #    "box": {"xmin": int, "ymin": int, "xmax": int, "ymax": int}}
        labels: List[Dict[str, Any]] = []
        return labels