RobertoBarrosoLuque committed on
Commit
34d08ee
·
1 Parent(s): e59c3d6

Lexical search is working

Browse files
Files changed (2) hide show
  1. src/app.py +71 -35
  2. src/data_prep/data_prep.py +1 -1
src/app.py CHANGED
@@ -4,6 +4,7 @@ from typing import List, Dict, Tuple
4
  from pathlib import Path
5
  import os
6
  from config import GRADIO_THEME, CUSTOM_CSS, EXAMPLE_QUERIES
 
7
 
8
  _FILE_PATH = Path(__file__).parents[1]
9
 
@@ -44,16 +45,23 @@ SAMPLE_PRODUCTS = [
44
 
45
 
46
  def format_results(results: List[Dict], stage_name: str, metrics: Dict) -> str:
47
- """Format search results as HTML."""
 
 
 
 
 
 
48
  html_parts = [f"### {stage_name} Results\n\n"]
49
 
50
  for idx, result in enumerate(results, 1):
 
51
  html_parts.append(
52
  f"""
53
  <div class="result-card">
54
- <strong>{idx}. {result['title']}</strong><br/>
55
- <span style="color: #64748B; font-size: 0.9em;">{result['description']}</span><br/>
56
- <span style="color: #94A3B8; font-size: 0.85em;">Category: {result['category']}</span><br/>
57
  <span style="color: #6720FF; font-weight: 600;">Score: {result['score']:.3f}</span>
58
  </div>
59
  """
@@ -77,18 +85,21 @@ def search_stage_1(query: str) -> Tuple[str, Dict]:
77
  """Stage 1: Baseline BM25 keyword search."""
78
  start_time = time.time()
79
 
80
- # Placeholder: Simple keyword matching
81
- results = []
82
- for product in SAMPLE_PRODUCTS[:3]:
83
- results.append({**product, "score": 0.65 + (len(results) * 0.05)})
84
-
85
  latency = int((time.time() - start_time) * 1000)
86
 
 
 
 
 
 
 
87
  metrics = {
88
- "semantic_match": 0.58,
89
- "diversity": 0.60,
90
- "latency_ms": max(50, latency),
91
  }
 
92
 
93
  return format_results(results, "Stage 1: BM25 Baseline", metrics), metrics
94
 
@@ -97,10 +108,17 @@ def search_stage_2(query: str) -> Tuple[str, Dict]:
97
  """Stage 2: BM25 + Vector Embeddings."""
98
  start_time = time.time()
99
 
100
- # Placeholder: Simulated embedding search
101
- results = []
102
- for product in SAMPLE_PRODUCTS[:4]:
103
- results.append({**product, "score": 0.72 + (len(results) * 0.04)})
 
 
 
 
 
 
 
104
 
105
  latency = int((time.time() - start_time) * 1000)
106
 
@@ -117,10 +135,17 @@ def search_stage_3(query: str) -> Tuple[str, Dict]:
117
  """Stage 3: BM25 + Embeddings + Query Expansion."""
118
  start_time = time.time()
119
 
120
- # Placeholder: Simulated query expansion
121
- results = []
122
- for product in SAMPLE_PRODUCTS[:5]:
123
- results.append({**product, "score": 0.78 + (len(results) * 0.03)})
 
 
 
 
 
 
 
124
 
125
  latency = int((time.time() - start_time) * 1000)
126
 
@@ -137,10 +162,17 @@ def search_stage_4(query: str) -> Tuple[str, Dict]:
137
  """Stage 4: BM25 + Embeddings + Query Expansion + LLM Reranking."""
138
  start_time = time.time()
139
 
140
- # Placeholder: Simulated reranking
141
- results = []
142
- for product in SAMPLE_PRODUCTS[:5]:
143
- results.append({**product, "score": 0.85 + (len(results) * 0.025)})
 
 
 
 
 
 
 
144
 
145
  latency = int((time.time() - start_time) * 1000)
146
 
@@ -224,21 +256,25 @@ def set_example(example: str) -> str:
224
  # Code snippets for each stage
225
  CODE_STAGE_1 = """
226
  ```python
227
- from rank_bm25 import BM25Okapi
 
228
 
229
- # Tokenize documents
230
- tokenized_docs = [doc.split() for doc in documents]
 
 
231
 
232
- # Create BM25 index
233
- bm25 = BM25Okapi(tokenized_docs)
 
234
 
235
- # Search
236
- query_tokens = query.split()
237
- scores = bm25.get_scores(query_tokens)
 
238
 
239
- # Get top results
240
- top_indices = scores.argsort()[-5:][::-1]
241
- results = [documents[i] for i in top_indices]
242
  ```
243
  """
244
 
 
4
  from pathlib import Path
5
  import os
6
  from config import GRADIO_THEME, CUSTOM_CSS, EXAMPLE_QUERIES
7
+ from src.search.bm25_lexical_search import search_bm25
8
 
9
  _FILE_PATH = Path(__file__).parents[1]
10
 
 
45
 
46
 
47
  def format_results(results: List[Dict], stage_name: str, metrics: Dict) -> str:
48
+ """Format search results as HTML.
49
+
50
+ Args:
51
+ results: List of dicts with keys: product_name, description, main_category, secondary_category, score
52
+ stage_name: Name of the search stage
53
+ metrics: Dict with keys: semantic_match, diversity, latency_ms
54
+ """
55
  html_parts = [f"### {stage_name} Results\n\n"]
56
 
57
  for idx, result in enumerate(results, 1):
58
+ category = f"{result.get('main_category', 'N/A')} > {result.get('secondary_category', 'N/A')}"
59
  html_parts.append(
60
  f"""
61
  <div class="result-card">
62
+ <strong>{idx}. {result['product_name']}</strong><br/>
63
+ <span style="color: #64748B; font-size: 0.9em;">{result['description'][:150]}...</span><br/>
64
+ <span style="color: #94A3B8; font-size: 0.85em;">Category: {category}</span><br/>
65
  <span style="color: #6720FF; font-weight: 600;">Score: {result['score']:.3f}</span>
66
  </div>
67
  """
 
85
  """Stage 1: Baseline BM25 keyword search."""
86
  start_time = time.time()
87
 
88
+ results = search_bm25(query, top_k=5)
 
 
 
 
89
  latency = int((time.time() - start_time) * 1000)
90
 
91
+ unique_categories = len(set(r["main_category"] for r in results)) if results else 0
92
+ diversity = min(1.0, unique_categories / 5.0)
93
+
94
+ avg_score = sum(r["score"] for r in results) / len(results) if results else 0
95
+ semantic_match = min(1.0, avg_score / 10.0)
96
+
97
  metrics = {
98
+ "semantic_match": semantic_match,
99
+ "diversity": diversity,
100
+ "latency_ms": latency,
101
  }
102
+ print(f"Searched BM25 for {query} in {latency}ms")
103
 
104
  return format_results(results, "Stage 1: BM25 Baseline", metrics), metrics
105
 
 
108
  """Stage 2: BM25 + Vector Embeddings."""
109
  start_time = time.time()
110
 
111
+ # Placeholder: Simulated embedding search with correct format
112
+ results = [
113
+ {
114
+ "product_name": product["title"],
115
+ "description": product["description"],
116
+ "main_category": product["category"],
117
+ "secondary_category": "Placeholder",
118
+ "score": 0.72 + (idx * 0.04),
119
+ }
120
+ for idx, product in enumerate(SAMPLE_PRODUCTS[:4])
121
+ ]
122
 
123
  latency = int((time.time() - start_time) * 1000)
124
 
 
135
  """Stage 3: BM25 + Embeddings + Query Expansion."""
136
  start_time = time.time()
137
 
138
+ # Placeholder: Simulated query expansion with correct format
139
+ results = [
140
+ {
141
+ "product_name": product["title"],
142
+ "description": product["description"],
143
+ "main_category": product["category"],
144
+ "secondary_category": "Placeholder",
145
+ "score": 0.78 + (idx * 0.03),
146
+ }
147
+ for idx, product in enumerate(SAMPLE_PRODUCTS[:5])
148
+ ]
149
 
150
  latency = int((time.time() - start_time) * 1000)
151
 
 
162
  """Stage 4: BM25 + Embeddings + Query Expansion + LLM Reranking."""
163
  start_time = time.time()
164
 
165
+ # Placeholder: Simulated reranking with correct format
166
+ results = [
167
+ {
168
+ "product_name": product["title"],
169
+ "description": product["description"],
170
+ "main_category": product["category"],
171
+ "secondary_category": "Placeholder",
172
+ "score": 0.85 + (idx * 0.025),
173
+ }
174
+ for idx, product in enumerate(SAMPLE_PRODUCTS[:5])
175
+ ]
176
 
177
  latency = int((time.time() - start_time) * 1000)
178
 
 
256
  # Code snippets for each stage
257
  CODE_STAGE_1 = """
258
  ```python
259
+ import bm25s
260
+ import pandas as pd
261
 
262
+ # Step 1: Create BM25 index (one-time setup)
263
+ df = pd.read_parquet("data/amazon_products.parquet")
264
+ corpus = df["FullText"].tolist()
265
+ corpus_tokens = bm25s.tokenize(corpus, stopwords="en")
266
 
267
+ retriever = bm25s.BM25()
268
+ retriever.index(corpus_tokens)
269
+ retriever.save("data/bm25_index")
270
 
271
+ # Step 2: Load index and search
272
+ bm25_index = bm25s.BM25.load("data/bm25_index", load_corpus=False)
273
+ query_tokens = bm25s.tokenize(query, stopwords="en")
274
+ results, scores = bm25_index.retrieve(query_tokens, k=5)
275
 
276
+ # Extract top results
277
+ top_products = [df.iloc[idx] for idx in results[0]]
 
278
  ```
279
  """
280
 
src/data_prep/data_prep.py CHANGED
@@ -4,7 +4,7 @@ from pathlib import Path
4
  import numpy as np
5
  import faiss
6
  import bm25s
7
- from src.modules.inference import create_client
8
  from src.config import EMBEDDING_MODEL
9
 
10
  _FILE_PATH = Path(__file__).parents[2]
 
4
  import numpy as np
5
  import faiss
6
  import bm25s
7
+ from src.fireworks.inference import create_client
8
  from src.config import EMBEDDING_MODEL
9
 
10
  _FILE_PATH = Path(__file__).parents[2]