Spaces: Running

Upload 11 files
Browse files

- README.md +33 -5
- app.py +304 -0
- data/tfidf_vectorizer.pkl +3 -0
- main.py +94 -0
- requirements.txt +11 -0
- utils/.DS_Store +0 -0
- utils/clean_text.py +28 -0
- utils/ranker.py +23 -0
- utils/semantic_similarity.py +25 -0
- utils/syntactic_similarity.py +73 -0
- utils/tfidf_similarity.py +34 -0
README.md
CHANGED
@@ -1,12 +1,40 @@
 ---
-title: Team 149 Project
-emoji:
-colorFrom:
-colorTo:
+title: Team 149 Project 2
+emoji: 🐠
+colorFrom: indigo
+colorTo: red
 sdk: gradio
 sdk_version: 6.0.0
 app_file: app.py
 pinned: false
 ---
 
-
+# Restaurant Recommendation System - UI
+
+A web-based interface for searching and discovering restaurants in Paris, with natural language search, interactive map visualization, and popularity ranking.
+
+## Features
+
+- Natural language search for restaurants
+- Interactive Paris map with color-coded rating indicators
+- Bayesian popularity ranking
+- Semantic and keyword search options
+- Database of 5,277+ restaurants
+
+## Installation
+
+### Install Dependencies
+
+```bash
+python -m venv team-149-project
+source team-149-project/bin/activate  # On Windows: team-149-project\Scripts\activate
+pip install -r requirements.txt --no-cache-dir
+```
+
+### Run Application
+
+```bash
+python app.py
+```
+
+The application launches at `http://127.0.0.1:7860`.
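Beyond the Gradio UI, the retrieval pipeline added in `main.py` below exposes `get_recommendations` directly. A minimal usage sketch, assuming the `data/` files referenced in this commit are present and the requirements are installed:

```python
# Minimal programmatic usage sketch; assumes the data/ files referenced in
# this commit exist and requirements.txt has been installed.
from main import get_recommendations

# Stage 1 retrieves 100 candidates by blended similarity; Stage 2 reranks
# them by popularity and returns the ids of the top 10 restaurants.
ids = get_recommendations("romantic Italian dinner", n_candidates=100, n_rec=10)
print(ids)
```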
app.py
ADDED
@@ -0,0 +1,304 @@
```python
import gradio as gr
import pandas as pd
import numpy as np
import folium
import sys
import os

# Add utils to path
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), 'utils'))
from clean_text import clean_text
from semantic_similarity import Encoder
from ranker import compute_bayesian_popularity_score
from main import get_recommendations

print("Loading restaurant data...")
data = pd.read_csv("data/toy_data_aggregated_embeddings.csv")
print(f"Loaded {len(data)} restaurants")

# Compute Bayesian popularity scores
print("Computing popularity scores...")
data = compute_bayesian_popularity_score(data)
print("Popularity scores computed")

print("Loading pre-computed embeddings...")
all_desc_embeddings = np.vstack(data["embedding"].values)
print(f"Loaded embeddings with shape {all_desc_embeddings.shape}")

# Initialize semantic encoder
print("Loading semantic encoder model...")
try:
    encoder = Encoder()
    print("Semantic encoder loaded")
except Exception as e:
    print(f"Warning: Could not load semantic encoder: {e}")
    print("Falling back to keyword-only search")

def create_paris_map(results_df):
    """Create an interactive map of Paris restaurants."""
    paris_center = [48.8566, 2.3522]
    m = folium.Map(location=paris_center, zoom_start=12, tiles='OpenStreetMap')

    for idx, row in results_df.iterrows():
        # Placeholder coordinates: jitter around the city center until real
        # latitude/longitude columns are available
        lat_offset = np.random.uniform(-0.05, 0.05)
        lng_offset = np.random.uniform(-0.07, 0.07)
        coords = [48.8566 + lat_offset, 2.3522 + lng_offset]

        rating = float(row.get('overall_rating', 0))
        color = 'green' if rating >= 4.5 else 'blue' if rating >= 4.0 else 'orange' if rating >= 3.5 else 'red'

        popup_html = f"""
        <div style="width:250px">
            <h4><b>{row['name']}</b></h4>
            <p>Rating: {row.get('overall_rating', 'N/A')}</p>
            <p>Reviews: {row.get('review_count', 'N/A')}</p>
            <p>Popularity Score: {row.get('pop_score', 'N/A'):.2f}</p>
        </div>
        """

        folium.Marker(
            location=coords,
            popup=folium.Popup(popup_html, max_width=300),
            icon=folium.Icon(color=color, icon='cutlery', prefix='fa')
        ).add_to(m)

    return m._repr_html_()

# def semantic_search(query, data_source, num_results, use_popularity):
#     """Semantic search using embeddings"""
#     if not query.strip():
#         return "Please enter a search query", None
#
#     try:
#         query_clean = clean_text(query)
#
#         # Generate query embedding
#         print(f"Encoding query: {query_clean}")
#         query_embedding = encoder.encode([query_clean], show_progress_bar=False)
#         query_embedding = query_embedding.cpu().numpy()
#
#         # Compute semantic similarity
#         similarities = cosine_similarity(query_embedding, all_desc_embeddings)[0]
#
#         # Combine with popularity if requested
#         if use_popularity:
#             sim_normalized = (similarities - similarities.min()) / (similarities.max() - similarities.min() + 1e-10)
#             pop_normalized = (data["pop_score"] - data["pop_score"].min()) / (data["pop_score"].max() - data["pop_score"].min() + 1e-10)
#             # Combined score: 70% semantic, 30% popularity
#             scores = 0.7 * sim_normalized + 0.3 * pop_normalized
#         else:
#             scores = similarities
#
#         top_indices = np.argsort(scores)[-int(num_results):][::-1]
#         results = data.iloc[top_indices].copy()
#         results['similarity_score'] = scores[top_indices]
#
#         map_html = create_paris_map(results)
#
#         output = f"Found {len(results)} restaurants for '{query}'\n"
#         output += f"Data Source: {data_source}\n"
#         output += f"Search Method: Semantic Search {'+ Popularity' if use_popularity else ''}\n\n"
#
#         for idx, (_, row) in enumerate(results.iterrows(), 1):
#             name = row.get('name', 'Unknown')
#             rating = row.get('overall_rating', 'N/A')
#             reviews = row.get('review_count', 'N/A')
#             similarity = row.get('similarity_score', 0)
#             pop_score = row.get('pop_score', 0)
#
#             output += f"{idx}. **{name}**\n"
#             output += f"   Rating: {rating} | Reviews: {reviews}\n"
#             output += f"   Match: {similarity:.3f}"
#             if use_popularity:
#                 output += f" | Popularity: {pop_score:.2f}"
#             output += "\n"
#
#             if 'address' in row and pd.notna(row['address']):
#                 addr = str(row['address'])[:100]
#                 output += f"   Address: {addr}\n"
#
#             output += "\n"
#
#         return output, map_html
#
#     except Exception as e:
#         import traceback
#         return f"Error: {str(e)}\n\n{traceback.format_exc()}", None

# def keyword_search(query, data_source, num_results, use_popularity):
#     """Keyword-based search with optional popularity ranking"""
#     if not query.strip():
#         return "Please enter a search query", None
#
#     try:
#         query_clean = clean_text(query).lower()
#         query_words = set(query_clean.split())
#
#         scores = []
#         for idx, row in data.iterrows():
#             score = 0
#             name = str(row.get('name', '')).lower()
#
#             # Check name matches
#             for word in query_words:
#                 if word in name:
#                     score += 2
#
#             rating = float(row.get('overall_rating', 0))
#             score += rating * 0.5
#
#             # Add popularity if requested
#             if use_popularity:
#                 pop_score = float(row.get('pop_score', 0))
#                 score += pop_score * 0.3
#
#             scores.append(score)
#
#         top_indices = np.argsort(scores)[-int(num_results):][::-1]
#         results = data.iloc[top_indices].copy()
#         results['match_score'] = [scores[i] for i in top_indices]
#
#         map_html = create_paris_map(results)
#
#         output = f"Found {len(results)} restaurants for '{query}'\n"
#         output += f"Data Source: {data_source}\n"
#         output += f"Search Method: Keyword Search {'+ Popularity' if use_popularity else ''}\n\n"
#
#         for idx, (_, row) in enumerate(results.iterrows(), 1):
#             name = row.get('name', 'Unknown')
#             rating = row.get('overall_rating', 'N/A')
#             reviews = row.get('review_count', 'N/A')
#             match = row.get('match_score', 0)
#             pop_score = row.get('pop_score', 0)
#
#             output += f"{idx}. **{name}**\n"
#             output += f"   Rating: {rating} | Reviews: {reviews}\n"
#             output += f"   Match Score: {match:.2f}"
#             if use_popularity:
#                 output += f" | Popularity: {pop_score:.2f}"
#             output += "\n"
#
#             if 'address' in row and pd.notna(row['address']):
#                 addr = str(row['address'])[:100]
#                 output += f"   Address: {addr}\n"
#
#             output += "\n"
#
#         return output, map_html
#
#     except Exception as e:
#         import traceback
#         return f"Error: {str(e)}\n\n{traceback.format_exc()}", None

# def search_restaurants(query, data_source, search_method, num_results, use_popularity):
#     """Main search function that routes to the appropriate search method"""
#     if search_method == "Semantic Search" and use_semantic:
#         return semantic_search(query, data_source, num_results, use_popularity)
#     else:
#         return keyword_search(query, data_source, num_results, use_popularity)

def search_restaurants(query_input, data_source, num_results):
    n_candidates = 100
    query_clean = clean_text(query_input)
    restaurant_ids = get_recommendations(query_clean, n_candidates, num_results)
    # Look up the ranked ids to fill both Gradio outputs (text panel and map)
    results = data.set_index("id").loc[restaurant_ids].reset_index()
    output = "\n".join(f"{i}. {row['name']}" for i, (_, row) in enumerate(results.iterrows(), 1))
    return output, create_paris_map(results)

# Create Gradio interface
with gr.Blocks(title="Restaurant Finder", theme=gr.themes.Soft()) as app:
    gr.Markdown("""
    # Advanced Restaurant Recommendation System
    ### Search Through 5,000+ Restaurants with AI-Powered Semantic Search

    Find restaurants using semantic understanding and popularity ranking!
    """)

    with gr.Row():
        with gr.Column(scale=3):
            query_input = gr.Textbox(
                label="Search Query",
                placeholder="e.g., Italian pasta, best sushi, romantic dinner, family-friendly pizza",
                lines=2
            )

        with gr.Column(scale=2):
            data_source = gr.Dropdown(
                choices=["Michelin", "Google", "Yelp"],
                value="Yelp",
                label="Data Source",
                info="Select restaurant data source"
            )

    with gr.Row():
        # with gr.Column(scale=2):
        #     search_method = gr.Radio(
        #         choices=["Keyword Search", "Semantic Search"],
        #         value="Semantic Search" if use_semantic else "Keyword Search",
        #         label="Search Method",
        #         info="Semantic uses AI embeddings, Keyword uses exact matches"
        #     )

        with gr.Column(scale=1):
            num_results = gr.Slider(
                minimum=5,
                maximum=30,
                value=10,
                step=5,
                label="Results"
            )

        # with gr.Column(scale=1):
        #     use_popularity = gr.Checkbox(
        #         label="Use Popularity Ranking",
        #         value=True,
        #         info="Boost popular restaurants"
        #     )

    search_btn = gr.Button("Search Restaurants", variant="primary", size="lg")

    with gr.Row():
        with gr.Column(scale=1):
            results_output = gr.Textbox(
                label="Restaurant Results",
                lines=20,
                max_lines=30
            )

        with gr.Column(scale=1):
            map_output = gr.HTML(
                label="Paris Map"
            )

    gr.Markdown("### Try These Examples:")

    examples = [
        ["Italian pasta", "Yelp", 10],
        ["sushi", "Michelin", 10],
        ["romantic dinner", "Google", 8],
        ["family-friendly pizza", "Yelp", 10],
        ["best seafood", "Michelin", 10],
        ["cheap burger", "Google", 10]
    ]

    gr.Examples(
        examples=examples,
        inputs=[query_input, data_source, num_results]
    )

    search_btn.click(
        fn=search_restaurants,
        inputs=[query_input, data_source, num_results],
        outputs=[results_output, map_output]
    )

    query_input.submit(
        fn=search_restaurants,
        inputs=[query_input, data_source, num_results],
        outputs=[results_output, map_output]
    )

if __name__ == "__main__":
    print("\nStarting Advanced Restaurant Finder...")
    print(f"{len(data)} restaurants ready to search")
    print("Popularity Ranking: Enabled")
    print("Opening at http://127.0.0.1:7860\n")

    app.launch(share=False, server_name="127.0.0.1", server_port=7860, inbrowser=True)
```
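`create_paris_map` only needs `name`, `overall_rating`, `review_count`, and `pop_score` columns, so it can be smoke-tested in isolation; a minimal sketch with two made-up rows (copy the function into a scratch script, since importing `app.py` triggers its module-level data loading):

```python
# Standalone smoke test for create_paris_map; the two rows are hypothetical.
import pandas as pd

sample = pd.DataFrame([
    {"name": "Chez Test", "overall_rating": 4.6, "review_count": 120, "pop_score": 4.1},
    {"name": "Le Placeholder", "overall_rating": 3.8, "review_count": 15, "pop_score": 2.9},
])

html = create_paris_map(sample)  # returns the folium map as an HTML string
with open("map_preview.html", "w") as f:
    f.write(html)  # open in a browser to inspect the markers
```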
data/tfidf_vectorizer.pkl
ADDED
@@ -0,0 +1,3 @@
```
version https://git-lfs.github.com/spec/v1
oid sha256:32c417c211041c2ffadda4776cc7eaaa03d416920f5b4541b127fb6c816cc65a
size 473929
```
main.py
ADDED
@@ -0,0 +1,94 @@
```python
import torch
import nltk
import benepar
import pandas as pd
import numpy as np
from scipy.sparse import load_npz
from sklearn.metrics.pairwise import cosine_similarity

from utils.clean_text import clean_text
from utils.ranker import compute_bayesian_popularity_score
from utils.semantic_similarity import Encoder
from utils.syntactic_similarity import Parser
from utils.tfidf_similarity import TFIDF_Vectorizer

# Set default device to CUDA if available, otherwise CPU
if torch.cuda.is_available():
    torch.set_default_device("cuda")
else:
    torch.set_default_device("cpu")

# Download models/data
nltk.download('punkt')
nltk.download('punkt_tab')
benepar.download('benepar_en3_large')

# Load dataset and compute popularity scores (rerank below needs "pop_score")
data = pd.read_csv("data/toy_data_aggregated_embeddings.csv")
data = compute_bayesian_popularity_score(data)

# Load precomputed TF-IDF features (saved with scipy.sparse.save_npz)
restaurant_tfidf_features = load_npz("data/toy_data_tfidf_features.npz")

# Extract embeddings
all_desc_embeddings = np.vstack(data["embedding"].values)

# Initialize encoder
encoder = Encoder()

# Initialize syntactic parser
parser = Parser()

# Initialize TF-IDF vectorizer
tfidf_vectorizer = TFIDF_Vectorizer(load_vectorizer=True)

def retrieve_candidates(query: str, n_candidates: int):
    # Encode query
    query_emb = encoder.encode([query]).cpu().numpy()

    # Semantic similarities
    desc_sem_sim = cosine_similarity(query_emb, all_desc_embeddings)[0]

    # TF-IDF similarities
    tfidf_sim = tfidf_vectorizer.compute_tfidf_scores(query, restaurant_tfidf_features)

    # Syntactic similarities
    parsed_query = parser.parse_text(query)
    parsed_query = parser.subtree_set(parsed_query)

    syn_sims = []
    for trees_list in data["syntactic_tree"]:
        review_sims = []
        for review_tree_subs in trees_list:
            if review_tree_subs is None:
                review_tree_subs = set()
            sim = parser.compute_syntactic_similarity(parsed_query, review_tree_subs)
            review_sims.append(sim)
        syn_sims.append(np.mean(review_sims))

    # Combined Stage 1 score: 80% semantic, 10% syntactic, 10% TF-IDF
    syn_sims = np.array(syn_sims)
    combined_stage1_scores = 0.8 * desc_sem_sim + 0.1 * syn_sims + 0.1 * tfidf_sim

    # Get top N candidates for Stage 2 reranking
    candidates_idx = np.argsort(combined_stage1_scores)[-n_candidates:][::-1]

    return candidates_idx


def rerank(candidates_idx: np.ndarray, n_rec: int = 10) -> list:
    # Get popularity scores for the Stage 1 candidates
    rerank_scores = data.loc[candidates_idx, "pop_score"].values

    # Keep the n_rec restaurants with the highest pop_score
    topN_reranked_local_idx = np.argsort(rerank_scores)[-n_rec:][::-1]
    topN_reranked_global_idx = candidates_idx[topN_reranked_local_idx]

    # Get restaurant ids for the final recommendations
    restaurant_ids = data.loc[topN_reranked_global_idx, "id"].tolist()

    return restaurant_ids

def get_recommendations(query: str, n_candidates: int = 100, n_rec: int = 30):
    query_clean = clean_text(query)
    candidates_idx = retrieve_candidates(query_clean, n_candidates)
    restaurant_ids = rerank(candidates_idx, n_rec)
    return restaurant_ids
```
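To make the Stage 1 blend concrete, here is a small worked example with made-up similarity arrays (all values hypothetical), showing how the 0.8/0.1/0.1 weighting picks candidates before the popularity rerank:

```python
import numpy as np

# Hypothetical per-restaurant similarity scores for one query
desc_sem_sim = np.array([0.82, 0.35, 0.77, 0.50])  # semantic (dominant, weight 0.8)
syn_sims     = np.array([0.10, 0.60, 0.20, 0.40])  # syntactic subtree Jaccard (0.1)
tfidf_sim    = np.array([0.55, 0.20, 0.70, 0.10])  # TF-IDF cosine (0.1)

stage1 = 0.8 * desc_sem_sim + 0.1 * syn_sims + 0.1 * tfidf_sim
# stage1 -> [0.721, 0.36, 0.706, 0.45]; the top-2 candidates are indices
# 0 and 2, which Stage 2 then reorders purely by pop_score.
candidates_idx = np.argsort(stage1)[-2:][::-1]
print(candidates_idx)  # [0 2]
```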
requirements.txt
ADDED
@@ -0,0 +1,11 @@
```
--extra-index-url https://download.pytorch.org/whl/cpu
torch>=2.0.0
numpy==1.25.2
scipy==1.11.2
pandas==2.1.1
scikit-learn==1.3.0
sentence-transformers
nltk==3.8.1
benepar==0.2.0
tqdm==4.66.1
folium
```
utils/.DS_Store
ADDED
Binary file (6.15 kB).
utils/clean_text.py
ADDED
@@ -0,0 +1,28 @@
```python
import re

def clean_text(text) -> str:

    # Strip and lower
    text = text.strip().lower()

    # Remove mentions (@username) and hashtags (#tag)
    text = re.sub(r'[@#][\w∆]+', '', text)

    # Remove extra spaces left behind
    text = re.sub(r'\s+', ' ', text)
    text = text.replace("\n", " ").replace("\t", " ")

    # Remove phone numbers
    text = re.sub(r'\b\d{10}\b', '', text)

    # Collapse repeated punctuation (e.g. !!!!)
    text = re.sub(r'([^\w\s])\1+', r'\1', text)

    # Collapse multiple spaces
    text = re.sub(r'\s+', ' ', text)

    # Fix escaped apostrophes like: can\'t, don\'t, etc.
    text = re.sub(r"\\'", "'", text)

    return text.strip()
```
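A quick example of what `clean_text` does end to end; the expected output was traced through the regexes above:

```python
# Example behavior of clean_text (expected output derived by tracing the
# regexes above on this input).
from utils.clean_text import clean_text

print(clean_text("  Check out @JoesPizza!!! Call 0123456789 now  "))
# -> "check out ! call now"
# (mention removed, 10-digit phone number removed, "!!!" collapsed to "!")
```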
utils/ranker.py
ADDED
@@ -0,0 +1,23 @@
```python
import numpy as np
import pandas as pd

def compute_bayesian_popularity_score(df, rating_col="overall_rating", reviews_col="review_count", m_prior=20):
    # Convert to numeric
    df[rating_col] = pd.to_numeric(df[rating_col], errors="coerce")
    df[reviews_col] = pd.to_numeric(df[reviews_col], errors="coerce").fillna(0).astype(int)

    # Global mean rating
    mu = df[rating_col].dropna().mean()

    # Data
    n = df[reviews_col]
    r = df[rating_col].fillna(mu)

    # Bayesian rating
    df["bayes_rating"] = ((mu * m_prior + n * r) / (m_prior + n.replace(0, np.nan))).fillna(mu)

    # Popularity metrics
    df["pop_log"] = np.log1p(n)
    df["pop_score"] = 0.7 * df["bayes_rating"] + 0.3 * df["pop_log"]

    return df
```
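A quick worked example of the Bayesian shrinkage above, with a hypothetical global mean of 4.0 and the default prior strength of 20 (the simplified formula below drops the n=0 NaN handling, which only matters for restaurants with no reviews):

```python
# Worked example of the Bayesian shrinkage; mu is a made-up global mean.
import numpy as np

mu, m_prior = 4.0, 20  # hypothetical global mean rating, prior strength

def bayes_rating(r, n):
    # Few reviews -> pulled toward mu; many reviews -> stays near its own r
    return (mu * m_prior + n * r) / (m_prior + n)

print(bayes_rating(5.0, 10))   # ~4.33: 10 five-star reviews barely beat the prior
print(bayes_rating(5.0, 500))  # ~4.96: 500 reviews dominate the prior

# pop_score then blends the shrunk rating with log review volume:
print(0.7 * bayes_rating(5.0, 500) + 0.3 * np.log1p(500))  # ~5.34
```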
utils/semantic_similarity.py
ADDED
@@ -0,0 +1,25 @@
```python
from typing import List
from sentence_transformers import SentenceTransformer

class Encoder():
    def __init__(self):
        print("Loading embedding model...")
        self.model = SentenceTransformer(
            "KaLM-Embedding/KaLM-embedding-multilingual-mini-instruct-v2.5",
            model_kwargs={"attn_implementation": "sdpa"}
        )
        self.model = self.model.half()

    def encode(
            self,
            texts: List[str],
            batch_size: int = 8,
            show_progress_bar: bool = False,
            save_path: str = None):

        embeddings = self.model.encode(texts, convert_to_tensor=True, show_progress_bar=show_progress_bar, batch_size=batch_size)

        # if save_path:
        #     torch.save(embeddings, save_path)

        return embeddings
```
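A usage sketch for the encoder, pairing it with scikit-learn's cosine similarity as the rest of the pipeline does; note the model is loaded in half precision, which may need to be dropped on CPU-only machines:

```python
# Usage sketch: encode a query and two descriptions, then compare them.
from sklearn.metrics.pairwise import cosine_similarity
from utils.semantic_similarity import Encoder

encoder = Encoder()
embs = encoder.encode([
    "cozy italian trattoria with homemade pasta",
    "romantic dinner spot serving fresh tagliatelle",
    "late-night burger joint",
]).cpu().numpy()

# The first two texts should score noticeably higher against each other
# than either does against the third.
print(cosine_similarity(embs[:1], embs[1:])[0])
```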
utils/syntactic_similarity.py
ADDED
@@ -0,0 +1,73 @@
```python
import torch
import pickle
import benepar
import nltk
from nltk.tree import Tree

import os
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor

nltk.data.path.append('data/nltk_data')

class Parser():
    def __init__(self):
        torch.set_default_device("cpu")
        self.parser = benepar.Parser("benepar_en3_large")
        self.parser.batch_size = 64
        self.parsed_eval_reviews_path = "data/parsed/parsed_reviews.pkl"
        self.parsed_toy_reviews_path = "data/parsed/parsed_toy_data_reviews.pkl"

    def subtree_set(self, tree: Tree):
        """
        Return a flat set of all subtrees as strings (hashable).
        """
        subs = set()

        def helper(t):
            # Convert each subtree to a string and add to the set
            subs.add(str(t))
            for child in t:
                if isinstance(child, Tree):
                    helper(child)

        helper(tree)
        return subs

    def parse_text(self, text):
        try:
            return self.parser.parse(text[:10000])  # truncate long reviews
        except Exception as e:
            print(f"Parse error: {e}")
            return None

    def parse_reviews(self, reviews: list, toy_data: bool) -> list[set]:
        parsed_reviews = []
        with ThreadPoolExecutor(max_workers=os.cpu_count() - 1) as executor:
            for tree in tqdm(executor.map(self.parse_text, reviews), total=len(reviews)):
                if isinstance(tree, Tree):
                    parsed_reviews.append(self.subtree_set(tree))
                else:
                    parsed_reviews.append(set())  # fallback for parse errors

        # Save parsed reviews
        with open(self.parsed_toy_reviews_path if toy_data else self.parsed_eval_reviews_path, "wb") as f:
            pickle.dump(parsed_reviews, f)

        return parsed_reviews

    def compute_syntactic_similarity(self, query_tree_subs: set, review_tree_subs: set) -> float:
        """
        Jaccard similarity between two sets of subtrees (strings, hashable).
        """
        intersect = query_tree_subs.intersection(review_tree_subs)
        union = query_tree_subs.union(review_tree_subs)
        if not union:
            return 0.0
        return len(intersect) / len(union)

    def load_parsed_reviews(self, toy_data: bool) -> list[set]:
        path = self.parsed_toy_reviews_path if toy_data else self.parsed_eval_reviews_path
        with open(path, "rb") as f:
            parsed_reviews = pickle.load(f)
        return parsed_reviews
```
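The subtree-set Jaccard measure can be sanity-checked without downloading the benepar model by building small trees by hand; a minimal sketch that replicates the `subtree_set` logic standalone:

```python
# Standalone check of the subtree-set Jaccard similarity, using hand-written
# parses so the (heavy) benepar model does not need to be loaded.
from nltk.tree import Tree

def subtree_set(tree):
    # Same logic as Parser.subtree_set: every subtree, stringified
    subs = set()
    def helper(t):
        subs.add(str(t))
        for child in t:
            if isinstance(child, Tree):
                helper(child)
    helper(tree)
    return subs

a = subtree_set(Tree.fromstring("(S (NP great pasta) (VP here))"))
b = subtree_set(Tree.fromstring("(S (NP great pizza) (VP here))"))

# Jaccard similarity, as in Parser.compute_syntactic_similarity:
# only "(VP here)" is shared, so 1 shared subtree / 5 total = 0.2
print(len(a & b) / len(a | b))
```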
utils/tfidf_similarity.py
ADDED
@@ -0,0 +1,34 @@
```python
import pickle
from typing import Any
from scipy.sparse import save_npz
from sklearn.metrics.pairwise import cosine_similarity

from sklearn.feature_extraction.text import TfidfVectorizer

class TFIDF_Vectorizer():
    def __init__(self, load_vectorizer=None, stop_words='english', min_df=2):
        self.vectorizer_path = "data/tfidf_vectorizer.pkl"
        self.tfidf_matrix_path = "tfidf_matrix.npz"

        if load_vectorizer:
            with open(self.vectorizer_path, 'rb') as file:
                self.vectorizer = pickle.load(file)
        else:
            self.vectorizer = TfidfVectorizer(stop_words=stop_words, min_df=min_df)

    def compute_tfidf_matrix(self, texts):
        features = self.vectorizer.fit_transform(texts)

        # Save vectorizer
        with open(self.vectorizer_path, 'wb') as file:
            pickle.dump(self.vectorizer, file)

        # Save TF-IDF matrix
        save_npz(self.tfidf_matrix_path, features)
        return features

    def transform(self, texts: list) -> Any:
        return self.vectorizer.transform(texts)

    def compute_tfidf_scores(self, query: str, restaurant_tfidf_features: Any) -> list:
        query_tfidf_features = self.vectorizer.transform([query])
        return cosine_similarity(query_tfidf_features, restaurant_tfidf_features)[0]
```
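A usage sketch for the vectorizer on a toy corpus; `min_df=1` overrides the default because the corpus is tiny, and note that `compute_tfidf_matrix` also writes the pickle and matrix to the paths configured in `__init__`:

```python
# Usage sketch: fit on a toy corpus, then score a query against it.
# Caution: compute_tfidf_matrix overwrites the saved vectorizer/matrix files.
from utils.tfidf_similarity import TFIDF_Vectorizer

corpus = [
    "wood-fired pizza and fresh pasta",
    "sushi and sashimi counter",
    "classic french bistro fare",
]

vec = TFIDF_Vectorizer(min_df=1)             # fresh vectorizer, not loaded from disk
features = vec.compute_tfidf_matrix(corpus)  # fits, then pickles vectorizer + matrix
print(vec.compute_tfidf_scores("best pizza and pasta", features))
# -> highest cosine score for the first document
```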