File size: 3,784 Bytes
b9fc179
 
 
 
 
01c6427
b9fc179
3956336
b9fc179
f65d404
492193c
5760476
492193c
 
 
5760476
492193c
 
 
 
 
 
 
 
 
b9fc179
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import streamlit as st
import plotly.graph_objects as go
import networkx as nx
import pandas as pd

st.set_page_config(layout="wide")

# Load the spice/cuisine table. index_col=0 turns the first CSV column (the
# spice names) into the index; the remaining columns are "Flavor Description"
# plus one column per cuisine, where 1 marks that the cuisine uses the spice.
df = pd.read_csv('spices_by_cuisine_with_all_flavors.csv', index_col=0)

st.dataframe(df)

# The spice names are already the index (index_col=0 above), so there is no
# "Unnamed: 0" column to rename and set_index("Spice") would raise KeyError.
# Drop the free-text column and just label the existing index.
pivot = df.drop(columns=['Flavor Description']).rename_axis("Spice")

# Map each cuisine to the list of spices it uses.
spices = {
    cuisine: pivot.index[pivot[cuisine] == 1].to_list()
    for cuisine in pivot.columns
}

# Let the user pick a cuisine and list its spices. selectbox returns None when
# there are no options, hence .get() with an empty default.
option = st.selectbox('Select a cuisine', list(spices))

st.write('You selected:', option)
st.write('Spices used:', ', '.join(map(str, spices.get(option, []))))


# Build an undirected graph with one node per cuisine (type='cuisine') and one
# per spice (type='spice'), joined by an edge wherever the cuisine's column
# holds 1 for that spice.
G = nx.Graph()

for col in df.columns:
    if col == "Flavor Description":
        continue  # free-text column, not a cuisine
    G.add_node(col, type='cuisine')
    for spice in df[df[col] == 1].index.tolist():
        G.add_node(spice, type='spice')
        G.add_edge(col, spice)

# Force-directed layout. Streamlit re-runs this whole script on every widget
# interaction; a fixed seed keeps node positions stable across re-runs instead
# of reshuffling the graph each time.
pos = nx.spring_layout(G, seed=42)

# Collect edge coordinates in plain lists and build the trace once. Each edge
# contributes its two endpoints followed by None, so consecutive edges are
# drawn as separate line segments. (Appending to trace['x'] per edge would
# re-validate and copy the whole array every time - quadratic in edge count.)
edge_x = []
edge_y = []
for u, v in G.edges():
    x0, y0 = pos[u]
    x1, y1 = pos[v]
    edge_x += [x0, x1, None]
    edge_y += [y0, y1, None]

edge_trace = go.Scatter(
    x=edge_x,
    y=edge_y,
    line=dict(width=0.5, color='#888'),
    hoverinfo='none',
    mode='lines')

# Give each cuisine its own hue, spread evenly around the colour wheel.
# Enumerating the filtered list (rather than all of df.columns) means the
# "Flavor Description" column neither needs to be last nor consumes a hue.
# Multiplying before dividing avoids the uneven floor(360 / n) step, which
# also collapsed to 0 (one colour for all) when n > 360. No division happens
# when the list is empty, because the comprehension body never runs.
cuisine_columns = [c for c in df.columns if c != "Flavor Description"]
cuisine_colors = {
    cuisine: f"hsl({i * 360 // len(cuisine_columns)}, 80%, 50%)"
    for i, cuisine in enumerate(cuisine_columns)
}


# Gather positions, labels, colours and hover text for every node in plain
# lists, then build each Scatter trace once. Appending to trace arrays per
# node would re-validate and copy them each time.
cuisine_x, cuisine_y, cuisine_text, cuisine_hover, cuisine_color = [], [], [], [], []
spice_x, spice_y, spice_text, spice_hover = [], [], [], []

for node, attrs in G.nodes(data=True):
    x, y = pos[node]
    # Edges are only added between a cuisine and one of its spices, so a
    # node's neighbours are the spices a cuisine uses, or the cuisines that
    # use a spice. map(str, ...) keeps join() working for non-string labels.
    related = ', '.join(map(str, G.neighbors(node)))
    if attrs['type'] == 'cuisine':
        cuisine_x.append(x)
        cuisine_y.append(y)
        cuisine_text.append(node)
        cuisine_color.append(cuisine_colors[node])
        cuisine_hover.append(f"{node} uses: {related}")
    else:
        spice_x.append(x)
        spice_y.append(y)
        spice_text.append(node)
        spice_hover.append(f"{node} is used in: {related}")

# Cuisine nodes: large markers, each coloured with its cuisine's hue.
node_trace_cuisines = go.Scatter(
    x=cuisine_x,
    y=cuisine_y,
    text=cuisine_text,
    hovertext=cuisine_hover,
    mode='markers+text',
    hoverinfo='text',
    marker=dict(
        showscale=False,
        size=20,
        color=cuisine_color,
        line=dict(width=0)))

# Spice nodes: smaller grey markers.
node_trace_spices = go.Scatter(
    x=spice_x,
    y=spice_y,
    text=spice_text,
    hovertext=spice_hover,
    mode='markers+text',
    hoverinfo='text',
    marker=dict(
        showscale=False,
        color='grey',
        size=10,
        line=dict(width=0)))

# Assemble the edge and node traces into one figure with hidden axes and
# render it at full container width.
fig = go.Figure(
    data=[edge_trace, node_trace_cuisines, node_trace_spices],
    layout=go.Layout(
        # `titlefont_size` is a deprecated alias that newer plotly releases
        # removed; the title font now lives under `title.font`.
        title=dict(
            text="Network Graph of Cuisines and their Spices",
            font=dict(size=16),
        ),
        showlegend=False,
        hovermode='closest',
        margin=dict(b=20, l=5, r=5, t=40),
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
    ),
)

st.plotly_chart(fig, use_container_width=True)