# Author: davidaf3
# Commit: c7ab34d ("Added config.json")
# File size: 1.27 kB (scraped from the hosting page's raw / history / blame view)
from typing import Dict, List, Any
from PIL import Image
from fcnutr import FCNutr
import os
import tensorflow as tf
class PreTrainedPipeline:
    """Inference pipeline that predicts nutrition values from a single image.

    Wraps an ``FCNutr`` model whose weights are loaded from ``fcnutr.h5``
    inside the directory given by ``path``.
    """

    def __init__(self, path="", img_size=256):
        """Build the model and load its weights.

        Args:
            path: Directory containing ``fcnutr.h5``.
            img_size: Length, in pixels, that the shorter image side is resized
                to before center-cropping to ``self.crop_size``.
                NOTE(review): the original code read ``self.img_size`` without
                ever setting it; 256 is an assumed default — confirm against the
                preprocessing used when the model was trained.
        """
        self.crop_size = (224, 224)
        self.img_size = img_size
        self.nutr_names = ('energy', 'fat', 'protein', 'carbs')
        self.model = FCNutr(self.nutr_names, self.crop_size, 4096, 3, False)
        self.model.compile()
        # Presumably a dummy forward pass so the model's variables are created
        # before load_weights() assigns to them.
        self.model(tf.zeros((1, self.crop_size[0], self.crop_size[1], 3)))
        self.model.load_weights(os.path.join(path, "fcnutr.h5"))
        # Built once here rather than re-instantiated on every call.
        self._center_crop = tf.keras.layers.CenterCrop(*self.crop_size)

    def __call__(self, inputs: "Image.Image") -> Dict[str, float]:
        """Predict nutrition values for one image.

        The image is converted to RGB and resized so that its shorter side
        equals ``self.img_size`` (aspect ratio preserved). It is then scaled
        with Inception-v3's ``preprocess_input``, center-cropped to
        ``self.crop_size`` and fed to the model as a batch of one.

        Args:
            inputs: The PIL image to analyse.

        Returns:
            A dict mapping each name in ``self.nutr_names`` to the model's
            prediction for it, as a Python float.
        """
        # The model is built for 3-channel input (see the dummy pass in
        # __init__); RGBA / grayscale / palette images would otherwise yield a
        # different channel count.
        image = tf.keras.preprocessing.image.img_to_array(inputs.convert("RGB"))
        # img_to_array returns a NumPy array, so these are plain Python ints.
        height, width = image.shape[:2]
        if width > height:
            new_size = (self.img_size, int(self.img_size * width / height))
        else:
            new_size = (int(self.img_size * height / width), self.img_size)
        image = tf.image.resize(image, new_size)
        image = tf.keras.applications.inception_v3.preprocess_input(image)
        image = self._center_crop(image)
        # Add a leading batch dimension of size 1.
        prediction = self.model(image[tf.newaxis, :])
        return {
            name: float(prediction[name].numpy()[0, 0])
            for name in self.nutr_names
        }