# Metadata recovered from the repository file-view page this file was copied from:
#   author: davidaf3
#   commit: 65e71e9 ("Added pipeline")
#   file size: 1.3 kB
from typing import Dict, List, Any
from PIL import Image
from fcnutr import FCNutr
import os
import json
import numpy as np
import tensorflow as tf
class PreTrainedPipeline:
    """Predict the nutrient values named in ``nutr_names`` from one image.

    Wraps an ``FCNutr`` model. Each call applies this preprocessing:
    1. Aspect-preserving resize so the shorter side equals ``img_size``.
    2. Inception-v3 pixel scaling.
    3. Centre crop to ``crop_size``.
    """

    def __init__(self, path="", img_size=256):
        """Build the model and load its weights.

        Args:
            path: Directory containing the ``fcnutr.h5`` weights file.
            img_size: Length in pixels that the shorter image side is resized
                to before centre-cropping. Must be at least as large as every
                crop dimension.
                NOTE(review): 256 is an assumed default (the common
                resize-256 / crop-224 convention). Confirm it matches the
                preprocessing used when the weights were trained.

        Raises:
            ValueError: If ``img_size`` is smaller than the crop size.
        """
        crop_size = (224, 224)
        if img_size < max(crop_size):
            raise ValueError(
                f"img_size ({img_size}) must be >= the crop size {crop_size}"
            )
        # __call__ needs both sizes, so they must live on the instance.
        self.crop_size = crop_size
        self.img_size = img_size
        self.nutr_names = ('energy', 'fat', 'protein', 'carbs')
        self.model = FCNutr(self.nutr_names, crop_size, 4096, 3, False)
        self.model.compile()
        # Run one forward pass on a dummy batch before load_weights.
        # Presumably FCNutr creates its variables lazily on first call, so the
        # weights must exist before they can be loaded. Confirm.
        self.model(tf.zeros((1, crop_size[0], crop_size[1], 3)))
        self.model.load_weights(os.path.join(path, "fcnutr.h5"))

    @staticmethod
    def _resize_shape(height, width, img_size):
        """Return ``(new_height, new_width)`` with the shorter side set to ``img_size``.

        The aspect ratio is preserved. The longer side is truncated to an int.

        Raises:
            ValueError: If ``height`` or ``width`` is not positive.
        """
        if height <= 0 or width <= 0:
            raise ValueError(f"Cannot resize an empty image of size {width}x{height}")
        if width > height:
            return img_size, int(img_size * width / height)
        return int(img_size * height / width), img_size

    def __call__(self, inputs: "Image.Image") -> Dict[str, float]:
        """Predict nutrient values for a single image.

        Args:
            inputs: The input image. PIL images are converted to RGB first,
                so grayscale, RGBA or palette images still yield the
                three-channel input the model is built with in ``__init__``.

        Returns:
            Mapping from each name in ``self.nutr_names`` to the model's
            scalar prediction for that name.

        Raises:
            ValueError: If the image has zero width or height.
        """
        if isinstance(inputs, Image.Image):
            inputs = inputs.convert("RGB")
        image = tf.keras.preprocessing.image.img_to_array(inputs)
        # img_to_array returns a NumPy array, so its shape gives plain ints.
        # This assumes channels-last layout (height, width, channels).
        height, width = image.shape[:2]
        image = tf.image.resize(image, self._resize_shape(height, width, self.img_size))
        image = tf.keras.applications.inception_v3.preprocess_input(image)
        image = tf.keras.layers.CenterCrop(*self.crop_size)(image)
        # Add a batch axis; the model returns one output tensor per nutrient name.
        prediction = self.model(image[tf.newaxis, :])
        return {name: float(prediction[name].numpy()[0, 0]) for name in self.nutr_names}