Jasminder committed on
Commit
a468d98
1 Parent(s): 94e7cfe

Upload 10 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ data/restaurant-menus.csv filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ data/
3
+ .venv/
4
+ .vscode/
5
+ *.csv
6
+ .env
README.md CHANGED
@@ -1,13 +1,16 @@
1
- ---
2
- title: Food Feud
3
- emoji: 🏃
4
- colorFrom: pink
5
- colorTo: yellow
6
- sdk: streamlit
7
- sdk_version: 1.40.2
8
- app_file: app.py
9
- pinned: false
10
- short_description: Survey game to recommend recipes.
11
- ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
1
+ # Food Feud
2
+ ## How to run locally
3
+ Python 3 and git are required.
4
+ 1) `git clone https://github.com/jsgarcha/food-feud`
5
+ 2) `cd ./food-feud`
6
+ 3) `pip install -r requirements.txt`
7
+ 4) `python clean_data.py`
8
+ 5) `streamlit run main.py`
 
 
 
9
 
10
+ NOTE: JAX v0.4.36 does not work for this Huggingface model (https://huggingface.co/flax-community/t5-recipe-generation)
11
+ `pip install --force-reinstall -v "jax==0.4.34"`
12
+
13
+ Running the first time may take a minute or so, depending on your internet connection, because the model has to be downloaded from Huggingface (~900 MB).
14
+ Subsequent executions will not pause for long since the model will already be in cache.
15
+
16
+ You also need to provide your own key for Gemini in `.env` under the `GEMINI_API_KEY` key.
clean_data.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Download the Uber Eats USA restaurants dataset from Kaggle, clean it, and
export the top restaurants per cuisine category to data/top_restaurants.csv.

Run directly: ``python clean_data.py``.
"""
import os
import shutil

import pandas as pd

DATA_PATH = "data/"
RESULT_FILE = "top_restaurants.csv"

# Cuisine categories kept in the cleaned dataset. Immutable tuple so it is
# safe to use as a default argument.
SELECTED_CATEGORIES = ('steak', 'chinese', 'japanese', 'italian', 'indian', 'mediterranean')


def download_data():
    """Download the Kaggle dataset and copy its files into DATA_PATH."""
    # Imported lazily so the module can be imported (e.g. for testing the
    # pure helpers below) without kagglehub installed.
    import kagglehub

    kaggle_path = kagglehub.dataset_download("ahmedshahriarsakib/uber-eats-usa-restaurants-menus")
    print("Downloaded datasets from Kaggle.")

    os.makedirs(DATA_PATH, exist_ok=True)
    for file in os.listdir(kaggle_path):
        source = os.path.join(kaggle_path, file)
        if os.path.isfile(source):
            shutil.copy(source, os.path.join(DATA_PATH, file))
    print("Moved datasets to " + DATA_PATH)


def clean_restaurants(restaurants_df):
    """Drop rows unusable for the survey.

    Removes rows with a null 'category', keeps only price ranges $$ to $$$$,
    and keeps only ratings of 3.5 and above. Returns a new DataFrame.
    """
    restaurants_df = restaurants_df.dropna(subset=['category'])
    return restaurants_df[
        (restaurants_df['price_range'].isin(['$$', '$$$', '$$$$'])) &
        (restaurants_df['score'] >= 3.5)
    ]


def select_top_restaurants(restaurants_df, categories=SELECTED_CATEGORIES, top_n=100):
    """Return the top ``top_n`` highest-scored restaurants for each category.

    A restaurant whose 'category' string contains the category name
    (case-insensitive) counts as a match; a restaurant matching several
    categories appears once per match. Result is sorted by score descending.
    """
    # Collect per-category frames and concatenate once (avoids the quadratic
    # cost of concatenating inside the loop).
    frames = []
    for category in categories:
        filtered = restaurants_df[restaurants_df['category'].str.contains(category, case=False, na=False)]
        frames.append(filtered.sort_values(by='score', ascending=False).head(top_n))

    final_result = pd.concat(frames).reset_index(drop=True)
    return final_result.sort_values(by='score', ascending=False)


def main():
    """Full pipeline: download, clean, select, export."""
    download_data()

    restaurants_df = pd.read_csv(DATA_PATH + 'restaurants.csv')
    print("Loaded data.")

    final_result = select_top_restaurants(clean_restaurants(restaurants_df))

    final_result.to_csv(DATA_PATH + RESULT_FILE, index=False)
    print("Cleaned data and exported to " + DATA_PATH + RESULT_FILE)


if __name__ == "__main__":
    main()
data/restaurant-menus.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1fe16b49b5db6b35b7522c6f6861c52f965c16ab610c7b24113dd7cc9ec50c20
3
+ size 870834478
data/restaurants.csv ADDED
The diff for this file is too large to render. See raw diff
 
data/top_restaurants.csv ADDED
The diff for this file is too large to render. See raw diff
 
gemini.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
"""Configure the Gemini SDK and expose a ready-to-use chat session."""
import os

import google.generativeai as genai
from dotenv import load_dotenv

# Read GEMINI_API_KEY from the local .env file and register it with the SDK.
load_dotenv()
genai.configure(api_key=os.getenv('GEMINI_API_KEY'))

# High temperature for varied ingredient suggestions; JSON mime type so the
# caller can json.loads() the response text directly.
GENERATION_CONFIG = {
    "temperature": 2,
    "top_p": 0.95,
    "top_k": 40,
    "max_output_tokens": 8192,
    "response_mime_type": "application/json",
}

model = genai.GenerativeModel(model_name="gemini-1.5-flash", generation_config=GENERATION_CONFIG)
chat_session = model.start_chat()
main.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import json
import random
import re

import pandas as pd
import streamlit as st

from recipe_generator import generation_function
from gemini import chat_session

data_path = "data/"
data_file = "top_restaurants.csv"

# App stages: survey first, then recipe generation.
RESTAURANT_SURVEY_STAGE = 1
RECIPE_GENERATION_STAGE = 2

# How many restaurants the user must like before moving to recipe generation.
LIKE_NUMBER = 20

# "Top" is relative to our data set; meaning, these categories exhibited the
# "cleanest" data. To be changed later.
top_food_categories = ['Steak', 'Chinese', 'Japanese', 'Italian', 'Indian', 'Mediterranean']

st.markdown("<h1 style='text-align: center'>Food Feud</h1>", unsafe_allow_html=True)

# Initialize session state exactly once per browser session; Streamlit reruns
# the whole script on every interaction, so only missing keys are set.
_session_defaults = {
    'stage': RESTAURANT_SURVEY_STAGE,  # Start stage
    'like': [],
    'dislike': [],
    'like_count': LIKE_NUMBER,
    'survey_progress': 0,
}
for _key, _default in _session_defaults.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
35
+
36
@st.cache_data
def load_restaurant_data():
    """Read the cleaned top-restaurants CSV once; cached across app reruns."""
    csv_path = data_path + data_file
    return pd.read_csv(csv_path)
39
+
40
def clear_string(s):
    """Normalize a restaurant name for display.

    Strips any parenthesized text, keeps only the part before the first
    hyphen, decodes '&amp;' HTML entities, and trims surrounding whitespace.
    """
    without_parens = re.sub(r"\(.*?\)", "", s)
    before_dash = without_parens.split('-')[0]
    return before_dash.replace("&amp;", "&").strip()
42
+
43
def add_like(like):  # Row in a DataFrame
    """Record a liked restaurant and advance the survey progress bar."""
    st.session_state.like.append(like)  # Build up likes
    # Cap the bar at 100 while still counting down the remaining selections.
    if st.session_state.survey_progress < 100:
        st.session_state.survey_progress += 100 // LIKE_NUMBER
    st.session_state.like_count -= 1
    # survey_progress_bar is created by the survey-stage UI code below.
    survey_progress_bar.progress(
        st.session_state.survey_progress,
        text=f"Select {st.session_state.like_count} more.",
    )
49
+
50
def add_dislike(dislike):
    """Record a disliked restaurant (the list is not read elsewhere in this file)."""
    st.session_state.dislike.append(dislike)
52
+
53
def generate_recipe(ingredients):
    """Generate one recipe from a comma-separated ingredient string and render it.

    The generator emits newline-separated sections prefixed with "title:",
    "ingredients:" and "directions:"; items inside a section are separated by
    "--". The title renders as a centered heading, the other sections as a
    heading plus a bulleted list.
    """
    generated = generation_function(ingredients)
    sections = generated.split("\n")
    # Fix: previously 'headline' was unbound (UnboundLocalError) if the first
    # section carried no known prefix; initialize and skip such sections.
    headline = None
    for section in sections:
        section = section.strip()
        if not section:
            continue  # Nothing to render for blank lines between sections
        if section.startswith("title:"):
            section = section.replace("title:", "")
            headline = "TITLE"
        elif section.startswith("ingredients:"):
            section = section.replace("ingredients:", "")
            headline = "Ingredients"
        elif section.startswith("directions:"):
            section = section.replace("directions:", "")
            headline = "Directions"

        if headline is None:
            continue  # Unrecognized leading section: no headline to render under

        if headline == "TITLE":
            st.markdown("<h3 style='text-align: center'>"+str(section.strip().capitalize())+"</h3>", unsafe_allow_html=True)
        else:
            section_info = [f" - {info.strip().capitalize()}" for info in section.split("--")]
            st.markdown("<h4>"+f'{headline}'+"</h4>", unsafe_allow_html=True)
            st.write("\n".join(section_info))
74
+
75
df_restaurants = load_restaurant_data()

# Placeholder lets the survey UI be cleared in place when the stage advances.
placeholder = st.empty()

if st.session_state.stage == RESTAURANT_SURVEY_STAGE:
    with placeholder.container():
        st.markdown("<h4 style='text-align: center'>Start by taking our survey of eating establishments whose food you enjoy.</h4>", unsafe_allow_html=True)
        survey_progress_bar = st.progress(st.session_state.survey_progress, text=f"Select {st.session_state.like_count} more.")
        # One random restaurant per rerun; the buttons record the answer.
        random_restaurant = df_restaurants.sample()
        st.markdown("<h3 style='text-align: center'>" + clear_string(random_restaurant.iloc[0]['name']) + "</h3>", unsafe_allow_html=True)
        yes_col, no_col = st.columns(2)
        if yes_col.button('Yes 👍', type="secondary", use_container_width=True):
            add_like(random_restaurant)
        if no_col.button('No 👎', type="secondary", use_container_width=True):
            add_dislike(random_restaurant)

# Enough likes collected: clear the survey UI and advance the stage.
if st.session_state.like_count == 0 and st.session_state.stage != RECIPE_GENERATION_STAGE:
    placeholder.empty()
    st.balloons()
    st.session_state.stage = RECIPE_GENERATION_STAGE
95
+
96
if st.session_state.stage == RECIPE_GENERATION_STAGE:
    # Combine all liked single-row DataFrames into one table to sample from.
    df_restaurant_likes = pd.concat(st.session_state.like)

    # Fix: user-facing typo "restaurants your liked" -> "you liked".
    st.markdown("<h4 style='text-align: center'>Now generate recipes based on the restaurants you liked!</h4>", unsafe_allow_html=True)
    col = st.columns([1])[0]  # One column with equal width
    with col:
        if st.button('Generate Recipe!', type='primary', use_container_width=True):
            liked_restaurant = df_restaurant_likes.sample()
            liked_restaurant_categories = liked_restaurant['category'].values[0]
            # Fix: previously '[...][0]' raised IndexError when none of the
            # curated categories appeared in the restaurant's category string;
            # fall back to a random curated category instead of crashing.
            liked_restaurant_category = next(
                (category for category in top_food_categories if category in liked_restaurant_categories),
                random.choice(top_food_categories),
            )

            # Ask Gemini (configured for JSON output) for typical ingredients.
            response = chat_session.send_message(f"List common ingredients in {liked_restaurant_category} food.")
            model_response = response.text
            response = json.loads(model_response)
            ingredients = response['ingredients']
            random.shuffle(ingredients)  # Change

            st.markdown(
                "<h4 style='text-align: center'>Based on your like of <span style='color: red;'>"
                + clear_string(liked_restaurant.iloc[0]['name']) +
                "</span>, survey says...</h4>",
                unsafe_allow_html=True
            )
            generate_recipe(','.join(map(str, ingredients)))

# 2 major things to be fixed:
# 1) Huggingface model input giving more than 1 recipe, but limiting to 1 produces the same recipe
# 2) Decision function - do more analytics to determine top restaurant; also randomize ingredients list
recipe_generator.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from transformers import FlaxAutoModelForSeq2SeqLM
from transformers import AutoTokenizer
import streamlit  # NOTE(review): unused in this module; kept to avoid changing file-level imports

MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"

# Loaded once at import time; the first run downloads the model from the
# Hugging Face Hub (see README: ~900 MB), later runs hit the local cache.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)

# Every model input is prefixed with this marker before the ingredient list.
prefix = "items: "

# Sampling settings passed to model.generate().
generation_kwargs = {
    "max_length": 512,
    "min_length": 64,
    "no_repeat_ngram_size": 3,
    "do_sample": True,
    "top_k": 60,
    "top_p": 0.95
}

special_tokens = tokenizer.all_special_tokens
# Maps the model's structural tokens to the plain-text separators the UI
# parser expects ("--" between items, newline between sections).
tokens_map = {
    "<sep>": "--",
    "<section>": "\n"
}
24
+
25
def skip_special_tokens(text, special_tokens):
    """Return *text* with every token in *special_tokens* stripped out."""
    cleaned = text
    for token in special_tokens:
        cleaned = cleaned.replace(token, "")
    return cleaned
29
+
30
def target_postprocessing(texts, special_tokens):
    """Clean decoded model output.

    Strips the given special tokens and replaces the model's structural
    tokens (via module-level ``tokens_map``) with plain-text separators.
    Accepts a single string or a list of strings; always returns a list.
    """
    if not isinstance(texts, list):
        texts = [texts]

    processed = []
    for text in texts:
        cleaned = skip_special_tokens(text, special_tokens)
        for token, replacement in tokens_map.items():
            cleaned = cleaned.replace(token, replacement)
        processed.append(cleaned)

    return processed
44
+
45
def generation_function(text):
    """Generate one recipe string from a comma-separated ingredient list.

    Returns the raw model text after postprocessing: sections separated by
    newlines and items by "--" (see ``tokens_map``); the caller parses the
    "title:" / "ingredients:" / "directions:" sections.
    """
    # Ensure the input is a single string with the required task prefix.
    model_input = prefix + str(text)
    encoded = tokenizer(
        model_input,
        max_length=256,
        padding="max_length",
        truncation=True,
        return_tensors="jax"
    )

    # Sample an output token sequence (settings in generation_kwargs).
    output = model.generate(
        input_ids=encoded.input_ids,
        attention_mask=encoded.attention_mask,
        **generation_kwargs
    )

    decoded = tokenizer.batch_decode(output.sequences, skip_special_tokens=False)
    recipes = target_postprocessing(decoded, special_tokens)
    return recipes[0]  # Only return the first recipe generated
requirements.txt ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==2.1.0
2
+ altair==5.5.0
3
+ annotated-types==0.7.0
4
+ attrs==24.2.0
5
+ blinker==1.9.0
6
+ cachetools==5.5.0
7
+ certifi==2024.8.30
8
+ charset-normalizer==3.4.0
9
+ chex==0.1.87
10
+ click==8.1.7
11
+ etils==1.11.0
12
+ filelock==3.16.1
13
+ flax==0.10.2
14
+ fsspec==2024.10.0
15
+ gitdb==4.0.11
16
+ GitPython==3.1.43
17
+ google-ai-generativelanguage==0.6.10
18
+ google-api-core==2.23.0
19
+ google-api-python-client==2.154.0
20
+ google-auth==2.36.0
21
+ google-auth-httplib2==0.2.0
22
+ google-generativeai==0.8.3
23
+ googleapis-common-protos==1.66.0
24
+ grpcio==1.68.1
25
+ grpcio-status==1.68.1
26
+ httplib2==0.22.0
27
+ huggingface-hub==0.26.5
28
+ humanize==4.11.0
29
+ idna==3.10
30
+ importlib_resources==6.4.5
31
+ jax==0.4.34
32
+ jaxlib==0.4.34
33
+ Jinja2==3.1.4
34
+ jsonschema==4.23.0
35
+ jsonschema-specifications==2024.10.1
36
+ kagglehub==0.3.4
37
+ markdown-it-py==3.0.0
38
+ MarkupSafe==3.0.2
39
+ mdurl==0.1.2
40
+ ml_dtypes==0.5.0
41
+ msgpack==1.1.0
42
+ narwhals==1.16.0
43
+ nest-asyncio==1.6.0
44
+ numpy==2.1.3
45
+ opt_einsum==3.4.0
46
+ optax==0.2.4
47
+ orbax-checkpoint==0.10.2
48
+ packaging==24.2
49
+ pandas==2.2.3
50
+ pillow==11.0.0
51
+ proto-plus==1.25.0
52
+ protobuf==5.29.1
53
+ pyarrow==18.1.0
54
+ pyasn1==0.6.1
55
+ pyasn1_modules==0.4.1
56
+ pydantic==2.10.3
57
+ pydantic_core==2.27.1
58
+ pydeck==0.9.1
59
+ Pygments==2.18.0
60
+ pyparsing==3.2.0
61
+ python-dateutil==2.9.0.post0
62
+ python-dotenv==1.0.1
63
+ pytz==2024.2
64
+ PyYAML==6.0.2
65
+ referencing==0.35.1
66
+ regex==2024.11.6
67
+ requests==2.32.3
68
+ rich==13.9.4
69
+ rpds-py==0.22.3
70
+ rsa==4.9
71
+ safetensors==0.4.5
72
+ scipy==1.14.1
73
+ simplejson==3.19.3
74
+ six==1.17.0
75
+ smmap==5.0.1
76
+ streamlit==1.40.2
77
+ tenacity==9.0.0
78
+ tensorstore==0.1.69
79
+ tokenizers==0.21.0
80
+ toml==0.10.2
81
+ toolz==1.0.0
82
+ tornado==6.4.2
83
+ tqdm==4.67.1
84
+ transformers==4.47.0
85
+ typing_extensions==4.12.2
86
+ tzdata==2024.2
87
+ uritemplate==4.1.1
88
+ urllib3==2.2.3
89
+ watchdog==6.0.0
90
+ zipp==3.21.0