import json
import os
from typing import Any, Dict, List

import numpy as np
import tensorflow as tf
from PIL import Image

from tfing import TFIng
from tfport import TFPort, get_look_ahead_mask, get_padding_mask
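

# Inference pipeline for a food image model: given a photo of a dish, it
# predicts the ingredients, their relative portion sizes, and aggregate
# nutrition values taken from the ingredient metadata.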
class PreTrainedPipeline():
    def __init__(self, path=""):
        crop_size = (224, 224)
        embed_dim = 256
        num_layers = 3
        seq_length = 20
        hidden_dim = 1024
        num_heads = 8
        self.crop_size = crop_size
        # Assumed: the shorter image side is resized to the crop size before center cropping.
        self.img_size = crop_size[0]
        self.nutr_names = ('energy', 'fat', 'protein', 'carbs')
        with open(os.path.join(path, 'ingredients_metadata.json'), encoding='UTF-8') as f:
self.ingredients = json.load(f)
self.ing_names = {ing['name']: int(ing_id) for ing_id, ing in self.ingredients.items()}
self.vocab_size = len(self.ingredients) + 3
self.seq_length = seq_length
self.tfing = TFIng(
crop_size,
embed_dim,
num_layers,
seq_length,
hidden_dim,
num_heads,
self.vocab_size
)
self.tfing.compile()
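        # Build the model variables with a dummy forward pass so the saved
        # weights can be loaded below.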
self.tfing((tf.zeros((1, 224, 224, 3)), tf.zeros((1, seq_length))))
        self.tfing.load_weights(os.path.join(path, 'tfing.h5'))
self.tfport = TFPort(
crop_size,
embed_dim,
num_layers,
num_layers,
seq_length,
seq_length,
hidden_dim,
num_heads,
self.vocab_size
)
self.tfport.compile()
self.tfport((tf.zeros((1, 224, 224, 3)), tf.zeros((1, seq_length)), tf.zeros((1, seq_length))))
        self.tfport.load_weights(os.path.join(path, 'tfport.h5'))
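
    # Pipeline entry point: takes a PIL image, resizes its shorter side to
    # self.img_size, center-crops to self.crop_size, and returns one
    # {label, score} dict per predicted ingredient, where the score is the
    # predicted portion as a percentage.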
def __call__(self, inputs: "Image.Image") -> List[Dict[str, Any]]:
image = tf.keras.preprocessing.image.img_to_array(inputs)
height = tf.shape(image)[0]
width = tf.shape(image)[1]
if width > height:
image = tf.image.resize(image, (self.img_size, int(float(self.img_size * width) / float(height))))
else:
image = tf.image.resize(image, (int(float(self.img_size * height) / float(width)), self.img_size))
image = tf.keras.applications.inception_v3.preprocess_input(image)
image = tf.keras.layers.CenterCrop(*self.crop_size)(image)
prediction = self.predict(image)
return [
{
"label": prediction['ingredients'][i],
"score": prediction['portions'][i]
}
for i in range(len(prediction['ingredients']))
]
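
    # Encode the image with TFIng's backbone (presumably an InceptionV3-style
    # network, given the preprocessing above), then flatten the spatial feature
    # map into a sequence of per-position embeddings for the transformer decoders.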
def encode_image(self, image):
encoder_out = self.tfing.encoder(image)
encoder_out = self.tfing.conv(encoder_out)
encoder_out = tf.reshape(
encoder_out,
(tf.shape(encoder_out)[0], -1, tf.shape(encoder_out)[3])
)
return encoder_out
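
    # Encode the ingredient token sequence with TFPort's ingredient encoder,
    # masking out padding positions.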
def encode_ingredients(self, ingredients, padding_mask):
return self.tfport.ingredient_encoder(ingredients, padding_mask)
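
    # One decoding step of the ingredient model: run TFIng's decoder over the
    # partial token sequence with the image encoding as memory, project to
    # vocabulary logits, and add a mask (presumably to stop tokens that have
    # already been emitted from being predicted again).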
def decode_ingredients(self, encoded_img, decoder_in):
decoder_outputs = self.tfing.decoder(decoder_in, encoded_img)
output = self.tfing.linear(decoder_outputs)
return output + self.tfing.get_replacement_mask(decoder_in)
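
    # One decoding step of the portion model: use the concatenated image and
    # ingredient encodings as the cross-attention memory (image positions are
    # never masked), run the portion decoder layers, and project each position
    # to a scalar portion value.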
def decode_portions(self, encoded_img, encoded_ingr, decoder_in, padding_mask):
encoder_outputs = tf.concat([encoded_img, encoded_ingr], axis=1)
img_mask = tf.ones((tf.shape(encoded_img)[0], 1, tf.shape(encoded_img)[1]), dtype=tf.int32)
padding_mask = tf.concat([img_mask, padding_mask], axis=2)
look_ahead_mask = get_look_ahead_mask(decoder_in)
x = self.tfport.portion_embedding(decoder_in)
for i in range(len(self.tfport.decoder_layers)):
x = self.tfport.decoder_layers[i](x, encoder_outputs, look_ahead_mask, padding_mask=padding_mask)
x = self.tfport.linear(x)
return tf.squeeze(x)
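
    # Greedy autoregressive decoding of ingredient tokens. Judging by how they
    # are used here, vocab_size - 2 is the start token and vocab_size - 1 the
    # end token; decoding stops at the end token, or forces it once seq_length
    # ingredients have been generated.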
def predict_ingredients(self, encoded_img, known_ing=None):
predicted = np.zeros((1, self.seq_length + 1), dtype=int)
predicted[0, 0] = self.vocab_size - 2
start_index = 0
if known_ing:
predicted[0, 1:len(known_ing) + 1] = known_ing
start_index = len(known_ing)
for i in range(start_index, self.seq_length):
decoded = self.decode_ingredients(encoded_img, predicted[:, :-1])
next_token = int(np.argmax(decoded[0, i]))
predicted[0, i + 1] = next_token
if next_token == self.vocab_size - 1:
return predicted[0, 1:]
if i == self.seq_length - 1:
predicted[0, i + 1] = self.vocab_size - 1
return predicted[0, 1:]
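
    # Autoregressive regression of one portion value per ingredient, seeded
    # with -1 as the start value; stops when the ingredient sequence reaches
    # its end token.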
def predict_portions(self, encoded_image, ingredients):
predicted = np.zeros((1, self.seq_length + 1), dtype=float)
predicted[0, 0] = -1
padding_mask = get_padding_mask(ingredients)
encoded_ingr = self.encode_ingredients(ingredients, padding_mask)
for i in range(self.seq_length):
if ingredients[0, i] == self.vocab_size - 1:
return predicted[0, 1:]
next_proportion = float(
self.decode_portions(
encoded_image,
encoded_ingr,
predicted[:, :-1],
padding_mask
)[i]
)
predicted[0, i + 1] = next_proportion
return predicted[0, 1:]
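
    # Parse user-supplied ingredients, one name per line; a line containing
    # only '.' marks the list as complete, so ingredient prediction is skipped
    # and only portions are predicted.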
def process_ingredients(self, ingredients):
processed = []
for ingredient in ingredients.split('\n'):
stripped = ingredient.strip()
if stripped == '.':
return processed, True
if stripped in self.ing_names:
processed.append(self.ing_names[stripped])
return processed, False
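
    # Full prediction: encode the image, merge in any user-supplied
    # ingredients, decode ingredient tokens and portions, rescale the portions
    # to sum to 100, and aggregate nutrition values (each ingredient's metadata
    # appears to be scaled by its predicted portion / 100).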
def predict(self, image, known_ing=None):
encoded_image = self.encode_image(image[tf.newaxis, :])
        known_ing, skip_ing = (
            self.process_ingredients(known_ing) if known_ing else (None, False)
        )
if not skip_ing:
ingredients = self.predict_ingredients(encoded_image, known_ing=known_ing)
else:
ingredients = known_ing[:self.seq_length - 1]
ingredients.append(self.vocab_size - 1)
ingredients = np.pad(ingredients, (0, self.seq_length - len(ingredients)))
readable_ingredients = [
self.ingredients[str(token)]['name'] for token in ingredients
if token != 0 and token != self.vocab_size - 1
]
        portions = (
            self.predict_portions(encoded_image, ingredients[tf.newaxis, :])
            if len(readable_ingredients) > 1 else [100]
        )
portions_slice = portions[:len(readable_ingredients)]
scale = 100 / sum(portions_slice)
return {
'ingredients': readable_ingredients,
'portions': [portion * scale for portion in portions_slice],
'nutrition': {
name: sum(
self.ingredients[str(ingredients[i])][name] * portions[i] / 100
for i in range(len(readable_ingredients))
) for name in self.nutr_names
}
}