"""
FINAL RAG SYSTEM FOR AMAZON MULTIMODAL DATASET (LOCAL CHROMA DB)
-----------------------------------------------------------------
Features:
- Clean product text before embedding
- CLIP text + image embedding (safe 77-token truncation)
- Chroma PersistentClient (current API)
- CSV loader for the Amazon dataset
- Image downloader
- Build vector DB for products
- Query using text or image
"""

import os
import csv
import re
import logging
import argparse
from urllib.parse import unquote

import requests
import torch
import clip  # OpenAI CLIP: pip install git+https://github.com/openai/CLIP.git
import numpy as np
from PIL import Image
import chromadb

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def clean_text(text: str, max_chars: int = 400) -> str:
    """Removes Amazon boilerplate text and limits size."""
    if not isinstance(text, str):
        return ""

    patterns = [
        r"Make sure this fits.*?model number\.",
        r"Technical details:.*",
        r"Specifications:.*",
        r"ProductDimensions:.*?(?=\|)",
        r"ShippingWeight:.*?(?=\|)",
        r"ASIN:.*?(?=\|)",
        r"Item model number:.*?(?=\|)",
        r"Go to your orders.*",
        r"Learn More.*",
    ]
    for p in patterns:
        text = re.sub(p, "", text, flags=re.IGNORECASE)

    text = text.replace("|", " ")
    text = re.sub(r"\s+", " ", text).strip()

    return text[:max_chars]
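
# Illustrative example (made-up product string):
#   clean_text("Mini Drone | Make sure this fits by entering your model number. | Fun toy")
#   -> "Mini Drone Fun toy"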


class CLIPEmbedder:
    """Multimodal embedder using OpenAI CLIP with safe truncation."""

    def __init__(self, model_name="ViT-B/32"):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"[CLIP] Loading model on {self.device} ...")
        self.model, self.preprocess = clip.load(model_name, device=self.device)
        logger.info(f"[CLIP] Model {model_name} loaded successfully")

    def _truncate_tokens(self, text: str):
        # Unused helper kept for reference; embed_text tokenizes directly.
        # Note that clip.tokenize raises on over-length input unless
        # truncate=True, which already caps the sequence at CLIP's 77-token
        # context window, so slicing the token tensor afterwards never runs.
        return clip.tokenize([text], truncate=True).to(self.device)

    def embed_text(self, text: str):
        text = clean_text(text)

        # Coarse word-level cap before tokenizing; truncate=True below is the
        # actual guard against the 77-token context limit.
        words = text.split()
        text = " ".join(words[:50])

        tokens = clip.tokenize([text], truncate=True).to(self.device)

        with torch.no_grad():
            emb = self.model.encode_text(tokens)[0]
            emb = emb / emb.norm()

        return emb.cpu().numpy().astype("float32")

    def embed_image(self, path: str):
        # convert("RGB") guards against grayscale/RGBA images, which CLIP's
        # preprocessing does not expect.
        image = self.preprocess(Image.open(path).convert("RGB")).unsqueeze(0).to(self.device)

        with torch.no_grad():
            vec = self.model.encode_image(image)[0]
            vec = vec / vec.norm()

        return vec.cpu().numpy().astype("float32")


class ChromaVectorStore:
    """Thin wrapper around the Chroma PersistentClient."""

    def __init__(self, persist_dir="chromadb_store"):
        logger.info(f"[Chroma] Initializing DB at: {persist_dir}")
        self.client = chromadb.PersistentClient(path=persist_dir)
        self.collection = self.client.get_or_create_collection(
            name="amazon_products",
            metadata={"hnsw:space": "cosine"}
        )

    def add_item(self, item_id: str, embedding, metadata: dict):
        self.collection.add(
            ids=[item_id],
            embeddings=[embedding],
            metadatas=[metadata]
        )

    def query(self, embedding, top_k=5):
        return self.collection.query(
            query_embeddings=[embedding],
            n_results=top_k
        )
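
# query() returns parallel lists per query embedding, e.g.:
#   {"ids": [[...]], "metadatas": [[...]], "distances": [[...]], ...}
# The [0] indexing used downstream selects the results for the single
# query vector passed in.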


def download_first_image(urls: str, save_dir="images"):
    """Downloads the first valid image from the |-separated URL list."""
    if not urls or not isinstance(urls, str):
        return None

    os.makedirs(save_dir, exist_ok=True)

    first_url = urls.split("|")[0].strip()
    if not first_url.lower().startswith("http"):
        return None

    # Derive a filesystem-safe name from the decoded, length-capped URL basename.
    base = unquote(os.path.basename(first_url)[:50])
    base = re.sub(r"[^A-Za-z0-9._-]", "_", base)
    img_name = os.path.join(save_dir, base + ".jpg")

    try:
        r = requests.get(first_url, timeout=5)
        if r.status_code == 200:
            with open(img_name, "wb") as f:
                f.write(r.content)
            return img_name
        logger.debug(f"Failed to download image (status {r.status_code}): {first_url}")
    except requests.RequestException as e:
        logger.debug(f"Image download error for {first_url}: {e}")
    except Exception as e:
        logger.warning(f"Unexpected error downloading image {first_url}: {e}")

    return None
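
# Caveat: the response body is saved as-is, so a 200 response that is not an
# image (e.g. an HTML error page) still gets written to disk. embed_image will
# then fail on it, and that failure is caught and counted in build_index.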


def build_index(csv_path, persist_dir, max_items=None):
    embedder = CLIPEmbedder()
    vectorstore = ChromaVectorStore(persist_dir)

    logger.info(f"Loading dataset: {csv_path}")

    stats = {
        "total_processed": 0,
        "text_embed_failures": 0,
        "image_download_failures": 0,
        "image_embed_failures": 0,
        "skipped_no_image": 0
    }

    with open(csv_path, newline='', encoding="utf-8") as f:
        reader = csv.DictReader(f)

        for i, row in enumerate(reader):
            if max_items and i >= max_items:
                break

            pid = row.get("uniq_id")
            name = row.get("product_name", "")
            desc = row.get("product_text", "")
            cat = row.get("main_category", "")
            img_urls = row.get("image", "")

            full_text = f"{name} | {cat} | {clean_text(desc)}"

            try:
                t_emb = embedder.embed_text(full_text)
            except Exception as e:
                logger.error(f"Could not embed text for {pid}: {e}")
                stats["text_embed_failures"] += 1
                continue

            img_path = download_first_image(img_urls)
            if not img_path:
                logger.info(f"Skipping product {pid} - no valid image")
                stats["image_download_failures"] += 1
                stats["skipped_no_image"] += 1
                continue

            try:
                img_emb = embedder.embed_image(img_path)
            except Exception as e:
                logger.debug(f"Could not embed image for {pid}: {e}")
                stats["image_embed_failures"] += 1
                stats["skipped_no_image"] += 1
                continue

            # Average the two unit vectors, then renormalize so the fused
            # embedding lies on the unit sphere as well.
            final_emb = (t_emb + img_emb) / 2
            final_emb = final_emb / np.linalg.norm(final_emb)

            metadata = {
                "id": pid or "",
                "name": name or "",
                "category": cat or "",
                "image_path": img_path or ""
            }

            vectorstore.add_item(pid, final_emb, metadata)
            stats["total_processed"] += 1

            if (i + 1) % 20 == 0:
                logger.info(f"Indexed {i + 1} items...")

    logger.info("Index build complete.")
    logger.info(f"Statistics: {stats}")
    return vectorstore
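
# Example programmatic use (illustrative paths):
#   build_index("amazon_products.csv", "chromadb_store", max_items=200)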


def run_query(query_text=None, image_path=None, persist_dir="chromadb_store"):
    embedder = CLIPEmbedder()
    vectorstore = ChromaVectorStore(persist_dir)

    if query_text:
        emb = embedder.embed_text(query_text)
    elif image_path:
        emb = embedder.embed_image(image_path)
    else:
        raise ValueError("Provide query text or an image path")

    results = vectorstore.query(emb, top_k=5)

    print("\nQUERY RESULTS")
    print("------------------------")

    for i in range(len(results["ids"][0])):
        pid = results["ids"][0][i]
        meta = results["metadatas"][0][i]
        dist = results["distances"][0][i]

        print(f"\nRank {i + 1}")
        print(f"Product ID: {pid}")
        print(f"Name: {meta.get('name')}")
        print(f"Category: {meta.get('category')}")
        print(f"Distance: {dist:.4f}")

    return results
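
# With {"hnsw:space": "cosine"}, Chroma reports cosine *distance*
# (1 - cosine similarity), so lower values mean closer matches.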


def evaluate_retrieval(csv_path, persist_dir="chromadb_store", max_eval=50):
    """
    Evaluate retrieval performance using category match as ground truth.
    Each product's own cleaned text is used as the query, and a hit is counted
    when a retrieved item shares its category. Computes:
    - Accuracy@1 (identical to Recall@1 under this single-label setup)
    - Recall@1
    - Recall@5
    - Recall@10
    """
    print("\nStarting retrieval evaluation...\n")

    embedder = CLIPEmbedder()
    vectorstore = ChromaVectorStore(persist_dir)

    queries = []
    with open(csv_path, newline='', encoding="utf-8") as f:
        reader = csv.DictReader(f)
        for i, row in enumerate(reader):
            if i >= max_eval:
                break
            queries.append(row)

    total = len(queries)
    if total == 0:
        raise ValueError(f"No rows found in {csv_path}")

    correct_at_1 = 0
    recall_at_1 = 0
    recall_at_5 = 0
    recall_at_10 = 0

    for idx, row in enumerate(queries):
        category = row["main_category"]
        text_query = clean_text(row["product_name"] + " " + row["product_text"])

        query_emb = embedder.embed_text(text_query)
        results = vectorstore.query(query_emb, top_k=10)

        retrieved_metas = results["metadatas"][0]
        retrieved_categories = [m.get("category") for m in retrieved_metas]

        if retrieved_categories and retrieved_categories[0] == category:
            correct_at_1 += 1
            recall_at_1 += 1

        if category in retrieved_categories[:5]:
            recall_at_5 += 1

        if category in retrieved_categories[:10]:
            recall_at_10 += 1

        if (idx + 1) % 10 == 0:
            print(f"Evaluated {idx + 1}/{total} queries...")

    accuracy_at_1 = correct_at_1 / total
    recall_1 = recall_at_1 / total
    recall_5 = recall_at_5 / total
    recall_10 = recall_at_10 / total

    print("\nRETRIEVAL EVALUATION RESULTS")
    print("-----------------------------------")
    print(f"Accuracy@1: {accuracy_at_1:.3f}")
    print(f"Recall@1:   {recall_1:.3f}")
    print(f"Recall@5:   {recall_5:.3f}")
    print(f"Recall@10:  {recall_10:.3f}")

    return {
        "Accuracy@1": accuracy_at_1,
        "Recall@1": recall_1,
        "Recall@5": recall_5,
        "Recall@10": recall_10
    }
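
# Note: the evaluation queries come from rows that are themselves indexed, so
# each item tends to retrieve itself (its fused embedding stays close to its
# own text embedding); category-match metrics are therefore optimistic.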


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Build, query, or evaluate the multimodal product index."
    )
    parser.add_argument("--build", action="store_true", help="Build the vector index from a CSV")
    parser.add_argument("--csv", type=str, help="Path to the Amazon products CSV")
    parser.add_argument("--max", type=int, help="Maximum number of rows to index")
    parser.add_argument("--text", type=str, help="Text query")
    parser.add_argument("--image", type=str, help="Path to an image query")
    parser.add_argument("--db", type=str, default="chromadb_store", help="Chroma persistence directory")
    parser.add_argument("--eval", action="store_true", help="Run retrieval evaluation")

    args = parser.parse_args()

    if args.build:
        if not args.csv:
            parser.error("--build requires --csv")
        build_index(args.csv, args.db, args.max)
    elif args.eval:
        if not args.csv:
            parser.error("--eval requires --csv")
        evaluate_retrieval(args.csv, persist_dir=args.db, max_eval=50)
    elif args.text or args.image:
        run_query(args.text, args.image, persist_dir=args.db)
    else:
        print("No action specified. Use one of:")
        print("  --build --csv yourfile.csv")
        print("  --eval --csv yourfile.csv")
        print("  --text \"your query\"")
        print("  --image path_to_image")