|
|
|
import logging
import os

from flask import request, jsonify
from sentence_transformers import SentenceTransformer, util
|
|
|
|
|
|
|
# Directory handed to SentenceTransformer as its cache folder: HF_HOME when
# that variable is set, otherwise a scratch path under /tmp.
cache_dir = os.getenv("HF_HOME", "/tmp/huggingface")
os.makedirs(cache_dir, exist_ok=True)

# The model is instantiated once, when this module is imported.
print("Loading SentenceTransformer model (paraphrase-MiniLM-L6-v2)...")
matcher_model = SentenceTransformer(
    "sentence-transformers/paraphrase-MiniLM-L6-v2",
    cache_folder=cache_dir,
)
print("SentenceTransformer model loaded.")

# Cosine-similarity score a stored question must exceed to count as a match.
SIMILARITY_THRESHOLD = 0.6
|
|
|
def handle_match_question():
    """Answer a user's question with the note whose stored question matches best.

    Reads a JSON object from the current Flask request containing:
      * "user_question": the question text to match (must be a string).
      * "documents": a list of objects; each may carry "note_text" and a
        "questions" list of strings.

    Every stored question is embedded together with the user's question, and
    the stored question with the highest cosine similarity is selected. Its
    note text is returned only when the score exceeds SIMILARITY_THRESHOLD;
    otherwise an apology message is returned.

    Malformed entries inside "documents" (non-object documents, a
    non-list "questions" value, non-string questions) are skipped rather
    than failing the whole request.

    Returns:
        A Flask JSON response:
          * {"answer": ...} with the default 200 status on success,
          * {"error": ...} with 400 when the request body is malformed,
          * {"error": ...} with 500 when embedding or scoring raises.
    """
    # silent=True yields None for a missing or unparsable JSON body, so the
    # client receives the JSON 400 below instead of Flask's own error page.
    data = request.get_json(silent=True)
    if (
        not isinstance(data, dict)
        or 'user_question' not in data
        or 'documents' not in data
    ):
        return jsonify({'error': 'Invalid request. "user_question" and "documents" are required.'}), 400

    user_question = data['user_question']
    documents = data['documents']

    if not documents:
        return jsonify({'answer': "There are no notes to search."})

    if not isinstance(user_question, str):
        return jsonify({'error': 'Invalid request. "user_question" must be a string.'}), 400
    if not isinstance(documents, list):
        return jsonify({'error': 'Invalid request. "documents" must be a list.'}), 400

    # Two index-aligned lists: note_texts[i] is the note that all_questions[i]
    # came from. A dict keyed by question text would let a question shared by
    # two notes resolve to the wrong note.
    all_questions = []
    note_texts = []
    for doc in documents:
        if not isinstance(doc, dict):
            continue
        note_text = doc.get('note_text', '')
        questions = doc.get('questions') or []  # tolerate "questions": null
        if not isinstance(questions, list):
            continue
        for question in questions:
            if isinstance(question, str):
                all_questions.append(question)
                note_texts.append(note_text)

    if not all_questions:
        return jsonify({'answer': "No questions have been generated for your notes yet."})

    try:
        user_embedding = matcher_model.encode(user_question, convert_to_tensor=True)
        stored_embeddings = matcher_model.encode(all_questions, convert_to_tensor=True)

        # One row (the user question) by len(all_questions) columns.
        cosine_scores = util.pytorch_cos_sim(user_embedding, stored_embeddings)

        # argmax returns a 0-d tensor; convert it to a plain int before using
        # it as a list index.
        best_match_idx = int(cosine_scores.argmax())
        best_score = float(cosine_scores[0][best_match_idx])
        best_question = all_questions[best_match_idx]

        print(f"User Question: '{user_question}'")
        print(f"Best matched stored question: '{best_question}' with score: {best_score:.4f}")

        if best_score > SIMILARITY_THRESHOLD:
            answer = note_texts[best_match_idx]
        else:
            answer = "Sorry, I couldn't find a relevant note to answer your question."

        return jsonify({'answer': answer})

    except Exception as e:
        # Top-level boundary for model failures. logger.exception records the
        # traceback, which printing str(e) alone discarded.
        logging.getLogger(__name__).exception("Error during question matching")
        # NOTE(review): str(e) exposes internal error text to the client;
        # kept so the existing response contract is unchanged.
        return jsonify({'error': str(e)}), 500
|
|