Spaces:
Running
Running
import requests | |
from io import BytesIO | |
import numpy as np | |
from gensim.models.fasttext import FastText | |
from scipy import spatial | |
import itertools | |
import gdown | |
import warnings | |
import nltk | |
# warnings.filterwarnings('ignore') | |
import pickle | |
import pdb | |
from concurrent.futures import ProcessPoolExecutor | |
import matplotlib.pyplot as plt | |
import streamlit as st | |
import argparse | |
# NLTK resources needed at runtime by word_tokenize / pos_tag
# (filter_adjectives) and WordNetLemmatizer (filter_lemma).
for _nltk_resource in ('wordnet', 'punkt', 'averaged_perceptron_tagger'):
    nltk.download(_nltk_resource)
del _nltk_resource
# Average embedding -> compare
def recommend_ingredients(yum, leftovers, n=10):
    '''
    Recommend ingredients that pair with ``leftovers`` (mean aggregation).

    The unit-normalized vectors of all leftovers are averaged, and the
    vocabulary entries nearest to that mean vector are returned.

    :params
        yum -> FastText Word2Vec (KeyedVectors) obj, i.e. ``model.wv``
        leftovers -> list of str vocabulary keys (may contain underscores)
        n -> int top_n to return
    :returns
        output -> list of (name, similarity) tuples, at most ``n`` long, with
                  underscores in names replaced by spaces. Empty when
                  ``leftovers`` is empty.
    '''
    if not leftovers:
        # Averaging zero vectors would divide by zero and query with NaNs.
        return []
    # Mean of the normalized embeddings. Accumulate in float64 like the
    # original running sum did; the dimension comes from the vectors
    # themselves instead of a hard-coded 32.
    leftovers_embedding = np.mean(
        [yum.get_vector(ingredient, norm=True) for ingredient in leftovers],
        axis=0,
        dtype=np.float64,
    )
    top_matches = yum.similar_by_vector(leftovers_embedding, topn=100)
    # Compare in the same space-separated form the matches are reported in,
    # so a multi-word leftover such as "romaine_lettuce" is filtered too.
    ignore = [ingredient.replace('_', ' ') for ingredient in leftovers]
    output = []
    for name, score in top_matches:
        name = name.replace('_', ' ')
        # Remove boring same-item matches, e.g. "romaine lettuce" when the
        # leftovers already contain "lettuce".
        if not any(word in name for word in ignore):
            output.append((name, score))
    return output[:n]
# Compare each leftover -> find intersection
def recommend_ingredients_intersect(yum, leftovers, n=10):
    '''
    Recommend ingredients that are close to EVERY leftover individually.

    For each leftover, its 10000 nearest vocabulary entries are retrieved and
    matches containing any leftover are dropped. Only names present in every
    leftover's list survive. The result keeps the order and the similarity
    score of the FIRST leftover's list (scores are not combined).

    :params
        yum -> FastText Word2Vec (KeyedVectors) obj, i.e. ``model.wv``
        leftovers -> list of str vocabulary keys (may contain underscores)
        n -> int top_n to return
    :returns
        output -> list of (name, similarity) tuples, at most ``n`` long, with
                  underscores in names replaced by spaces. Empty when
                  ``leftovers`` is empty.
    '''
    if not leftovers:
        return []
    # Same space-separated form as the reported names, so multi-word
    # leftovers such as "romaine_lettuce" are filtered as well.
    ignore = [ingredient.replace('_', ' ') for ingredient in leftovers]
    output = None
    for ingredient in leftovers:
        ingredient_embedding = yum.get_vector(ingredient, norm=True)
        ingredient_matches = yum.similar_by_vector(ingredient_embedding, topn=10000)
        ingredient_output = []
        for name, score in ingredient_matches:
            name = name.replace('_', ' ')
            # Remove boring same-item matches, e.g. "romaine lettuce" when the
            # leftovers already contain "lettuce".
            if not any(word in name for word in ignore):
                ingredient_output.append((name, score))
        if output is None:
            output = ingredient_output
        else:
            # Set membership keeps this linear instead of comparing every
            # pair of the (up to 10000 x 10000) candidates.
            keep = {name for name, _ in ingredient_output}
            output = [match for match in output if match[0] in keep]
    return output[:n]
