# dragon-notdragon / pipeline.py
# Last commit: fb5d392 by hadilq (unverified) - "add training script and prediction"
from typing import Dict
from PIL import Image
import numpy as np
import os
import json
import tensorflow as tf
from tensorflow import keras
class PreTrainedPipeline:
    """Binary dragon / not-dragon image classifier backed by a Keras model.

    The saved model and its ``config.json`` are both read from ``path``.
    Images are resized to 224x224 RGB and passed through
    ``keras.applications.vgg16.preprocess_input`` before prediction.
    """

    # Input size the network is fed (height, width); square, so it can also be
    # passed to PIL's ``resize``, which takes (width, height).
    _TARGET_SIZE = (224, 224)
    # Threshold on the model's first output: at or above it the image is
    # reported as "dragon".
    _DRAGON_THRESHOLD = 0.99

    def __init__(self, path=""):
        """Load the Keras model and label mapping from ``path``.

        Args:
            path: Directory holding the saved Keras model and ``config.json``.
                The default empty string means the current working directory.

        Raises:
            OSError: if ``config.json`` cannot be opened.
            KeyError: if ``config.json`` has no ``"id2label"`` entry.
        """
        # The model used to be loaded from "./" unconditionally while
        # config.json came from `path`; load both from the same directory.
        self.model = keras.saving.load_model(path or "./")
        config_path = os.path.join(path, "config.json")
        with open(config_path, encoding="utf-8") as config_file:
            config = json.load(config_file)
        # NOTE(review): stored but not consulted by __call__, whose labels are
        # hard-coded; confirm whether they should come from this mapping.
        self.id2label = config["id2label"]

    def __call__(self, inputs: "Image.Image") -> Dict[str, str]:
        """Classify a single image as dragon / not-dragon.

        Args:
            inputs: The raw image as a PIL image (an image file path is also
                accepted). All preprocessing happens here.

        Returns:
            ``{'label': "detected", 'score': "dragon"}`` when the model's first
            output is >= 0.99, otherwise
            ``{'label': "detected", 'score': "not-dragon"}``.

            NOTE(review): the Hugging Face custom-pipeline contract this method
            was written against asks for a list of ``{"label": str,
            "score": float}`` dicts in decreasing score order. The dict shape
            above is kept for backward compatibility; confirm what the caller
            expects.
        """
        batch = self._preprocess(inputs)
        # verbose=0 suppresses Keras' per-call progress bar in server logs.
        prediction = self.model.predict(batch, verbose=0)
        if prediction[0][0] >= self._DRAGON_THRESHOLD:
            verdict = "dragon"
        else:
            verdict = "not-dragon"
        return {'label': "detected", 'score': verdict}

    def _preprocess(self, inputs):
        """Turn an image into a single-item, VGG16-preprocessed batch.

        With Keras' default channels_last data format the result has shape
        (1, 224, 224, 3).
        """
        if isinstance(inputs, (str, bytes, os.PathLike)):
            # load_img only accepts paths / file objects, not PIL images.
            img = keras.preprocessing.image.load_img(
                os.fspath(inputs), target_size=self._TARGET_SIZE
            )
        else:
            # Mirror load_img's defaults for an in-memory PIL image:
            # force RGB, then nearest-neighbour resize.
            img = inputs.convert("RGB")
            if img.size != self._TARGET_SIZE:
                img = img.resize(self._TARGET_SIZE, Image.NEAREST)
        array = keras.preprocessing.image.img_to_array(img)
        batch = np.expand_dims(array, axis=0)
        return keras.applications.vgg16.preprocess_input(batch)