# vartur-search-engine / src/db_agent.py
import chromadb
import os, logging, re
import pandas as pd
from chromadb.utils import embedding_functions
from typing import List, Dict, Union, Optional
from datetime import datetime
from src.utils.config import OPENAI_API_KEY, COLLECTION_NAME
class VarturRealEstateSearch:
def __init__(self, persist_directory: str="database"):
"""Initialize the search system with OpenAI embeddings."""
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
        os.makedirs(persist_directory, exist_ok=True)
self.client = chromadb.PersistentClient(path=persist_directory)
self.embedding_function = embedding_functions.OpenAIEmbeddingFunction(
api_key=OPENAI_API_KEY,
model_name="text-embedding-3-small"
)
self._file_stats = {}
self._collection_initialized = False # Flag to track collection initialization
        # Create or load the persistent Chroma collection immediately
self._initialize_collection()
def _initialize_collection(self):
"""Initialize the collection if it hasn't been already."""
if self._collection_initialized:
return # Avoid redundant initialization
try:
self.collection = self.client.get_or_create_collection(
name=COLLECTION_NAME,
embedding_function=self.embedding_function
)
self._collection_initialized = True # Mark collection as initialized
logging.info(f'Collection {COLLECTION_NAME} is ready')
except Exception as e:
logging.error(f"Error initializing collection: {str(e)}")
raise
def _normalize_text(self, text: str) -> str:
"""Normalize any text input for consistent matching."""
if not text:
return ""
if not isinstance(text, str):
text = str(text)
# Remove extra spaces and convert to lowercase
return " ".join(text.lower().split())
def _generate_unit_variations(self, unit_type: str) -> List[str]:
"""Generate different variations of unit type for better matching."""
variations = set()
# Normalize the input
unit = self._normalize_text(unit_type)
variations.add(unit)
# Handle bedroom variations
bedroom_match = re.match(r'(\d+)\s*(?:bedroom|br|b/r|bed)', unit)
if bedroom_match:
num = bedroom_match.group(1)
variations.update([
f"{num} bedroom",
f"{num} br",
f"{num}br",
f"{num} bed",
f"{num}bed",
f"{num} b/r",
f"{num}b/r",
f"{num} bedrooms"
])
# Handle other common types
type_variations = {
'studio': ['studio', 'studio apartment', 'studio apt'],
'penthouse': ['penthouse', 'ph', 'pent house'],
'townhouse': ['townhouse', 'th', 'town house'],
'villa': ['villa', 'vil'],
'apartment': ['apartment', 'apt', 'appartment']
}
# Add variations for the unit type
        for aliases in type_variations.values():
            if any(alias in unit for alias in aliases):
                variations.update(aliases)
return list(variations)
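    # Illustrative traces of the variation rules above (derived from the regex
    # and alias table; not exhaustive):
    #   "2BR"   -> {"2br", "2 bedroom", "2 bedrooms", "2 br", "2 bed", "2bed", "2 b/r", "2b/r"}
    #   "Villa" -> {"villa", "vil"}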
@staticmethod
def clean_price(price: Union[str, float, int]) -> float:
"""Clean price string and convert to float."""
if isinstance(price, (int, float)):
return float(price)
if not isinstance(price, str):
raise ValueError(f"Unsupported price type: {type(price)}")
price = price.strip()
if not price:
raise ValueError("Empty price value")
try:
# Handle million notation (e.g., "1.5M" or "1.5 million")
million_pattern = r'^(\d+\.?\d*)\s*[mM](?:illion)?$'
match = re.match(million_pattern, price)
if match:
return float(match.group(1)) * 1_000_000
# Remove any non-numeric characters except dots and commas
price = ''.join(c for c in price if c.isdigit() or c == '.' or c == ',')
# Remove commas and convert to float
price = price.replace(',', '')
return float(price)
except Exception as e:
raise ValueError(f"Failed to parse price '{price}': {str(e)}")
def prepare_document(self, row: pd.Series) -> str:
"""Create a searchable document from a row of data with normalized text."""
try:
price_value = self.clean_price(row['AskingPrice'])
formatted_price = f"{price_value/1_000_000:.2f}M" if price_value >= 1_000_000 else f"{price_value:,.2f}"
# Create alternative forms for unit type
unit_type = str(row['UnitType'])
unit_variations = self._generate_unit_variations(unit_type)
# Normalize all fields
developer = self._normalize_text(str(row['Developer']))
view = self._normalize_text(str(row['View']))
total_area = str(row['TotalArea']).replace(',', '')
# Create searchable document with all variations
document = (
f"Unit {row['UnitCode']} is a {unit_type}. "
f"Alternative names: {' '.join(unit_variations)}. "
f"Located on floor {row['Floor']} and developed by {developer}. "
f"Features a {view} view and total area of {total_area} square feet. "
f"The asking price is AED {formatted_price}."
)
return document
except Exception as e:
logging.error(f"Error preparing document: {str(e)}, Row: {row.to_dict()}")
raise
def update_file_stats(self, filename: str, count: int):
"""Update file statistics."""
self._file_stats[filename] = {
"count": count,
"last_updated": datetime.now().isoformat()
}
def get_file_stats(self):
"""Get current file statistics."""
return self._file_stats
def remove_file_stats(self, filename: str):
"""Remove file statistics."""
if filename in self._file_stats:
del self._file_stats[filename]
def load_data(self, csv_path: str, reset_collection: bool = False) -> Dict:
"""Load and process data from CSV file."""
try:
filename = os.path.basename(csv_path)
            # If requested, drop any rows previously indexed from this file before re-adding it
            if reset_collection:
                existing = self.collection.get()
                existing_ids = [
                    doc_id for doc_id, meta in zip(existing.get("ids", []),
                                                   existing.get("metadatas", []))
                    if meta.get("source_file") == filename
                ]
                if existing_ids:
                    self.collection.delete(ids=existing_ids)
                    self.remove_file_stats(filename)
# Read and process CSV
df = pd.read_csv(csv_path)
documents = []
metadatas = []
ids = []
for idx, row in df.iterrows():
try:
doc = self.prepare_document(row)
# Clean and prepare metadata
metadata = {
"unit_code": str(row["UnitCode"]),
"unit_type": self._normalize_text(str(row["UnitType"])),
"developer": self._normalize_text(str(row["Developer"])),
"price": float(self.clean_price(row["AskingPrice"])),
"total_area": float(str(row["TotalArea"]).replace(',', '')),
"view": self._normalize_text(str(row["View"])),
"floor": str(row["Floor"]),
"source_file": filename,
"indexed_date": datetime.now().isoformat()
}
documents.append(doc)
metadatas.append(metadata)
ids.append(str(row["UnitCode"]))
except Exception as row_error:
logging.warning(f"Error processing row {idx}: {str(row_error)}")
continue
if documents:
# Add documents to collection
self.collection.add(
documents=documents,
metadatas=metadatas,
ids=ids
)
self.update_file_stats(filename, len(documents))
return {
"status": "success",
"message": f"Added {len(documents)} properties from {filename}",
"count": len(documents)
}
else:
return {
"status": "error",
"message": "No valid data to add",
"count": 0
}
except Exception as e:
logging.error(f"Error loading data: {str(e)}")
return {
"status": "error",
"message": str(e),
"count": 0
}
def _clean_metadata(self, metadata: Dict) -> Dict:
"""Clean and normalize metadata values."""
return {
'unit_code': str(metadata['unit_code']),
'unit_type': self._normalize_text(metadata['unit_type']),
'developer': self._normalize_text(metadata['developer']),
'view': self._normalize_text(metadata['view']),
'floor': str(metadata['floor']),
'price': float(str(metadata['price']).replace(',', '')),
'total_area': float(str(metadata['total_area']).replace(',', ''))
}
def _normalize_floor(self, floor_value: str) -> Optional[int]:
"""Convert floor value to normalized integer for comparison."""
if not floor_value:
return None
floor = str(floor_value).strip().upper()
# Special cases
floor_map = {
'G': 0, # Ground
'GR': 0, # Ground
'LG': 0, # Lower Ground
'UG': 0, # Upper Ground
'B': -1, # Basement
'B1': -1, # Basement 1
'B2': -2, # Basement 2
'B3': -3, # Basement 3
'M': 1, # Mezzanine
'P': 1, # Podium
'P1': 1, # Podium 1
'P2': 2, # Podium 2
'P3': 3, # Podium 3
'L': 1, # Lower level
'PH': 999 # Penthouse (high value for sorting)
}
# Check exact matches
if floor in floor_map:
return floor_map[floor]
# Check prefixes
for prefix, value in floor_map.items():
if floor.startswith(prefix):
# Try to get number after prefix
num_part = ''.join(filter(str.isdigit, floor[len(prefix):]))
if num_part:
if prefix in ['B', 'B1', 'B2', 'B3']:
return -int(num_part) # Negative for basement levels
return int(num_part)
return value
# Try to extract just numbers
nums = ''.join(filter(str.isdigit, floor))
if nums:
return int(nums)
return None
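    # Illustrative floor normalization (derived from the mapping above):
    #   "G" -> 0, "B2" -> -2, "P2" -> 2, "PH" -> 999, "M" -> 1, "15" -> 15, "12A" -> 12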
    def _in_range(self, value: float, min_val: Optional[float] = None, max_val: Optional[float] = None) -> bool:
"""Check if value is within specified range."""
if min_val is not None and value < min_val:
return False
if max_val is not None and value > max_val:
return False
return True
def _generate_variations(self, text: str, property_type: str) -> List[str]:
"""Generate variations based on property type."""
if not text:
return []
variations = set([text.lower()])
if property_type == "unit_type":
# Handle bedroom variations
bedroom_match = re.match(r'(\d+)\s*(?:bedroom|br|b/r|bed)', text.lower())
if bedroom_match:
num = bedroom_match.group(1)
variations.update([
f"{num} bedroom",
f"{num}bedroom",
f"{num} br",
f"{num}br",
f"{num} b/r",
f"{num}b/r",
f"{num} bed",
f"{num}bed"
])
# Handle special types
type_map = {
'studio': ['studio', 'studio apartment', 'studio apt'],
'penthouse': ['penthouse', 'ph', 'pent house'],
'townhouse': ['townhouse', 'th', 'town house'],
'duplex': ['duplex', 'du', 'dup'],
'villa': ['villa', 'vil', 'v'],
'apartment': ['apartment', 'apt', 'flat']
}
            for aliases in type_map.values():
                if any(alias in text.lower() for alias in aliases):
                    variations.update(aliases)
elif property_type == "view":
# Handle view variations
view_map = {
'sea': ['sea', 'ocean', 'water', 'beach'],
'burj': ['burj', 'burj khalifa', 'downtown'],
'garden': ['garden', 'park', 'landscape'],
'marina': ['marina', 'yacht', 'boat'],
'palm': ['palm', 'jumeirah', 'atlantis']
}
            for aliases in view_map.values():
                if any(alias in text.lower() for alias in aliases):
                    variations.update(aliases)
return list(variations)
    def _check_floor_range(self, floor: str, min_floor: Optional[str] = None, max_floor: Optional[str] = None) -> bool:
"""Check if floor value is within specified range."""
if min_floor is None and max_floor is None:
return True
floor_num = self._normalize_floor(floor)
min_floor_num = self._normalize_floor(min_floor) if min_floor else None
max_floor_num = self._normalize_floor(max_floor) if max_floor else None
if floor_num is None: # If we can't parse the floor, don't filter it out
return True
if min_floor_num is not None and floor_num < min_floor_num:
return False
if max_floor_num is not None and floor_num > max_floor_num:
return False
return True
def _calculate_similarity(self, query: str, metadata: Dict, base_score: float) -> float:
"""Calculate enhanced similarity score."""
score = base_score
if not query:
return score
normalized_query = self._normalize_text(query)
# Unit type matching
if any(var in metadata['unit_type']
for var in self._generate_unit_variations(normalized_query)):
score *= 1.3 # 30% boost for unit type match
# View matching
if normalized_query in metadata['view']:
score *= 1.2 # 20% boost for view match
# Developer matching
if normalized_query in metadata['developer']:
score *= 1.15 # 15% boost for developer match
return min(score, 1.0) # Cap at 1.0
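    # Worked example of the boost logic above: search() turns a raw distance of 0.25
    # into base_score = 1 / (1 + 0.25) = 0.8; a unit-type match then gives
    # 0.8 * 1.3 = 1.04, which is capped at 1.0.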
def search(self, query: str, n_results: int = 10) -> List[Dict]:
"""Base vectorized search implementation."""
try:
if not self.collection.count():
return []
if not query or query.strip() == "*":
return []
# Process query with variations
normalized_query = self._normalize_text(query)
variations = self._generate_unit_variations(normalized_query)
# Build semantic query with boosted relevant terms
search_query = f"{normalized_query} {' '.join(variations)}"
# Get raw results with buffer for better scoring
buffer_multiplier = 2
results = self.collection.query(
query_texts=[search_query],
n_results=min(n_results * buffer_multiplier, self.collection.count()),
include=['metadatas', 'documents', 'distances']
)
if not results['ids'][0]:
return []
processed_results = []
for i, doc_id in enumerate(results['ids'][0]):
try:
metadata = results['metadatas'][0][i]
cleaned_metadata = self._clean_metadata(metadata)
# Calculate enhanced similarity score
base_score = 1 / (1 + float(results['distances'][0][i]))
final_score = self._calculate_similarity(query, cleaned_metadata, base_score)
processed_results.append({
'id': doc_id,
'document': results['documents'][0][i],
'metadata': cleaned_metadata,
'similarity': final_score
})
except Exception as e:
logging.warning(f"Error processing result {doc_id}: {e}")
continue
# Sort by enhanced similarity score
processed_results.sort(key=lambda x: x['similarity'], reverse=True)
return processed_results[:n_results]
except Exception as e:
logging.error(f"Search error: {e}", exc_info=True)
return []
    def filter_search(self,
                      query: Optional[str] = None,
                      min_price: Optional[float] = None,
                      max_price: Optional[float] = None,
                      min_area: Optional[float] = None,
                      max_area: Optional[float] = None,
                      min_floor: Optional[str] = None,
                      max_floor: Optional[str] = None,
                      unit_type: Optional[str] = None,
                      developer: Optional[str] = None,
                      view: Optional[str] = None,
                      n_results: int = 10) -> List[Dict]:
"""Combined search and filter implementation."""
try:
# Get base results
if query:
base_results = self.search(query, n_results=self.collection.count())
else:
results = self.collection.get(include=['metadatas', 'documents'])
base_results = [{
'id': id,
'document': doc,
'metadata': self._clean_metadata(meta),
'similarity': 1.0
} for id, doc, meta in zip(
results.get('ids', []),
results.get('documents', []),
results.get('metadatas', [])
)]
if not base_results:
return []
# Apply filters with enhanced scoring
filtered_results = []
for result in base_results:
try:
metadata = result['metadata']
score = result['similarity']
# Basic filters
if not all([
self._in_range(metadata['price'], min_price, max_price),
self._in_range(metadata['total_area'], min_area, max_area),
self._check_floor_range(metadata['floor'], min_floor, max_floor)
]):
continue
# Enhanced unit type matching
if unit_type:
unit_vars = self._generate_unit_variations(metadata['unit_type'])
if any(unit_type.lower() == var.lower() for var in unit_vars):
score *= 1.3 # Exact match
elif any(unit_type.lower() in var.lower() for var in unit_vars):
score *= 1.15 # Partial match
else:
continue
# Enhanced developer matching
if developer:
if developer.lower() == metadata['developer'].lower():
score *= 1.25 # Exact match
elif developer.lower() in metadata['developer'].lower():
score *= 1.1 # Partial match
else:
continue
# Enhanced view matching
if view:
if view.lower() == metadata['view'].lower():
score *= 1.25 # Exact match
elif view.lower() in metadata['view'].lower():
score *= 1.1 # Partial match
else:
continue
result['similarity'] = min(score, 1.0) # Cap final score
filtered_results.append(result)
except Exception as e:
logging.warning(f"Filter error for {result['id']}: {e}")
continue
# Sort by final similarity score
filtered_results.sort(key=lambda x: x['similarity'], reverse=True)
return filtered_results[:n_results]
except Exception as e:
logging.error(f"Filter search error: {e}", exc_info=True)
return []
    def delete_properties(self, files: Optional[List[str]] = None):
"""Delete properties by source filename."""
try:
if files:
for filename in files:
results = self.collection.get()
to_delete = [
id for id, meta in zip(results.get("ids", []),
results.get("metadatas", []))
if meta.get("source_file") == filename
]
if to_delete:
self.collection.delete(ids=to_delete)
self.remove_file_stats(filename)
logging.info(f"Deleted {len(to_delete)} properties from {filename}")
else:
logging.warning(f"No properties found for file {filename}")
except Exception as e:
logging.error(f"Error deleting properties: {str(e)}")
def get_collection_stats(self) -> Dict:
"""Get collection statistics."""
try:
total_properties = sum(stat["count"] for stat in self._file_stats.values())
stats = {
"total_properties": total_properties,
"file_stats": self._file_stats,
"collection_name": self.collection.name,
"last_updated": max((stat["last_updated"] for stat in self._file_stats.values()),
default=None) if self._file_stats else None
}
return stats
except Exception as e:
logging.error(f"Error getting collection stats: {str(e)}")
return {"error": str(e)}