File size: 2,685 Bytes
a7ce59e d299b84 9471efd c6b5997 a7ce59e c6b5997 9471efd 83b866c 9471efd c6b5997 83b866c c6b5997 83b866c c6b5997 a7ce59e 6cb57f7 83b866c c6b5997 83b866c c6b5997 83b866c c6b5997 83b866c c6b5997 d299b84 83b866c d299b84 176687d c6b5997 d299b84 c6b5997 7aad423 83b866c 6cb57f7 7aad423 6cb57f7 d299b84 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
import json
from typing import Any, Dict, List
import tensorflow as tf
from tensorflow import keras
import base64
import io
import os
import numpy as np
from PIL import Image
class PreTrainedPipeline:
    """Keras semantic-segmentation pipeline that emits one PNG mask per class.

    The return format (a list of ``{"label", "mask", "score"}`` dicts whose
    ``mask`` is a base64-encoded PNG) is the format the widget consumes.
    """

    # (height, width) that input images are resized to before prediction.
    # NOTE(review): presumably the resolution the model was trained at — confirm.
    input_size = (128, 128)

    def __init__(self, path: str):
        """Load the Keras model saved as ``tf_model.h5`` inside *path*.

        Args:
            path: Directory containing ``tf_model.h5``.
        """
        self.model = keras.models.load_model(os.path.join(path, "tf_model.h5"))

    def __call__(self, inputs: "Image.Image") -> List[Dict[str, Any]]:
        """Segment *inputs* and return a binary mask for every output class.

        Args:
            inputs: A PIL image, or anything ``PIL.Image.open`` accepts
                (a file path or a binary file object).

        Returns:
            One dict per model output class:
            ``{"label": "LABEL_<cls>", "mask": <base64 PNG str>, "score": 1.0}``.
            Each mask is 255 where that class wins the per-pixel argmax and 0
            elsewhere, at the model's output resolution.
        """
        img = self._to_array(inputs)
        # Resize and scale pixel values to [0, 1] to match the model input.
        im = tf.image.resize(img, self.input_size)
        im = tf.cast(im, tf.float32) / 255.0
        pred_mask = self.model.predict(im[tf.newaxis, ...])

        # Winning class index for every pixel of the single batch element;
        # shape (height, width) of the model output.
        pred_classes = np.argmax(pred_mask, axis=-1)[0]

        labels = []
        for cls in range(pred_mask.shape[-1]):
            # uint8 (not int8) so 255 is representable; Pillow infers mode "L".
            mask = np.where(pred_classes == cls, 255, 0).astype(np.uint8)
            labels.append({
                "label": f"LABEL_{cls}",
                "mask": self._encode_png(mask),
                "score": 1.0,
            })
        return labels

    @staticmethod
    def _to_array(inputs) -> np.ndarray:
        """Return *inputs* as an RGB pixel array, opening it first if it is not a PIL image."""
        # NOTE(review): RGB conversion assumes the model expects 3-channel
        # input; it also keeps grayscale images from producing a 2-D array,
        # which tf.image.resize rejects.
        if isinstance(inputs, Image.Image):
            return np.array(inputs.convert("RGB"))
        with Image.open(inputs) as img:
            return np.array(img.convert("RGB"))

    @staticmethod
    def _encode_png(mask: np.ndarray) -> str:
        """Encode a 2-D uint8 mask as a base64 PNG string readable by the widget."""
        with io.BytesIO() as out:
            Image.fromarray(mask).save(out, format="PNG")
            return base64.b64encode(out.getvalue()).decode("utf-8")