import numpy as np
import pandas as pd
import os
from PIL import Image
import streamlit as st
from streamlit import components
from datasets import Dataset, load_dataset, load_from_disk
import faiss
from scripts.preprocessing import preprocess
# App config: browser-tab icon, page title and wide layout.
icon = Image.open('./images/wine_icon.png')
st.set_page_config(
    page_title="Sommeli-AI",
    page_icon=icon,
    layout="wide",
)

# Markup injected once at start-up with HTML allowed. The string currently
# holds only a newline -- presumably a placeholder for CSS that hides
# Streamlit's default page chrome (TODO confirm).
hide_default_format = """
"""
st.markdown(hide_default_format, unsafe_allow_html=True)
# App functions
@st.cache_data
def read_data(ds_path=None):
    """Load the wine embeddings dataset and return it as a preprocessed DataFrame.

    Args:
        ds_path: Path of a dataset previously saved to disk. When ``None``
            (the default) the "pdjewell/sommeli_ai" train split is fetched
            with ``load_dataset`` instead.

    Returns:
        The full dataset as a pandas DataFrame after it has been passed
        through ``preprocess`` (per the original comment: adds a type column
        and removes duplicates).
    """
    if ds_path is None:
        wine_ds = load_dataset("pdjewell/sommeli_ai", split="train")
    else:
        # Read the dataset from the local Hugging Face on-disk format
        wine_ds = load_from_disk(ds_path)

    # Switch the output format so that slicing yields a pandas DataFrame
    wine_ds.set_format("pandas")
    return preprocess(wine_ds[:])
def get_neighbours(df, query_embedding, k=6,
                   metric='inner'):
    """Find the rows of ``df`` whose embeddings are closest to ``query_embedding``.

    The DataFrame is converted to a Hugging Face ``Dataset``, a FAISS index
    is built over its "embeddings" column, and the index is queried.

    Args:
        df: DataFrame of candidates; must contain an "embeddings" column.
        query_embedding: Embedding vector to search with. It is coerced to a
            float32 NumPy array before the search.
        k: Maximum number of neighbours to retrieve.
        metric: ``'inner'`` selects the inner-product metric; any other
            value selects L2 distance.

    Returns:
        A ``(scores, samples)`` tuple as produced by
        ``Dataset.get_nearest_examples``. ``samples`` maps column names to
        the retrieved values, with the "embeddings" column and the
        serialised pandas index column removed. For an empty ``df`` the
        scores are an empty array and every remaining column maps to an
        empty list.
    """
    # Nothing to index: return empty results with the expected columns
    # instead of handing an empty dataset to FAISS.
    if len(df) == 0:
        empty_samples = {name: [] for name in df.columns if name != 'embeddings'}
        return np.empty(0, dtype=np.float32), empty_samples

    # Convert from pandas df to hf ds, returning NumPy values on access
    ds = Dataset.from_pandas(df).with_format("np")

    # Add the faiss index with the requested metric
    if metric == 'inner':
        metric_type = faiss.METRIC_INNER_PRODUCT
    else:
        metric_type = faiss.METRIC_L2
    ds.add_faiss_index(column="embeddings", metric_type=metric_type)

    # The search expects a NumPy array; float32 matches the default dtype
    # used by add_faiss_index when building the index.
    query = np.asarray(query_embedding, dtype=np.float32)
    scores, samples = ds.get_nearest_examples("embeddings", query, k=k)

    # Drop helper columns callers do not display. '__index_level_0__' exists
    # only when the DataFrame index was serialised as a column, so pop with
    # a default rather than raising KeyError when it is absent.
    samples.pop('embeddings', None)
    samples.pop('__index_level_0__', None)
    return scores, samples
