# FastAPI-Backend-Models / topic_similarity_google_example.py
from datetime import datetime
import os
import json
import hashlib
from pathlib import Path
from dotenv import load_dotenv
from google import genai
from google.genai import types
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Load environment variables from .env file
load_dotenv()

GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
if not GOOGLE_API_KEY:
    raise ValueError("GOOGLE_API_KEY is not set in environment variables.")

# Get the path to topics.json relative to this file
TOPICS_FILE = Path(__file__).parent.parent / "data" / "topics.json"

# Cache file for topic embeddings
EMBEDDINGS_CACHE_FILE = Path(__file__).parent.parent / "data" / "topic_embeddings_cache.json"

# Create a Google Generative AI client with the API key
client = genai.Client(api_key=GOOGLE_API_KEY)
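
# For local development, a minimal .env file next to the project root would
# contain just the key loaded above (placeholder value, not a real key):
#
#   GOOGLE_API_KEY=your-api-key-here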

def load_topics():
    """Load topics from topics.json file."""
    with open(TOPICS_FILE, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return data.get("topics", [])
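
# For reference, load_topics() expects topics.json to be shaped roughly like
# the following (illustrative entries, inferred from the data.get("topics", [])
# access above; the real file lives in the repo's data/ directory):
#
#   {
#     "topics": [
#       "School uniforms should be mandatory in public schools",
#       "Social media does more harm than good"
#     ]
#   }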

def get_topics_hash(topics):
    """Generate a hash of the topics list to verify cache validity."""
    topics_str = json.dumps(topics, sort_keys=True)
    return hashlib.md5(topics_str.encode('utf-8')).hexdigest()

def load_cached_embeddings():
    """Load cached topic embeddings if they exist and are valid."""
    if not EMBEDDINGS_CACHE_FILE.exists():
        return None
    try:
        with open(EMBEDDINGS_CACHE_FILE, 'r', encoding='utf-8') as f:
            cache_data = json.load(f)
        # Verify cache is valid by checking topics hash
        current_topics = load_topics()
        current_hash = get_topics_hash(current_topics)
        if cache_data.get("topics_hash") == current_hash:
            # Convert list embeddings back to numpy arrays
            embeddings = [np.array(emb) for emb in cache_data.get("embeddings", [])]
            return embeddings
        else:
            # Topics have changed, cache is invalid
            return None
    except (json.JSONDecodeError, KeyError, ValueError) as e:
        # Cache file is corrupted or invalid format
        print(f"Warning: Could not load cached embeddings: {e}")
        return None

def save_cached_embeddings(embeddings, topics):
    """Save topic embeddings to cache file."""
    topics_hash = get_topics_hash(topics)
    # Convert numpy arrays to lists for JSON serialization
    embeddings_list = [emb.tolist() for emb in embeddings]
    cache_data = {
        "topics_hash": topics_hash,
        "embeddings": embeddings_list,
        "model": "models/text-embedding-004",
        "cached_at": datetime.now().isoformat()
    }
    try:
        with open(EMBEDDINGS_CACHE_FILE, 'w', encoding='utf-8') as f:
            json.dump(cache_data, f, indent=2)
        print(f"Cached {len(embeddings)} topic embeddings to {EMBEDDINGS_CACHE_FILE}")
    except Exception as e:
        print(f"Warning: Could not save cached embeddings: {e}")

def get_topic_embeddings():
    """
    Get topic embeddings, loading from cache if available, otherwise generating and caching them.

    Returns:
        numpy.ndarray: Array of topic embeddings
    """
    topics = load_topics()

    # Try to load from cache first
    cached_embeddings = load_cached_embeddings()
    if cached_embeddings is not None:
        print(f"Loaded {len(cached_embeddings)} topic embeddings from cache")
        return np.array(cached_embeddings)

    # Cache miss or invalid - generate embeddings
    print(f"Generating embeddings for {len(topics)} topics (this may take a moment)...")
    embedding_response = client.models.embed_content(
        model="models/text-embedding-004",
        contents=topics,
        config=types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY")
    )
    if not hasattr(embedding_response, "embeddings") or embedding_response.embeddings is None:
        raise RuntimeError("Embedding API did not return embeddings.")
    embeddings = [np.array(e.values) for e in embedding_response.embeddings]

    # Save to cache for future use
    save_cached_embeddings(embeddings, topics)
    return np.array(embeddings)
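
# A hypothetical convenience helper, not part of the original module: the hash
# check in load_cached_embeddings() already invalidates the cache whenever
# topics.json changes, so this is only needed to force a refresh manually
# (e.g. after switching embedding models).
def refresh_topic_embeddings():
    """Force regeneration of topic embeddings on the next get_topic_embeddings() call."""
    if EMBEDDINGS_CACHE_FILE.exists():
        # Deleting the cache file triggers a full re-embed on the next call
        EMBEDDINGS_CACHE_FILE.unlink()
        print(f"Deleted embeddings cache at {EMBEDDINGS_CACHE_FILE}")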

def find_most_similar_topic(input_text: str):
    """
    Compare a single input text to all topics and return the highest cosine similarity.
    Uses cached topic embeddings to avoid re-embedding topics on every call.

    Args:
        input_text: The text to compare against topics

    Returns:
        dict: Contains 'topic', 'similarity', and 'index' of the most similar topic
    """
    # Load topics from JSON file
    topics = load_topics()
    if not topics:
        raise ValueError("No topics found in topics.json")

    # Get topic embeddings (from cache, or generate them)
    topic_embeddings = get_topic_embeddings()

    # Only embed the input text (much faster than re-embedding every topic)
    embedding_response = client.models.embed_content(
        model="models/text-embedding-004",
        contents=[input_text],
        config=types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY")
    )
    if not hasattr(embedding_response, "embeddings") or embedding_response.embeddings is None:
        raise RuntimeError("Embedding API did not return embeddings.")

    # Extract the input embedding as a (1, d) array for cosine_similarity
    input_embedding = np.array(embedding_response.embeddings[0].values).reshape(1, -1)

    # Calculate cosine similarity between the input and each topic
    similarities = cosine_similarity(input_embedding, topic_embeddings)[0]

    # Find the highest similarity
    max_index = int(np.argmax(similarities))
    max_similarity = similarities[max_index]
    most_similar_topic = topics[max_index]

    return {
        "topic": most_similar_topic,
        "similarity": float(max_similarity),
        "index": max_index
    }
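
# A minimal pure-NumPy sketch of the sklearn cosine_similarity call used in
# find_most_similar_topic, for environments without scikit-learn. The helper
# name is hypothetical; it assumes input_embedding has shape (1, d) and
# topic_embeddings has shape (n, d), matching the arrays built above.
def _cosine_similarities_numpy(input_embedding, topic_embeddings):
    """Return the cosine similarity between one input row and each topic row."""
    # Normalize each row to unit length; the dot product of unit vectors is
    # exactly the cosine of the angle between them.
    input_norm = input_embedding / np.linalg.norm(input_embedding, axis=1, keepdims=True)
    topic_norms = topic_embeddings / np.linalg.norm(topic_embeddings, axis=1, keepdims=True)
    return (input_norm @ topic_norms.T)[0]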

if __name__ == "__main__":
    # Example usage, with simple wall-clock timing
    start_time = datetime.now()

    test_text = "we should abandon the use of school uniform since one should be allowed to express their individuality by the clothes they wear."
    result = find_most_similar_topic(test_text)

    print(f"Input text: '{test_text}'")
    print(f"Most similar topic: '{result['topic']}'")
    # Cosine similarity is a score in [-1, 1], not a percentage
    print(f"Cosine similarity: {result['similarity']:.4f}")

    end_time = datetime.now()
    print(f"Time taken: {(end_time - start_time).total_seconds()} seconds")