hf-models / inference.py
DimasMP3
Normalise image inputs and force RGB channels
10d4547
from __future__ import annotations
import io
from typing import Any, Dict, List
import numpy as np
from PIL import Image
import tensorflow as tf
# Class names, looked up by the index of the highest score in the model
# output (see PreTrainedModel.predict). Presumably this order mirrors the
# class order used at training time -- verify before reordering.
LABELS: List[str] = ["Heart", "Oblong", "Oval", "Round", "Square"]

# Edge length, in pixels, of the square image handed to the model.
# NOTE(review): 244 is unusual (224 is the more common input size) -- confirm
# against the saved model's expected input shape.
TARGET_SIZE = 244
def _load_image(image_bytes: bytes) -> Image.Image:
    """Decode encoded image bytes into a PIL image in RGB mode.

    Args:
        image_bytes: Contents of an encoded image file.

    Returns:
        The decoded image, converted to ``"RGB"`` unless it already has
        that mode.
    """
    decoded = Image.open(io.BytesIO(image_bytes))
    # Images that are already RGB are handed back without conversion.
    return decoded if decoded.mode == "RGB" else decoded.convert("RGB")
def _ensure_three_channels(array: np.ndarray) -> np.ndarray:
    """Coerce an image array so that its last axis holds three channels.

    Handled layouts:
      * 2-D array: treated as a single plane and stacked three times on a
        new trailing axis.
      * 3-D array with 1 trailing channel: the channel is repeated three
        times.
      * 3-D array with 2 trailing channels: channel 0 is repeated three
        times and channel 1 is dropped (assumes a grayscale + alpha
        layout -- TODO confirm if other 2-channel inputs are expected).
      * 3-D array with more than 3 trailing channels: only the first three
        are kept.

    Anything else (already three channels, or a rank other than 2 or 3) is
    returned unchanged.

    Args:
        array: Image data as a NumPy array.

    Returns:
        The adjusted array; the input object itself when no change is needed.
    """
    if array.ndim == 2:
        return np.stack([array] * 3, axis=-1)
    if array.ndim != 3:
        return array

    channels = array.shape[-1]
    if channels in (1, 2):
        # Slicing with ``:1`` keeps the channel axis so ``repeat`` expands it;
        # for the 2-channel case this also discards the second channel, in
        # the same way the >3 branch below discards trailing channels.
        return np.repeat(array[..., :1], 3, axis=-1)
    if channels > 3:
        return array[..., :3]
    return array
def _preprocess(image_bytes: bytes) -> np.ndarray:
    """Convert encoded image bytes into a one-image float32 batch.

    The image is decoded to RGB, resized to ``TARGET_SIZE`` x ``TARGET_SIZE``
    with bilinear resampling, converted to float32, divided by 255 and given
    a leading batch axis.

    Args:
        image_bytes: Contents of an encoded image file.

    Returns:
        Array of shape ``(1, TARGET_SIZE, TARGET_SIZE, 3)``.
    """
    side = TARGET_SIZE
    rgb = _load_image(image_bytes).resize((side, side), Image.BILINEAR)
    pixels = _ensure_three_channels(np.asarray(rgb, dtype="float32"))
    # In-place scaling avoids allocating a second full-size array.
    pixels /= 255.0
    return pixels[np.newaxis, ...]
class PreTrainedModel:
    """Wrap a saved Keras model and classify one image at a time."""

    def __init__(self, model_path: str = "model/best_model.keras") -> None:
        """Load the Keras model stored at *model_path*.

        Args:
            model_path: Path passed to ``tf.keras.models.load_model``.
        """
        self.model = tf.keras.models.load_model(model_path)

    def predict(self, inputs: bytes) -> List[Dict[str, Any]]:
        """Classify a single encoded image.

        Args:
            inputs: Contents of an encoded image file.

        Returns:
            A one-element list holding a dict with the winning ``"label"``
            (taken from ``LABELS``) and its ``"score"`` as a float.
        """
        batch = _preprocess(inputs)
        raw = self.model.predict(batch, verbose=0)
        # When the model yields a list/tuple of outputs, only the first one
        # is used.
        first = raw[0] if isinstance(raw, (list, tuple)) else raw
        scores = np.asarray(first).squeeze().tolist()
        best = int(np.argmax(scores))
        top = {"label": LABELS[best], "score": float(scores[best])}
        return [top]
def load_model(model_dir: str = ".") -> PreTrainedModel:
    """Build a ``PreTrainedModel`` from a model directory.

    Args:
        model_dir: Directory that contains ``model/best_model.keras``.

    Returns:
        A ``PreTrainedModel`` loaded from ``<model_dir>/model/best_model.keras``.
    """
    keras_path = f"{model_dir}/model/best_model.keras"
    return PreTrainedModel(model_path=keras_path)