# NOTE: page-capture residue from a Hugging Face Space listing (banner text: "Spaces: Sleeping").
import streamlit as st | |
import torch | |
import numpy as np | |
import time | |
import string | |
import pandas as pd | |
import numpy as np | |
from transformers import BertTokenizer, BertModel | |
from collections import defaultdict, Counter | |
from tqdm.auto import tqdm | |
from sklearn.metrics.pairwise import cosine_similarity | |
import time | |
import random | |
# Loading the model
@st.cache_resource
def get_models():
    '''
    Load the pretrained "bert-base-uncased" encoder and its tokenizer.

    Streamlit re-executes this script on every widget interaction, so the
    loader is wrapped in ``st.cache_resource``: the weights are read once
    and the same objects are handed back on later reruns.

    Returns:
        tuple: ``(model, tokenizer)`` as produced by
        ``BertModel.from_pretrained`` and ``BertTokenizer.from_pretrained``.
    '''
    st.write('Loading the model...')
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = BertModel.from_pretrained("bert-base-uncased")
    st.write("_The model is loaded and ready to use! :tada:_")
    return model, tokenizer
# Convert numpy arrays that were stored as strings back to arrays
def str_to_numpy(array_string):
    '''
    Parse the printed form of a 1-D numpy array into a ``(1, n)`` array.

    Args:
        array_string: text such as ``"[0.1 0.2\\n 0.3]"``; square brackets
            are dropped and the remaining whitespace-separated numbers are
            parsed.

    Returns:
        numpy.ndarray of shape ``(1, n)``.
    '''
    # Line breaks become spaces (not nothing) so the last number on one line
    # can never be glued to the first number on the next line.
    cleanup = str.maketrans({'\n': ' ', '[': None, ']': None})
    flat_values = np.fromstring(array_string.translate(cleanup), sep=' ')
    return flat_values.reshape((1, -1))
# Cached so the CSV is parsed once instead of on every Streamlit rerun
@st.cache_data
def load_data():
    '''
    Read the restaurant table and prepare the structures used for ranking.

    Returns:
        tuple: ``(embeds, rest_names, vectors_df)`` where
        ``embeds`` maps row position (0..n-1) to the raw value of the
        'Embeddings' column, ``rest_names`` is the 'Names' column as a list,
        and ``vectors_df`` is the full DataFrame with an added 'Weights'
        column initialised to 1.0 for every row.
    '''
    vectors_df = pd.read_csv('filtered_restaurants_dataframe_with_embeddings.csv', encoding="utf-8")
    embeds = dict(enumerate(vectors_df['Embeddings']))
    rest_names = list(vectors_df['Names'])
    # Float from the start: the weights are later scaled by 0.1 / 1.5.
    vectors_df['Weights'] = 1.0
    return embeds, rest_names, vectors_df
# Module-level data used by the helper functions below.
# restaurants_embeds - dict keyed by row position (0..n-1) holding the raw 'Embeddings' values
# rest_names         - list of restaurant names in the same row order
# init_df            - the restaurant DataFrame with its initial 'Weights' column
restaurants_embeds, rest_names, init_df = load_data()
# BERT encoder and tokenizer used to embed the users' query
model, tokenizer = get_models()
# Turn a sentence into one embedding vector per input
def get_bert_embeddings(sentence, model, tokenizer):
    '''
    Embed ``sentence`` by averaging the model's last hidden state over tokens.

    Args:
        sentence: text passed to ``tokenizer``.
        model: callable accepting the tokenizer output as keyword arguments
            and returning an object with a ``last_hidden_state`` tensor.
        tokenizer: callable producing the model inputs as PyTorch tensors.

    Returns:
        torch.Tensor: ``last_hidden_state`` averaged along dimension 1.
    '''
    encoded = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True)
    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        hidden_states = model(**encoded).last_hidden_state
    # Average pool over tokens
    return hidden_states.mean(dim=1)
