| | """Visual RAG Tool - Image similarity search using local SigLIP embeddings. |
| | |
| | Schema: place_image_embeddings (place_id, embedding, image_url, metadata) |
| | places_metadata (place_id, name, category, rating, raw_data) |
| | |
| | Uses local SigLIP model (ViT-B-16-SigLIP) for generating 768-dim image embeddings. |
| | """ |
| |
|
| | from dataclasses import dataclass |
| | from typing import Optional |
| |
|
| | from sqlalchemy import text |
| | from sqlalchemy.ext.asyncio import AsyncSession |
| |
|
| | from app.shared.integrations.siglip_client import get_siglip_client |
| |
|
| |
|
@dataclass
class ImageSearchResult:
    """One place returned by a visual similarity search.

    Attributes:
        place_id: Identifier of the matched place.
        name: Place name.
        category: Place category.
        rating: Place rating.
        similarity: Similarity score of the place against the query image.
        matched_images: How many of the place's images contributed to the
            match (defaults to 1).
        image_url: URL of the image chosen to represent the match
            (defaults to an empty string).
    """

    # Identity and descriptive metadata of the place.
    place_id: str
    name: str
    category: str
    rating: float

    # Match quality.
    similarity: float
    matched_images: int = 1

    # Representative image for the match.
    image_url: str = ""
| |
|
| | |
| | from app.shared.prompts import RETRIEVE_SIMILAR_VISUALS_TOOL as TOOL_DEFINITION |
| |
|
| |
|
async def retrieve_similar_visuals(
    db: AsyncSession,
    image_url: str | None = None,
    image_bytes: bytes | None = None,
    limit: int = 10,
    threshold: float = 0.2,
) -> list[ImageSearchResult]:
    """Find the places whose stored images are most similar to a query image.

    The query image is embedded with the local SigLIP client, compared against
    ``place_image_embeddings`` (joined to ``places_metadata``) using the
    pgvector ``<=>`` distance operator, and the matching image rows are then
    grouped per place. A place's score is the mean similarity of its matching
    images.

    Args:
        db: Database session used to run the query.
        image_url: URL of the query image. Ignored when ``image_bytes`` is
            also given.
        image_bytes: Raw bytes of the query image (takes precedence over
            ``image_url``).
        limit: Maximum number of places to return. Values ``<= 0`` yield an
            empty list.
        threshold: Minimum similarity (``1 - distance``) an image row must
            exceed to be considered.

    Returns:
        Up to ``limit`` results ordered by descending mean similarity. An
        empty list is returned when no image is supplied, when the image
        cannot be embedded, or when nothing passes the threshold.
    """
    # Bail out before touching the model or the database when there is
    # nothing to search for or nothing to return.
    if limit <= 0 or not (image_bytes or image_url):
        return []

    siglip = get_siglip_client()

    # NOTE(review): these embed calls are synchronous inside a coroutine;
    # presumably they run model inference (and a download for URLs) and
    # therefore block the event loop - confirm and offload if so.
    if image_bytes:
        image_embedding = siglip.embed_image_bytes(image_bytes)
    else:
        image_embedding = siglip.embed_image_url(image_url)

    if image_embedding is None:
        return []

    # pgvector text form: "[v1,v2,...]". Passed as a bound parameter and cast
    # server-side, so the SQL text stays constant across calls and no value
    # is ever spliced into the statement.
    embedding_literal = "[" + ",".join(str(x) for x in image_embedding.tolist()) + "]"

    # Several images can belong to one place, so fetch a pool of image rows
    # that is at least as large as the number of places requested.
    candidate_limit = max(100, limit)

    # CAST(... AS vector) is used instead of "::vector" because "::" directly
    # after a ":name" bind parameter is ambiguous inside text().
    sql = text("""
        SELECT
            e.place_id,
            e.image_url,
            1 - (e.embedding <=> CAST(:embedding AS vector)) AS similarity,
            m.name,
            m.category,
            m.rating
        FROM place_image_embeddings e
        JOIN places_metadata m ON e.place_id = m.place_id
        WHERE 1 - (e.embedding <=> CAST(:embedding AS vector)) > :threshold
          AND m.name IS NOT NULL
          AND m.name != ''
        ORDER BY e.embedding <=> CAST(:embedding AS vector)
        LIMIT :candidate_limit
    """)

    results = await db.execute(sql, {
        "embedding": embedding_literal,
        "threshold": threshold,
        "candidate_limit": candidate_limit,
    })
    rows = results.fetchall()

    # Group image rows by place, accumulating the similarity sum/count for
    # the mean and remembering the single most similar image explicitly
    # (rather than relying on the row order of the query).
    place_scores: dict = {}
    for row in rows:
        similarity = float(row.similarity)
        entry = place_scores.get(row.place_id)
        if entry is None:
            place_scores[row.place_id] = {
                'total_score': similarity,
                'count': 1,
                'data': row,
                'best_similarity': similarity,
                'best_image': row.image_url,
            }
            continue

        entry['total_score'] += similarity
        entry['count'] += 1
        if similarity > entry['best_similarity']:
            entry['best_similarity'] = similarity
            entry['best_image'] = row.image_url

    # Rank places by mean similarity, best first, and keep the top `limit`.
    sorted_places = sorted(
        place_scores.items(),
        key=lambda item: item[1]['total_score'] / item[1]['count'],
        reverse=True,
    )[:limit]

    return [
        ImageSearchResult(
            place_id=pid,
            name=data['data'].name or '',
            category=data['data'].category or '',
            rating=float(data['data'].rating or 0),
            similarity=round(data['total_score'] / data['count'], 4),
            matched_images=data['count'],
            image_url=data['best_image'] or '',
        )
        for pid, data in sorted_places
    ]
| |
|
| |
|
async def search_by_image_url(
    db: AsyncSession,
    image_url: str,
    limit: int = 10,
) -> list[ImageSearchResult]:
    """Run a visual similarity search for the image located at ``image_url``.

    Convenience wrapper around ``retrieve_similar_visuals`` that supplies only
    the URL, so that function's default similarity threshold applies.

    Args:
        db: Database session used to run the query.
        image_url: URL of the query image.
        limit: Maximum number of places to return.

    Returns:
        The results produced by ``retrieve_similar_visuals``.
    """
    matches = await retrieve_similar_visuals(
        db=db,
        image_url=image_url,
        limit=limit,
    )
    return matches
| |
|
| |
|
async def search_by_image_bytes(
    db: AsyncSession,
    image_bytes: bytes,
    limit: int = 10,
) -> list[ImageSearchResult]:
    """Run a visual similarity search for an image given as raw bytes.

    Convenience wrapper around ``retrieve_similar_visuals`` that supplies only
    the image bytes, so that function's default similarity threshold applies.

    Args:
        db: Database session used to run the query.
        image_bytes: Raw bytes of the query image (e.g. an upload).
        limit: Maximum number of places to return.

    Returns:
        The results produced by ``retrieve_similar_visuals``.
    """
    matches = await retrieve_similar_visuals(
        db=db,
        image_bytes=image_bytes,
        limit=limit,
    )
    return matches
| |
|