from datasets import load_dataset from typing import List, Optional, Dict, Any from datetime import datetime from models import ArticleResponse, ArticleDetail, Argument, FiltersResponse from collections import Counter from functools import lru_cache from whoosh import index from whoosh.fields import Schema, TEXT, ID from whoosh.qparser import QueryParser from whoosh.filedb.filestore import RamStorage from dateutil import parser as date_parser import numpy as np from sentence_transformers import SentenceTransformer # Constants SEARCH_CACHE_MAX_SIZE = 1000 LABOR_SCORE_WEIGHT = 0.1 # Weight for labor score in relevance calculation DATE_RANGE_START = "2022-01-01" DATE_RANGE_END = "2025-12-31" class DataLoader: """ Handles loading, indexing, and searching of AI labor economy articles. Uses Whoosh for full-text search and maintains in-memory data structures for fast filtering and pagination. """ def __init__(self): self.dataset = None self.articles = [] self.articles_by_id = {} # ID -> article mapping self.filter_options = None # Initialize Whoosh search index for full-text search self.search_schema = Schema( id=ID(stored=True), title=TEXT(stored=False), summary=TEXT(stored=False), content=TEXT(stored=False) # Combined title + summary for search ) # Create in-memory index using RamStorage storage = RamStorage() self.search_index = storage.create_index(self.search_schema) self.query_parser = QueryParser("content", self.search_schema) # Dense retrieval components (lazy-loaded for efficiency) self.embeddings = None # Article embeddings from dataset self.embedding_model = None # SentenceTransformer model self.model_path = "ibm-granite/granite-embedding-english-r2" # Note: Using lru_cache for search caching instead of manual cache management async def load_dataset(self): """Load and process the HuggingFace dataset""" # Load dataset self.dataset = load_dataset("yjernite/ai-economy-labor-articles-annotated-embed", split="train") # Convert to list of dicts for easier processing self.articles = [] # Prepare Whoosh index writer writer = self.search_index.writer() for i, row in enumerate(self.dataset): # Parse date using dateutil (more flexible than pandas) date = date_parser.parse(row['date']) if isinstance(row['date'], str) else row['date'] # Parse arguments arguments = [] if row.get('arguments'): for arg in row['arguments']: arguments.append(Argument( argument_quote=arg.get('argument_quote', []), argument_summary=arg.get('argument_summary', ''), argument_source=arg.get('argument_source', ''), argument_type=arg.get('argument_type', ''), )) article = { 'id': i, 'title': row.get('title', ''), 'source': row.get('source', ''), 'url': row.get('url', ''), 'date': date, 'summary': row.get('summary', ''), 'ai_labor_relevance': row.get('ai_labor_relevance', 0), 'document_type': row.get('document_type', ''), 'author_type': row.get('author_type', ''), 'document_topics': row.get('document_topics', []), 'text': row.get('text', ''), 'arguments': arguments, } self.articles.append(article) self.articles_by_id[i] = article # Add to search index search_content = f"{article['title']} {article['summary']}" writer.add_document( id=str(i), title=article['title'], summary=article['summary'], content=search_content ) # Commit search index writer.commit() print(f"DEBUG: Search index populated with {len(self.articles)} articles") # Load pre-computed embeddings for dense retrieval print("DEBUG: Loading pre-computed embeddings...") raw_embeddings = np.array(self.dataset['embeddings-granite']) # Normalize embeddings for cosine similarity self.embeddings = raw_embeddings / np.linalg.norm(raw_embeddings, axis=1, keepdims=True) print(f"DEBUG: Loaded {len(self.embeddings)} article embeddings") # Build filter options self._build_filter_options() def _build_filter_options(self): """Build available filter options from the dataset""" document_types = sorted(set(article['document_type'] for article in self.articles if article['document_type'])) author_types = sorted(set(article['author_type'] for article in self.articles if article['author_type'])) # Flatten all topics all_topics = [] for article in self.articles: if article['document_topics']: all_topics.extend(article['document_topics']) topics = sorted(set(all_topics)) # Date range - fixed for research period min_date = DATE_RANGE_START max_date = DATE_RANGE_END # Relevance range relevances = [article['ai_labor_relevance'] for article in self.articles if article['ai_labor_relevance'] is not None] min_relevance = min(relevances) if relevances else 0 max_relevance = max(relevances) if relevances else 10 self.filter_options = FiltersResponse( document_types=document_types, author_types=author_types, topics=topics, date_range={"min_date": min_date, "max_date": max_date}, relevance_range={"min_relevance": min_relevance, "max_relevance": max_relevance} ) def get_filter_options(self) -> FiltersResponse: """Get all available filter options""" return self.filter_options def _filter_articles( self, document_types: Optional[List[str]] = None, author_types: Optional[List[str]] = None, min_relevance: Optional[float] = None, max_relevance: Optional[float] = None, start_date: Optional[str] = None, end_date: Optional[str] = None, topics: Optional[List[str]] = None, search_query: Optional[str] = None, search_type: Optional[str] = None, ) -> List[Dict[str, Any]]: """Filter articles based on criteria""" filtered = self.articles if document_types: filtered = [a for a in filtered if a['document_type'] in document_types] if author_types: filtered = [a for a in filtered if a['author_type'] in author_types] if min_relevance is not None: filtered = [a for a in filtered if a['ai_labor_relevance'] >= min_relevance] if max_relevance is not None: filtered = [a for a in filtered if a['ai_labor_relevance'] <= max_relevance] if start_date: start_dt = date_parser.parse(start_date) filtered = [a for a in filtered if a['date'] >= start_dt] if end_date: end_dt = date_parser.parse(end_date) filtered = [a for a in filtered if a['date'] <= end_dt] if topics: filtered = [a for a in filtered if any(topic in a['document_topics'] for topic in topics)] if search_query: print(f"DEBUG: Applying search filter for query: '{search_query}' with type: '{search_type}'") if search_type == 'dense': # For dense search, get similarity scores for all articles dense_scores = self._dense_search_all_articles(search_query) dense_score_dict = {idx: score for idx, score in dense_scores} # Attach dense scores to filtered articles and filter by similarity threshold filtered_with_scores = [] for article in filtered: article_id = article['id'] if article_id in dense_score_dict: # Create a copy to avoid modifying the original article_copy = article.copy() article_copy['dense_score'] = dense_score_dict[article_id] # Only include articles with meaningful similarity (> 0.8) if dense_score_dict[article_id] > 0.8: filtered_with_scores.append(article_copy) filtered = filtered_with_scores print(f"DEBUG: After dense search filtering: {len(filtered)} articles remaining") else: # Existing exact search logic - inline the matching check search_results = self._search_articles(search_query, search_type) filtered = [a for a in filtered if a.get('id') in search_results] print(f"DEBUG: After exact search filtering: {len(filtered)} articles remaining") return filtered def _search_articles(self, search_query: str, search_type: Optional[str] = None) -> Dict[int, float]: """Search articles using Whoosh and return article_id -> score mapping Note: Dense search is handled separately in _filter_articles method. This method only handles exact/Whoosh search. """ if not search_query: return {} # Use cached Whoosh search (lru_cache handles caching automatically) return self._cached_whoosh_search(search_query) @lru_cache(maxsize=SEARCH_CACHE_MAX_SIZE) def _cached_whoosh_search(self, search_query: str) -> Dict[int, float]: """Cached version of Whoosh search using lru_cache""" return self._whoosh_search(search_query) def _whoosh_search(self, search_query: str) -> Dict[int, float]: """Perform search using Whoosh index""" try: with self.search_index.searcher() as searcher: # Parse query - Whoosh handles tokenization automatically query = self.query_parser.parse(search_query) results = searcher.search(query, limit=None) # Get all results print(f"DEBUG: Search query '{search_query}' found {len(results)} results") # Return mapping of article_id -> normalized score article_scores = {} max_score = max((r.score for r in results), default=1.0) for result in results: article_id = int(result['id']) # Normalize score to 0-1 range normalized_score = result.score / max_score if max_score > 0 else 0.0 article_scores[article_id] = normalized_score print(f"DEBUG: Returning {len(article_scores)} scored articles") return article_scores except Exception as e: print(f"Search error: {e}") return {} def _initialize_embedding_model(self): """Lazy initialization of embedding model (CPU-only)""" if self.embedding_model is None: print("DEBUG: Initializing embedding model (CPU-only)...") # Force CPU usage and disable problematic features import os os.environ['CUDA_VISIBLE_DEVICES'] = '' # Initialize model with CPU device and specific config self.embedding_model = SentenceTransformer( self.model_path, device='cpu', model_kwargs={ 'dtype': 'float32', # Fixed deprecation warning 'attn_implementation': 'eager' # Use eager attention instead of flash attention } ) print("DEBUG: Embedding model initialized") @lru_cache(maxsize=100) # Cache encoded queries (smaller cache for this) def _encode_query_cached(self, query: str) -> tuple: """Cache-friendly version of query encoding (returns tuple for hashing)""" embedding = self._encode_query_internal(query) return tuple(embedding.tolist()) # Convert to tuple for caching def _encode_query(self, query: str) -> np.ndarray: """Encode a query string into an embedding vector""" cached_result = self._encode_query_cached(query) return np.array(cached_result) # Convert back to numpy array def _encode_query_internal(self, query: str) -> np.ndarray: """Internal method that does the actual encoding""" self._initialize_embedding_model() query_embedding = self.embedding_model.encode([query]) # Normalize for cosine similarity return query_embedding[0] / np.linalg.norm(query_embedding[0]) def _dense_search_all_articles(self, query: str, k: int = None) -> List[tuple]: """ Perform dense retrieval across ALL articles and return (index, score) pairs. This computes all similarities upfront for maximum flexibility. """ if self.embeddings is None: print("ERROR: Embeddings not loaded") return [] print(f"DEBUG: Performing dense search for query: '{query}'") # Encode query query_embed = self._encode_query(query) # Compute similarities with ALL articles similarities = np.dot(self.embeddings, query_embed) # Get all articles with their similarity scores article_scores = [(i, float(similarities[i])) for i in range(len(similarities))] # Sort by similarity (highest first) article_scores.sort(key=lambda x: x[1], reverse=True) # Apply k limit if specified if k is not None: article_scores = article_scores[:k] print(f"DEBUG: Dense search found {len(article_scores)} scored articles") return article_scores def _calculate_query_score(self, article: Dict[str, Any], search_query: str, search_type: Optional[str] = None) -> float: """Calculate query relevance score based on search type""" if not search_query: return 0.0 if search_type == 'dense': # For dense search, return the pre-computed similarity score return article.get('dense_score', 0.0) else: # Existing exact search logic using Whoosh search_results = self._search_articles(search_query, search_type) article_id = article.get('id') # Return Whoosh score or 0.0 if not found return search_results.get(article_id, 0.0) def _sort_by_relevance(self, articles: List[Dict[str, Any]], search_query: str, search_type: str = 'exact') -> List[Dict[str, Any]]: """Sort articles by relevance score (query score + labor score)""" def relevance_key(article): query_score = self._calculate_query_score(article, search_query, search_type) labor_score = article.get('ai_labor_relevance', 0) / 10.0 # Normalize to 0-1 # Prioritize query score but include labor score as tiebreaker return query_score + (labor_score * LABOR_SCORE_WEIGHT) return sorted(articles, key=relevance_key, reverse=True) def get_articles( self, page: int = 1, limit: int = 20, **filters ) -> List[ArticleResponse]: """Get filtered and paginated articles""" # Extract sort_by, search_query, and search_type for special handling sort_by = filters.pop('sort_by', 'date') search_query = filters.get('search_query') search_type = filters.get('search_type', 'exact') filtered_articles = self._filter_articles(**filters) # Apply sorting if sort_by == 'score' and search_query: # Sort by query relevance score descending, then by labor score filtered_articles = self._sort_by_relevance(filtered_articles, search_query, search_type) else: # Sort by date (oldest first) - default filtered_articles.sort(key=lambda x: x['date'], reverse=False) # Paginate start_idx = (page - 1) * limit end_idx = start_idx + limit page_articles = filtered_articles[start_idx:end_idx] # Convert to response models - use the original ID from the sorted/filtered results return [ ArticleResponse( id=article['id'], title=article['title'], source=article['source'], url=article['url'], date=article['date'], summary=article['summary'], ai_labor_relevance=article['ai_labor_relevance'], query_score=self._calculate_query_score(article, search_query or '', search_type), document_type=article['document_type'], author_type=article['author_type'], document_topics=article['document_topics'], ) for article in page_articles ] def get_articles_count(self, **filters) -> int: """Get count of articles matching filters""" filtered_articles = self._filter_articles(**filters) return len(filtered_articles) def get_filter_counts(self, filter_type: str, **filters) -> Dict[str, int]: """Get counts for each option in a specific filter type, given other filters""" # Remove the current filter type from filters to avoid circular filtering filters_copy = filters.copy() filters_copy.pop(filter_type, None) # Get base filtered articles (without the current filter type) base_filtered = self._filter_articles(**filters_copy) # Extract values based on filter type and count with Counter if filter_type == 'document_types': values = [article.get('document_type') for article in base_filtered if article.get('document_type')] elif filter_type == 'author_types': values = [article.get('author_type') for article in base_filtered if article.get('author_type')] elif filter_type == 'topics': values = [topic for article in base_filtered for topic in article.get('document_topics', []) if topic] else: return {} return dict(Counter(values)) def get_article_detail(self, article_id: int) -> ArticleDetail: """Get detailed article by ID""" if article_id not in self.articles_by_id: raise ValueError(f"Article ID {article_id} not found") article = self.articles_by_id[article_id] return ArticleDetail( id=article['id'], title=article['title'], source=article['source'], url=article['url'], date=article['date'], summary=article['summary'], ai_labor_relevance=article['ai_labor_relevance'], query_score=0.0, # Detail view doesn't have search context document_type=article['document_type'], author_type=article['author_type'], document_topics=article['document_topics'], text=article['text'], arguments=article['arguments'], )