def filter_df_search(df: pd.DataFrame) -> pd.DataFrame:
    """Let the user optionally narrow the search candidates with extra widgets.

    A checkbox toggles the filter UI. When it is ticked the user picks the
    columns to filter on; numeric columns ('Score', 'Price') get a range
    slider and the remaining columns get a multiselect of their values.

    Args:
        df: Candidate wines to filter.

    Returns:
        ``df`` itself when the checkbox is not ticked, otherwise a filtered
        copy. Rows whose value is missing in a filtered column are dropped,
        because neither the range comparison nor ``isin`` matches NaN.
    """
    modify_search = st.checkbox("🔍 Further filter search selection")
    if not modify_search:
        return df

    df = df.copy()

    with st.container():
        to_filter_columns = st.multiselect(
            "Filter on:",
            ['Province', 'Region', 'Winery', 'Score', 'Price'],
            key='search')

        for column in to_filter_columns:
            if column in ['Score', 'Price']:  # Use a slider for 'Score' and 'Price'
                column_max = df[column].max()
                if pd.isna(column_max):
                    # Empty or all-missing column: there is no range to
                    # offer, and int(NaN) would raise.
                    continue
                min_val = 0
                # Round up so a fractional maximum is still inside the
                # default range, and guarantee max_val > min_val so the
                # slider range is never degenerate.
                max_val = max(int(np.ceil(column_max)), min_val + 1)
                # An explicit key keeps this widget distinct from any other
                # widget created with the same label and arguments.
                user_input = st.slider(f"Values for {column}", min_val, max_val,
                                       (min_val, max_val),
                                       key=f"search_slider_{column}")
                df = df[(df[column] >= user_input[0]) & (df[column] <= user_input[1])]
            elif column in ['Country', 'Province', 'Region', 'Variety', 'Winery']:  # Use multiselect for these columns
                unique_values = df[column].dropna().unique()
                # Pre-select only the first unique value if it exists
                default_values = [unique_values[0]] if len(unique_values) > 0 else []
                user_input = st.multiselect(f"Values for {column}", unique_values,
                                            default_values,
                                            key=f"search_values_{column}")
                df = df[df[column].isin(user_input)]
    return df
def filter_df_recs(df: pd.DataFrame) -> pd.DataFrame:
    """Let the user optionally restrict which wines may be recommended.

    A checkbox toggles the filter UI. When it is ticked the user picks the
    columns to filter on; numeric columns ('Score', 'Price') get a range
    slider and the remaining columns get a multiselect of their values.

    Args:
        df: Wines eligible to appear as recommendations.

    Returns:
        ``df`` itself when the checkbox is not ticked, otherwise a filtered
        copy. Rows whose value is missing in a filtered column are dropped,
        because neither the range comparison nor ``isin`` matches NaN.
    """
    modify_recs = st.checkbox("🔍 Filter recommendation results")
    if not modify_recs:
        return df

    df = df.copy()

    with st.container():
        to_filter_columns2 = st.multiselect(
            "Filter on:",
            ['Country', 'Province', 'Region', 'Variety', 'Winery',
             'Score', 'Price'],
            key='recs')

        for column in to_filter_columns2:
            if column in ['Score', 'Price']:  # Use a slider for 'Score' and 'Price'
                column_max = df[column].max()
                if pd.isna(column_max):
                    # Empty or all-missing column: there is no range to
                    # offer, and int(NaN) would raise.
                    continue
                min_val = 0
                # Round up so a fractional maximum is still inside the
                # default range, and guarantee max_val > min_val so the
                # slider range is never degenerate.
                max_val = max(int(np.ceil(column_max)), min_val + 1)
                # An explicit key keeps this widget distinct from any other
                # widget created with the same label and arguments.
                user_input = st.slider(f"Values for {column}", min_val, max_val,
                                       (min_val, max_val),
                                       key=f"recs_slider_{column}")
                df = df[(df[column] >= user_input[0]) & (df[column] <= user_input[1])]
            elif column in ['Country', 'Province', 'Region', 'Variety', 'Winery']:  # Use multiselect for these columns
                unique_values = df[column].dropna().unique()
                # Pre-select only the first unique value if it exists
                default_values = [unique_values[0]] if len(unique_values) > 0 else []
                user_input = st.multiselect(f"Values for {column}", unique_values,
                                            default_values,
                                            key=f"recs_values_{column}")
                df = df[df[column].isin(user_input)]
    return df