# Score every restaurant against the query text
def compute_cos_sim(input):
    '''
    Compute query-to-restaurant cosine similarity and weighted relevancy.

    The query is embedded with ``get_bert_embeddings``; every stored
    restaurant embedding is parsed with ``str_to_numpy`` and all of them are
    compared with the query in a single ``cosine_similarity`` call.

    Args:
        input: query text. (The name shadows the builtin but is kept so
            existing callers are unaffected.)

    Returns:
        ``st.session_state.df`` with two columns written:
        'cos_sim' (raw similarity) and 'Relevancy' (similarity multiplied by
        the row's 'Weights' value).
    '''
    embedded_query = get_bert_embeddings(input, model, tokenizer).numpy()

    if len(restaurants_embeds) > 0:
        # One (n, dim) matrix and one similarity call instead of growing an
        # array with np.append inside a Python loop.
        embedding_matrix = np.vstack(
            [str_to_numpy(raw) for raw in restaurants_embeds.values()]
        )
        similarities = cosine_similarity(embedded_query, embedding_matrix)[0]
    else:
        # Nothing to compare against: keep the columns empty.
        similarities = np.array([])

    df = st.session_state.df
    df['cos_sim'] = similarities.tolist()
    weights = np.array(df['Weights'])
    # multiply weights by the cosine similarity
    df['Relevancy'] = np.multiply(similarities, weights)
    return df
def sort_by_relevancy(k):
    '''
    Rank restaurants by the 'Relevancy' column of the precalculated table.

    k - int - how many top-matching places to show

    Returns a dict mapping restaurant name to relevancy score, best first.
    '''
    scores = st.session_state.precalculated_df['Relevancy']
    # (row position, score) pairs, highest score first
    ranked = sorted(enumerate(scores), key=lambda pair: pair[1], reverse=True)
    # keep only the K best rows and label them with restaurant names
    return {rest_names[position]: score for position, score in ranked[:k]}
def sort_by_price(k):
    '''
    Rank restaurants by relevancy scaled with a price factor.

    Each row's 'Relevancy' is multiplied by the factor that
    ``st.session_state.price`` assigns to the row's 'Price' label. The
    products are stored in the 'Sort_price' column of the precalculated
    table.

    k - int - how many top-matching places to show

    Returns a dict mapping restaurant name to price-adjusted score, best first.
    '''
    df = st.session_state.precalculated_df
    price_factors = st.session_state.price
    relevance = np.array(df['Relevancy'])
    # Labels missing from the mapping get the neutral factor 1 (the same value
    # the mapping uses for "nan") instead of aborting with a KeyError.
    prices = np.array([price_factors.get(str(label), 1) for label in df['Price']])
    scores = np.multiply(relevance, prices)
    df['Sort_price'] = scores
    # sort in the descending order
    ranked = sorted(enumerate(scores), key=lambda pair: pair[1], reverse=True)
    # leave only K recommendations, keyed by restaurant name
    return {rest_names[position]: score for position, score in ranked[:k]}
def sort_by_rating(k):
    '''
    Rank restaurants by relevancy multiplied by rating.

    The products are stored in the 'Sort_rating' column of the precalculated
    table.

    k - int - how many top-matching places to show

    Returns a dict mapping restaurant name to rating-adjusted score, best first.
    '''
    table = st.session_state.precalculated_df
    combined = np.multiply(np.array(table['Relevancy']), np.array(table['Rating']))
    table['Sort_rating'] = combined
    # highest combined score first
    ordered = sorted(enumerate(combined), key=lambda entry: entry[1], reverse=True)
    # keep the K best rows and look up their restaurant names
    return {rest_names[row]: value for row, value in ordered[:k]}