def recommend_ingredients_subsets(model, yum, leftovers, subset_size, n=10):
    '''
    Recommend ingredients for every subset of ``subset_size`` leftovers.

    Each subset's normalized vectors are averaged and the nearest vocabulary
    entries to that mean are collected (same method as recommend_ingredients).

    :params
        model -> FastText Obj. Kept for backward compatibility but unused: in
                 gensim 4 similarity queries live on ``model.wv`` (``yum``).
        yum -> FastText Word2Vec (KeyedVectors) obj, i.e. ``model.wv``
        leftovers -> list of str vocabulary keys (may contain underscores)
        subset_size -> int >= 1, number of leftovers per subset
        n -> int recommendations kept per subset (default 10)
    :returns
        all_outputs -> dict mapping each subset tuple (itertools.combinations
                       order) to a list of (name, similarity) tuples
    :raises
        ValueError -> if subset_size < 1 (an empty subset has no mean)
    '''
    if subset_size < 1:
        raise ValueError('subset_size must be at least 1')
    all_outputs = {}
    for leftovers_subset in itertools.combinations(leftovers, subset_size):
        # get_vector(norm=True) is the gensim 4 API; word_vec(use_norm=True)
        # raises TypeError there.
        leftovers_embedding = np.mean(
            [yum.get_vector(ingredient, norm=True) for ingredient in leftovers_subset],
            axis=0,
            dtype=np.float64,
        )
        top_matches = yum.similar_by_vector(leftovers_embedding, topn=100)
        # Compare in the reported space-separated form so multi-word
        # leftovers such as "romaine_lettuce" are filtered too.
        ignore = [ingredient.replace('_', ' ') for ingredient in leftovers_subset]
        output = []
        for name, score in top_matches:
            name = name.replace('_', ' ')
            # Remove boring same-item matches, e.g. "romaine lettuce" when the
            # subset already contains "lettuce".
            if not any(word in name for word in ignore):
                output.append((name, score))
        all_outputs[leftovers_subset] = output[:n]
    return all_outputs
def filter_adjectives(data):
    '''
    Drop entries that are not a multi-word phrase or a single noun.

    Intended to remove stray adjectives that are not a food item. Every entry
    with more than one token is kept. An entry with a single token (or none)
    is kept only if NLTK tags one of its tokens as "NN" or "NNS".

    :params
        data -> indexable sequence of str
    :returns
        list of the kept entries, in their original order
    '''
    kept = []
    for entry in data:
        tags = [tag for _token, tag in nltk.pos_tag(nltk.word_tokenize(entry))]
        is_phrase = len(tags) > 1
        is_noun = 'NN' in tags or 'NNS' in tags
        if is_phrase or is_noun:
            kept.append(entry)
    return kept
def plural_to_singular(lemma, recipe):
    '''
    Lemmatize every ingredient of one recipe (e.g. plural -> singular).

    Defined at module level so that filter_lemma can hand it to a process
    pool, which requires a picklable callable.

    :params
        lemma -> nltk lemmatizer Obj (anything with a ``lemmatize(str)`` method)
        recipe -> list of str
    :returns
        list of str, the lemmatized recipe in the same order
    '''
    return list(map(lemma.lemmatize, recipe))
def filter_lemma(data, max_workers=None, chunksize=256):
    '''
    Lemmatize every recipe in parallel (convert plurals to roots).

    :params
        data -> list of lists of str (one list of ingredients per recipe)
        max_workers -> int or None. Number of worker processes; None lets
                       ProcessPoolExecutor use the number of processors
                       (i.e. all the computational resources of your computer).
        chunksize -> int. Recipes sent to a worker per task. Batching
                     amortizes the per-task inter-process overhead, which
                     dominates with the executor's default of 1 on a large
                     corpus. It does not change the result.
    :returns
        out -> list of lists of str, in the same order as ``data``
    '''
    # Initialize Lemmatizer (to reduce plurals to stems). It is shipped to
    # the workers alongside each chunk of recipes.
    lemma = nltk.wordnet.WordNetLemmatizer()
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        # repeat() is infinite, but map() stops at the end of `data`.
        out = list(executor.map(plural_to_singular, itertools.repeat(lemma), data, chunksize=chunksize))
    return out
def train_model(data, vector_size=32, window=99, min_count=5, workers=40, sg=1):
    '''
    Train a FastText ("FastFood") model on the recipe corpus.

    NOTE: gensim==4.1.2. The defaults reproduce the original hard-coded
    hyper-parameters. Persist the result with ``model.save(path)``.

    :params
        data -> list of lists of all recipes (each a list of str tokens)
        vector_size -> int embedding dimension
        window -> int context window size
        min_count -> int, tokens with a lower total frequency are ignored
        workers -> int number of training threads
        sg -> int, 1 = skip-gram, 0 = CBOW
    :returns
        model -> trained FastText model obj
    '''
    return FastText(data, vector_size=vector_size, window=window, min_count=min_count, workers=workers, sg=sg)
