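"""Match input text against a topic list using Gemini text embeddings.

Topic embeddings are cached on disk and regenerated only when
data/topics.json changes, so each lookup embeds just the input text.
Requires GOOGLE_API_KEY in the environment (e.g. via a .env file).
"""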
import hashlib
import json
import os
from datetime import datetime
from pathlib import Path

import numpy as np
from dotenv import load_dotenv
from google import genai
from google.genai import types
from sklearn.metrics.pairwise import cosine_similarity
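
# Third-party packages used above: google-genai, python-dotenv, numpy, scikit-learn.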

load_dotenv()

GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
if not GOOGLE_API_KEY:
    raise ValueError("GOOGLE_API_KEY is not set in environment variables.")

TOPICS_FILE = Path(__file__).parent.parent / "data" / "topics.json"
EMBEDDINGS_CACHE_FILE = Path(__file__).parent.parent / "data" / "topic_embeddings_cache.json"
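
# Cache file layout (as written by save_cached_embeddings below):
# {"topics_hash": "<md5 of topics list>", "embeddings": [[...], ...],
#  "model": "models/text-embedding-004", "cached_at": "<ISO timestamp>"}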

client = genai.Client(api_key=GOOGLE_API_KEY)


def load_topics():
    """Load topics from the topics.json file."""
    with open(TOPICS_FILE, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return data.get("topics", [])

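
# Expected topics.json shape (inferred from load_topics above):
# {"topics": ["first topic", "second topic", ...]}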

def get_topics_hash(topics):
    """Generate a hash of the topics list to verify cache validity."""
    # MD5 is acceptable here: the hash only detects changes to the topic
    # list and is not used for anything security-sensitive.
    topics_str = json.dumps(topics, sort_keys=True)
    return hashlib.md5(topics_str.encode('utf-8')).hexdigest()


def load_cached_embeddings():
    """Load cached topic embeddings if they exist and are valid."""
    if not EMBEDDINGS_CACHE_FILE.exists():
        return None

    try:
        with open(EMBEDDINGS_CACHE_FILE, 'r', encoding='utf-8') as f:
            cache_data = json.load(f)

        # The cache is only valid if topics.json is unchanged since it was written.
        current_topics = load_topics()
        current_hash = get_topics_hash(current_topics)

        if cache_data.get("topics_hash") == current_hash:
            embeddings = [np.array(emb) for emb in cache_data.get("embeddings", [])]
            return embeddings
        else:
            # Topics changed: discard the cache and force regeneration.
            return None
    except (json.JSONDecodeError, KeyError, ValueError) as e:
        # A corrupt cache is not fatal; fall back to regenerating embeddings.
        print(f"Warning: Could not load cached embeddings: {e}")
        return None


def save_cached_embeddings(embeddings, topics):
    """Save topic embeddings to the cache file."""
    topics_hash = get_topics_hash(topics)

    # numpy arrays are not JSON-serializable; convert them to plain lists first.
    embeddings_list = [emb.tolist() for emb in embeddings]

    cache_data = {
        "topics_hash": topics_hash,
        "embeddings": embeddings_list,
        "model": "models/text-embedding-004",
        "cached_at": datetime.now().isoformat()
    }

    try:
        with open(EMBEDDINGS_CACHE_FILE, 'w', encoding='utf-8') as f:
            json.dump(cache_data, f, indent=2)
        print(f"Cached {len(embeddings)} topic embeddings to {EMBEDDINGS_CACHE_FILE}")
    except Exception as e:
        # Caching is an optimization; a failed write should not crash the caller.
        print(f"Warning: Could not save cached embeddings: {e}")


def get_topic_embeddings():
    """
    Get topic embeddings, loading from cache if available; otherwise generate and cache them.

    Returns:
        numpy.ndarray: Array of topic embeddings, one row per topic.
    """
    topics = load_topics()

    # Try the cache first so topics are only re-embedded when topics.json changes.
    cached_embeddings = load_cached_embeddings()
    if cached_embeddings is not None:
        print(f"Loaded {len(cached_embeddings)} topic embeddings from cache")
        return np.array(cached_embeddings)

    # Cache miss: embed all topics in a single batch request. (Very large topic
    # lists may exceed the API's per-request batch limit and would need chunking.)
    print(f"Generating embeddings for {len(topics)} topics (this may take a moment)...")
    embedding_response = client.models.embed_content(
        model="models/text-embedding-004",
        contents=topics,
        config=types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY")
    )

    if not hasattr(embedding_response, "embeddings") or embedding_response.embeddings is None:
        raise RuntimeError("Embedding API did not return embeddings.")

    embeddings = [np.array(e.values) for e in embedding_response.embeddings]
    save_cached_embeddings(embeddings, topics)
    return np.array(embeddings)


def find_most_similar_topic(input_text: str):
    """
    Compare a single input text to all topics and return the most similar one.

    Uses cached topic embeddings to avoid re-embedding topics on every call.

    Args:
        input_text: The text to compare against topics.

    Returns:
        dict: Contains 'topic', 'similarity', and 'index' of the most similar topic.
    """
    topics = load_topics()
    if not topics:
        raise ValueError("No topics found in topics.json")

    topic_embeddings = get_topic_embeddings()

    # Only the input text is embedded per call; topic embeddings come from the cache.
    embedding_response = client.models.embed_content(
        model="models/text-embedding-004",
        contents=[input_text],
        config=types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY")
    )

    if not hasattr(embedding_response, "embeddings") or embedding_response.embeddings is None:
        raise RuntimeError("Embedding API did not return embeddings.")

    # cosine_similarity expects 2D inputs: shape (1, dim) against (n_topics, dim).
    input_embedding = np.array(embedding_response.embeddings[0].values).reshape(1, -1)
    similarities = cosine_similarity(input_embedding, topic_embeddings)[0]

    max_index = np.argmax(similarities)
    max_similarity = similarities[max_index]
    most_similar_topic = topics[max_index]

    return {
        "topic": most_similar_topic,
        "similarity": float(max_similarity),
        "index": int(max_index)
    }

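
# Hedged sketch (not part of the original script): the same cached embeddings
# also support a top-k lookup when more than one candidate topic is useful.
# The function name and `k` parameter are illustrative, not an existing API.
def find_top_k_topics(input_text: str, k: int = 3):
    """Return the k topics most similar to input_text, best match first."""
    topics = load_topics()
    topic_embeddings = get_topic_embeddings()

    embedding_response = client.models.embed_content(
        model="models/text-embedding-004",
        contents=[input_text],
        config=types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY")
    )
    input_embedding = np.array(embedding_response.embeddings[0].values).reshape(1, -1)
    similarities = cosine_similarity(input_embedding, topic_embeddings)[0]

    # argsort is ascending; take the last k indices and reverse for best-first.
    top_indices = np.argsort(similarities)[-k:][::-1]
    return [
        {"topic": topics[i], "similarity": float(similarities[i]), "index": int(i)}
        for i in top_indices
    ]
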

if __name__ == "__main__":
    start_time = datetime.now()

    test_text = "we should abandon the use of school uniform since one should be allowed to express their individuality by the clothes they wear."
    result = find_most_similar_topic(test_text)
    print(f"Input text: '{test_text}'")
    print(f"Most similar topic: '{result['topic']}'")
    # Cosine similarity is a value in [-1, 1], not a percentage.
    print(f"Cosine similarity: {result['similarity']:.4f}")

    end_time = datetime.now()
    print(f"Time taken: {(end_time - start_time).total_seconds():.2f} seconds")