pdjewell committed on
Commit
e564ec1
•
1 Parent(s): 0e36e26

updated cols in app

Browse files
Files changed (2) hide show
  1. app.py +85 -79
  2. app_old.py +236 -0
app.py CHANGED
@@ -132,41 +132,43 @@ def filter_df_recs(df: pd.DataFrame) -> pd.DataFrame:
132
 
133
  if __name__ == "__main__":
134
  st.title("🍷 Sommeli-AI")
135
- col1, col2 = st.columns([0.6,0.4], gap="medium")
136
-
137
  # Read in data
138
  ds_path = "./data/wine_ds.hf"
139
  df = read_data(ds_path=None)
140
 
141
- with col2:
142
- st.header("Explore the world of wine 🌍")
143
- wine_plot = st.radio('Select plot type:', ['2D','3D'],
144
- label_visibility = "hidden",
145
- horizontal=True)
146
- st.text("Click the legend categories to filter")
147
-
148
- # Load the HTML file
149
- with open('./images/px_2d.html', 'r') as file:
150
- plot2d_html = file.read()
151
- # Load the HTML file
152
- with open('./images/px_3d.html', 'r') as file:
153
- plot3d_html = file.read()
154
- # Display the HTML plot in the Streamlit app
155
- if wine_plot == '2D':
156
- components.v1.html(plot2d_html, width=512, height=512)
157
- elif wine_plot == '3D':
158
- components.v1.html(plot3d_html, width=512, height=512)
159
-
160
- with col1:
161
-
162
- # Select all wine types initially
163
- st.header("Search for similar wines πŸ₯‚")
164
- # Select wine type: default is all
165
- wine_types = df['Type'].unique()
166
- selected_wine_types = st.multiselect("Select category πŸ‘‡", wine_types, default=wine_types)
167
- df = df[df['Type'].isin(selected_wine_types)]
168
- subcol1, subcol2 = st.columns([0.5,0.5], gap="small")
169
- with subcol1:
 
 
 
170
  # Select wine variety: default is all
171
  wine_vars = df['Variety'].unique()