def load_model(filename='models/fastfood_orig_4.model'):
    '''
    Read a saved FastText model from disk.

    :params:
        filename -> path handed to FastText.load
    :returns
        model -> the full FastText obj
        yum -> its word vectors (``model.wv``), used for similarity queries
    '''
    fasttext_model = FastText.load(filename)
    return fasttext_model, fasttext_model.wv
def load_data(filename='data/all_recipes_ingredients_lemma.pkl'):
    '''
    Load the pickled recipe dataset.

    SECURITY: pickle.load can execute arbitrary code embedded in the file;
    only load datasets from a trusted source.

    :params:
        filename -> path to dataset
    :return
        data -> the unpickled object (expected: list of all recipes)
    '''
    # `with` closes the file handle, which the bare open() call leaked.
    with open(filename, 'rb') as f:
        return pickle.load(f)
def plot_results(names, probs, n=5, leftovers=None):
    '''
    Plot a bar chart of the names of the items vs. their similarity score.

    A new Figure is built on every call and is not registered with pyplot.
    Repeated calls (e.g. Streamlit reruns) therefore neither stack bars and
    titles onto pyplot's shared current figure nor accumulate open figures.

    :params:
        names -> list of str
        probs -> list of float values, parallel to ``names``
        n -> int of how many bars to show (also shown in the title) NOTE: Max = 100
        leftovers -> shown in the title; defaults to st.session_state.leftovers
    :return
        fig -> matplotlib Figure for st.pyplot
    '''
    from matplotlib.figure import Figure

    if leftovers is None:
        leftovers = st.session_state.leftovers
    names = list(names)[:n]
    probs = list(probs)[:n]
    fig = Figure()
    ax = fig.subplots()
    positions = list(range(len(names)))
    ax.bar(positions, probs, align='center')
    # One tick per bar, labelled with the ingredient name.
    ax.set_xticks(positions)
    ax.set_xticklabels(names)
    ax.set_ylabel('Probability', fontsize='large', fontweight='bold')
    ax.set_xlabel('Ingredients', fontsize='large', fontweight='bold')
    ax.xaxis.labelpad = 10
    ax.set_title(f'FastFood Top {n} Predictions for Leftovers = {leftovers}')
    return fig
if __name__ == "__main__":
    # Streamlit entry point. The pretrained model is expected in the working
    # directory as 'fastfood.pth' plus 'fastfood.pth.wv.vectors_ngrams.npy'.
    # They can be fetched with:
    #   gdown.download('https://drive.google.com/uc?id=1fXGsWEbr-1BftKtOsnxc61cM3akMAIC0', 'fastfood.pth')
    #   gdown.download('https://drive.google.com/uc?id=1h_TijdSw1K9RT3dnlfIg4xtl8WPNNQmn', 'fastfood.pth.wv.vectors_ngrams.npy')
    # To train a new model instead:
    #   model = train_model(load_data(dataset_path)); model.save(model_path)

    # Streamlit re-executes this script on every widget interaction. Cache
    # the model load when this Streamlit version provides cache_resource.
    if hasattr(st, 'cache_resource'):
        model, yum = st.cache_resource(load_model)('fastfood.pth')
    else:
        model, yum = load_model('fastfood.pth')

    ##### UI/UX #####
    ## Sidebar ##
    # NOTE(review): the selected page is not read anywhere below yet.
    add_selectbox = st.sidebar.selectbox(
        "Food Utilization App",
        ("FastFood Recommendation Model", "Food Donation Resources", "Contact Team")
    )
    ## Selection Tool ##
    st.multiselect("Select leftovers", list(yum.key_to_index.keys()), default=['bread', 'lettuce'], key="leftovers")
    ## Slider ##
    st.slider("Number of Recommendations", min_value=1, max_value=100, value=5, step=1, key='top_n')
    if not st.session_state.leftovers:
        # With no leftovers selected there is nothing to recommend from, so
        # stop this run early instead of querying the model.
        st.info("Select at least one leftover to get recommendations.")
        st.stop()
    ## Get food recommendation ##
    out = recommend_ingredients(yum, st.session_state.leftovers, n=st.session_state.top_n)
    names = [name for name, _ in out]
    probs = [score for _, score in out]
    st.checkbox(label="Show model score", value=False, key="probs")
    st.table(data=out if st.session_state.probs else names)
    ## Plot Results ##
    st.checkbox(label="Show model bar chart", value=False, key="plot")
    if st.session_state.plot:
        fig = plot_results(names, probs, st.session_state.top_n)
        ## Show Plot ##
        st.pyplot(fig)