| |
|
|
| import re |
| import os |
| from typing import List, Dict |
| from rank_bm25 import BM25Okapi |
| from sqlalchemy import create_engine |
| from sqlalchemy.orm import Session, DeclarativeBase, Mapped, mapped_column |
| from dotenv import load_dotenv |
| |
| from .logging_utils import get_logger |
|
|
| |
| load_dotenv() |
|
|
| logger = get_logger("bm25_utils") |
|
|
| |
| TOKEN_RE = re.compile(r"\b\w+\b") |
|
|
| def tokenize(text: str) -> List[str]: |
| """Tokenize text into lowercase words.""" |
| if not text: |
| return [] |
| return TOKEN_RE.findall(text.lower()) |
|
|
|
|
| |
| class Base(DeclarativeBase): |
| pass |
|
|
|
|
| class Whole_Blogs(Base): |
| __tablename__ = "travel_blogs" |
|
|
| id: Mapped[int] = mapped_column(primary_key=True) |
| blog_url: Mapped[str] |
| page_url: Mapped[str] = mapped_column(unique=True, nullable=False) |
| page_title: Mapped[str] |
| page_description: Mapped[str] |
| page_author: Mapped[str] |
| location_name: Mapped[str] |
| latitude: Mapped[float] |
| longitude: Mapped[float] |
| content: Mapped[str] |
| |
|
|
| def __repr__(self) -> str: |
| return f"Whole_Blogs(id={self.id!r}, location_name={self.location_name!r}, page_title={self.page_title!r})" |
|
|
|
|
| |
| _cached_posts = None |
| _cached_bm25 = None |
|
|
|
|
| def _load_blogs_from_db(): |
| """Load blog posts from database and build BM25 index.""" |
| global _cached_posts, _cached_bm25 |
| |
| database_url = os.getenv("DATABASE_URL") |
| if not database_url: |
| logger.error("DATABASE_URL not found in environment variables") |
| raise ValueError("DATABASE_URL not found in environment variables") |
| |
| logger.info("Loading blog posts from database...") |
| engine = create_engine(database_url) |
| |
| posts = [] |
| corpus = [] |
| |
| with Session(engine) as session: |
| blog_posts = session.query(Whole_Blogs).all() |
| logger.info(f"Loaded {len(blog_posts)} blog posts from database") |
| |
| for post in blog_posts: |
| posts.append({ |
| "id": post.id, |
| "location_name": post.location_name, |
| "page_title": post.page_title, |
| "page_description": post.page_description, |
| "page_author": post.page_author, |
| "page_url": post.page_url, |
| "blog_url": post.blog_url, |
| "latitude": post.latitude, |
| "longitude": post.longitude, |
| "content": post.content, |
| }) |
| |
| |
| search_text = f"{post.page_title} {post.page_description} {post.content}" |
| corpus.append(search_text) |
| |
| |
| logger.info("Building BM25 index...") |
| tokenized_corpus = [tokenize(doc) for doc in corpus] |
| bm25 = BM25Okapi(tokenized_corpus) |
| logger.info(f"BM25 index built with {len(posts)} documents") |
| |
| |
| _cached_posts = posts |
| _cached_bm25 = bm25 |
| |
| return posts, bm25 |
|
|
|
|
| def search_bm25(query: str, top_n: int = 12) -> List[Dict]: |
| """ |
| Search blog posts using BM25. |
| |
| Args: |
| query: Search query string |
| top_n: Number of top results to return |
| |
| Returns: |
| List of dicts with search results |
| """ |
| global _cached_posts, _cached_bm25 |
| |
| |
| if _cached_posts is None or _cached_bm25 is None: |
| _load_blogs_from_db() |
| |
| if not query.strip(): |
| return [] |
| |
| |
| tokenized_query = tokenize(query) |
| |
| if not tokenized_query: |
| return [] |
| |
| |
| scores = _cached_bm25.get_scores(tokenized_query) |
| |
| |
| top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_n] |
| |
| |
| results = [] |
| for idx in top_indices: |
| post = _cached_posts[idx] |
| |
| |
| content_preview = post.get("content", "")[:300] |
| if len(post.get("content", "")) > 300: |
| content_preview += "..." |
| |
| |
| location_parts = post.get("location_name", "").split(",") |
| country = location_parts[-1].strip() if len(location_parts) > 1 else "" |
| |
| results.append({ |
| "destination": post.get("location_name", "Unknown"), |
| "country": country, |
| "lat": float(post.get("latitude", 0)) if post.get("latitude") else None, |
| "lon": float(post.get("longitude", 0)) if post.get("longitude") else None, |
| "score": float(scores[idx]), |
| "page_title": post.get("page_title", ""), |
| "page_url": post.get("page_url", ""), |
| "blog_url": post.get("blog_url", ""), |
| "author": post.get("page_author", ""), |
| "description": post.get("page_description", ""), |
| "content_preview": content_preview, |
| "full_content": post.get('content', ""), |
| }) |
| |
| return results |