BackEnd / core /embedding_model.py
HaRin2806
fix bug
289c45a
import logging
from sentence_transformers import SentenceTransformer
import chromadb
from chromadb.config import Settings
import uuid
import os
from config import EMBEDDING_MODEL, CHROMA_PERSIST_DIRECTORY, COLLECTION_NAME
# Cấu hình logging
logger = logging.getLogger(__name__)
# Global instance để implement singleton pattern
_embedding_model_instance = None
def get_embedding_model():
"""
Singleton pattern để đảm bảo chỉ có một instance của EmbeddingModel
"""
global _embedding_model_instance
if _embedding_model_instance is None:
logger.info("Khởi tạo EmbeddingModel instance lần đầu")
_embedding_model_instance = EmbeddingModel()
else:
logger.debug("Sử dụng EmbeddingModel instance đã có")
return _embedding_model_instance
class EmbeddingModel:
def __init__(self):
"""Khởi tạo embedding model và ChromaDB client"""
logger.info(f"Đang khởi tạo embedding model: {EMBEDDING_MODEL}")
try:
# Khởi tạo sentence transformer với trust_remote_code=True
self.model = SentenceTransformer(EMBEDDING_MODEL, trust_remote_code=True)
logger.info("Đã tải sentence transformer model")
except Exception as e:
logger.error(f"Lỗi khởi tạo model: {e}")
# Thử với cache_folder explicit
cache_dir = os.getenv('SENTENCE_TRANSFORMERS_HOME', '/app/.cache/sentence-transformers')
self.model = SentenceTransformer(EMBEDDING_MODEL, cache_folder=cache_dir, trust_remote_code=True)
logger.info("Đã tải sentence transformer model với cache folder explicit")
# Đảm bảo thư mục ChromaDB tồn tại và có quyền ghi
try:
os.makedirs(CHROMA_PERSIST_DIRECTORY, exist_ok=True)
# Test ghi file để kiểm tra permission
test_file = os.path.join(CHROMA_PERSIST_DIRECTORY, 'test_permission.tmp')
with open(test_file, 'w') as f:
f.write('test')
os.remove(test_file)
logger.info(f"Thư mục ChromaDB đã sẵn sàng: {CHROMA_PERSIST_DIRECTORY}")
except Exception as e:
logger.error(f"Lỗi tạo/kiểm tra thư mục ChromaDB: {e}")
# Fallback to /tmp directory
import tempfile
CHROMA_PERSIST_DIRECTORY = os.path.join(tempfile.gettempdir(), 'chroma_db')
os.makedirs(CHROMA_PERSIST_DIRECTORY, exist_ok=True)
logger.warning(f"Sử dụng thư mục tạm thời: {CHROMA_PERSIST_DIRECTORY}")
# Khởi tạo ChromaDB client với persistent storage
try:
self.chroma_client = chromadb.PersistentClient(
path=CHROMA_PERSIST_DIRECTORY,
settings=Settings(
anonymized_telemetry=False,
allow_reset=True
)
)
logger.info(f"Đã kết nối ChromaDB tại: {CHROMA_PERSIST_DIRECTORY}")
except Exception as e:
logger.error(f"Lỗi kết nối ChromaDB: {e}")
# Fallback to in-memory client
logger.warning("Fallback to in-memory ChromaDB client")
self.chroma_client = chromadb.Client()
# Lấy hoặc tạo collection
try:
self.collection = self.chroma_client.get_collection(name=COLLECTION_NAME)
logger.info(f"Đã kết nối collection '{COLLECTION_NAME}' với {self.collection.count()} items")
except Exception:
logger.warning(f"Collection '{COLLECTION_NAME}' không tồn tại, tạo mới...")
self.collection = self.chroma_client.create_collection(name=COLLECTION_NAME)
logger.info(f"Đã tạo collection mới: {COLLECTION_NAME}")
def _add_prefix_to_text(self, text, is_query=True):
"""
Thêm prefix cho text theo yêu cầu của multilingual-e5-base
"""
# Kiểm tra xem text đã có prefix chưa
if text.startswith(('query:', 'passage:')):
return text
# Thêm prefix phù hợp
if is_query:
return f"query: {text}"
else:
return f"passage: {text}"
def encode(self, texts, is_query=True):
"""
Encode văn bản thành embeddings
Args:
texts (str or list): Văn bản hoặc danh sách văn bản cần encode
is_query (bool): True nếu là query, False nếu là passage
Returns:
list: Embeddings vector
"""
try:
if isinstance(texts, str):
texts = [texts]
# Thêm prefix cho texts
processed_texts = [self._add_prefix_to_text(text, is_query) for text in texts]
logger.debug(f"Đang encode {len(processed_texts)} văn bản")
embeddings = self.model.encode(processed_texts, show_progress_bar=False, normalize_embeddings=True)
return embeddings.tolist()
except Exception as e:
logger.error(f"Lỗi encode văn bản: {e}")
raise
def search(self, query, top_k=5, age_filter=None):
"""
Tìm kiếm văn bản tương tự trong ChromaDB
Args:
query (str): Câu hỏi cần tìm kiếm
top_k (int): Số lượng kết quả trả về
age_filter (int): Lọc theo độ tuổi (optional)
Returns:
list: Danh sách kết quả tìm kiếm
"""
try:
logger.debug(f"Dang tim kiem cho query: {query[:50]}...")
# Encode query thành embedding (với prefix query:)
query_embedding = self.encode(query, is_query=True)[0]
# Tạo where clause cho age filter
where_clause = None
if age_filter:
where_clause = {
"$and": [
{"age_min": {"$lte": age_filter}},
{"age_max": {"$gte": age_filter}}
]
}
# Thực hiện search trong ChromaDB
search_results = self.collection.query(
query_embeddings=[query_embedding],
n_results=top_k,
where=where_clause,
include=['documents', 'metadatas', 'distances']
)
if not search_results or not search_results['documents']:
logger.warning("Khong tim thay ket qua nao")
return []
# Format kết quả
results = []
documents = search_results['documents'][0]
metadatas = search_results['metadatas'][0]
distances = search_results['distances'][0]
for i, (doc, metadata, distance) in enumerate(zip(documents, metadatas, distances)):
results.append({
'document': doc,
'metadata': metadata or {},
'distance': distance,
'similarity': 1 - distance, # Chuyển distance thành similarity
'rank': i + 1
})
logger.info(f"Tim thay {len(results)} ket qua cho query")
return results
except Exception as e:
logger.error(f"Loi tim kiem: {e}")
return []
def add_documents(self, documents, metadatas=None, ids=None):
"""
Thêm documents vào ChromaDB
Args:
documents (list): Danh sách văn bản
metadatas (list): Danh sách metadata tương ứng
ids (list): Danh sách ID tương ứng (optional)
Returns:
bool: True nếu thành công
"""
try:
if not documents:
logger.warning("Không có documents để thêm")
return False
# Tạo IDs nếu không được cung cấp
if not ids:
ids = [str(uuid.uuid4()) for _ in documents]
# Tạo metadatas rỗng nếu không được cung cấp
if not metadatas:
metadatas = [{} for _ in documents]
logger.info(f"Đang thêm {len(documents)} documents vào ChromaDB")
# Encode documents thành embeddings (với prefix passage:)
embeddings = self.encode(documents, is_query=False)
# Thêm vào collection
self.collection.add(
embeddings=embeddings,
documents=documents,
metadatas=metadatas,
ids=ids
)
logger.info(f"Đã thêm thành công {len(documents)} documents")
return True
except Exception as e:
logger.error(f"Lỗi thêm documents: {e}")
return False
def index_chunks(self, chunks):
"""
Index các chunks dữ liệu vào ChromaDB
"""
try:
if not chunks:
logger.warning("Không có chunks để index")
return False
documents = []
metadatas = []
ids = []
for chunk in chunks:
if not chunk.get('content'):
logger.warning(f"Chunk thiếu content: {chunk}")
continue
documents.append(chunk['content'])
# Lấy metadata đã được chuẩn bị sẵn
metadata = chunk.get('metadata', {})
metadatas.append(metadata)
# Sử dụng ID có sẵn hoặc tạo mới
chunk_id = chunk.get('id') or str(uuid.uuid4())
ids.append(chunk_id)
if not documents:
logger.warning("Không có documents hợp lệ để index")
return False
# Batch processing để tránh overload
batch_size = 100
total_batches = (len(documents) + batch_size - 1) // batch_size
for i in range(0, len(documents), batch_size):
batch_docs = documents[i:i + batch_size]
batch_metas = metadatas[i:i + batch_size]
batch_ids = ids[i:i + batch_size]
batch_num = (i // batch_size) + 1
logger.info(f"Đang xử lý batch {batch_num}/{total_batches} ({len(batch_docs)} items)")
success = self.add_documents(batch_docs, batch_metas, batch_ids)
if not success:
logger.error(f"Lỗi xử lý batch {batch_num}")
return False
logger.info(f"Đã index thành công {len(documents)} chunks")
return True
except Exception as e:
logger.error(f"Lỗi index chunks: {e}")
return False
def count(self):
"""Đếm số lượng documents trong collection"""
try:
return self.collection.count()
except Exception as e:
logger.error(f"Lỗi đếm documents: {e}")
return 0
def delete_collection(self):
"""Xóa collection hiện tại"""
try:
logger.warning(f"Đang xóa collection: {COLLECTION_NAME}")
self.chroma_client.delete_collection(name=COLLECTION_NAME)
# Tạo lại collection mới
self.collection = self.chroma_client.create_collection(name=COLLECTION_NAME)
logger.info("Đã tạo lại collection mới")
return True
except Exception as e:
logger.error(f"Lỗi xóa collection: {e}")
return False
def get_stats(self):
"""Lấy thống kê về collection"""
try:
total_count = self.count()
# Lấy sample để phân tích metadata
sample_results = self.collection.get(limit=min(100, total_count))
# Thống kê content types
content_types = {}
chapters = {}
age_groups = {}
if sample_results and sample_results.get('metadatas'):
for metadata in sample_results['metadatas']:
if not metadata:
continue
# Content type stats
content_type = metadata.get('content_type', 'unknown')
content_types[content_type] = content_types.get(content_type, 0) + 1
# Chapter stats
chapter = metadata.get('chapter', 'unknown')
chapters[chapter] = chapters.get(chapter, 0) + 1
# Age group stats
age_group = metadata.get('age_group', 'unknown')
age_groups[age_group] = age_groups.get(age_group, 0) + 1
return {
'total_documents': total_count,
'content_types': content_types,
'chapters': chapters,
'age_groups': age_groups,
'collection_name': COLLECTION_NAME,
'embedding_model': EMBEDDING_MODEL
}
except Exception as e:
logger.error(f"Lỗi lấy stats: {e}")
return {
'total_documents': 0,
'error': str(e)
}