"""
Embedding Service Module

Provides multilingual semantic search using sentence-transformers.

Uses paraphrase-multilingual-MiniLM-L12-v2 by default, which supports 50+ languages
including English, Ukrainian, Russian, Spanish, German, French, etc.

References:
- Reimers & Gurevych (2019): Sentence-BERT
- Reimers & Gurevych (2020): Making Monolingual Sentence Embeddings Multilingual
"""
|
|
|
|
|
import hashlib
import json
import logging
import os
import re
import shutil
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any

import numpy as np
|
|
|
|
|
# Module-level logger, named after this module so log output can be filtered per module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
@dataclass
class SearchResult:
    """Result from semantic search.

    One instance describes a single indexed entity that matched a query.
    """

    # Identifier of the matched entity (key of the indexed entities dict).
    entity_id: str
    # Similarity between the query embedding and the entity embedding.
    score: float
    # The entity's stored data as held by the service at search time.
    entity_data: Dict[str, Any]
|
|
|
|
|
|
|
|
class EmbeddingService:
    """
    Multilingual embedding service for semantic search.

    Replaces keyword-based matching with embedding similarity,
    enabling language-agnostic symptom/entity matching.

    Embeddings are requested L2-normalised from the model (see ``encode``),
    and ranking uses the inner product between query and entity vectors.
    """

    # Fields combined into the text that gets embedded when the caller does
    # not specify any. Kept immutable so it can never be shared-and-mutated.
    _DEFAULT_TEXT_FIELDS: Tuple[str, ...] = ("name", "description", "synonyms")

    # Separators used to break free text into phrases: punctuation plus
    # English / Ukrainian / Russian conjunctions. Compiled once; case is
    # ignored so capitalised conjunctions ("And", "With") also split.
    _PHRASE_SPLIT_RE = re.compile(
        r'[,;.]|\band\b|\bwith\b|\balso\b|\bі\b|\bта\b|\bи\b',
        re.IGNORECASE,
    )

    def __init__(
        self,
        model_name: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
        cache_dir: str = "./data/embeddings",
        device: str = "cpu"
    ):
        """
        Create the service; the model itself is loaded lazily on first use.

        Args:
            model_name: Name passed to ``SentenceTransformer``.
            cache_dir: Directory for cached embeddings (created if missing).
            device: Device string passed to ``SentenceTransformer``.
        """
        self.model_name = model_name
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.device = device

        self._model = None
        self._entity_embeddings: Dict[str, np.ndarray] = {}
        self._entity_data: Dict[str, Dict] = {}
        # Default until the model is loaded or an index is built.
        self._embedding_dim: int = 384

        # FAISS index when faiss is importable, otherwise None and
        # ``_embedding_matrix`` holds the vectors for the numpy fallback.
        self._index = None
        self._index_ids: List[str] = []
        self._embedding_matrix: Optional[np.ndarray] = None

    @property
    def model(self):
        """Lazy load the embedding model.

        Raises:
            ImportError: If sentence-transformers is not installed.
        """
        if self._model is None:
            try:
                from sentence_transformers import SentenceTransformer
            except ImportError:
                logger.error(
                    "sentence-transformers not installed. "
                    "Run: pip install sentence-transformers"
                )
                raise
            logger.info("Loading embedding model: %s", self.model_name)
            self._model = SentenceTransformer(self.model_name, device=self.device)
            self._embedding_dim = self._model.get_sentence_embedding_dimension()
            logger.info("Model loaded. Embedding dimension: %s", self._embedding_dim)
        return self._model

    def encode(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """
        Encode texts to embeddings.

        Args:
            texts: List of text strings to encode
            batch_size: Batch size for encoding

        Returns:
            numpy array of shape (len(texts), embedding_dim). For an empty
            input this is an empty (0, embedding_dim) array and the model
            is not loaded.
        """
        if not texts:
            # Keep the documented 2-D shape even when there is nothing to encode.
            return np.empty((0, self._embedding_dim), dtype=np.float32)

        return self.model.encode(
            texts,
            batch_size=batch_size,
            show_progress_bar=len(texts) > 100,
            convert_to_numpy=True,
            normalize_embeddings=True
        )

    def encode_single(self, text: str) -> np.ndarray:
        """Encode a single text string and return its 1-D embedding."""
        return self.encode([text])[0]

    @staticmethod
    def _entity_text(entity: Dict[str, Any], text_fields: List[str]) -> Optional[str]:
        """Join the non-empty ``text_fields`` of ``entity`` into one string.

        List/tuple values contribute each non-None item; every piece is
        converted with ``str`` so non-string items cannot break the join.
        Returns None when the entity has no usable text at all.
        """
        parts: List[str] = []
        for field in text_fields:
            value = entity.get(field)
            if not value:
                continue
            if isinstance(value, (list, tuple)):
                parts.extend(str(item) for item in value if item is not None)
            else:
                parts.append(str(value))
        return " ".join(parts) if parts else None

    def index_entities(
        self,
        entities: Dict[str, Dict[str, Any]],
        text_fields: Optional[List[str]] = None
    ):
        """
        Build search index from entities.

        Any previously indexed entities are replaced. Results are cached on
        disk keyed by model, text fields and entity content.

        Args:
            entities: Dict of entity_id -> entity_data
            text_fields: Fields to combine for embedding text. Defaults to
                ``["name", "description", "synonyms"]``.
        """
        fields = list(self._DEFAULT_TEXT_FIELDS if text_fields is None else text_fields)
        logger.info("Indexing %d entities for semantic search", len(entities))

        cache_key = self._compute_cache_key(entities, fields)
        if self._load_from_cache(cache_key):
            logger.info("Loaded embeddings from cache")
            return

        texts: List[str] = []
        entity_ids: List[str] = []
        entity_data: Dict[str, Dict] = {}
        for entity_id, entity in entities.items():
            combined_text = self._entity_text(entity, fields)
            if combined_text is None:
                continue
            texts.append(combined_text)
            entity_ids.append(entity_id)
            entity_data[entity_id] = entity

        if not texts:
            logger.warning("No texts to index")
            return

        logger.info("Computing embeddings for %d entities...", len(texts))
        embeddings = self.encode(texts)

        # Replace (not merge into) the previous state so entities from an
        # earlier indexing run cannot linger or leak into the saved cache.
        self._entity_data = entity_data
        self._entity_embeddings = dict(zip(entity_ids, embeddings))
        self._index_ids = entity_ids

        self._build_index(embeddings)
        self._save_to_cache(cache_key)

        logger.info("Indexed %d entities", len(self._entity_embeddings))

    def _build_index(self, embeddings: np.ndarray):
        """Build search index from embeddings.

        Uses a FAISS inner-product index when faiss is importable; otherwise
        keeps the raw matrix for a numpy dot-product search.
        """
        # Trust the data, not the default: cached embeddings may come from a
        # model whose dimension differs from the initial 384.
        if embeddings.ndim == 2:
            self._embedding_dim = int(embeddings.shape[1])

        try:
            import faiss
        except ImportError:
            logger.info("FAISS not available, using numpy for similarity search")
            self._index = None
            self._embedding_matrix = embeddings
            return

        index = faiss.IndexFlatIP(self._embedding_dim)
        index.add(np.ascontiguousarray(embeddings, dtype=np.float32))
        self._index = index
        # The matrix is only needed by the numpy fallback.
        self._embedding_matrix = None
        logger.info("Built FAISS index for fast similarity search")

    def search(
        self,
        query: str,
        top_k: int = 10,
        threshold: float = 0.3,
        category_filter: Optional[str] = None
    ) -> "List[SearchResult]":
        """
        Search for entities similar to query.

        Args:
            query: Search query (any language)
            top_k: Maximum number of results
            threshold: Minimum similarity score
            category_filter: Optional category to filter by (compared with
                each entity's ``"category"`` value)

        Returns:
            List of SearchResult sorted by score descending
        """
        if not self._entity_embeddings:
            logger.warning("No entities indexed. Call index_entities first.")
            return []
        if top_k <= 0:
            return []

        total = len(self._index_ids)
        # Without a filter, a small over-fetch is enough. With a filter the
        # best matches of that category may rank anywhere, so consider all.
        candidates = total if category_filter else min(top_k * 2, total)

        query_embedding = self.encode_single(query)

        if self._index is not None:
            scores, indices = self._index.search(
                query_embedding.reshape(1, -1).astype(np.float32),
                candidates
            )
            scores = scores[0]
            indices = indices[0]
        else:
            all_scores = np.dot(self._embedding_matrix, query_embedding)
            indices = np.argsort(all_scores)[::-1][:candidates]
            scores = all_scores[indices]

        results: List[SearchResult] = []
        for score, idx in zip(scores, indices):
            if score < threshold:
                # Candidates arrive in descending score order; nothing
                # after this one can pass the threshold.
                break
            idx = int(idx)
            if idx < 0 or idx >= total:
                continue

            entity_id = self._index_ids[idx]
            entity_data = self._entity_data.get(entity_id, {})

            if category_filter and entity_data.get("category") != category_filter:
                continue

            results.append(SearchResult(
                entity_id=entity_id,
                score=float(score),
                entity_data=entity_data
            ))
            if len(results) >= top_k:
                break

        return results

    def search_multiple(
        self,
        queries: List[str],
        top_k_per_query: int = 5,
        threshold: float = 0.3,
        deduplicate: bool = True,
        category_filter: Optional[str] = None
    ) -> "List[SearchResult]":
        """
        Search with multiple queries, combining results.

        Useful for extracting multiple symptoms from a single user query.

        Args:
            queries: Query strings, each searched independently
            top_k_per_query: Maximum results taken from each query
            threshold: Minimum similarity score
            deduplicate: When True, keep only the best-scoring result per
                entity; when False, return every hit from every query
            category_filter: Optional category forwarded to ``search``

        Returns:
            Combined results sorted by score descending
        """
        collected: List[SearchResult] = []
        for query in queries:
            collected.extend(self.search(
                query,
                top_k=top_k_per_query,
                threshold=threshold,
                category_filter=category_filter
            ))

        if deduplicate:
            best: Dict[str, SearchResult] = {}
            for result in collected:
                current = best.get(result.entity_id)
                if current is None or result.score > current.score:
                    best[result.entity_id] = result
            collected = list(best.values())

        return sorted(collected, key=lambda r: r.score, reverse=True)

    def extract_entities_from_text(
        self,
        text: str,
        category: Optional[str] = None,
        top_k: int = 5,
        threshold: float = 0.4
    ) -> "List[SearchResult]":
        """
        Extract relevant entities from free-form text.

        This is the main method for symptom extraction from user queries.
        Works across all supported languages.

        The whole text is searched first; if it splits into more than one
        phrase, each phrase is searched too and the results are merged,
        keeping the best score per entity.

        Args:
            text: User input text (any language)
            category: Filter by category (e.g., "symptom", "disease")
            top_k: Maximum entities to return
            threshold: Minimum similarity threshold

        Returns:
            Up to ``top_k`` results sorted by score descending
        """
        results = self.search(
            query=text,
            top_k=top_k,
            threshold=threshold,
            category_filter=category
        )

        stripped = (piece.strip() for piece in self._PHRASE_SPLIT_RE.split(text))
        phrases = [phrase for phrase in stripped if len(phrase) > 2]

        if len(phrases) > 1:
            # The category filter must apply to phrase-level hits as well,
            # otherwise entities of other categories leak into the output.
            phrase_results = self.search_multiple(
                phrases,
                top_k_per_query=3,
                threshold=threshold,
                category_filter=category
            )

            merged = {r.entity_id: r for r in results}
            for pr in phrase_results:
                existing = merged.get(pr.entity_id)
                if existing is None or pr.score > existing.score:
                    merged[pr.entity_id] = pr
            results = list(merged.values())

        results.sort(key=lambda r: r.score, reverse=True)
        return results[:top_k]

    def _compute_cache_key(
        self,
        entities: Dict,
        text_fields: Optional[List[str]] = None
    ) -> str:
        """Compute cache key from model name, text fields and entity content.

        Hashing the content (not just the ids) means edited entities or a
        different field selection produce a new key instead of silently
        reusing stale embeddings.
        """
        payload = json.dumps(
            {
                "model": self.model_name,
                "text_fields": None if text_fields is None else list(text_fields),
                "entities": entities,
            },
            sort_keys=True,
            # Only used for hashing: stringify values JSON cannot represent.
            default=str,
        )
        # md5 is used as a fast fingerprint here, not for security.
        return hashlib.md5(payload.encode("utf-8")).hexdigest()[:16]

    def _load_from_cache(self, cache_key: str) -> bool:
        """Try to load embeddings from cache.

        Returns:
            True if the cache was loaded and the index rebuilt, else False.
            Service state is only replaced on success.
        """
        embeddings_path = self.cache_dir / f"{cache_key}_embeddings.npy"
        metadata_path = self.cache_dir / f"{cache_key}_metadata.json"

        if not embeddings_path.exists() or not metadata_path.exists():
            return False

        # Cache loading is best-effort: any failure falls back to recomputing.
        try:
            with open(metadata_path, encoding="utf-8") as f:
                metadata = json.load(f)

            if metadata.get("model") != self.model_name:
                return False

            embeddings = np.load(embeddings_path)
            entity_ids = metadata["entity_ids"]
            entity_data = metadata["entity_data"]

            if embeddings.ndim != 2 or embeddings.shape[0] != len(entity_ids):
                logger.warning("Failed to load from cache: %s", "embeddings/metadata mismatch")
                return False

            # Build first, commit afterwards, so a failed build leaves the
            # previously indexed ids/data untouched.
            self._build_index(embeddings)
            self._index_ids = entity_ids
            self._entity_data = entity_data
            self._entity_embeddings = dict(zip(entity_ids, embeddings))
            return True

        except Exception as e:
            logger.warning("Failed to load from cache: %s", e)
            return False

    def _save_to_cache(self, cache_key: str):
        """Save embeddings to cache (best-effort; failures are only logged)."""
        try:
            embeddings_path = self.cache_dir / f"{cache_key}_embeddings.npy"
            metadata_path = self.cache_dir / f"{cache_key}_metadata.json"

            metadata = {
                "model": self.model_name,
                "entity_ids": self._index_ids,
                "entity_data": self._entity_data,
                "embedding_dim": self._embedding_dim
            }
            # Serialise before touching disk so non-JSON-able entity data
            # cannot leave an orphan .npy file behind.
            payload = json.dumps(metadata)

            embeddings = np.array([
                self._entity_embeddings[eid] for eid in self._index_ids
            ])

            self.cache_dir.mkdir(parents=True, exist_ok=True)
            np.save(embeddings_path, embeddings)
            with open(metadata_path, "w", encoding="utf-8") as f:
                f.write(payload)

            logger.info("Saved embeddings cache: %s", cache_key)

        except Exception as e:
            logger.warning("Failed to save cache: %s", e)

    def clear_cache(self):
        """Clear all cached embeddings, on disk and in memory."""
        if self.cache_dir.exists():
            shutil.rmtree(self.cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        self._entity_embeddings.clear()
        self._entity_data.clear()
        self._index = None
        self._index_ids = []
        self._embedding_matrix = None
|
|
|
|
|
|
|
|
|
|
|
# Process-wide service instance, created on the first get_embedding_service() call.
_embedding_service: Optional[EmbeddingService] = None


def get_embedding_service() -> EmbeddingService:
    """Get the global embedding service instance.

    The instance is built on first call from ``get_config().embedding``
    (``model_name``, ``cache_dir``, ``device``) and reused afterwards.
    """
    global _embedding_service
    if _embedding_service is None:
        # Imported here rather than at module top, so the configuration is
        # only resolved when the service is first requested.
        from .config import get_config

        config = get_config()
        _embedding_service = EmbeddingService(
            model_name=config.embedding.model_name,
            cache_dir=config.embedding.cache_dir,
            device=config.embedding.device
        )
    return _embedding_service
|
|
|