# Combines two users' preferences into one query string
def get_combined_preferences(user1, user2):
    '''
    Merge both users' preferences and record the predefined ones as keywords.

    Args:
        user1: list of preference strings of the first user.
        user2: list of preference strings of the second user.

    Returns:
        tuple: ``(shared_pref, freq_words)`` where ``shared_pref`` is the
        lower-cased preferences of both users joined by spaces and
        ``freq_words`` is a ``Counter`` of the words in it.

    Side effect:
        ``st.session_state.fixed_preferences`` is replaced by the lower-cased,
        punctuation-free words of those preferences that are listed in
        ``st.session_state.food`` or ``st.session_state.ambiance``.
    '''
    # TODO: optimize for more users
    first_part = "".join(pref.lower() + " " for pref in user1)
    second_part = "".join(pref.lower() + " " for pref in user2)
    shared_pref = first_part + " " + second_part
    freq_words = Counter(shared_pref.split())

    # Use the arguments rather than re-reading session state, so the result
    # always describes the users that were actually passed in.
    food = st.session_state.food
    ambiance = st.session_state.ambiance
    predefined = [
        pref for pref in (*user1, *user2)
        if pref.capitalize() in food or pref in ambiance
    ]
    translator = str.maketrans('', '', string.punctuation)
    stripped_words = (
        word.translate(translator) for phrase in predefined for word in phrase.split()
    )
    # Drop words that are empty once punctuation has been removed.
    st.session_state.fixed_preferences = [word.lower() for word in stripped_words if word]
    return shared_pref, freq_words
def filter_places(restrictions):
    '''
    Penalise places whose description does not mention a restriction.

    For every distinct restriction word (case-insensitive) that is absent
    from a row's 'Strings' description, that row's 'Weights' value is
    multiplied by 0.1.

    Args:
        restrictions: iterable of restriction words.

    Returns:
        ``st.session_state.df`` with the updated 'Weights' column.
    '''
    df = st.session_state.df
    taboo = {word.lower() for word in restrictions}
    if taboo:
        # One penalty factor per row, applied with a single column assignment;
        # element-wise chained assignment (df[col][i] = ...) is not reliable
        # in pandas.
        penalties = [
            0.1 ** len(taboo - set(description.lower().split()))
            for description in df['Strings']
        ]
        df['Weights'] = df['Weights'].astype(float) * penalties
    return df
def promote_places():
    '''
    Boost places whose description mentions the users' fixed preferences.

    Reads the keyword list from ``st.session_state.fixed_preferences``. For
    every entry of that list (repeated entries count each time) found among
    the words of a row's 'Strings' description, the row's 'Weights' value is
    multiplied by 1.5.

    Returns:
        ``st.session_state.df`` with the updated 'Weights' column.
    '''
    df = st.session_state.df
    preferences = [pref.lower() for pref in st.session_state.fixed_preferences]
    if preferences:
        boosts = []
        for description in df['Strings']:
            words = set(description.lower().split())
            hits = sum(pref in words for pref in preferences)
            boosts.append(1.5 ** hits)
        # Single column assignment instead of chained df[col][i] = ... writes,
        # which are not reliable in pandas.
        df['Weights'] = df['Weights'].astype(float) * boosts
    return df
def generate_results():
    '''
    Fill ``st.session_state.results`` with the top-10 ranking for each sort option.

    'Distance' currently reuses the relevancy ranking.
    '''
    top_k = 10
    rankers = (
        ('Price', sort_by_price),
        ('Rating', sort_by_rating),
        ('Relevancy (default)', sort_by_relevancy),
        ('Distance', sort_by_relevancy),
    )
    for option, ranker in rankers:
        st.session_state.results[option] = ranker(top_k)
def get_normalized_val(values):
    '''
    Rescale scores to a 0-100 percentage relative to the whole table.

    The minimum and maximum come from the column of
    ``st.session_state.precalculated_df`` that matches the current
    ``st.session_state.sort_by``: 'Sort_rating' for 'Rating', 'Sort_price'
    for 'Price', and 'Relevancy' for everything else.

    Args:
        values: dict mapping restaurant name to raw score.

    Returns:
        dict mapping the same names to percentages rounded to one decimal.
        If every score in the column is identical, each entry gets 100.0.
    '''
    if not values:
        return {}

    column_by_option = {'Rating': 'Sort_rating', 'Price': 'Sort_price'}
    # Any other option (including an unset one) is normalised by relevancy,
    # so the bounds are always defined.
    column = column_by_option.get(st.session_state.sort_by, 'Relevancy')
    scores = st.session_state.precalculated_df[column]
    min_value = min(scores)
    max_value = max(scores)
    span = max_value - min_value

    if span == 0:
        # All places score the same; avoid dividing by zero.
        return {name: 100.0 for name in values}

    # Round after scaling to percent so results look like 57.0, not
    # 56.99999999999999.
    return {
        name: float(round(100 * (score - min_value) / span, 1))
        for name, score in values.items()
    }
