cuisines / app.py
adrianpierce's picture
Update app.py
0d65562
raw
history blame contribute delete
No virus
5.51 kB
import streamlit as st
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px
import networkx as nx
import random
# Fixed seed shared by the stdlib and NumPy global random generators, so that
# anything drawing from them below produces the same values run to run.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

# Render the app using the full browser width.
st.set_page_config(layout="wide")
# Load and process data
# The CSV's first column becomes the row index (the spice names, judging by how
# the rest of the app labels these rows -- confirm against the file). Every
# other column except 'Flavor Description' is a cuisine flag column, where a
# value of 1 marks the spice as used by that cuisine.
df = pd.read_csv('spices_by_cuisine_with_all_flavors.csv', index_col=0)
pivot = df.drop(columns=['Flavor Description']).sort_index()

# cuisine -> spices flagged 1 in that cuisine's column (rows are sorted, so the
# lists come out in sorted spice order).
cuisines = {
    cuisine: pivot.loc[pivot[cuisine] == 1].index.to_list()
    for cuisine in pivot.columns
}

# Transposed view: one row per cuisine (sorted by name), one column per spice.
pivot_t = pivot.T.sort_index()

# spice -> cuisines whose row holds a 1 for that spice.
spices = {
    spice: pivot_t.loc[pivot_t[spice] == 1].index.to_list()
    for spice in pivot_t.columns
}
def similarity(ratings, kind='user', epsilon=1e-9):
    """Return the pairwise cosine-similarity matrix of the rows or columns of *ratings*.

    Args:
        ratings: 2-D array-like supporting ``.dot`` and ``.T`` (e.g. a NumPy array).
        kind: ``'user'`` compares the rows of *ratings*; ``'item'`` compares
            its columns.
        epsilon: Added to every entry of the Gram matrix before normalising,
            so an all-zero row/column does not produce a division by zero.

    Returns:
        A square array whose ``[i, j]`` entry is the dot product of vectors
        ``i`` and ``j`` divided by the product of their norms (each term
        including *epsilon*). The diagonal is therefore 1.

    Raises:
        ValueError: If *kind* is neither ``'user'`` nor ``'item'``.
    """
    if kind == 'user':
        sim = ratings.dot(ratings.T) + epsilon
    elif kind == 'item':
        sim = ratings.T.dot(ratings) + epsilon
    else:
        # Without this branch an unknown kind surfaced as an UnboundLocalError
        # on `sim` below, which hides the real mistake from the caller.
        raise ValueError(f"kind must be 'user' or 'item', got {kind!r}")
    # The Gram-matrix diagonal holds each vector's squared norm; shape (1, n)
    # so dividing by `norms` and `norms.T` scales columns and rows respectively.
    norms = np.array([np.sqrt(np.diagonal(sim))])
    return sim / norms / norms.T
# Cuisine-by-cuisine cosine similarity, computed from each cuisine's 0/1 row of
# spice flags in `pivot_t`, and labelled by cuisine name on both axes.
pivot_names = pivot_t.columns  # NOTE(review): not referenced elsewhere in this view
pivot_np = np.array(pivot_t)
cuisine_labels = pivot_t.index.values
cuisine_similarity = pd.DataFrame(
    similarity(pivot_np, kind='user'),
    index=cuisine_labels,
    columns=cuisine_labels,
)
st.title('Spices Across Cuisines')
col1, col2, col3 = st.columns(3)

with col1:
    # Spices flagged for the chosen cuisine.
    # NOTE(review): the labels say "top 10", but this lists every spice flagged
    # for the cuisine -- confirm the CSV holds exactly ten per cuisine.
    st.subheader('By Cuisine')
    select_cuisine = st.selectbox('Select a cuisine to view the top 10 spices', cuisines.keys())
    st.write(f'The top 10 ingredients in {select_cuisine} are:', cuisines[select_cuisine])

with col2:
    # Reverse lookup: cuisines that use the chosen spice.
    st.subheader('By Spice')
    select_spice = st.selectbox('Select a spice to view which cuisines it is present in', spices.keys())
    st.write(f'{select_spice} is part of the following cuisines:', spices[select_spice])

with col3:
    st.subheader("Similar Cuisines")
    select_cuisine_sim = st.selectbox('Select a cuisine to view the 10 most similar cuisines by spices', cuisines.keys())
    # Exclude the selected cuisine by label rather than slicing off position 0.
    # Another cuisine with an identical spice set also scores 1.0 and may sort
    # ahead of it, in which case `[1:11]` would keep the cuisine itself and
    # drop a genuine match.
    most_similar = (
        cuisine_similarity[select_cuisine_sim]
        .drop(select_cuisine_sim)
        .sort_values(ascending=False)
        .head(10)
        .index.to_list()
    )
    st.write(f'{select_cuisine_sim} is most similar to:', most_similar)
