import base64
import io
import json
import os
from typing import Any, Dict, List

import numpy as np
import tensorflow as tf
from PIL import Image
from tensorflow import keras


class PreTrainedPipeline():
    """Segmentation inference pipeline wrapping a saved Keras model.

    Runs the model on one image and returns, for every output class, a
    binary PNG mask encoded as a base64 string, in the list-of-dicts
    format consumed by the image-segmentation widget.
    """

    def __init__(self, path: str):
        """Load the Keras model stored as ``tf_model.h5`` inside *path*."""
        self.model = keras.models.load_model(os.path.join(path, "tf_model.h5"))

    @staticmethod
    def _encode_mask(mask: np.ndarray) -> str:
        """Return a 2-D uint8 mask as a base64-encoded PNG string."""
        # A 2-D uint8 array maps to Pillow mode "L" (8-bit grayscale).
        img = Image.fromarray(mask)
        with io.BytesIO() as out:
            img.save(out, format="PNG")
            png_bytes = out.getvalue()
        return base64.b64encode(png_bytes).decode("utf-8")

    def __call__(self, inputs: "Image.Image") -> List[Dict[str, Any]]:
        """Segment *inputs* and return one mask entry per output class.

        Args:
            inputs: PIL image to segment. It is resized to 128x128 and
                scaled to [0, 1] before prediction.
                NOTE(review): a grayscale or RGBA image yields a 2-D or
                4-channel array; confirm callers always pass RGB, or
                convert first.

        Returns:
            A list with one dict per class index ``cls``, containing:
            ``"label"`` (``f"LABEL_{cls}"``), ``"mask"`` (a base64 PNG in
            which pixels predicted as ``cls`` are 255 and all others are
            0), and ``"score"`` (always 1.0).
        """
        img = np.array(inputs)
        im = tf.image.resize(img, (128, 128))
        im = tf.cast(im, tf.float32) / 255.0
        pred_mask = self.model.predict(im[tf.newaxis, ...])

        # Presumably pred_mask is (1, H, W, num_classes); TODO confirm
        # against the model. Taking argmax over the last axis gives the
        # winning class per pixel, as an (H, W) map that uses the true
        # height and width rather than assuming the output is square.
        class_map = np.argmax(pred_mask[0], axis=-1)

        labels = []
        for cls in range(pred_mask.shape[-1]):
            # Build the binary mask with a vectorized comparison instead of
            # a per-pixel loop. Using uint8 keeps 255 in range; the original
            # cast to int8 overflowed.
            binary_mask = np.where(class_map == cls, 255, 0).astype(np.uint8)
            # The widget needs this format: for each class, a label and a
            # mask string.
            labels.append({
                "label": f"LABEL_{cls}",
                "mask": self._encode_mask(binary_mask),
                "score": 1.0,
            })
        return labels