# Seed st.session_state with default values. Keys that already exist (from an
# earlier rerun of the same session) are left untouched.
_session_defaults = {
    'preferences_1': [],
    'preferences_2': [],
    'fixed_preferences': [],
    'additional_1': [],
    'additional_2': [],
    'food': ['Coffee', 'Italian', 'Mexican', 'Chinese', 'Indian', 'Asian', 'Fast food', 'Other'],
    'ambiance': ['Romantic date', 'Friends catching up', 'Family gathering', 'Big group', 'Business-meeting', 'Other'],
    'restrictions': [],
    # Multiplier applied to relevancy when sorting by price: cheaper categories
    # get a larger factor. The won-sign keys were mis-encoded and are restored
    # here so they can match the labels in the data.
    'price': {'$': 2, '₩': 2, '$$': 1, '₩₩': 1, '$$$': 0.5, '$$$$': 0.1, "nan": 1},
    'sort_by': '',
    'options': ['Relevancy (default)', 'Price', 'Rating', 'Distance'],
    'df': init_df,
    'precalculated_df': pd.DataFrame(),
    'results': {},
    'fixed_restrictions': [],
}
for _state_key, _state_default in _session_defaults.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# Page header shown above the preference form
st.title("GoTogether!")
st.markdown("Tell us about your preferences!")
st.caption("In section 'Others', you can describe any wishes.")
# Custom CSS for the coloured label boxes (orange, blue, green, violet).
# Injected into the page with st.markdown(..., unsafe_allow_html=True).
css = """
<style>
.orange-box {
background-color: orange;
border: 2px solid darkred;
border-radius: 10px;
display: inline-block;
padding: 5px 10px;
margin: 0px;
}
.blue-box {
background-color: #0077b6;
border: 2px solid navy;
border-radius: 10px;
display: inline-block;
padding: 5px 10px;
color: white;
}
.green-box {
border: 2px solid #004d00; /* Dark green contour */
border-radius: 10px;
background-color: #4CAF50; /* green background */
display: inline-block;
padding: 5px 10px;
color: #FFFFFF; /* White text color */
}
.violet-box {
border: 2px solid #8a2be2; /* Violet contour */
border-radius: 10px;
background-color: #4169E1; /* Blue background */
display: inline-block;
padding: 5px 10px;
color: #FFFFFF; /* White text color */
}
</style>
"""
# CSS for the bold blue section captions inside each result
text_css = """
<style>
.text {
font-weight: bold;
color: #0077b6; /* Sea-blue text color */
margin-right: 1px;
}
</style>
"""
# Preference form: the same set of widgets is rendered once per user.
# Every widget gets an explicit key so the two identical-looking sets do not clash.
# options_disability_1 = st.multiselect(
# 'Do you need a wheelchair?',
# ['Yes', 'No'], ['No'], key=101)
# if options_disability_1 == 'Yes':
# st.session_state.restrictions.append('Wheelchair')

# --- User 1 ---
st.markdown(css, unsafe_allow_html=True)
st.markdown(f'<div class="violet-box">User 1</div>', unsafe_allow_html=True)
food_1 = st.selectbox('Select the food type you prefer', st.session_state.food, key=1)
# 'Other' swaps the selected value for a free-text description
if food_1 == 'Other':
    food_1 = st.text_input(label="Your description", placeholder="What kind of food would you like to eat?", key=10)
ambiance_1 = st.selectbox('What describes your occasion the best?', st.session_state.ambiance, key=2)
if ambiance_1 == 'Other':
    ambiance_1 = st.text_input(label="Your description", placeholder="How would you describe your meeting?", key=11)
options_food_1 = st.multiselect(
    'Do you have any dietary restrictions?',
    ['Vegan', 'Vegetarian', 'Halal'], key=100)