# Number of cuisines using each spice (row sums of the 0/1 pivot), most common
# first. Naming the axis explicitly guarantees the "Spice" column px.bar needs;
# previously it existed only if the CSV's index header happened to be "Spice".
count = (
    pivot.sum(axis=1)
    .sort_values(ascending=False)
    .rename_axis("Spice")
    .reset_index(name="Count")
)
fig_bar = px.bar(count, x="Spice", y="Count", title="Most Frequently Occurring Spices Across Cuisines")
st.plotly_chart(fig_bar, use_container_width=True)
# Create a graph
# Bipartite network: one node per cuisine column of `df`, one node per row
# label (spice), and an edge wherever that cuisine's column holds a 1.
cuisine_names = [col for col in df.columns if col != "Flavor Description"]
G = nx.Graph()
for cuisine in cuisine_names:
    G.add_node(cuisine, type='cuisine')
    for spice in df[df[cuisine] == 1].index:
        G.add_node(spice, type='spice')
        G.add_edge(cuisine, spice)

# Get node positions using the spring layout. An explicit seed keeps the
# picture stable regardless of what else has consumed the global RNG.
pos = nx.spring_layout(G, seed=42)

# Edge segments, collected into plain lists first: appending to a plotly
# trace's tuple property re-validates the whole growing tuple on every edge.
# The None after each segment breaks the line so segments are not joined.
edge_x, edge_y = [], []
for u, v in G.edges():
    (x0, y0), (x1, y1) = pos[u], pos[v]
    edge_x.extend((x0, x1, None))
    edge_y.extend((y0, y1, None))

edge_trace = go.Scatter(
    x=edge_x,
    y=edge_y,
    line=dict(width=0.5, color='#888'),
    hoverinfo='none',
    mode='lines')

# Spread cuisine hues evenly around the colour wheel. Float division keeps the
# spacing even; the old `360 // n` step became 0 (every cuisine the same
# colour) once there were more than 360 cuisines.
n_cuisines = max(len(cuisine_names), 1)
cuisine_colors = {
    cuisine: f"hsl({int(i * 360 / n_cuisines)}, 80%, 50%)"
    for i, cuisine in enumerate(cuisine_names)
}

# Per-trace point data for the two node kinds.
cuisine_x, cuisine_y, cuisine_text, cuisine_hover, cuisine_color = [], [], [], [], []
spice_x, spice_y, spice_text, spice_hover = [], [], [], []

for node, attrs in G.nodes(data=True):
    x, y = pos[node]
    if attrs['type'] == 'cuisine':
        # Collect all spices associated with this cuisine
        spices_associated = df[df[node] == 1].index.tolist()
        cuisine_x.append(x)
        cuisine_y.append(y)
        cuisine_text.append(node)
        cuisine_color.append(cuisine_colors[node])
        cuisine_hover.append(f"{node} uses: {', '.join(map(str, spices_associated))}")
    else:
        # Collect all cuisines that use this spice
        cuisines_using_spice = df.columns[df.loc[node] == 1].tolist()
        spice_x.append(x)
        spice_y.append(y)
        spice_text.append(node)
        spice_hover.append(f"{node} is used in: {', '.join(map(str, cuisines_using_spice))}")

# Node trace for cuisines: larger markers, one colour per cuisine.
node_trace_cuisines = go.Scatter(
    x=cuisine_x,
    y=cuisine_y,
    text=cuisine_text,
    hovertext=cuisine_hover,
    mode='markers+text',
    hoverinfo='text',
    marker=dict(
        showscale=False,
        size=20,
        color=cuisine_color,
        line=dict(width=0)))

# Node trace for spices: small grey markers.
node_trace_spices = go.Scatter(
    x=spice_x,
    y=spice_y,
    text=spice_text,
    hovertext=spice_hover,
    mode='markers+text',
    hoverinfo='text',
    marker=dict(
        showscale=False,
        color='grey',
        size=10,
        line=dict(width=0)))

# Create the network graph figure with updated hover information.
# `titlefont_size` is a deprecated alias (dropped in plotly 6); the nested
# title.font form is the supported spelling.
fig_graph = go.Figure(
    data=[edge_trace, node_trace_cuisines, node_trace_spices],
    layout=go.Layout(
        title=dict(text="Network Graph of Cuisines and their Spices", font=dict(size=16)),
        showlegend=False,
        hovermode='closest',
        margin=dict(b=20, l=5, r=5, t=40),
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)))
st.plotly_chart(fig_graph, use_container_width=True)