import logging
import os
import re
from datetime import datetime
from typing import Dict, List, Optional, Union

import chromadb
import pandas as pd
from chromadb.utils import embedding_functions

from src.utils.config import OPENAI_API_KEY, COLLECTION_NAME


class VarturRealEstateSearch:
    def __init__(self, persist_directory: str = "database"):
        """Initialize the search system with OpenAI embeddings."""
        logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
        if not os.path.exists(persist_directory):
            os.makedirs(persist_directory)
        self.client = chromadb.PersistentClient(path=persist_directory)
        self.embedding_function = embedding_functions.OpenAIEmbeddingFunction(
            api_key=OPENAI_API_KEY,
            model_name="text-embedding-3-small"
        )
        self._file_stats = {}
        self._collection_initialized = False  # Flag to track collection initialization
        # Set up the Chroma collection as part of construction
        self._initialize_collection()

    def _initialize_collection(self):
        """Initialize the collection if it hasn't been already."""
        if self._collection_initialized:
            return  # Avoid redundant initialization
        try:
            self.collection = self.client.get_or_create_collection(
                name=COLLECTION_NAME,
                embedding_function=self.embedding_function
            )
            self._collection_initialized = True  # Mark collection as initialized
            logging.info(f"Collection {COLLECTION_NAME} is ready")
        except Exception as e:
            logging.error(f"Error initializing collection: {str(e)}")
            raise

    def _normalize_text(self, text: str) -> str:
        """Normalize any text input for consistent matching."""
        if not text:
            return ""
        if not isinstance(text, str):
            text = str(text)
        # Collapse whitespace and convert to lowercase
        return " ".join(text.lower().split())

    def _generate_unit_variations(self, unit_type: str) -> List[str]:
        """Generate different variations of unit type for better matching."""
        variations = set()
        # Normalize the input
        unit = self._normalize_text(unit_type)
        variations.add(unit)
        # Handle bedroom variations
        bedroom_match = re.match(r'(\d+)\s*(?:bedroom|br|b/r|bed)', unit)
        if bedroom_match:
            num = bedroom_match.group(1)
            variations.update([
                f"{num} bedroom",
                f"{num} br",
                f"{num}br",
                f"{num} bed",
                f"{num}bed",
                f"{num} b/r",
                f"{num}b/r",
                f"{num} bedrooms"
            ])
        # Handle other common types
        type_variations = {
            'studio': ['studio', 'studio apartment', 'studio apt'],
            'penthouse': ['penthouse', 'ph', 'pent house'],
            'townhouse': ['townhouse', 'th', 'town house'],
            'villa': ['villa', 'vil'],
            'apartment': ['apartment', 'apt', 'appartment']
        }
        # Add variations for the unit type
        for base_type, variants in type_variations.items():
            if any(variant in unit for variant in variants):
                variations.update(variants)
        return list(variations)
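
    # Illustrative example (hedged): for input "2BR", the regex above matches
    # the normalized form "2br", so the result would include "2 bedroom",
    # "2 br", "2br", "2 bed", "2bed", "2 b/r", "2b/r" and "2 bedrooms".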

    def clean_price(self, price: Union[str, float, int]) -> float:
        """Clean a price value and convert it to float."""
        if isinstance(price, (int, float)):
            return float(price)
        if not isinstance(price, str):
            raise ValueError(f"Unsupported price type: {type(price)}")
        price = price.strip()
        if not price:
            raise ValueError("Empty price value")
        try:
            # Handle million notation (e.g., "1.5M" or "1.5 million")
            million_pattern = r'^(\d+\.?\d*)\s*[mM](?:illion)?$'
            match = re.match(million_pattern, price)
            if match:
                return float(match.group(1)) * 1_000_000
            # Remove any non-numeric characters except dots and commas
            price = ''.join(c for c in price if c.isdigit() or c in '.,')
            # Remove commas and convert to float
            price = price.replace(',', '')
            return float(price)
        except Exception as e:
            raise ValueError(f"Failed to parse price '{price}': {str(e)}")

    def prepare_document(self, row: pd.Series) -> str:
        """Create a searchable document from a row of data with normalized text."""
        try:
            price_value = self.clean_price(row['AskingPrice'])
            formatted_price = f"{price_value / 1_000_000:.2f}M" if price_value >= 1_000_000 else f"{price_value:,.2f}"
            # Create alternative forms for the unit type
            unit_type = str(row['UnitType'])
            unit_variations = self._generate_unit_variations(unit_type)
            # Normalize all fields
            developer = self._normalize_text(str(row['Developer']))
            view = self._normalize_text(str(row['View']))
            total_area = str(row['TotalArea']).replace(',', '')
            # Create a searchable document with all variations
            document = (
                f"Unit {row['UnitCode']} is a {unit_type}. "
                f"Alternative names: {' '.join(unit_variations)}. "
                f"Located on floor {row['Floor']} and developed by {developer}. "
                f"Features a {view} view and total area of {total_area} square feet. "
                f"The asking price is AED {formatted_price}."
            )
            return document
        except Exception as e:
            logging.error(f"Error preparing document: {str(e)}, Row: {row.to_dict()}")
            raise

    def update_file_stats(self, filename: str, count: int):
        """Update file statistics."""
        self._file_stats[filename] = {
            "count": count,
            "last_updated": datetime.now().isoformat()
        }

    def get_file_stats(self):
        """Get current file statistics."""
        return self._file_stats

    def remove_file_stats(self, filename: str):
        """Remove file statistics."""
        if filename in self._file_stats:
            del self._file_stats[filename]

    def load_data(self, csv_path: str, reset_collection: bool = False) -> Dict:
        """Load and process data from a CSV file."""
        try:
            filename = os.path.basename(csv_path)
            # Drop any previously indexed rows for this file if a reset was requested
            if reset_collection:
                existing = self.collection.get()  # Fetch once so ids and metadatas stay aligned
                existing_ids = [doc_id for doc_id, meta in zip(existing.get("ids", []),
                                                               existing.get("metadatas", []))
                                if meta.get("source_file") == filename]
                if existing_ids:
                    self.collection.delete(ids=existing_ids)
                    self.remove_file_stats(filename)
            # Read and process the CSV
            df = pd.read_csv(csv_path)
            documents = []
            metadatas = []
            ids = []
            for idx, row in df.iterrows():
                try:
                    doc = self.prepare_document(row)
                    # Clean and prepare metadata
                    metadata = {
                        "unit_code": str(row["UnitCode"]),
                        "unit_type": self._normalize_text(str(row["UnitType"])),
                        "developer": self._normalize_text(str(row["Developer"])),
                        "price": float(self.clean_price(row["AskingPrice"])),
                        "total_area": float(str(row["TotalArea"]).replace(',', '')),
                        "view": self._normalize_text(str(row["View"])),
                        "floor": str(row["Floor"]),
                        "source_file": filename,
                        "indexed_date": datetime.now().isoformat()
                    }
                    documents.append(doc)
                    metadatas.append(metadata)
                    ids.append(str(row["UnitCode"]))
                except Exception as row_error:
                    logging.warning(f"Error processing row {idx}: {str(row_error)}")
                    continue
            if documents:
                # Add documents to the collection
                self.collection.add(
                    documents=documents,
                    metadatas=metadatas,
                    ids=ids
                )
                self.update_file_stats(filename, len(documents))
                return {
                    "status": "success",
                    "message": f"Added {len(documents)} properties from {filename}",
                    "count": len(documents)
                }
            else:
                return {
                    "status": "error",
                    "message": "No valid data to add",
                    "count": 0
                }
        except Exception as e:
            logging.error(f"Error loading data: {str(e)}")
            return {
                "status": "error",
                "message": str(e),
                "count": 0
            }

    def _clean_metadata(self, metadata: Dict) -> Dict:
        """Clean and normalize metadata values."""
        return {
            'unit_code': str(metadata['unit_code']),
            'unit_type': self._normalize_text(metadata['unit_type']),
            'developer': self._normalize_text(metadata['developer']),
            'view': self._normalize_text(metadata['view']),
            'floor': str(metadata['floor']),
            'price': float(str(metadata['price']).replace(',', '')),
            'total_area': float(str(metadata['total_area']).replace(',', ''))
        }

    def _normalize_floor(self, floor_value: str) -> Optional[int]:
        """Convert a floor value to a normalized integer for comparison."""
        if not floor_value:
            return None
        floor = str(floor_value).strip().upper()
        # Special cases
        floor_map = {
            'G': 0,     # Ground
            'GR': 0,    # Ground
            'LG': 0,    # Lower Ground
            'UG': 0,    # Upper Ground
            'B': -1,    # Basement
            'B1': -1,   # Basement 1
            'B2': -2,   # Basement 2
            'B3': -3,   # Basement 3
            'M': 1,     # Mezzanine
            'P': 1,     # Podium
            'P1': 1,    # Podium 1
            'P2': 2,    # Podium 2
            'P3': 3,    # Podium 3
            'L': 1,     # Lower level
            'PH': 999   # Penthouse (high value for sorting)
        }
        # Check exact matches
        if floor in floor_map:
            return floor_map[floor]
        # Check prefixes, longest first so that e.g. "PH..." is not matched by "P"
        for prefix in sorted(floor_map, key=len, reverse=True):
            if floor.startswith(prefix):
                # Try to get the number after the prefix
                num_part = ''.join(filter(str.isdigit, floor[len(prefix):]))
                if num_part:
                    if prefix.startswith('B'):
                        return -int(num_part)  # Negative for basement levels
                    return int(num_part)
                return floor_map[prefix]
        # Fall back to extracting just the digits
        nums = ''.join(filter(str.isdigit, floor))
        if nums:
            return int(nums)
        return None
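
    # Illustrative mappings (hedged; derived from floor_map and the digit
    # fallback above): "G" -> 0, "B2" -> -2, "PH" -> 999, "12" -> 12,
    # "Level 7" -> 7.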

    def _in_range(self, value: float, min_val: Optional[float] = None, max_val: Optional[float] = None) -> bool:
        """Check if a value is within the specified range."""
        if min_val is not None and value < min_val:
            return False
        if max_val is not None and value > max_val:
            return False
        return True

    def _generate_variations(self, text: str, property_type: str) -> List[str]:
        """Generate variations based on property type."""
        if not text:
            return []
        variations = set([text.lower()])
        if property_type == "unit_type":
            # Handle bedroom variations
            bedroom_match = re.match(r'(\d+)\s*(?:bedroom|br|b/r|bed)', text.lower())
            if bedroom_match:
                num = bedroom_match.group(1)
                variations.update([
                    f"{num} bedroom",
                    f"{num}bedroom",
                    f"{num} br",
                    f"{num}br",
                    f"{num} b/r",
                    f"{num}b/r",
                    f"{num} bed",
                    f"{num}bed"
                ])
            # Handle special types
            type_map = {
                'studio': ['studio', 'studio apartment', 'studio apt'],
                'penthouse': ['penthouse', 'ph', 'pent house'],
                'townhouse': ['townhouse', 'th', 'town house'],
                'duplex': ['duplex', 'du', 'dup'],
                'villa': ['villa', 'vil', 'v'],
                'apartment': ['apartment', 'apt', 'flat']
            }
            for base_type, variants in type_map.items():
                if any(variant in text.lower() for variant in variants):
                    variations.update(variants)
        elif property_type == "view":
            # Handle view variations
            view_map = {
                'sea': ['sea', 'ocean', 'water', 'beach'],
                'burj': ['burj', 'burj khalifa', 'downtown'],
                'garden': ['garden', 'park', 'landscape'],
                'marina': ['marina', 'yacht', 'boat'],
                'palm': ['palm', 'jumeirah', 'atlantis']
            }
            for base_view, variants in view_map.items():
                if any(variant in text.lower() for variant in variants):
                    variations.update(variants)
        return list(variations)
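
    # Illustrative example (hedged): _generate_variations("sea view", "view")
    # would expand to include "sea", "ocean", "water" and "beach", since "sea"
    # appears in the input text.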

    def _check_floor_range(self, floor: str, min_floor: Optional[str] = None, max_floor: Optional[str] = None) -> bool:
        """Check if a floor value is within the specified range."""
        if min_floor is None and max_floor is None:
            return True
        floor_num = self._normalize_floor(floor)
        min_floor_num = self._normalize_floor(min_floor) if min_floor else None
        max_floor_num = self._normalize_floor(max_floor) if max_floor else None
        if floor_num is None:  # If we can't parse the floor, don't filter it out
            return True
        if min_floor_num is not None and floor_num < min_floor_num:
            return False
        if max_floor_num is not None and floor_num > max_floor_num:
            return False
        return True
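
    # Illustrative example (hedged): _check_floor_range("PH", min_floor="10")
    # -> True, because "PH" normalizes to 999, which is above floor 10.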

    def _calculate_similarity(self, query: str, metadata: Dict, base_score: float) -> float:
        """Calculate an enhanced similarity score."""
        score = base_score
        if not query:
            return score
        normalized_query = self._normalize_text(query)
        # Unit type matching
        if any(var in metadata['unit_type']
               for var in self._generate_unit_variations(normalized_query)):
            score *= 1.3  # 30% boost for unit type match
        # View matching
        if normalized_query in metadata['view']:
            score *= 1.2  # 20% boost for view match
        # Developer matching
        if normalized_query in metadata['developer']:
            score *= 1.15  # 15% boost for developer match
        return min(score, 1.0)  # Cap at 1.0
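
    # Illustrative example (hedged): with base_score=0.5, query="2 bedroom"
    # and metadata['unit_type']="2 bedroom", the unit-type boost applies and
    # the result is 0.5 * 1.3 = 0.65.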

    def search(self, query: str, n_results: int = 10) -> List[Dict]:
        """Base vectorized search implementation."""
        try:
            if not self.collection.count():
                return []
            if not query or query.strip() == "*":
                return []
            # Process the query with variations
            normalized_query = self._normalize_text(query)
            variations = self._generate_unit_variations(normalized_query)
            # Build a semantic query with boosted relevant terms
            search_query = f"{normalized_query} {' '.join(variations)}"
            # Get raw results with a buffer for better scoring
            buffer_multiplier = 2
            results = self.collection.query(
                query_texts=[search_query],
                n_results=min(n_results * buffer_multiplier, self.collection.count()),
                include=['metadatas', 'documents', 'distances']
            )
            if not results['ids'][0]:
                return []
            processed_results = []
            for i, doc_id in enumerate(results['ids'][0]):
                try:
                    metadata = results['metadatas'][0][i]
                    cleaned_metadata = self._clean_metadata(metadata)
                    # Calculate the enhanced similarity score
                    base_score = 1 / (1 + float(results['distances'][0][i]))
                    final_score = self._calculate_similarity(query, cleaned_metadata, base_score)
                    processed_results.append({
                        'id': doc_id,
                        'document': results['documents'][0][i],
                        'metadata': cleaned_metadata,
                        'similarity': final_score
                    })
                except Exception as e:
                    logging.warning(f"Error processing result {doc_id}: {e}")
                    continue
            # Sort by enhanced similarity score
            processed_results.sort(key=lambda x: x['similarity'], reverse=True)
            return processed_results[:n_results]
        except Exception as e:
            logging.error(f"Search error: {e}", exc_info=True)
            return []

    def filter_search(self,
                      query: Optional[str] = None,
                      min_price: Optional[float] = None,
                      max_price: Optional[float] = None,
                      min_area: Optional[float] = None,
                      max_area: Optional[float] = None,
                      min_floor: Optional[str] = None,
                      max_floor: Optional[str] = None,
                      unit_type: Optional[str] = None,
                      developer: Optional[str] = None,
                      view: Optional[str] = None,
                      n_results: int = 10) -> List[Dict]:
        """Combined search and filter implementation."""
        try:
            # Get base results
            if query:
                base_results = self.search(query, n_results=self.collection.count())
            else:
                results = self.collection.get(include=['metadatas', 'documents'])
                base_results = [{
                    'id': doc_id,
                    'document': doc,
                    'metadata': self._clean_metadata(meta),
                    'similarity': 1.0
                } for doc_id, doc, meta in zip(
                    results.get('ids', []),
                    results.get('documents', []),
                    results.get('metadatas', [])
                )]
            if not base_results:
                return []
            # Apply filters with enhanced scoring
            filtered_results = []
            for result in base_results:
                try:
                    metadata = result['metadata']
                    score = result['similarity']
                    # Basic range filters
                    if not all([
                        self._in_range(metadata['price'], min_price, max_price),
                        self._in_range(metadata['total_area'], min_area, max_area),
                        self._check_floor_range(metadata['floor'], min_floor, max_floor)
                    ]):
                        continue
                    # Enhanced unit type matching
                    if unit_type:
                        unit_vars = self._generate_unit_variations(metadata['unit_type'])
                        if any(unit_type.lower() == var.lower() for var in unit_vars):
                            score *= 1.3  # Exact match
                        elif any(unit_type.lower() in var.lower() for var in unit_vars):
                            score *= 1.15  # Partial match
                        else:
                            continue
                    # Enhanced developer matching
                    if developer:
                        if developer.lower() == metadata['developer'].lower():
                            score *= 1.25  # Exact match
                        elif developer.lower() in metadata['developer'].lower():
                            score *= 1.1  # Partial match
                        else:
                            continue
                    # Enhanced view matching
                    if view:
                        if view.lower() == metadata['view'].lower():
                            score *= 1.25  # Exact match
                        elif view.lower() in metadata['view'].lower():
                            score *= 1.1  # Partial match
                        else:
                            continue
                    result['similarity'] = min(score, 1.0)  # Cap the final score
                    filtered_results.append(result)
                except Exception as e:
                    logging.warning(f"Filter error for {result['id']}: {e}")
                    continue
            # Sort by final similarity score
            filtered_results.sort(key=lambda x: x['similarity'], reverse=True)
            return filtered_results[:n_results]
        except Exception as e:
            logging.error(f"Filter search error: {e}", exc_info=True)
            return []

    def delete_properties(self, files: Optional[List[str]] = None):
        """Delete properties by source filename."""
        try:
            if files:
                for filename in files:
                    # Re-fetch each pass, since deletions change the collection
                    results = self.collection.get()
                    to_delete = [
                        doc_id for doc_id, meta in zip(results.get("ids", []),
                                                       results.get("metadatas", []))
                        if meta.get("source_file") == filename
                    ]
                    if to_delete:
                        self.collection.delete(ids=to_delete)
                        self.remove_file_stats(filename)
                        logging.info(f"Deleted {len(to_delete)} properties from {filename}")
                    else:
                        logging.warning(f"No properties found for file {filename}")
        except Exception as e:
            logging.error(f"Error deleting properties: {str(e)}")

    def get_collection_stats(self) -> Dict:
        """Get collection statistics."""
        try:
            total_properties = sum(stat["count"] for stat in self._file_stats.values())
            stats = {
                "total_properties": total_properties,
                "file_stats": self._file_stats,
                "collection_name": self.collection.name,
                "last_updated": max((stat["last_updated"] for stat in self._file_stats.values()),
                                    default=None) if self._file_stats else None
            }
            return stats
        except Exception as e:
            logging.error(f"Error getting collection stats: {str(e)}")
            return {"error": str(e)}