Spaces:
Sleeping
Sleeping
| """ | |
| Vector database manager for NBA data using ChromaDB and sentence-transformers. | |
| """ | |
| import os | |
| import pandas as pd | |
| import chromadb | |
| from chromadb.config import Settings | |
| from sentence_transformers import SentenceTransformer | |
| from typing import List, Dict, Optional | |
| import json | |
class NBAVectorDB:
    """
    Manages vector embeddings and semantic search for NBA data.

    Rows of a CSV file are turned into short text descriptions, embedded with
    a sentence-transformers model and stored in a persistent ChromaDB
    collection. The collection is populated from the CSV only when it is
    found empty at construction time.
    """

    def __init__(self, csv_path: str, collection_name: str = "nba_data", persist_directory: str = "./chroma_db"):
        """
        Initialize the vector database.

        Args:
            csv_path: Path to the NBA CSV file
            collection_name: Name of the ChromaDB collection
            persist_directory: Directory to persist the vector database
        """
        self.csv_path = csv_path
        self.collection_name = collection_name
        self.persist_directory = persist_directory
        # (cache key, DataFrame) pair filled lazily by _load_dataframe().
        self._df_cache = None

        # Initialize embedding model (open-source, runs locally)
        # Using all-MiniLM-L6-v2: fast, good quality, 384 dimensions
        print("Loading embedding model...")
        self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
        print("Embedding model loaded!")

        # Initialize ChromaDB client
        os.makedirs(persist_directory, exist_ok=True)
        self.client = chromadb.PersistentClient(
            path=persist_directory,
            settings=Settings(anonymized_telemetry=False)
        )

        # Get or create collection
        self.collection = self.client.get_or_create_collection(
            name=collection_name,
            metadata={"description": "NBA 2024-25 season data"}
        )

        # Index the CSV only when the collection holds nothing yet; the count
        # is fetched once and reused for the status message.
        record_count = self.collection.count()
        if record_count == 0:
            print("Vector database is empty. Indexing CSV data...")
            self._index_csv()
        else:
            print(f"Vector database loaded with {record_count} records")

    def _load_dataframe(self) -> pd.DataFrame:
        """
        Return the CSV at ``self.csv_path`` as a DataFrame, parsing it at most
        once per file version.

        The cache key is (path, modification time, size), so reassigning
        ``csv_path`` or rewriting the file triggers a fresh read.

        Returns:
            The parsed DataFrame (shared with the cache; do not mutate it)

        Raises:
            OSError: If the file cannot be stat'ed (e.g. it does not exist)
        """
        stat = os.stat(self.csv_path)
        key = (self.csv_path, stat.st_mtime_ns, stat.st_size)
        # getattr: instances created without __init__ have no cache attribute.
        cached = getattr(self, "_df_cache", None)
        if cached is None or cached[0] != key:
            cached = (key, pd.read_csv(self.csv_path))
            self._df_cache = cached
        return cached[1]

    def _create_text_representation(self, row: pd.Series) -> str:
        """
        Convert a DataFrame row to a text representation for embedding.

        Only columns that are present and hold a non-missing value contribute
        a fragment, so the text never contains a literal "nan" and a missing
        result is not reported as a loss.

        Args:
            row: A pandas Series representing a row

        Returns:
            str: Fragments such as "Player: X" joined with ". "
        """
        def present(column: str) -> bool:
            # True when the column exists in the row and its value is not NaN/None.
            return column in row and pd.notna(row[column])

        parts = []
        if present('Player'):
            parts.append(f"Player: {row['Player']}")
        if present('Tm'):
            parts.append(f"Team: {row['Tm']}")
        if present('Opp'):
            parts.append(f"Opponent: {row['Opp']}")
        if present('Res'):
            # Any value other than 'W' is described as a loss.
            parts.append(f"Result: {'Win' if row['Res'] == 'W' else 'Loss'}")
        if present('PTS'):
            parts.append(f"Points: {row['PTS']}")
        if present('AST'):
            parts.append(f"Assists: {row['AST']}")
        if present('TRB'):
            parts.append(f"Rebounds: {row['TRB']}")
        if present('FG%'):
            # float() lets numeric strings from an object-dtype column format too.
            parts.append(f"Field Goal Percentage: {float(row['FG%']):.3f}")
        if present('3P%'):
            parts.append(f"Three Point Percentage: {float(row['3P%']):.3f}")
        if present('Data'):
            # The date is read from a column literally named 'Data'.
            parts.append(f"Date: {row['Data']}")
        return ". ".join(parts)

    def _index_csv(self, batch_size: int = 100):
        """
        Read CSV file, create embeddings, and store in ChromaDB.

        Rows are embedded and added batch by batch; nothing beyond the current
        batch is kept in memory besides the DataFrame itself.

        Args:
            batch_size: Number of rows embedded and added per batch

        Raises:
            ValueError: If batch_size is not positive
        """
        if batch_size <= 0:
            raise ValueError("batch_size must be a positive integer")

        print(f"Reading CSV from {self.csv_path}...")
        df = self._load_dataframe()
        print(f"Creating embeddings for {len(df)} records...")

        total_batches = (len(df) + batch_size - 1) // batch_size
        for batch_num, start in enumerate(range(0, len(df), batch_size), start=1):
            batch_df = df.iloc[start:start + batch_size]
            batch_texts = []
            batch_metadatas = []
            batch_ids = []
            for idx, row in batch_df.iterrows():
                batch_texts.append(self._create_text_representation(row))
                # A few fields are stored as metadata next to the embedding.
                points = row.get('PTS')
                batch_metadatas.append({
                    'row_index': int(idx),
                    'player': str(row.get('Player', '')),
                    'team': str(row.get('Tm', '')),
                    'opponent': str(row.get('Opp', '')),
                    'result': str(row.get('Res', '')),
                    'points': float(points) if pd.notna(points) else 0.0,
                    'date': str(row.get('Data', '')),
                })
                batch_ids.append(f"row_{idx}")

            # Generate embeddings for this batch
            print(f"Processing batch {batch_num}/{total_batches} ({len(batch_texts)} records)...")
            embeddings = self.embedding_model.encode(
                batch_texts,
                show_progress_bar=False,
                convert_to_numpy=True
            ).tolist()

            # Add to ChromaDB
            self.collection.add(
                embeddings=embeddings,
                documents=batch_texts,
                metadatas=batch_metadatas,
                ids=batch_ids
            )

        print(f"Successfully indexed {len(df)} records in vector database!")

    def search(self, query: str, n_results: int = 10) -> List[Dict]:
        """
        Perform semantic search on the NBA data.

        Args:
            query: Natural language query
            n_results: Maximum number of results to return; values <= 0 yield
                an empty list

        Returns:
            List of dictionaries with the keys 'id', 'document', 'metadata',
            'distance' and 'similarity' (computed as 1 - distance)
        """
        if n_results <= 0:
            return []

        # Never ask for more neighbours than the collection holds.
        # NOTE(review): how ChromaDB reacts to an oversized n_results appears
        # to differ between versions - confirm against the pinned version.
        available = self.collection.count()
        if available == 0:
            return []
        n_results = min(n_results, available)

        # Generate embedding for the query
        query_embedding = self.embedding_model.encode(
            query,
            convert_to_numpy=True
        ).tolist()

        # Search in ChromaDB
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=n_results,
            include=['documents', 'metadatas', 'distances']
        )

        if not results['ids'] or not results['ids'][0]:
            return []

        # Results are nested one list per query; only one query was sent.
        # NOTE(review): 'similarity' is simply 1 - distance, so its range
        # depends on the collection's distance metric (presumably Chroma's
        # default, since none is configured here) - verify before treating
        # it as a cosine similarity.
        return [
            {
                'id': result_id,
                'document': document,
                'metadata': metadata,
                'distance': distance,
                'similarity': 1 - distance,
            }
            for result_id, document, metadata, distance in zip(
                results['ids'][0],
                results['documents'][0],
                results['metadatas'][0],
                results['distances'][0],
            )
        ]

    def get_original_row(self, row_index: int) -> Optional[pd.Series]:
        """
        Retrieve the original CSV row by index.

        The CSV is served from the cached DataFrame instead of being re-read
        on every call.

        Args:
            row_index: Zero-based position of the row in the original CSV

        Returns:
            A copy of the row as a pandas Series, or None when the index is
            out of range or the CSV could not be loaded
        """
        try:
            df = self._load_dataframe()
            if 0 <= row_index < len(df):
                # Copy so callers cannot modify the cached DataFrame.
                return df.iloc[row_index].copy()
        except Exception as e:
            # Deliberate best-effort lookup: report the problem and return None.
            print(f"Error retrieving row {row_index}: {e}")
        return None
# Module-wide singleton; stays None until get_vector_db() is first called.
_vector_db_instance: Optional[NBAVectorDB] = None


def get_vector_db(csv_path: str) -> NBAVectorDB:
    """
    Return the shared NBAVectorDB, building it on first use.

    The existing instance is reused only while it was created for the same
    ``csv_path``; a different path replaces it with a new instance.

    Args:
        csv_path: Path to the CSV file

    Returns:
        NBAVectorDB instance
    """
    global _vector_db_instance
    current = _vector_db_instance
    # Fast path: an instance already exists for this very CSV path.
    if current is not None and current.csv_path == csv_path:
        return current
    _vector_db_instance = NBAVectorDB(csv_path)
    return _vector_db_instance