if __name__ == "__main__":
    # Script entry point: builds the page (plot panel, search panel and the
    # recommendation table) on every Streamlit rerun.
    st.title("🍷 Sommeli-AI")

    # Read in data. ds_path names a local copy of the dataset, but the call
    # below passes None, so read_data takes its load_dataset branch.
    ds_path = "./data/wine_ds.hf"
    df = read_data(ds_path=None)

    # acol is a near-zero-width column that is never drawn into.
    maincol, acol = st.columns([0.999, 0.001])
    with maincol:
        col1, col2 = st.columns([0.65, 0.35], gap="medium")

        with col2:
            st.header("Explore the world of wine 🌍")
            wine_plot = st.radio('Select plot type:', ['2D', '3D'],
                                 label_visibility="hidden",
                                 horizontal=True)
            st.text("Click the legend categories to filter")

            # Read only the plot that is going to be displayed, decoding it
            # explicitly as UTF-8 rather than with the platform default.
            if wine_plot == '2D':
                plot_path = './images/px_2d.html'
            else:
                plot_path = './images/px_3d.html'
            with open(plot_path, 'r', encoding='utf-8') as file:
                plot_html = file.read()

            # Display the HTML plot in the Streamlit app
            components.v1.html(plot_html, width=512, height=512)

        with col1:
            st.header("Search for similar wines 🥂")

            # Select wine type: default is all
            wine_types = df['Type'].unique()
            selected_wine_types = st.multiselect("Select category 👇", wine_types,
                                                 default=wine_types)
            df = df[df['Type'].isin(selected_wine_types)]

            # Select wine variety: default is all
            wine_vars = df['Variety'].unique()
            selected_wine_vars = st.multiselect("Narrow down the variety 🍇",
                                                ['Select all'] + list(wine_vars),
                                                default='Select all')
            if "Select all" in selected_wine_vars:
                df_search = df
            else:
                df_search = df[df['Variety'].isin(selected_wine_vars)]

            # Select the country: default is all
            countries = df_search['Country'].unique()
            selected_countries = st.multiselect("Narrow down the country 🌎",
                                                ['Select all'] + list(countries),
                                                default='Select all')
            if "Select all" not in selected_countries:
                df_search = df_search[df_search['Country'].isin(selected_countries)]

            # Add additional filters
            df_search = filter_df_search(df_search)

            # Create a search bar for the wine 'title'; the leading empty
            # option means "nothing selected yet".
            selected_wine = st.selectbox("Search for and select a wine 👇",
                                         [''] + list(df_search["Title"].unique()))

            if selected_wine:
                # Look the selected wine up once and reuse the rows below
                selected_rows = df.loc[df['Title'] == selected_wine]
                query_embedding = selected_rows['embeddings'].iloc[0]
                tasting_notes = selected_rows['Tasting notes'].iloc[0]
                st.write(f"Tasting notes: {tasting_notes}")

                # CSS to inject contained in a string (currently the string
                # holds only a newline).
                hide_table_row_index = """
"""
                # Inject CSS with Markdown
                st.markdown(hide_table_row_index, unsafe_allow_html=True)

                # Display selected wine
                st.header("Your selected wine 🍷")
                selected_cols = ['Title', 'Country', 'Province', 'Region', 'Winery',
                                 'Variety', 'Tasting notes', 'Score']
                st.table(selected_rows[selected_cols].fillna(""))

                # Slider for results to show
                k = st.slider("Choose how many similar wines to show 👇", 1, 10, value=4)

                # Filter recommendation results
                df_results = filter_df_recs(df)
            else:
                print("Awaiting selection")

        if selected_wine:
            # Display results as table
            if st.button("🔘 Press me to generate similar tasting wines"):
                if df_results.empty:
                    # The recommendation filters removed every candidate, so
                    # there is nothing to index or search.
                    st.warning("No wines match the current recommendation filters.")
                else:
                    # Ask for one extra neighbour because the selected wine
                    # may itself be among the candidates.
                    scores, samples = get_neighbours(df_results, query_embedding,
                                                     k=k+1, metric='l2')
                    recs_df = pd.DataFrame(samples).fillna("")
                    # Remove the selected wine by title, not by position:
                    # when the filters have already excluded it, the first
                    # hit is a genuine recommendation and must be kept.
                    recs_df = recs_df[recs_df['Title'] != selected_wine].head(k)
                    # Number the rows from 1 so the index reads as a rank
                    recs_df.index = range(1, len(recs_df) + 1)
                    # Display results
                    st.header(f"Top {k} similar tasting wines 🍾")
                    st.table(recs_df[selected_cols])
        else:
            print("Awaiting selection")