additional_1 = st.text_input(label="Your description", placeholder="Anything else you wanna share?", key=102)
with_kids = st.checkbox('I will come with kids', key=200)

# --- User 2 ---
st.markdown(css, unsafe_allow_html=True)
st.markdown(f'<div class="violet-box">User 2</div>', unsafe_allow_html=True)
food_2 = st.selectbox('Select the food type you prefer', st.session_state.food, key=3)
if food_2 == 'Other':
    food_2 = st.text_input(label="Your description", placeholder="What kind of food would you like to eat?", key=4)
ambiance_2 = st.selectbox('What describes your occasion the best?', st.session_state.ambiance, key=5)
if ambiance_2 == 'Other':
    ambiance_2 = st.text_input(label="Your description", placeholder="How would you describe your meeting?", key=6)
options_food_2 = st.multiselect(
    'Do you have any dietary restrictions?',
    ['Vegan', 'Vegetarian', 'Halal'], key=7)
additional_2 = st.text_input(label="Your description", placeholder="Anything else you wanna share?", key=8)
with_kids_2 = st.checkbox('I will come with kids', key=201)
# Store the form values in session state when the users press "Submit!".
# NOTE(review): the nesting below was reconstructed from flattened text -
# confirm which statements belong inside the spinner and the emptiness checks.
submitted = st.button('Submit!')
if submitted:
    with st.spinner('Processing your request...'):
        time.sleep(1)
        # Only record user 1's answers if none are stored yet
        if len(st.session_state.preferences_1) == 0:
            st.session_state.preferences_1.append(food_1)
            # if food_1 in st.session_state.food:
            # st.session_state.preferences_1.append(food_1)
            # else:
            # st.session_state.additional_1.append(food_1_o)
            st.session_state.preferences_1.append(ambiance_1)
            # if ambiance_1 in st.session_state.ambiance:
            # st.session_state.preferences_1.append(ambiance_1)
            # else:
            # st.session_state.additional_1.append(ambiance_1_o)
            # Dietary choices are stored as restrictions shared by both users
            st.session_state.restrictions.extend(options_food_1)
            if with_kids:
                st.session_state.restrictions.append('kids')
            if additional_1:
                st.session_state.preferences_1.append(additional_1)
        # Same for user 2
        if len(st.session_state.preferences_2) == 0:
            st.session_state.preferences_2.append(food_2)
            # if food_2 in st.session_state.food:
            # st.session_state.preferences_2.append(food_2)
            # else:
            # st.session_state.additional_2.append(food_2_o)
            st.session_state.preferences_2.append(ambiance_2)
            # if ambiance_2 in st.session_state.ambiance:
            # st.session_state.preferences_2.append(ambiance_2)
            # else:
            # st.session_state.additional_2.append(ambiance_2_o)
            st.session_state.restrictions.extend(options_food_2)
            if additional_2:
                st.session_state.preferences_2.append(additional_2)
            if with_kids_2:
                st.session_state.restrictions.append('kids')
    st.success("Thanks, we received your preferences!")
else:
    # Shown until the form has been submitted in this run
    st.write('โ๏ธ Describe your preferences!')
