# hkust_bnb_visualiser.py
# This module provides the main visualization for the HKUST BNB+ platform.
# It handles database connections, data retrieval, search relevance calculation, and map visualization
# for BNB listings across different neighborhoods in Hong Kong. The class integrates with traffic data
# to provide eco-friendly discount calculations based on traffic conditions.
# Key capabilities:
# - Text search functionality using sentence transformers
# - Traffic spot integration for eco-friendly discount calculations
# Author: Gordon Li (20317033)
# Date: March 2025
import oracledb | |
import pandas as pd | |
import folium | |
from html import escape | |
from sentence_transformers import SentenceTransformer, util | |
from geopy.distance import geodesic | |
import logging | |
from visualiser.td_traffic_spot_visualiser import TrafficSpotManager | |
from constant.hkust_bnb_constant import ( | |
GET_ALL_NEIGHBORHOODS, | |
GET_NEIGHBORHOOD_LISTINGS, | |
GET_LISTING_REVIEWS, | |
GET_LISTING_REVIEWS_FOR_SEARCH, | |
DISCOUNT_INFO_TEMPLATE, | |
TRAFFIC_SPOT_INFO_TEMPLATE, | |
RELEVANCE_INFO_TEMPLATE, | |
POPUP_CONTENT_TEMPLATE, | |
MAP_SCRIPT | |
) | |
class HKUSTBNBVisualiser:
    """Main class for BNB data visualization and management.

    Handles Oracle database connection pooling, listing/review retrieval
    (with per-neighborhood caching), semantic search relevance scoring via
    a sentence-transformer model, and rendering of interactive folium maps
    that overlay nearby traffic spots and eco-friendly discount details.
    """

    def __init__(self):
        """Initialize the connection pool, traffic manager, NLP model and caches.

        Sets up the Oracle session pool, loads traffic spot data, loads the
        sentence-transformer model (falling back to ``None`` on failure so
        search degrades gracefully), and prepares the neighborhood list plus
        the listing/embedding caches.
        """
        # SECURITY NOTE(review): credentials are hard-coded here; they should
        # be loaded from environment variables or a secrets store instead.
        self.connection_params = {
            'user': 'slliac',
            'password': '7033',
            'dsn': 'imz409.ust.hk:1521/imz409'
        }
        self.pool = oracledb.SessionPool(
            user=self.connection_params['user'],
            password=self.connection_params['password'],
            dsn=self.connection_params['dsn'],
            min=2,
            max=5,
            increment=1,
            getmode=oracledb.SPOOL_ATTRVAL_WAIT
        )
        self.traffic_manager = TrafficSpotManager(self.connection_params)
        logging.info(f"Traffic spots initialized, {len(self.traffic_manager.traffic_spots)} spots loaded")
        try:
            model_name = "sentence-transformers/all-MiniLM-L6-v2"
            self.model = SentenceTransformer(model_name)
            # Use logging (not print) so diagnostics are consistent with the
            # rest of this module.
            logging.info("Loaded Sentence Transformer model: %s", model_name)
        except Exception as e:
            logging.error("Error loading model: %s", e)
            self.model = None  # search scoring is silently disabled
        try:
            self.neighborhoods = self.get_all_neighborhoods()
            self.cached_listings = {}    # {neighborhood: {limit: [rows]}}
            self.cached_embeddings = {}  # {cache_key: embedding tensor}
        except Exception as e:
            logging.error("Initialization error: %s", e)
            self.neighborhoods = []
            self.cached_listings = {}
            self.cached_embeddings = {}

    def find_nearest_traffic_spot(self, airbnb_lat, airbnb_lng, max_distance_km=0.7):
        """Find the nearest valid traffic spot to a BNB listing location.

        Parameters:
            airbnb_lat: Latitude of the BNB listing.
            airbnb_lng: Longitude of the BNB listing.
            max_distance_km: Maximum radius in kilometers to consider
                (default 0.7).

        Returns:
            Tuple (nearest_spot, distance_km), or (None, None) when no valid
            spot lies within the radius.
        """
        nearest_spot = None
        min_distance = float('inf')
        for spot in self.traffic_manager.traffic_spots:
            if not spot.is_valid():
                continue
            distance = geodesic(
                (airbnb_lat, airbnb_lng),
                (spot.latitude, spot.longitude)
            ).kilometers
            # Track the closest spot that is still inside the search radius.
            if distance < min_distance and distance <= max_distance_km:
                min_distance = distance
                nearest_spot = spot
        if nearest_spot is None:
            return None, None
        return nearest_spot, min_distance

    def get_all_neighborhoods(self):
        """Retrieve all available neighborhoods from the database.

        Returns:
            List of neighborhood name strings; empty list on database error.
        """
        connection = self.pool.acquire()
        try:
            cursor = connection.cursor()
            # Batch row fetches to cut round trips for this small result set.
            cursor.prefetchrows = 50
            cursor.arraysize = 50
            cursor.execute(GET_ALL_NEIGHBORHOODS)
            return [row[0] for row in cursor.fetchall()]
        except Exception as e:
            logging.error("Database error getting neighborhoods: %s", e)
            return []
        finally:
            self.pool.release(connection)

    def get_neighborhood_listings(self, neighborhood, limit=10):
        """Retrieve BNB listings for a neighborhood, with caching.

        Parameters:
            neighborhood: Neighborhood name to retrieve listings for.
            limit: Maximum number of listings; coerced to 10 unless one of
                10/20/30/40/50 (keeps cache keys bounded and predictable).

        Returns:
            List of listing rows from the database; empty list on error.
        """
        if limit not in (10, 20, 30, 40, 50):
            limit = 10
        cached_for_hood = self.cached_listings.get(neighborhood)
        if cached_for_hood is not None and limit in cached_for_hood:
            return cached_for_hood[limit]
        if cached_for_hood is None:
            self.cached_listings[neighborhood] = {}
        connection = self.pool.acquire()
        try:
            cursor = connection.cursor()
            cursor.prefetchrows = 50
            cursor.arraysize = 50
            cursor.execute(
                GET_NEIGHBORHOOD_LISTINGS,
                neighborhood=neighborhood,
                limit=limit
            )
            listings = cursor.fetchall()
            self.cached_listings[neighborhood][limit] = listings
            return listings
        except Exception as e:
            logging.error("Database error: %s", e)
            return []
        finally:
            self.pool.release(connection)

    def get_listing_reviews(self, listing_id):
        """Retrieve formatted reviews for a specific listing.

        Parameters:
            listing_id: ID of the listing (coerced to int for the bind).

        Returns:
            List of (review_date, reviewer_name, comments) string tuples,
            with falsy fields normalized to ''; empty list on error.
        """
        connection = self.pool.acquire()
        try:
            cursor = connection.cursor()
            cursor.execute(
                GET_LISTING_REVIEWS,
                listing_id=int(listing_id)
            )
            formatted_reviews = []
            for review_date, reviewer_name, comments in cursor.fetchall():
                formatted_reviews.append((
                    str(review_date) if review_date else '',
                    str(reviewer_name) if reviewer_name else '',
                    str(comments) if comments else ''
                ))
            return formatted_reviews
        except Exception as e:
            logging.error("Error fetching reviews: %s", e)
            return []
        finally:
            self.pool.release(connection)

    def get_listing_reviews_for_search(self, listing_id):
        """Retrieve raw review comment strings for semantic search.

        Parameters:
            listing_id: ID of the listing (coerced to int for the bind).

        Returns:
            List of review comment strings; empty list on error.
        """
        connection = self.pool.acquire()
        try:
            cursor = connection.cursor()
            cursor.execute(
                GET_LISTING_REVIEWS_FOR_SEARCH,
                listing_id=int(listing_id)
            )
            formatted_reviews = []
            for row in cursor.fetchall():
                comment = row[0]
                if comment is None:
                    continue
                # Oracle LOB columns arrive as objects exposing read();
                # materialize them into plain strings.
                if hasattr(comment, 'read'):
                    formatted_reviews.append(comment.read())
                else:
                    formatted_reviews.append(str(comment))
            return formatted_reviews
        except Exception as e:
            logging.error("Error fetching reviews for search: %s", e)
            return []
        finally:
            self.pool.release(connection)

    def compute_similarity(self, query_embedding, target_embedding):
        """Compute cosine similarity between two embeddings.

        Parameters:
            query_embedding: Embedding tensor for the search query (or None).
            target_embedding: Embedding tensor for the target text (or None).

        Returns:
            Similarity as a float; 0.0 when either embedding is missing or
            the computation fails.
        """
        if query_embedding is None or target_embedding is None:
            return 0.0
        try:
            return util.pytorch_cos_sim(query_embedding, target_embedding).item()
        except Exception as e:
            logging.error("Error computing similarity: %s", e)
            return 0.0

    def compute_search_scores(self, df, search_query):
        """Score each listing's relevance to a free-text search query.

        The final score blends title similarity (weight 0.7) with review
        similarity (weight 0.3); listings without reviews fall back to the
        title similarity alone. Embeddings are memoized in
        ``self.cached_embeddings``.

        Parameters:
            df: DataFrame of listings with at least 'id' and 'name' columns.
            search_query: User's search query string.

        Returns:
            List of float scores aligned with df's rows; all zeros when the
            query is empty, the model is unavailable, or scoring fails.
        """
        if not search_query or self.model is None:
            return [0.0] * len(df)
        try:
            # NOTE(review): this cache grows without bound across distinct
            # queries/listings; consider an LRU if memory becomes a concern.
            query_key = f"query_{search_query}"
            if query_key not in self.cached_embeddings:
                self.cached_embeddings[query_key] = self.model.encode(search_query, convert_to_tensor=True)
            query_embedding = self.cached_embeddings[query_key]
            scores = []
            for _, row in df.iterrows():
                title_key = f"title_{row['id']}"
                if title_key not in self.cached_embeddings:
                    self.cached_embeddings[title_key] = self.model.encode(str(row['name']), convert_to_tensor=True)
                title_embedding = self.cached_embeddings[title_key]
                review_key = f"review_{row['id']}"
                if review_key in self.cached_embeddings:
                    review_embedding = self.cached_embeddings[review_key]
                else:
                    # Only query the database when the review embedding is
                    # not already cached (the original fetched reviews
                    # unconditionally on every call).
                    review_embedding = None
                    reviews = self.get_listing_reviews_for_search(row['id'])
                    if reviews:
                        review_text = " ".join(reviews[:5])  # cap at 5 reviews
                        review_embedding = self.model.encode(review_text, convert_to_tensor=True)
                        self.cached_embeddings[review_key] = review_embedding
                title_similarity = self.compute_similarity(query_embedding, title_embedding)
                if review_embedding is None:
                    scores.append(title_similarity)
                else:
                    review_similarity = self.compute_similarity(query_embedding, review_embedding)
                    scores.append(title_similarity * 0.7 + review_similarity * 0.3)
            return scores
        except Exception as e:
            logging.error("Error in search scoring: %s", e)
            return [0.0] * len(df)

    def sort_by_relevance(self, df, search_query):
        """Sort a DataFrame of listings by relevance to a search query.

        Adds 'relevance_score' and 'relevance_percentage' columns in place.

        Parameters:
            df: DataFrame containing listing data.
            search_query: User's search query string; a falsy query returns
                df unchanged.

        Returns:
            DataFrame sorted by descending relevance (df itself if no query).
        """
        if not search_query:
            return df
        df['relevance_score'] = self.compute_search_scores(df, search_query)
        df['relevance_percentage'] = df['relevance_score'] * 100
        return df.sort_values('relevance_score', ascending=False)

    def create_map_and_data(self, neighborhood="Sha Tin", show_traffic=True, center_lat=None, center_lng=None,
                            selected_id=None, search_query=None, current_page=1, items_per_page=3, listings_limit=10):
        """Create an interactive map and DataFrame for display in the UI.

        Parameters:
            neighborhood: Neighborhood to display listings for (default "Sha Tin").
            show_traffic: Whether to show traffic spots on the map.
            center_lat / center_lng: Optional explicit map center; when omitted
                the mean listing position is used.
            selected_id: ID of the currently selected listing (green marker).
            search_query: Optional query used to rank listings by relevance.
            current_page / items_per_page: Accepted for the UI's pagination
                contract but unused here — presumably the caller paginates.
            listings_limit: Max listings to retrieve (coerced to 10 unless one
                of 10/20/30/40/50).

        Returns:
            Tuple (folium_map, listings_dataframe), or (None, None) when there
            are no listings to display.
        """
        if listings_limit not in (10, 20, 30, 40, 50):
            listings_limit = 10
        listings = self.get_neighborhood_listings(neighborhood, listings_limit)
        if not listings:
            return None, None
        df = pd.DataFrame(listings, columns=[
            'id', 'name', 'host_name', 'neighbourhood',
            'latitude', 'longitude', 'room_type', 'price',
            'number_of_reviews', 'reviews_per_month',
            'minimum_nights', 'availability_365'
        ])
        numeric_cols = ['latitude', 'longitude', 'price', 'number_of_reviews',
                        'minimum_nights', 'availability_365', 'reviews_per_month']
        for col in numeric_cols:
            df[col] = pd.to_numeric(df[col], errors='coerce')
        if search_query:
            df = self.sort_by_relevance(df, search_query)
        if df.empty:
            return None, None
        # BUGFIX: remember whether the caller pinned an explicit center BEFORE
        # filling in defaults. The original evaluated the zoom condition after
        # defaulting, so it was always true and the map always opened at 16.
        caller_centered = center_lat is not None and center_lng is not None
        if not caller_centered:
            center_lat = df['latitude'].mean()
            center_lng = df['longitude'].mean()
        m = folium.Map(
            location=[center_lat, center_lng],
            zoom_start=16 if caller_centered else 14,
            tiles='OpenStreetMap'
        )
        # Pre-compute each listing's nearest traffic spot so markers and
        # connection lines can be drawn in a single pass below.
        all_traffic_spots_to_display = set()
        all_nearest_traffic_spots = {}
        for _, row in df.iterrows():
            nearest_spot, distance = self.find_nearest_traffic_spot(row['latitude'], row['longitude'])
            if nearest_spot:
                all_nearest_traffic_spots[row['id']] = (nearest_spot, distance)
                all_traffic_spots_to_display.add(nearest_spot.key)
        lines_group = folium.FeatureGroup(name="Connection Lines")
        m.add_child(lines_group)
        if show_traffic and all_traffic_spots_to_display:
            self.traffic_manager.add_spots_to_map(m, all_traffic_spots_to_display)
        for _, row in df.iterrows():
            marker_id = f"marker_{row['id']}"
            traffic_spot_info = ""
            discount_info = ""
            discounted_price = row['price']
            if row['id'] in all_nearest_traffic_spots:
                nearest_spot, distance = all_nearest_traffic_spots[row['id']]
                discount_rate = nearest_spot.get_discount_rate()
                if discount_rate > 0:
                    discounted_price = row['price'] * (1 - discount_rate)
                    discount_info = DISCOUNT_INFO_TEMPLATE.format(
                        discount_percentage=int(discount_rate * 100),
                        original_price=row['price'],
                        discounted_price=discounted_price,
                        avg_vehicle_count=nearest_spot.avg_vehicle_count,
                        observation_count=len(nearest_spot.dataset_rows)
                    )
                # Show meters for very short distances, km otherwise.
                distance_str = f"{distance:.2f} km" if distance >= 0.1 else f"{distance * 1000:.0f} meters"
                traffic_spot_info = TRAFFIC_SPOT_INFO_TEMPLATE.format(
                    spot_key=escape(str(nearest_spot.key)),
                    distance_str=distance_str
                )
                folium.PolyLine(
                    locations=[
                        [row['latitude'], row['longitude']],
                        [nearest_spot.latitude, nearest_spot.longitude]
                    ],
                    color='blue',
                    weight=2,
                    opacity=0.7,
                    dash_array='5',
                    tooltip=f"Distance: {distance_str}"
                ).add_to(lines_group)
            relevance_info = ""
            # NOTE(review): 'relevance_features'/'matching_features' are never
            # added by sort_by_relevance, so this branch is currently dead and
            # relevance info never renders — confirm the intended columns.
            if search_query and 'relevance_percentage' in row and 'relevance_features' in row:
                relevance_info = RELEVANCE_INFO_TEMPLATE.format(
                    relevance_percentage=row['relevance_percentage'],
                    relevance_features=row['relevance_features'],
                    matching_features=row['matching_features']
                )
            price_display = f"<strong>Price:</strong> ${row['price']:.0f}"
            if discount_info:
                # Strike through the original price next to the discounted one.
                price_display = (f"<strong>Price:</strong> "
                                 f"<span style='text-decoration: line-through;'>${row['price']:.0f}</span> "
                                 f"<span style='color: #2e7d32; font-weight: bold;'>${discounted_price:.0f}</span>")
            popup_content = POPUP_CONTENT_TEMPLATE.format(
                listing_name=escape(str(row['name'])),
                host_name=escape(str(row['host_name'])),
                room_type=escape(str(row['room_type'])),
                price_display=price_display,
                review_count=row['number_of_reviews'],
                discount_info=discount_info,
                traffic_spot_info=traffic_spot_info,
                relevance_info=relevance_info
            )
            marker = folium.Marker(
                location=[row['latitude'], row['longitude']],
                popup=popup_content,
                icon=folium.Icon(color='green' if selected_id == row['id'] else 'red', icon='home'),
            )
            marker.add_to(m)
            if selected_id is not None and row['id'] == selected_id:
                # folium private attribute: give the selected marker a stable
                # JS variable name so MAP_SCRIPT can target it.
                marker._name = marker_id
        folium.Element(MAP_SCRIPT).add_to(m)
        folium.LayerControl().add_to(m)
        return m, df