# utils/semantic_search.py
import numpy as np
from typing import List, Dict, Any, Tuple
import torch
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import re
from collections import defaultdict
from utils.logging import setup_logger
from utils.error_handling import handle_exceptions, AIModelError
# Initialize logger
logger = setup_logger(__name__)
# Global model cache
MODEL_CACHE = {}
def get_embedding_model():
"""Load and cache the sentence embedding model"""
model_name = "all-MiniLM-L6-v2" # A good balance of performance and speed
if model_name not in MODEL_CACHE:
logger.info(f"Loading embedding model: {model_name}")
try:
# Check if CUDA is available
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SentenceTransformer(model_name, device=device)
MODEL_CACHE[model_name] = model
logger.info(f"Embedding model loaded successfully on {device}")
except Exception as e:
logger.error(f"Error loading embedding model: {str(e)}")
raise AIModelError(f"Error loading embedding model", {"original_error": str(e)}) from e
return MODEL_CACHE[model_name]
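# Note on get_embedding_model: the first call downloads the model weights from the
# Hugging Face hub unless they are already cached locally; subsequent calls reuse the
# in-process MODEL_CACHE entry, so repeated lookups are cheap.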
def extract_text_from_item(item: Dict[str, Any]) -> str:
"""Extract searchable text from an item"""
text_parts = []
# Extract title and content
if "title" in item and item["title"]:
text_parts.append(item["title"])
if "content" in item and item["content"]:
text_parts.append(item["content"])
# Extract description if available
if "description" in item and item["description"]:
text_parts.append(item["description"])
# Extract tags if available
if "tags" in item and item["tags"]:
if isinstance(item["tags"], list):
text_parts.append(" ".join(item["tags"]))
elif isinstance(item["tags"], str):
text_parts.append(item["tags"])
# Join all parts with spaces
return " ".join(text_parts)
def get_item_embeddings(items: List[Dict[str, Any]]) -> Tuple[np.ndarray, List[Dict[str, Any]]]:
"""Get embeddings for a list of items"""
model = get_embedding_model()
texts = []
valid_items = []
for item in items:
text = extract_text_from_item(item)
if text.strip(): # Only include items with non-empty text
texts.append(text)
valid_items.append(item)
if not texts:
return np.array([]), []
try:
embeddings = model.encode(texts, convert_to_numpy=True)
return embeddings, valid_items
except Exception as e:
logger.error(f"Error generating embeddings: {str(e)}")
return np.array([]), []
def search_content(query: str, items: List[Dict[str, Any]], top_k: int = 10) -> List[Dict[str, Any]]:
"""Search content using semantic search with fallback to keyword search
Args:
query: Search query
items: List of items to search
top_k: Number of top results to return
Returns:
List of items sorted by relevance
"""
if not query or not items:
return []
logger.info(f"Performing semantic search for query: {query}")
try:
# Get embeddings for items
item_embeddings, valid_items = get_item_embeddings(items)
if len(valid_items) == 0:
logger.warning("No valid items with text content found")
return []
# Get embedding for query
model = get_embedding_model()
query_embedding = model.encode([query], convert_to_numpy=True)
# Calculate similarity scores
similarity_scores = cosine_similarity(query_embedding, item_embeddings)[0]
# Create result items with scores
results = []
        for item, score in zip(valid_items, similarity_scores):
item_copy = item.copy()
item_copy["relevance_score"] = float(score)
results.append(item_copy)
# Sort by relevance score
results.sort(key=lambda x: x["relevance_score"], reverse=True)
# Return top k results
return results[:top_k]
except Exception as e:
logger.error(f"Error in semantic search: {str(e)}. Falling back to keyword search.")
# Fallback to keyword search
return keyword_search(query, items, top_k)
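# Example usage of search_content (sketch; `notes` is a hypothetical list of item dicts):
#   results = search_content("quarterly budget review", notes, top_k=5)
#   for r in results:
#       print(r["title"], r["relevance_score"])
# Each result is a copy of the matching item with an added float "relevance_score" field.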
def keyword_search(query: str, items: List[Dict[str, Any]], top_k: int = 10) -> List[Dict[str, Any]]:
"""Fallback keyword search when semantic search fails
Args:
query: Search query
items: List of items to search
top_k: Number of top results to return
Returns:
List of items sorted by relevance
"""
logger.info(f"Performing keyword search for query: {query}")
# Prepare query terms
query_terms = re.findall(r'\w+', query.lower())
if not query_terms:
return []
results = []
for item in items:
text = extract_text_from_item(item).lower()
# Calculate simple relevance score based on term frequency
        score = 0
        title = (item.get("title") or "").lower()
        for term in query_terms:
            term_count = text.count(term)
            if term_count > 0:
                # Occurrences in the title (already counted once in `text`) get extra weight
                title_count = title.count(term)
                score += term_count + title_count * 2
if score > 0:
item_copy = item.copy()
item_copy["relevance_score"] = score
results.append(item_copy)
# Sort by relevance score
results.sort(key=lambda x: x["relevance_score"], reverse=True)
# Return top k results
return results[:top_k]
def find_similar_items(item: Dict[str, Any], items: List[Dict[str, Any]], top_k: int = 3) -> List[Dict[str, Any]]:
"""Find items similar to a given item
Args:
item: Reference item
items: List of items to search
top_k: Number of top results to return
Returns:
List of similar items
"""
if not item or not items:
return []
# Extract text from reference item
reference_text = extract_text_from_item(item)
if not reference_text.strip():
return []
try:
# Get embedding for reference item
model = get_embedding_model()
reference_embedding = model.encode([reference_text], convert_to_numpy=True)
# Get embeddings for items
item_embeddings, valid_items = get_item_embeddings(items)
if len(valid_items) == 0:
return []
# Calculate similarity scores
similarity_scores = cosine_similarity(reference_embedding, item_embeddings)[0]
# Create result items with scores
results = []
        for similar_item, score in zip(valid_items, similarity_scores):
# Skip the reference item itself
if similar_item.get("id") == item.get("id"):
continue
similar_item_copy = similar_item.copy()
similar_item_copy["similarity_score"] = float(score)
results.append(similar_item_copy)
# Sort by similarity score
results.sort(key=lambda x: x["similarity_score"], reverse=True)
# Return top k results
return results[:top_k]
except Exception as e:
logger.error(f"Error finding similar items: {str(e)}. Falling back to keyword similarity.")
return keyword_similarity(item, items, top_k)
def keyword_similarity(item: Dict[str, Any], items: List[Dict[str, Any]], top_k: int = 3) -> List[Dict[str, Any]]:
"""Fallback keyword-based similarity when semantic similarity fails
Args:
item: Reference item
items: List of items to search
top_k: Number of top results to return
Returns:
List of similar items
"""
# Extract text from reference item
reference_text = extract_text_from_item(item).lower()
if not reference_text.strip():
return []
# Extract words from reference text
reference_words = set(re.findall(r'\w+', reference_text))
results = []
for other_item in items:
# Skip the reference item itself
if other_item.get("id") == item.get("id"):
continue
other_text = extract_text_from_item(other_item).lower()
other_words = set(re.findall(r'\w+', other_text))
# Calculate Jaccard similarity
if not other_words or not reference_words:
continue
intersection = len(reference_words.intersection(other_words))
union = len(reference_words.union(other_words))
similarity = intersection / union if union > 0 else 0
if similarity > 0:
other_item_copy = other_item.copy()
other_item_copy["similarity_score"] = similarity
results.append(other_item_copy)
# Sort by similarity score
results.sort(key=lambda x: x["similarity_score"], reverse=True)
# Return top k results
return results[:top_k]
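# Worked example of the Jaccard score used in keyword_similarity: word sets
# {"tax", "report", "q3"} and {"report", "q3", "draft"} share 2 of 4 distinct words,
# giving a similarity of 2 / 4 = 0.5.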
def build_knowledge_graph(items: List[Dict[str, Any]]) -> Dict[str, Any]:
"""Build a simple knowledge graph from items
Args:
items: List of items to include in the graph
Returns:
Knowledge graph as a dictionary
"""
graph = {
"nodes": [],
"edges": []
}
# Track node IDs to avoid duplicates
node_ids = set()
# Add items as nodes
for item in items:
item_id = item.get("id")
if not item_id or item_id in node_ids:
continue
node_type = item.get("type", "unknown")
node = {
"id": item_id,
"label": item.get("title", "Untitled"),
"type": node_type
}
graph["nodes"].append(node)
node_ids.add(item_id)
# Find connections between nodes
    for item1 in items:
item1_id = item1.get("id")
if not item1_id or item1_id not in node_ids:
continue
# Find similar items
similar_items = find_similar_items(item1, items, top_k=5)
for similar_item in similar_items:
similar_id = similar_item.get("id")
if not similar_id or similar_id not in node_ids or similar_id == item1_id:
continue
# Add edge between items
edge = {
"source": item1_id,
"target": similar_id,
"weight": similar_item.get("similarity_score", 0.5),
"type": "similar"
}
graph["edges"].append(edge)
return graph
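# build_knowledge_graph returns a plain dict (sketch of the shape):
#   {"nodes": [{"id": ..., "label": ..., "type": ...}, ...],
#    "edges": [{"source": ..., "target": ..., "weight": ..., "type": "similar"}, ...]}
# Note: find_similar_items re-encodes the whole item list for every node, so building the
# graph costs O(n^2) embedding passes, and edges can appear in both directions; this is
# acceptable for small collections, but precomputing embeddings would help for large ones.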
def detect_duplicates(items: List[Dict[str, Any]], threshold: float = 0.85) -> List[List[Dict[str, Any]]]:
"""Detect potential duplicate items
Args:
items: List of items to check
threshold: Similarity threshold for considering items as duplicates
Returns:
List of groups of duplicate items
"""
if not items or len(items) < 2:
return []
try:
# Get embeddings for items
item_embeddings, valid_items = get_item_embeddings(items)
if len(valid_items) < 2:
return []
# Calculate pairwise similarity
similarity_matrix = cosine_similarity(item_embeddings)
# Find duplicate groups
duplicate_groups = []
processed = set()
for i in range(len(valid_items)):
if i in processed:
continue
group = [valid_items[i]]
processed.add(i)
for j in range(i+1, len(valid_items)):
if j in processed:
continue
if similarity_matrix[i, j] >= threshold:
group.append(valid_items[j])
processed.add(j)
if len(group) > 1:
duplicate_groups.append(group)
return duplicate_groups
except Exception as e:
logger.error(f"Error detecting duplicates: {str(e)}")
return []
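# Example (sketch): detect_duplicates(notes, threshold=0.9) returns groups such as
# [[item_a, item_a_copy], ...], where each group holds one seed item plus every later item
# whose cosine similarity to that seed meets the threshold.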
def cluster_content(items: List[Dict[str, Any]], num_clusters: int = 5) -> Dict[str, List[Dict[str, Any]]]:
"""Cluster content into groups
Args:
items: List of items to cluster
num_clusters: Number of clusters to create
Returns:
Dictionary mapping cluster labels to lists of items
"""
if not items or len(items) < num_clusters:
return {}
try:
# Get embeddings for items
item_embeddings, valid_items = get_item_embeddings(items)
if len(valid_items) < num_clusters:
return {}
        # Perform clustering (lazy import so sklearn's clustering module is only loaded when needed)
        from sklearn.cluster import KMeans
        kmeans = KMeans(n_clusters=min(num_clusters, len(valid_items)), random_state=42, n_init=10)
cluster_labels = kmeans.fit_predict(item_embeddings)
# Group items by cluster
clusters = defaultdict(list)
for i, label in enumerate(cluster_labels):
clusters[str(label)].append(valid_items[i])
# Generate cluster names based on common terms
named_clusters = {}
for label, cluster_items in clusters.items():
# Extract all text from cluster items
cluster_text = " ".join([extract_text_from_item(item) for item in cluster_items])
# Find most common words (excluding stopwords)
words = re.findall(r'\b[a-zA-Z]{3,}\b', cluster_text.lower())
word_counts = defaultdict(int)
# Simple stopwords list
stopwords = {"the", "and", "for", "with", "this", "that", "from", "have", "not"}
for word in words:
if word not in stopwords:
word_counts[word] += 1
# Get top words
top_words = sorted(word_counts.items(), key=lambda x: x[1], reverse=True)[:3]
            if top_words:
                cluster_name = ", ".join(word for word, _ in top_words)
            else:
                cluster_name = f"Cluster {label}"
            # Avoid silently overwriting a cluster when two clusters share the same top terms
            if cluster_name in named_clusters:
                cluster_name = f"{cluster_name} ({label})"
            named_clusters[cluster_name] = cluster_items
return named_clusters
except Exception as e:
logger.error(f"Error clustering content: {str(e)}")
return {}
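# Example usage of cluster_content (sketch; the cluster names shown are illustrative):
#   clusters = cluster_content(notes, num_clusters=4)
#   # -> {"budget, finance, q3": [...], "travel, hotel, booking": [...], ...}
# Keys are derived from the three most frequent non-stopword terms in each cluster.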
def identify_trends(items: List[Dict[str, Any]], time_field: str = "created_at") -> Dict[str, Any]:
"""Identify trends in content over time
Args:
items: List of items to analyze
time_field: Field containing timestamp
Returns:
Dictionary with trend information
"""
if not items:
return {}
try:
import datetime
from collections import Counter
# Group items by time periods
daily_counts = defaultdict(int)
weekly_counts = defaultdict(int)
monthly_counts = defaultdict(int)
# Track topics over time
topics_by_month = defaultdict(Counter)
for item in items:
timestamp = item.get(time_field)
if not timestamp:
continue
# Convert timestamp to datetime
if isinstance(timestamp, (int, float)):
dt = datetime.datetime.fromtimestamp(timestamp)
elif isinstance(timestamp, str):
try:
dt = datetime.datetime.fromisoformat(timestamp.replace('Z', '+00:00'))
except ValueError:
continue
else:
continue
# Count by time period
date_str = dt.strftime("%Y-%m-%d")
week_str = dt.strftime("%Y-%W")
month_str = dt.strftime("%Y-%m")
daily_counts[date_str] += 1
weekly_counts[week_str] += 1
monthly_counts[month_str] += 1
# Extract topics (tags or keywords)
topics = []
if "tags" in item and item["tags"]:
if isinstance(item["tags"], list):
topics.extend(item["tags"])
elif isinstance(item["tags"], str):
topics.extend(item["tags"].split(","))
# If no tags, extract keywords from title
if not topics and "title" in item:
title_words = re.findall(r'\b[a-zA-Z]{3,}\b', item["title"].lower())
stopwords = {"the", "and", "for", "with", "this", "that", "from", "have", "not"}
topics = [word for word in title_words if word not in stopwords][:3]
# Add topics to monthly counter
for topic in topics:
topics_by_month[month_str][topic] += 1
# Find trending topics by month
trending_topics = {}
for month, counter in topics_by_month.items():
trending_topics[month] = counter.most_common(5)
# Calculate growth rates
growth_rates = {}
if len(monthly_counts) >= 2:
months = sorted(monthly_counts.keys())
for i in range(1, len(months)):
current_month = months[i]
prev_month = months[i-1]
current_count = monthly_counts[current_month]
prev_count = monthly_counts[prev_month]
if prev_count > 0:
growth_rate = (current_count - prev_count) / prev_count * 100
growth_rates[current_month] = growth_rate
return {
"daily_counts": dict(daily_counts),
"weekly_counts": dict(weekly_counts),
"monthly_counts": dict(monthly_counts),
"trending_topics": trending_topics,
"growth_rates": growth_rates
}
except Exception as e:
logger.error(f"Error identifying trends: {str(e)}")
return {}
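# identify_trends returns a dict mapping period keys to counts, e.g. daily_counts["2024-01-15"],
# weekly_counts["2024-27"] (year and week number from strftime("%Y-%W")) and
# monthly_counts["2024-07"], plus trending_topics (the top five tags or title keywords per month)
# and month-over-month growth_rates expressed in percent.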
def identify_information_gaps(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Identify potential information gaps in the content
Args:
items: List of items to analyze
Returns:
List of identified information gaps
"""
if not items:
return []
try:
# Cluster the content
clusters = cluster_content(items)
# Identify potential gaps based on cluster sizes and coverage
gaps = []
# Find small clusters that might need more content
for cluster_name, cluster_items in clusters.items():
if len(cluster_items) <= 2: # Small clusters might indicate gaps
gaps.append({
"type": "underdeveloped_topic",
"topic": cluster_name,
"description": f"Limited content on topic: {cluster_name}",
"item_count": len(cluster_items),
"sample_items": [item.get("title", "Untitled") for item in cluster_items]
})
# Identify potential missing connections between clusters
if len(clusters) >= 2:
cluster_names = list(clusters.keys())
for i in range(len(cluster_names)):
for j in range(i+1, len(cluster_names)):
name1 = cluster_names[i]
name2 = cluster_names[j]
# Check if there are connections between clusters
has_connection = False
for item1 in clusters[name1]:
similar_items = find_similar_items(item1, clusters[name2], top_k=1)
if similar_items and similar_items[0].get("similarity_score", 0) > 0.5:
has_connection = True
break
if not has_connection:
gaps.append({
"type": "missing_connection",
"topics": [name1, name2],
"description": f"Potential missing connection between {name1} and {name2}"
})
return gaps
except Exception as e:
logger.error(f"Error identifying information gaps: {str(e)}")
return []
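
# Usage sketch: running this module directly exercises search_content on two hypothetical
# sample items; the items below only illustrate the expected schema (id, type, title,
# content, tags) and are not real data.
if __name__ == "__main__":
    sample_items = [
        {"id": "1", "type": "note", "title": "Sprint planning notes",
         "content": "Discussed sprint goals, blockers and the release timeline.",
         "tags": ["work", "agile"]},
        {"id": "2", "type": "note", "title": "Grocery list",
         "content": "Milk, eggs, bread, coffee.", "tags": ["personal"]},
    ]
    for result in search_content("sprint goals", sample_items, top_k=2):
        print(f"{result['title']}: {result['relevance_score']:.3f}")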