# Results view: compute the ranking once per search, then render the ten best
# places for the chosen sort option.
# NOTE(review): the nesting below was reconstructed from flattened text - confirm.
submit = st.button("Find best matches!", type='primary')
if submit or (not st.session_state.precalculated_df.empty):
    with st.spinner("Please wait while we are finding the best solution..."):
        # The scoring pipeline runs only while no precalculated table exists;
        # later reruns (e.g. changing the sort option) reuse it.
        if st.session_state.precalculated_df.empty:
            query = get_combined_preferences(st.session_state.preferences_1, st.session_state.preferences_2)
            # down-weight places that miss the restrictions
            st.session_state.precalculated_df = filter_places(st.session_state.restrictions)
            st.session_state.fixed_restrictions = st.session_state.restrictions
            # up-weight places that mention the preferences
            st.session_state.precalculated_df = promote_places()
            st.session_state.precalculated_df = compute_cos_sim(query[0])
    sort_by = st.selectbox(('Sort by:'), st.session_state.options, key=400,
                           index=st.session_state.options.index('Relevancy (default)'))
    if sort_by:
        st.session_state.sort_by = sort_by
        with st.spinner(f"Sorting your results by {sort_by.lower()}..."):
            if len(st.session_state.results) == 0:
                generate_results()
            results = st.session_state.results[sort_by]
        if sort_by == 'Distance':
            st.write(":pensive: Sorry, we are still working on this option. For now, the results are sorted by relevance")
        k = 10
        st.write(f"Here are the best {k} matches to your preferences:")
        # Emoji shortcodes used to number the results 1..10
        nums = list(range(1, 11))
        words = ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'one: :zero']
        nums_emojis = dict(zip(nums, words))
        results = get_normalized_val(results)
        # Lower-cased lookup sets, built once instead of once per word
        preference_words = {el.lower() for el in st.session_state.fixed_preferences}
        restriction_words = {el.lower() for el in st.session_state.fixed_restrictions}
        for i, (name, score) in enumerate(results.items(), start=1):
            condition = st.session_state.precalculated_df['Names'] == name
            matches = st.session_state.precalculated_df.loc[condition]
            rating = matches['Rating'].values[0]
            with st.expander(f":{nums_emojis[i]}: **{name}** **({str(rating)}**:star:): match score: {score}%"):
                # Missing prices are not strings and are simply skipped. The
                # earlier version shadowed the builtin ``type`` further down and
                # hid the resulting TypeError behind a bare ``except``, so the
                # price was only ever shown for the first result.
                price = matches['Price'].values[0]
                if isinstance(price, str):
                    st.write("Price category:", price)
                descr = matches['Strings'].values[0]
                # sorted() keeps the badge order stable between reruns
                for descr_word in sorted({part.lower() for part in descr.split()}):
                    if descr_word in preference_words:
                        st.markdown(f'โ {descr_word.capitalize()}')
                    if descr_word in restriction_words:
                        if descr_word == 'kids':
                            st.markdown(f'โ Good for kids')
                        else:
                            st.markdown(f'โ {descr_word.capitalize()}')
                # Restaurant category
                # SECURITY(review): eval() executes whatever text is stored in
                # the CSV; keep only if the file is fully trusted.
                categories = list(eval(matches['Category'].values[0]))
                st.markdown(text_css, unsafe_allow_html=True)
                st.markdown('<div class="text">Category</div>', unsafe_allow_html=True)
                # Display HTML with the custom styles
                for category in categories:
                    st.markdown(css, unsafe_allow_html=True)
                    st.markdown(f'<div class="blue-box">{category}</div>', unsafe_allow_html=True)
                # Keep keyword pairs whose second field (item[1]) exceeds 2
                # SECURITY(review): same eval() caveat as above.
                keywords = [item[0] for item in eval(matches['Keywords'].values[0]) if item[1] > 2]
                if len(keywords) > 0:
                    st.markdown(text_css, unsafe_allow_html=True)
                    st.markdown('<div class="text">Other users say:</div>', unsafe_allow_html=True)
                    for pair in keywords[:3]:
                        st.markdown(css, unsafe_allow_html=True)
                        st.markdown(f'<div class="orange-box">{pair[0]} {pair[1]}</div>', unsafe_allow_html=True)
                url = matches['URL'].values[0]
                st.write(f"_Check on the_ [_map_]({url})")
    # The submitted preferences have been consumed by this search
    st.session_state.preferences_1, st.session_state.preferences_2 = [], []
    # st.session_state.restrictions = []
# "New search!" wipes everything stored for the current search
stop = st.button("New search!", type='primary', key=500)
if stop:
    st.write("New search is launched. Please specify your preferences in the form!")
    # per-user answers
    st.session_state.preferences_1 = []
    st.session_state.preferences_2 = []
    st.session_state.restrictions = []
    st.session_state.additional_1 = []
    st.session_state.additional_2 = []
    # ranking state
    st.session_state.sort_by = ""
    st.session_state.df = init_df
    st.session_state.precalculated_df = pd.DataFrame()
    st.session_state.results = {}
    st.session_state.fixed_preferences = []