172
  selected_wine_vars = st.multiselect("Narrow down the variety πŸ‡",['Select all'] + list(wine_vars),
@@ -175,8 +177,8 @@ if __name__ == "__main__":
175
  df_search = df
176
  else:
177
  df_search = df[df['Variety'].isin(selected_wine_vars)]
178
-
179
- with subcol2:
180
  # Select the country: default is all
181
  countries = df_search['Country'].unique()
182
  selected_countries = st.multiselect("Narrow down the country 🌎",['Select all'] + list(countries),
@@ -186,51 +188,55 @@ if __name__ == "__main__":
186
  else:
187
  df_search = df_search[df_search['Country'].isin(selected_countries)]
188
 
189
- # Add additional filters
190
- df_search = filter_df_search(df_search)
191
 
192
- # Create a search bar for the wine 'title'
193
- selected_wine = st.selectbox("Search for and select a wine πŸ‘‡", [''] + list(df_search["Title"].unique()))
 
 
 
 
 
 
 
194
 
195
- if selected_wine:
196
- # Get the embedding for selected_wine
197
- query_embedding = df.loc[df['Title']==selected_wine, 'embeddings'].iloc[0]
198
-
199
- tasting_notes = df.loc[df['Title']==selected_wine, 'Tasting notes'].iloc[0]
200
- st.write(f"Tasting notes: {tasting_notes}")
201
-
202
- # CSS to inject contained in a string
203
- hide_table_row_index = """
204
- <style>
205
- thead tr th:first-child {display:none}
206
- tbody th {display:none}
207
- </style>
208
- """
209
- # Inject CSS with Markdown
210
- st.markdown(hide_table_row_index, unsafe_allow_html=True)
211
-
212
- # Display selected wine
213
- st.header(" 🍷 Your selected wine")
214
- selected_cols = ['Title','Country','Province','Region','Winery',
215
- 'Variety','Tasting notes','Score']
216
- st.table(df.loc[df['Title']==selected_wine, selected_cols].fillna(""))
217
-
218
- # Slider for results to show
219
- k = st.slider(f"Choose how many similar wines to show πŸ‘‡", 1, 10, value=4)
220
-
221
- # Filter recommendation results
222
- df_results = filter_df_recs(df)
223
-
224
- # Display results as table
225
- if st.button("πŸ”˜ Press me to generate similar tasting wines"):
226
- # Get neighbours
227
- scores, samples = get_neighbours(df_results, query_embedding,
228
- k=k+1, metric='l2')
229
- recs_df = pd.DataFrame(samples).fillna("")
230
- recs_df = recs_df.fillna(" ")
231
- # Display results
232
- st.header(f"🍾 Top {k} similar tasting wines")
233
- st.table(recs_df.loc[1:,selected_cols])
234
-
235
- else:
236
- print("Awaiting selection")
 
132
 
133
  if __name__ == "__main__":
134
  st.title("🍷 Sommeli-AI")
135
+
 
136
  # Read in data
137
  ds_path = "./data/wine_ds.hf"
138
  df = read_data(ds_path=None)
139
 
140
+ maincol, acol = st.columns([0.999,0.001])
141
+ with maincol:
142
+ col1, col2 = st.columns([0.65,0.35], gap="medium")
143
+ with col2:
144
+ st.header("Explore the world of wine 🌍")
145
+ wine_plot = st.radio('Select plot type:', ['2D','3D'],
146
+ label_visibility = "hidden",
147
+ horizontal=True)
148
+ st.text("Click the legend categories to filter")
149
+
150
+ # Load the HTML file
151
+ with open('./images/px_2d.html', 'r') as file:
152
+ plot2d_html = file.read()
153
+ # Load the HTML file
154
+ with open('./images/px_3d.html', 'r') as file:
155
+ plot3d_html = file.read()
156
+ # Display the HTML plot in the Streamlit app
157
+ if wine_plot == '2D':
158
+ components.v1.html(plot2d_html, width=512, height=512)
159
+ elif wine_plot == '3D':
160
+ components.v1.html(plot3d_html, width=512, height=512)
161
+
162
+ with col1:
163
+
164
+ # Select all wine types initially
165
+ st.header("Search for similar wines πŸ₯‚")
166
+ # Select wine type: default is all
167
+ wine_types = df['Type'].unique()
168
+ selected_wine_types = st.multiselect("Select category πŸ‘‡", wine_types, default=wine_types)
169
+ df = df[df['Type'].isin(selected_wine_types)]
170
+ #subcol1, subcol2 = st.columns([0.5,0.5], gap="small")
171
+ #with subcol1:
172
  # Select wine variety: default is all
173
  wine_vars = df['Variety'].unique()
174
  selected_wine_vars = st.multiselect("Narrow down the variety πŸ‡",['Select all'] + list(wine_vars),
 
177
  df_search = df
178
  else:
179
  df_search = df[df['Variety'].isin(selected_wine_vars)]
180
+
181
+ #with subcol2:
182
  # Select the country: default is all
183
  countries = df_search['Country'].unique()
184
  selected_countries = st.multiselect("Narrow down the country 🌎",['Select all'] + list(countries),
 
188
  else:
189
  df_search = df_search[df_search['Country'].isin(selected_countries)]
190
 
191
+ # Add additional filters
192
+ df_search = filter_df_search(df_search)
193
 
194
+ # Create a search bar for the wine 'title'
195
+ selected_wine = st.selectbox("Search for and select a wine πŸ‘‡", [''] + list(df_search["Title"].unique()))
196
+
197
+ if selected_wine:
198
+ # Get the embedding for selected_wine
199
+ query_embedding = df.loc[df['Title']==selected_wine, 'embeddings'].iloc[0]
200
+
201
+ tasting_notes = df.loc[df['Title']==selected_wine, 'Tasting notes'].iloc[0]
202
+ st.write(f"Tasting notes: {tasting_notes}")
203
 
204
+ # CSS to inject contained in a string
205
+ hide_table_row_index = """
206
+ <style>
207
+ thead tr th:first-child {display:none}
208
+ tbody th {display:none}
209
+ </style>
210
+ """
211
+ # Inject CSS with Markdown
212
+ st.markdown(hide_table_row_index, unsafe_allow_html=True)
213
+ # Display selected wine
214
+ st.header("Your selected wine 🍷")
215
+ selected_cols = ['Title','Country','Province','Region','Winery',
216
+ 'Variety','Tasting notes','Score']
217
+ st.table(df.loc[df['Title']==selected_wine, selected_cols].fillna(""))
218
+ # Slider for results to show
219
+ k = st.slider(f"Choose how many similar wines to show πŸ‘‡", 1, 10, value=4)
220
+
221
+ # Filter recommendation results
222
+ df_results = filter_df_recs(df)
223
+
224
+ else:
225
+ print("Awaiting selection")
226
+
227
+ if selected_wine:
228
+ # Display results as table
229
+ if st.button("πŸ”˜ Press me to generate similar tasting wines"):
230
+ # Get neighbours
231
+ scores, samples = get_neighbours(df_results, query_embedding,
232
+ k=k+1, metric='l2')
233
+ recs_df = pd.DataFrame(samples).fillna("")
234
+ recs_df = recs_df.fillna(" ")
235
+ # Display results
236
+ st.header(f"Top {k} similar tasting wines 🍾")
237
+ st.table(recs_df.loc[1:,selected_cols])
238
+
239
+ else:
240
+ print("Awaiting selection")
241
+
242
+
 
 
 
app_old.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ import os
4
+ from PIL import Image
5
+ import streamlit as st
6
+ from streamlit import components
7
+ from datasets import Dataset, load_dataset, load_from_disk
8
+ import faiss
9
+ from scripts.preprocessing import preprocess
10
+
11
# App config
# Page icon, loaded from a path relative to the working directory.
icon = Image.open('./images/wine_icon.png')
# NOTE(review): Streamlit documents set_page_config as a call that should
# precede other Streamlit commands -- keep it ahead of the st.markdown below.
st.set_page_config(page_title="Sommeli-AI",
                   page_icon=icon,
                   layout="wide")
# CSS override injected into the page. Despite the variable name, only the
# footer is hidden; the #MainMenu rule explicitly keeps the menu visible.
hide_default_format = """
<style>
#MainMenu {visibility: visible; }
footer {visibility: hidden;}
</style>
"""
# The string is raw HTML, hence unsafe_allow_html=True.
st.markdown(hide_default_format, unsafe_allow_html=True)
23
+
24
+ # App functions
25
@st.cache_data
def read_data(ds_path=None):
    """Load the wine embeddings dataset as a preprocessed DataFrame.

    Args:
        ds_path: Optional path passed to ``load_from_disk``. When ``None``,
            the ``train`` split of "pdjewell/sommeli_ai" is fetched with
            ``load_dataset`` instead.

    Returns:
        The result of ``preprocess`` applied to the full dataset in pandas
        format (per the original comment: adds a type column and removes
        duplicates).
    """
    if ds_path is None:
        dataset = load_dataset("pdjewell/sommeli_ai", split="train")
    else:
        # Read the saved dataset from disk
        dataset = load_from_disk(ds_path)

    # With the "pandas" format active, slicing everything yields a DataFrame.
    dataset.set_format("pandas")
    return preprocess(dataset[:])
42
+
43
+
44
def get_neighbours(df, query_embedding, k=6,
                   metric='inner'):
    """Return the ``k`` rows of ``df`` nearest to ``query_embedding``.

    The frame is converted to a Hugging Face ``Dataset``, a FAISS index is
    built on its ``embeddings`` column, and the index is queried.

    Args:
        df: DataFrame containing an ``embeddings`` column.
        query_embedding: Embedding vector to search with.
        k: Number of neighbours to retrieve.
        metric: ``'inner'`` selects ``faiss.METRIC_INNER_PRODUCT``; any
            other value selects ``faiss.METRIC_L2``.

    Returns:
        ``(scores, samples)`` as returned by ``get_nearest_examples``, with
        the ``embeddings`` column and the pandas index helper column removed
        from ``samples``.
    """
    # Convert from pandas df to hf ds, returned in numpy format.
    ds = Dataset.from_pandas(df)
    ds.reset_format()
    ds = ds.with_format("np")

    # One index-building call; only the metric differs between the modes.
    metric_type = (faiss.METRIC_INNER_PRODUCT if metric == 'inner'
                   else faiss.METRIC_L2)
    ds.add_faiss_index(column="embeddings", metric_type=metric_type)

    scores, samples = ds.get_nearest_examples(
        "embeddings", query_embedding, k=k)

    samples.pop('embeddings')
    # '__index_level_0__' only exists when the frame carried a non-default
    # index into from_pandas; tolerate its absence instead of raising
    # KeyError for frames with a plain RangeIndex.
    samples.pop('__index_level_0__', None)

    return scores, samples
67
+
68
+
69
# Columns filtered with a numeric range slider / a multiselect respectively.
_RANGE_FILTER_COLUMNS = ('Score', 'Price')
_CHOICE_FILTER_COLUMNS = ('Country', 'Province', 'Region', 'Variety', 'Winery')


def _apply_column_filters(df: pd.DataFrame, columns, key_prefix: str) -> pd.DataFrame:
    """Render one filter widget per column and return the narrowed frame.

    Args:
        df: Frame to filter; it is not modified in place.
        columns: Column names chosen by the user in the "Filter on:" box.
        key_prefix: Prefix for widget keys, so the search panel and the
            recommendation panel never produce identical widgets.

    Returns:
        ``df`` restricted to rows matching every rendered widget.
    """
    for column in columns:
        widget_key = f"{key_prefix}_{column}"
        if column in _RANGE_FILTER_COLUMNS:
            col_max = df[column].max()
            # Empty or all-NaN column: int(nan) would raise, so skip it.
            if pd.isna(col_max):
                continue
            min_val = 0
            # Round up so the default range still includes the largest
            # value when it is fractional (int() would truncate it away).
            max_val = int(np.ceil(col_max))
            # A slider needs min < max; nothing to narrow otherwise.
            if max_val <= min_val:
                continue
            low, high = st.slider(f"Values for {column}", min_val, max_val,
                                  (min_val, max_val), key=widget_key)
            # Rows with a missing value never satisfy the comparison and
            # are therefore dropped once this filter is active.
            df = df[(df[column] >= low) & (df[column] <= high)]
        elif column in _CHOICE_FILTER_COLUMNS:
            unique_values = df[column].dropna().unique()
            # Pre-select only the first available value, if there is one.
            default_values = [unique_values[0]] if len(unique_values) > 0 else []
            user_input = st.multiselect(f"Values for {column}", unique_values,
                                        default_values, key=widget_key)
            df = df[df[column].isin(user_input)]
    return df


def filter_df_search(df: pd.DataFrame) -> pd.DataFrame:
    """Optionally narrow the wine search selection with extra filters.

    A checkbox toggles the panel. When it is unticked, ``df`` is returned
    unchanged; otherwise the user picks columns to filter on and one widget
    per column is shown.

    Args:
        df: Candidate wines for the search box.

    Returns:
        The input frame, or a filtered copy of it.
    """
    modify_search = st.checkbox("🔍 Further filter search selection")

    if not modify_search:
        return df

    df = df.copy()

    modification_container_search = st.container()

    with modification_container_search:
        to_filter_columns = st.multiselect("Filter on:",
                                           ['Province', 'Region', 'Winery', 'Score', 'Price'],
                                           key='search')
        df = _apply_column_filters(df, to_filter_columns, key_prefix='search')

    return df


def filter_df_recs(df: pd.DataFrame) -> pd.DataFrame:
    """Optionally narrow the pool of wines used for recommendations.

    A checkbox toggles the panel. When it is unticked, ``df`` is returned
    unchanged; otherwise the user picks columns to filter on and one widget
    per column is shown.

    Args:
        df: Wines eligible to be recommended.

    Returns:
        The input frame, or a filtered copy of it.
    """
    modify_recs = st.checkbox("🔍 Filter recommendation results")

    if not modify_recs:
        return df

    df = df.copy()

    modification_container_recs = st.container()

    with modification_container_recs:
        to_filter_columns2 = st.multiselect("Filter on:",
                                            ['Country', 'Province', 'Region', 'Variety', 'Winery',
                                             'Score', 'Price'],
                                            key='recs')
        df = _apply_column_filters(df, to_filter_columns2, key_prefix='recs')

    return df
131
+
132
+
133
+ if __name__ == "__main__":
134
+ st.title("🍷 Sommeli-AI")
135
+ col1, col2 = st.columns([0.6,0.4], gap="medium")
136
+
137
+ # Read in data
138
+ ds_path = "./data/wine_ds.hf"
139
+ df = read_data(ds_path=None)
140
+
141
+ with col2:
142
+ st.header("Explore the world of wine 🌍")
143
+ wine_plot = st.radio('Select plot type:', ['2D','3D'],
144
+ label_visibility = "hidden",
145
+ horizontal=True)
146
+ st.text("Click the legend categories to filter")
147
+
148
+ # Load the HTML file
149
+ with open('./images/px_2d.html', 'r') as file:
150
+ plot2d_html = file.read()
151
+ # Load the HTML file
152
+ with open('./images/px_3d.html', 'r') as file:
153
+ plot3d_html = file.read()
154
+ # Display the HTML plot in the Streamlit app
155
+ if wine_plot == '2D':
156
+ components.v1.html(plot2d_html, width=512, height=512)
157
+ elif wine_plot == '3D':
158
+ components.v1.html(plot3d_html, width=512, height=512)
159
+
160
+ with col1:
161
+
162
+ # Select all wine types initially
163
+ st.header("Search for similar wines πŸ₯‚")
164
+ # Select wine type: default is all
165
+ wine_types = df['Type'].unique()
166
+ selected_wine_types = st.multiselect("Select category πŸ‘‡", wine_types, default=wine_types)
167
+ df = df[df['Type'].isin(selected_wine_types)]
168
+ subcol1, subcol2 = st.columns([0.5,0.5], gap="small")
169
+ with subcol1:
170
+ # Select wine variety: default is all
171
+ wine_vars = df['Variety'].unique()
172
+ selected_wine_vars = st.multiselect("Narrow down the variety πŸ‡",['Select all'] + list(wine_vars),
173
+ default = 'Select all')
174
+ if "Select all" in selected_wine_vars:
175
+ df_search = df
176
+ else:
177
+ df_search = df[df['Variety'].isin(selected_wine_vars)]
178
+
179
+ with subcol2:
180
+ # Select the country: default is all
181
+ countries = df_search['Country'].unique()
182
+ selected_countries = st.multiselect("Narrow down the country 🌎",['Select all'] + list(countries),
183
+ default = 'Select all')
184
+ if "Select all" in selected_countries:
185
+ df_search = df_search
186
+ else:
187
+ df_search = df_search[df_search['Country'].isin(selected_countries)]
188
+
189
+ # Add additional filters
190
+ df_search = filter_df_search(df_search)
191
+
192
+ # Create a search bar for the wine 'title'
193
+ selected_wine = st.selectbox("Search for and select a wine πŸ‘‡", [''] + list(df_search["Title"].unique()))
194
+
195
+ if selected_wine:
196
+ # Get the embedding for selected_wine
197
+ query_embedding = df.loc[df['Title']==selected_wine, 'embeddings'].iloc[0]
198
+
199
+ tasting_notes = df.loc[df['Title']==selected_wine, 'Tasting notes'].iloc[0]
200
+ st.write(f"Tasting notes: {tasting_notes}")
201
+
202
+ # CSS to inject contained in a string
203
+ hide_table_row_index = """
204
+ <style>
205
+ thead tr th:first-child {display:none}
206
+ tbody th {display:none}
207
+ </style>
208
+ """
209
+ # Inject CSS with Markdown
210
+ st.markdown(hide_table_row_index, unsafe_allow_html=True)
211
+
212
+ # Display selected wine
213
+ st.header(" 🍷 Your selected wine")
214
+ selected_cols = ['Title','Country','Province','Region','Winery',
215
+ 'Variety','Tasting notes','Score']
216
+ st.table(df.loc[df['Title']==selected_wine, selected_cols].fillna(""))
217
+
218
+ # Slider for results to show
219
+ k = st.slider(f"Choose how many similar wines to show πŸ‘‡", 1, 10, value=4)
220
+
221
+ # Filter recommendation results
222
+ df_results = filter_df_recs(df)
223
+
224
+ # Display results as table
225
+ if st.button("πŸ”˜ Press me to generate similar tasting wines"):
226
+ # Get neighbours
227
+ scores, samples = get_neighbours(df_results, query_embedding,
228
+ k=k+1, metric='l2')
229
+ recs_df = pd.DataFrame(samples).fillna("")
230
+ recs_df = recs_df.fillna(" ")
231
+ # Display results
232
+ st.header(f"🍾 Top {k} similar tasting wines")
233
+ st.table(recs_df.loc[1:,selected_cols])
234
+
235
+ else:
236
+ print("Awaiting selection")