prithivida commited on
Commit
a72a046
β€’
1 Parent(s): 3bc0a50

Initial Commit

Browse files
Files changed (1) hide show
  1. app.py +182 -2
app.py CHANGED
@@ -1,4 +1,184 @@
1
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- x = st.slider('Select a value')
4
- st.write(x, 'squared is', x * x)
 
1
  import streamlit as st
2
+ import streamlit.components.v1 as components
3
+ import pickle
4
+ import sentence_transformers
5
+ from sentence_transformers import SentenceTransformer, util
6
+ from PIL import Image
7
+ import torch
8
+ import spacy
9
+ import os
10
+ import glob
11
+ import random
12
+
13
+ torch.set_num_threads(4)
14
+
15
+
16
+ def get_spacy_dbpedia_highlights(ingredients):
17
+ import spacy
18
+ import spacy_dbpedia_spotlight
19
+
20
+ raw_ingredients = ingredients
21
+ import re
22
+ ingredients = re.sub("[0-9,()\/\-\.]", "", ingredients)
23
+ doc = nlp(ingredients)
24
+
25
+ for ent in doc.ents:
26
+ if ent.text.lower() not in stop_words and ent.text in raw_ingredients:
27
+ replace_str = '<mark style="color: green; background-color:yellow"> <a href="' + ent.kb_id_ + '" target="_blank"> ' + ent.text + '</a> </mark>'
28
+ raw_ingredients = raw_ingredients.replace(ent.text, replace_str)
29
+ return raw_ingredients
30
+
31
+ def detect_food(query, text_emb, labels, k=1):
32
+ query_emb = model.encode(Image.open(query), convert_to_tensor=True, show_progress_bar=False)
33
+ hits = util.semantic_search(query_emb, text_emb, top_k=k)[0]
34
+ results = []
35
+ for i, hit in enumerate(hits):
36
+ results.append((labels[hit['corpus_id']], hit['score']))
37
+ if i > 2:
38
+ break
39
+ return results
40
+
41
+ def run_search(food_image, col2):
42
+
43
+ with open("./Pretrained/labels.pkl", 'rb') as fIn:
44
+ labels = pickle.load(fIn)
45
+
46
+ emb_filename = './Pretrained/food_embeddings.pkl'
47
+ text_emb = torch.load(emb_filename, map_location=torch.device('cpu'))
48
+
49
+ results = detect_food(food_image, text_emb, labels, 3)
50
+ food_recognised, score = results[0]
51
+
52
+ del text_emb
53
+ del labels
54
+
55
+ import pysos
56
+ id2recipe = pysos.Dict("./Pretrained/id2recipe")
57
+ food2id = pysos.Dict("./Pretrained/food2id")
58
+
59
+
60
+ id = food2id[food_recognised]
61
+
62
+ recipe_name = food_recognised.title()
63
+ ingredients_list =id2recipe[id]['ingredients']
64
+ highlighted_ingredients= get_spacy_dbpedia_highlights(ingredients_list)
65
+ recipe= id2recipe[id]['instructions']
66
+ dataset = " " + id2recipe[id]['dataset']
67
+ if dataset.strip() == "Recipe1M":
68
+ nutritional_facts= "For nutritional facts, schedule and servings, visit the link in the footer"
69
+ else:
70
+ nutritional_facts = id2recipe[id]['nutrition_facts']
71
+ source= id2recipe[id]['recipesource']
72
+
73
+
74
+ del id2recipe
75
+ del food2id
76
+
77
+ st.markdown("<br/>", unsafe_allow_html=True)
78
+ with col2:
79
+ st.markdown("<b>Top 3 predictions &nbsp </b>", unsafe_allow_html=True)
80
+ results_static_tag = '<html><title>W3.CSS</title><meta name="viewport" content="width=device-width, initial-scale=1"><link rel="stylesheet" href="https://www.w3schools.com/w3css/4/w3.css"><body><div class="w3-container">{}</div></body></html>'
81
+ result_rows = ""
82
+ for i, result in enumerate(results):
83
+ results_dynamic_tag= '{} <br/> <div class="w3-light-grey"> <div class="{}" style="height:4px;width:{}%"></div> </div><br>'
84
+ if i == 0:
85
+ results_dynamic_tag = results_dynamic_tag.format("<b>" + str(i+1) + "." + result[0].title() + "</b>", 'w3-blue', result[1] * 100)
86
+ else:
87
+ results_dynamic_tag = results_dynamic_tag.format(str(i+1) + "." + result[0].title(), "w3-orange" ,result[1] * 100)
88
+ result_rows += results_dynamic_tag
89
+ results_static_tag = results_static_tag.format(result_rows)
90
+ st.markdown(results_static_tag, unsafe_allow_html=True)
91
+
92
+ title_tag = '<h4> Recipe for top result: &nbsp' + recipe_name + '</h4>'
93
+ st.markdown(title_tag, unsafe_allow_html=True)
94
+
95
+ ing_hdr_tag = '<h5> Ingredients </h5>'
96
+ ing_style= "{border: 3x outset white; background-color: #ccf5ff; color: black; text-align: left; font-size: 14px; padding: 5px;}"
97
+ ing_tag = '<html><head><style>.ingdiv{}</style></head><body><div class="ingdiv">{}</div></body></html>'
98
+ ing_tag = ing_tag.format(ing_style, highlighted_ingredients.strip())
99
+ st.markdown(ing_hdr_tag, unsafe_allow_html=True)
100
+ st.markdown(ing_tag + "<br/>", unsafe_allow_html=True)
101
+
102
+
103
+ rec_hdr_tag = '<h5> Recipe </h5>'
104
+ rec_style= "{border: 3x outset white; background-color: #ffeee6; color: black; text-align: left; font-size: 14px; padding: 5px;}"
105
+ rec_tag = '<html><head><style>.recdiv{}</style></head><body><div class="recdiv">{}</div></body></html>'
106
+ rec_tag = rec_tag.format(rec_style, recipe.strip())
107
+ st.markdown(rec_hdr_tag, unsafe_allow_html=True)
108
+ st.markdown(rec_tag + "<br/>", unsafe_allow_html=True)
109
+
110
+
111
+ src_hdr_tag = '<h5> Recipe source </h5>'
112
+ src_tag = '<a href={} target="_blank">{}</a>'
113
+ src_tag = src_tag.format(source, source)
114
+ st.markdown(src_hdr_tag, unsafe_allow_html=True)
115
+ st.markdown(src_tag + "<br/>", unsafe_allow_html=True)
116
+
117
+ return 1
118
+
119
+ if 'models_loaded' not in st.session_state:
120
+ st.session_state['models_loaded'] = False
121
+
122
+ st.title('WTF - What The Food 🀬')
123
+ st.subheader("Image to Recipe - 1.5M foods supported")
124
+ st.markdown("Built for fun with πŸ’™ by a quintessential foodie - Prithivi Da | [@prithivida](https://twitter.com/prithivida) |[[GitHub]](https://github.com/PrithivirajDamodaran) <br/> <hr style='height:1px;border:none;color:violet;background-color:gray;' />", unsafe_allow_html=True)
125
+ st.write("""Read Me: The goal is to detect a "Single food item" from the image and retrieve it's recipe. So by design the model works well on single foods. It works on platters too fx English breakfast but it may not perform well on a custom combination with multiple recipes or hyper-local foods.
126
+ """)
127
+
128
+
129
+ def load_image(image_file):
130
+ img = Image.open(image_file)
131
+ return img
132
+
133
+ def load_models():
134
+ with st.spinner(text="Loading Models..."):
135
+ os.system("python -m spacy download en_core_web_sm")
136
+ nlp = spacy.load('en_core_web_sm')
137
+ nlp.add_pipe('dbpedia_spotlight')
138
+ model = SentenceTransformer('clip-ViT-B-32')
139
+ stop_words = set(['chopped', 'freshly ground', 'freshly squeezed', 'dash', 'powder', 'rice', 'ice', 'noodles', 'pepper', 'milk', 'ced', 'cheese', 'sugar', 'salt', 'pkt', 'minced', 'onion', 'onions', 'garlic', 'butter', 'slices', 'ounce', 'sauce', 'freshly', 'grated', 'teaspoon', 'cup', 'oz', '⁄', 'to', 'or', 'diced', 'into', 'pound', 'dried', 'water', 'about', 'whole', 'small', 'vegetable', 'inch', 'tbsp', 'cooked', 'large', 'sliced', 'dry', 'optional', 'package', 'ounces', 'unsalted', 'lbs', 'green', 'flour', 'for', 'wine', 'crushed', 'drained', 'lb', 'frozen', 'tsp', 'finely', 'medium', 'tablespoon', 'tablespoons', 'juice', 'shredded', 'can', 'minced', 'fresh', 'cut', 'pieces', 'in', 'thinly', 'of', 'extract', 'teaspoons', 'ground', 'and', 'cups', 'peeled', 'taste', 'ml', 'lengths'])
140
+ st.session_state['nlp'] = nlp
141
+ st.session_state['model'] = model
142
+ st.session_state['stop_words'] = stop_words
143
+
144
+
145
+
146
+ if not st.session_state['models_loaded']:
147
+ load_models()
148
+ st.session_state['models_loaded'] = True
149
+
150
+ random_button = st.button('⚑ Try a Random Food')
151
+ st.write("(or)")
152
+ image_file = st.file_uploader("Tip: Upload HD images for better results.", type=["jpg","jpeg"])
153
+
154
+ nlp = st.session_state['nlp']
155
+ model = st.session_state['model']
156
+ stop_words = st.session_state['stop_words']
157
+ col1, col2 = st.columns(2)
158
+
159
+ if random_button:
160
+
161
+ with st.spinner(text="Detecting food..."):
162
+ samples = glob.glob('./samples' + "/*")
163
+ random_sample = random.choice(samples)
164
+ pil_image = load_image(random_sample)
165
+ with col1:
166
+ st.image(pil_image, use_column_width='auto')
167
+ return_code = run_search(random_sample, col2)
168
+ else:
169
+ if image_file is not None:
170
+ pil_image = load_image(image_file)
171
+ with open(image_file.name, 'wb') as f:
172
+ pil_image.save(f)
173
+
174
+ with col1:
175
+ st.image(pil_image, use_column_width='auto')
176
+
177
+ with st.spinner(text="Detecting food..."):
178
+ return_code = run_search(image_file.name, col2)
179
+ os.system('rm -r "' + image_file.name + '"')
180
+
181
+
182
+
183
+
184