poi_model / search_engine.py
Peersik's picture
Upload 25 files
412553b verified
# search_engine.py
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.preprocessing import normalize
from sklearn.metrics.pairwise import cosine_similarity
import json
import os
from geopy.distance import geodesic
from huggingface_hub import hf_hub_download
from typing import Dict, List, Tuple, Optional, Set
import logging
from dataclasses import dataclass
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@dataclass
class SearchConfig:
"""Конфигурация поиска"""
text_weight: float = 0.7
category_weight: float = 0.3
max_results: int = 20
max_distance_km: float = 5.0
min_similarity: float = 0.15
diversity_weight: float = 0.2
class EnhancedPOISearchEngine:
"""Поисковый движок для точек интереса"""
def __init__(self, model_path: str = 'model/enhanced'):
self.model_path = model_path
self.model = None
self.embeddings = None
self.df = None
self.category_vectors = None
self.category_keywords = self._init_category_keywords()
self.config = SearchConfig()
def _init_category_keywords(self) -> Dict[str, List[str]]:
"""Инициализация ключевых слов для категорий"""
return {
"Музеи": ["музей", "выставка", "искусство", "история", "галерея"],
"Рестораны": ["ресторан", "ужин", "обед", "кухня", "меню"],
"Кафе": ["кафе", "кофе", "чай", "завтрак", "десерт"],
"Парки": ["парк", "сквер", "прогулка", "отдых", "природа"],
"Магазины": ["магазин", "шопинг", "покупки", "торговый", "бутик"],
"Сувениры": ["сувениры", "подарки", "магниты", "памятный"],
"Театры": ["театр", "спектакль", "балет", "опера", "драма"],
"Кинотеатры": ["кино", "кинотеатр", "фильм", "премьера", "сеанс"],
"Отели": ["отель", "гостиница", "номер", "бронирование"],
"Церкви": ["церковь", "храм", "собор", "часовня", "монастырь"],
"Памятники": ["памятник", "скульптура", "статуя", "монумент"],
"Смотровые площадки": ["смотровая", "панорама", "вид", "обзор"],
"Достопримечательности": ["достопримечательность", "интересное место", "туристическое"]
}
def load_model(self) -> bool:
"""Загружает модель и данные"""
try:
logger.info(f"🔄 Загрузка модели из {self.model_path}")
if not os.path.exists(self.model_path):
logger.error(f"❌ Директория модели не найдена")
return False
# Загружаем данные
data_path = os.path.join(self.model_path, 'pois_data.csv')
if not os.path.exists(data_path):
logger.error(f"❌ Файл данных не найден")
return False
self.df = pd.read_csv(data_path)
logger.info(f"✅ Загружены данные: {len(self.df)} точек")
# Загружаем эмбеддинги
embeddings_files = ['proper_embeddings.npy', 'embeddings.npy']
for emb_file in embeddings_files:
embeddings_path = os.path.join(self.model_path, emb_file)
if os.path.exists(embeddings_path):
self.embeddings = np.load(embeddings_path)
logger.info(f"✅ Загружены эмбеддинги: {self.embeddings.shape}")
# Нормализуем эмбеддинги
self.embeddings = normalize(self.embeddings, norm='l2')
break
if self.embeddings is None:
logger.error("❌ Не найдены эмбеддинги")
return False
# Загружаем модель
model_name = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'
logger.info(f"🤖 Загрузка модели: {model_name}")
self.model = SentenceTransformer(model_name)
# Загружаем дополнительные данные если есть
category_vectors_path = os.path.join(self.model_path, 'category_vectors.npy')
if os.path.exists(category_vectors_path):
self.category_vectors = np.load(category_vectors_path)
logger.info(f"✅ Загружены категориальные векторы")
# Предобработка данных
self._preprocess_data()
logger.info("✅ Модель успешно инициализирована")
return True
except Exception as e:
logger.error(f"❌ Ошибка загрузки модели: {e}")
return False
def _preprocess_data(self):
"""Предобработка данных"""
if self.df is None:
return
# Заполняем NaN
text_columns = ['name', 'description']
for col in text_columns:
if col in self.df.columns:
self.df[col] = self.df[col].fillna('')
# Создаем категории если их нет
if 'category' not in self.df.columns:
self.df['category'] = 'Достопримечательности'
def analyze_query(self, query: str) -> Dict[str, float]:
"""Анализирует запрос и определяет веса категорий"""
query_lower = query.lower()
category_scores = {}
# Убираем разговорные фразы
stop_phrases = ['я хочу', 'мне бы', 'хочу', 'нужно', 'надо', 'посоветуйте', 'подскажите']
for phrase in stop_phrases:
query_lower = query_lower.replace(phrase, '')
# Анализируем по ключевым словам
for category, keywords in self.category_keywords.items():
score = 0.0
for keyword in keywords:
if keyword in query_lower:
score += 1.0
if score > 0:
category_scores[category] = score
# Обработка составных запросов с "и"
if ' и ' in query_lower:
parts = [p.strip() for p in query_lower.split(' и ') if len(p.strip()) > 2]
for part in parts:
for category, keywords in self.category_keywords.items():
for keyword in keywords:
if keyword in part:
category_scores[category] = category_scores.get(category, 0) + 2.0
# Если не найдено категорий
if not category_scores:
category_scores['Достопримечательности'] = 1.0
# Нормализуем веса
total = sum(category_scores.values())
return {k: v / total for k, v in category_scores.items()}
def _categorize_poi(self, name: str, description: str) -> str:
"""Определяет категорию POI по названию и описанию"""
text = f"{name} {description}".lower()
for category, keywords in self.category_keywords.items():
for keyword in keywords:
if keyword in text:
return category
return 'Достопримечательности'
def search(self, query: str, **kwargs) -> List[Dict]:
"""Базовый семантический поиск"""
debug = kwargs.get('debug', False)
max_results = kwargs.get('max_results', self.config.max_results)
if self.model is None or self.embeddings is None:
return []
# Кодируем запрос
query_embedding = self.model.encode([query])
query_embedding = normalize(query_embedding, norm='l2')
# Вычисляем сходство
similarities = cosine_similarity(query_embedding, self.embeddings)[0]
# Получаем топ-N индексов
top_indices = np.argsort(similarities)[-max_results * 2:][::-1]
# Формируем результаты
results = []
for idx in top_indices:
if similarities[idx] < self.config.min_similarity:
continue
row = self.df.iloc[idx]
# Определяем категорию
name = str(row.get('name', ''))
desc = str(row.get('description', ''))
category = self._categorize_poi(name, desc)
result = {
'id': int(idx),
'name': name,
'category': category,
'type': str(row.get('type', '')),
'lat': float(row.get('lat', 0)),
'lon': float(row.get('lon', 0)),
'score': float(similarities[idx]),
'description': desc[:150] if desc else None
}
results.append(result)
return results[:max_results]
def simple_search(self, query: str, max_results: int = 10) -> List[Dict]:
"""Простой текстовый поиск"""
if self.df is None:
return []
query_lower = query.lower()
results = []
for idx, row in self.df.iterrows():
score = 0.0
# Поиск по названию
name = str(row.get('name', '')).lower()
if query_lower in name:
score += 1.0
elif any(word in name for word in query_lower.split()):
score += 0.5
# Поиск по описанию
desc = str(row.get('description', '')).lower()
if query_lower in desc:
score += 0.3
if score > 0.2:
category = self._categorize_poi(
str(row.get('name', '')),
str(row.get('description', ''))
)
result = {
'id': int(idx),
'name': str(row.get('name', '')),
'category': category,
'type': str(row.get('type', '')),
'lat': float(row.get('lat', 0)),
'lon': float(row.get('lon', 0)),
'score': min(score, 1.0),
'description': str(row.get('description', ''))[:100] if row.get('description') else None
}
results.append(result)
results.sort(key=lambda x: x['score'], reverse=True)
return results[:max_results]
def multi_category_search(self, query: str, **kwargs) -> List[Dict]:
"""Многокатегорийный поиск"""
debug = kwargs.get('debug', False)
max_results = kwargs.get('max_results', self.config.max_results)
# Анализируем запрос
query_analysis = self.analyze_query(query)
target_categories = list(query_analysis.keys())
if debug:
logger.info(f"🎯 Целевые категории: {query_analysis}")
# Если одна категория - обычный поиск
if len(target_categories) == 1:
return self.search(query, **kwargs)
# Для нескольких категорий ищем отдельно по каждой
all_results = {}
for category in target_categories:
if category == 'Достопримечательности':
continue
# Ищем точки этой категории
category_keywords = self.category_keywords.get(category, [category.lower()])
for keyword in category_keywords[:2]: # Используем 2 ключевых слова
# Простой поиск по ключевому слову
cat_results = self._search_by_keyword(keyword, max_results)
# Добавляем результаты с весом категории
for res in cat_results:
res_id = res['id']
if res_id not in all_results:
res['score'] = res['score'] * query_analysis.get(category, 0.5)
all_results[res_id] = res
# Если не нашли по категориям, используем обычный поиск
if not all_results:
return self.search(query, **kwargs)
# Преобразуем в список и сортируем
results = list(all_results.values())
results.sort(key=lambda x: x['score'], reverse=True)
# Обеспечиваем разнообразие категорий
diverse_results = self._ensure_category_diversity(results, target_categories, max_results)
if debug:
categories_found = set(r['category'] for r in diverse_results)
logger.info(f"✅ Найдено {len(diverse_results)} точек, категории: {list(categories_found)}")
return diverse_results
def _search_by_keyword(self, keyword: str, max_results: int) -> List[Dict]:
"""Поиск по ключевому слову"""
if self.df is None:
return []
results = []
keyword_lower = keyword.lower()
for idx, row in self.df.iterrows():
name = str(row.get('name', '')).lower()
desc = str(row.get('description', '')).lower()
if keyword_lower in name or keyword_lower in desc:
category = self._categorize_poi(
str(row.get('name', '')),
str(row.get('description', ''))
)
result = {
'id': int(idx),
'name': str(row.get('name', '')),
'category': category,
'type': str(row.get('type', '')),
'lat': float(row.get('lat', 0)),
'lon': float(row.get('lon', 0)),
'score': 0.7 if keyword_lower in name else 0.5,
'description': str(row.get('description', ''))[:100] if row.get('description') else None
}
results.append(result)
return results[:max_results]
def _ensure_category_diversity(self, results: List[Dict], target_categories: List[str], max_results: int) -> List[
Dict]:
"""Обеспечивает разнообразие категорий в результатах"""
if len(results) <= max_results:
return results
# Группируем результаты по категориям
category_groups = {}
for res in results:
cat = res['category']
if cat not in category_groups:
category_groups[cat] = []
category_groups[cat].append(res)
# Собираем разнообразные результаты
diverse_results = []
remaining_slots = max_results
# Сначала берем из целевых категорий
for target_cat in target_categories:
if target_cat in category_groups and remaining_slots > 0:
take_count = min(len(category_groups[target_cat]), max(1, remaining_slots // len(target_categories)))
diverse_results.extend(category_groups[target_cat][:take_count])
remaining_slots -= take_count
# Затем дополняем другими категориями
other_categories = [cat for cat in category_groups.keys() if cat not in target_categories]
for cat in other_categories:
if remaining_slots > 0:
take_count = min(len(category_groups[cat]), 1) # По 1 из каждой другой категории
diverse_results.extend(category_groups[cat][:take_count])
remaining_slots -= take_count
# Сортируем по score
diverse_results.sort(key=lambda x: x['score'], reverse=True)
return diverse_results[:max_results]