# Streamlit app: fragrance recommendations via FAISS vector search over a local
# embedding model, with LLM-generated explanations for each match.

import streamlit as st
import os
import faiss
import numpy as np
import pandas as pd
import pickle
from langchain_huggingface import HuggingFaceEndpoint
from transformers import AutoTokenizer, AutoModel
import torch

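# Local MPNet checkpoint used to embed user queries; it should be the same model
# that produced the vectors stored in the FAISS index, so query and corpus
# embeddings live in one vector space.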
tokenizer = AutoTokenizer.from_pretrained('src/paraphrase-mpnet-base-v2')
model = AutoModel.from_pretrained('src/paraphrase-mpnet-base-v2')


def mean_pooling(model_output, attention_mask):
    # Average the token embeddings, using the attention mask so padding tokens
    # do not contribute to the sentence vector.
    token_embeddings = model_output[0]
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


def encode(sentences):
    # Embed a list of sentences and return numpy vectors ready for FAISS.
    encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
    with torch.no_grad():
        model_output = model(**encoded_input)
    return mean_pooling(model_output, encoded_input['attention_mask']).cpu().numpy()


def create_fragrance_card(name, rating, brand, perfumer_text, top_notes, middle_notes, base_notes, accords_text, explanation):
    # Render one recommendation as a self-contained HTML card.
    card_html = f"""
    <div style="border: 1px solid #ddd; padding: 15px; margin: 10px; border-radius: 15px;
                background: linear-gradient(to bottom right, #ffffff, #f2f6fc);
                width: 400px; color: #222; box-shadow: 0 4px 8px rgba(0,0,0,0.1);">
        <h3 style="color: #3a3a3a; text-align: center;">{name} ⭐{rating}</h3>
        <p><strong>🏷️ Brand:</strong> {brand}</p>
        <p><strong>Perfumer(s):</strong> {perfumer_text}</p>
        <p><strong>🌿 Top Notes:</strong> {top_notes}</p>
        <p><strong>Heart Notes:</strong> {middle_notes}</p>
        <p><strong>🌲 Base Notes:</strong> {base_notes}</p>
        <p><strong>🌼 Main Accords:</strong> {accords_text}</p>
        <p><strong>AI Explanation:</strong> {explanation}</p>
    </div>
    """
    return card_html


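# Heavy resources (the FAISS index and the metadata DataFrame) are loaded once and
# cached across Streamlit reruns via st.cache_resource.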
@st.cache_resource
def load_resources():
    index = faiss.read_index('src/fragrance_faiss.index')
    with open('src/fragrance_metadata.pkl', 'rb') as f:
        metadata = pickle.load(f)
    return index, metadata


def get_llm_explanation(query, description):
    # Ask the hosted LLM for a short, plain-English justification of the match.
    prompt = f"""
    A user is searching for a fragrance with this description: "{query}"

    One recommendation is:
    {description}

    Explain in 1-2 sentences, in plain English, why this fragrance matches the user's query.
    """
    response = llm.invoke(prompt)
    return response.strip()


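# Hosted text-generation endpoint used for the explanations. The Hugging Face API
# token is read from the LLM_TOKEN environment variable.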
llm = HuggingFaceEndpoint(
    repo_id="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
    task="text-generation",
    huggingfacehub_api_token=os.environ["LLM_TOKEN"]
)


st.set_page_config(page_title="Fragrance Recommendation System", layout="wide")

st.title("Fragrance Recommendation System")

st.sidebar.header("Filters")
query = st.text_input("Describe your ideal fragrance:")

col1, col2 = st.columns(2)
with col1:
    k = st.slider("Number of recommendations:", 1, 10, 5)
with col2:
    min_rating = st.slider("Minimum rating:", 1.0, 5.0, 3.5)

gender_filter = st.sidebar.selectbox("Gender:", ["All", "Male", "Female", "Unisex"])
brand_filter = st.sidebar.text_input("Brand (leave empty for all):", "").title()
note_filter = st.sidebar.text_input("Notes (comma-separated):", "").lower()


index, metadata = load_resources()

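# rating_value may not be stored as a numeric dtype in the pickled metadata, so
# coerce it; unparseable values become NaN and drop out of the rating filter.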
if 'rating_value' in metadata.columns:
    metadata['rating_value'] = pd.to_numeric(metadata['rating_value'], errors='coerce')


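# Main flow: apply the sidebar filters to the metadata, run the similarity search
# over the remaining rows only, and render one card per recommendation.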
if st.button('Get Recommendations'):
    with st.spinner('Finding your fragrance recs...'):
        if query == "":
            st.warning("No query entered.")
        else:
            current_df = metadata.copy()

            # Gender filter.
            if gender_filter != "All":
                current_df = current_df[current_df['gender'].str.lower() == gender_filter.lower()]

            # Brand filter (case-insensitive substring match).
            if brand_filter:
                current_df = current_df[current_df['brand'].str.contains(brand_filter, case=False, na=False)]

            # Minimum-rating filter.
            if 'rating_value' in current_df.columns:
                current_df = current_df[current_df['rating_value'].ge(min_rating)]

            # Note filter: keep rows where any requested note appears in the
            # top, heart, or base notes.
            if note_filter:
                notes = [n.strip().lower() for n in note_filter.split(",")]

                def note_check(row):
                    note_fields = [
                        str(row['top']).lower() if pd.notna(row['top']) else "",
                        str(row['middle']).lower() if pd.notna(row['middle']) else "",
                        str(row['base']).lower() if pd.notna(row['base']) else ""
                    ]
                    return any(note in field for note in notes for field in note_fields)

                current_df = current_df[current_df.apply(note_check, axis=1)]

            valid_indices = current_df.index.tolist()

            if not valid_indices:
                st.warning("No fragrances match all your filters. Try relaxing some criteria.")
                st.stop()

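            # Restrict the similarity search to rows that passed the filters:
            # reconstruct their vectors from the main index, build a temporary
            # inner-product index over just those vectors, and later map
            # temp-index positions back to metadata rows via valid_indices.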
            filtered_vectors = np.vstack([index.reconstruct(int(idx)) for idx in valid_indices])
            temp_index = faiss.IndexFlatIP(filtered_vectors.shape[1])
            temp_index.add(filtered_vectors)

            # Embed the query and L2-normalize it before the inner-product search.
            query_vector = encode([query])
            faiss.normalize_L2(query_vector)

            scores, I = temp_index.search(query_vector, min(k, len(valid_indices)))

            # Pair each hit's metadata row index with its similarity score.
            results = [(valid_indices[i], scores[0][j]) for j, i in enumerate(I[0])]

            st.subheader(f"Recommended Fragrances ({len(results)} results)")
            cols = st.columns(3)

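            # Lay the cards out three per row, each with an LLM-generated explanation.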
            for idx, (result_idx, score) in enumerate(results):
                rec = metadata.loc[result_idx]

                name = rec.get('perfume', 'Unknown')
                brand = rec.get('brand', 'Unknown')
                perfumer_text = rec.get('perfumer', 'Unknown')
                top_notes = rec.get('top', 'Unknown')
                middle_notes = rec.get('middle', 'Unknown')
                base_notes = rec.get('base', 'Unknown')
                accords_text = rec.get('accord', 'Unknown')
                rating = rec.get('rating_value', '?')

                # Plain-text description of the fragrance for the LLM to reason about.
                description = (
                    f"The fragrance is called {name}. It is by {brand}. "
                    f"The perfumer is {perfumer_text}. The top notes are {top_notes}, "
                    f"the heart notes are {middle_notes}, and the base notes are {base_notes}. "
                    f"The main accords are {accords_text}."
                )

                explanation = get_llm_explanation(query, description)

                card = create_fragrance_card(
                    name,
                    rating,
                    brand,
                    perfumer_text,
                    top_notes,
                    middle_notes,
                    base_notes,
                    accords_text,
                    explanation
                )
                cols[idx % 3].markdown(card, unsafe_allow_html=True)