Spaces:
Sleeping
Sleeping
import json | |
import sqlite3 | |
import os | |
import sys | |
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
from app.models.knowledge_graph import Neo4jConnection | |
from sentence_transformers import SentenceTransformer | |
from pyvi.ViTokenizer import tokenize | |
import faiss | |
import numpy as np | |
""" | |
Script này thực hiện lấy các entity từ neo4j từ xa về và tạo ra các data lưu trong sqlite, đồng thời tạo các embeddings | |
dựa trên từng row. | |
""" | |
# Kết nối SQLite | |
VECTOR_EMBEDDINGS_DB_PATH = 'app/data/vector_embeddings.db' | |
FAISS_INDEX_PATH = 'app/data/faiss_index.index' | |
conn = sqlite3.connect(VECTOR_EMBEDDINGS_DB_PATH) | |
cursor = conn.cursor() | |
# Tạo bảng embeddings nếu chưa tồn tại | |
cursor.execute(''' | |
CREATE TABLE IF NOT EXISTS embeddings ( | |
e_index INTEGER PRIMARY KEY, | |
id TEXT NOT NULL, | |
name TEXT NOT NULL, | |
label TEXT NOT NULL, | |
properties TEXT NOT NULL | |
) | |
''') | |
def insert_embedding(e_index, id, name, label, properties): | |
"""Thêm embedding vào SQLite.""" | |
cursor.execute(''' | |
INSERT INTO embeddings (e_index, id, name, label, properties) | |
VALUES (?, ?, ?, ?, ?) | |
''', (e_index, id, name, label, json.dumps(properties))) | |
conn.commit() | |
print(f"Đã thêm embedding: {name}") | |
def update_embedding(embedding_id, id, name, label, properties): | |
"""Cập nhật embedding trong SQLite.""" | |
cursor.execute(''' | |
UPDATE embeddings | |
SET id = ?, name = ?, label = ?, properties = ? | |
WHERE e_index = ? | |
''', (id, name, label, json.dumps(properties), embedding_id)) | |
conn.commit() | |
print(f"Đã cập nhật embedding ID: {embedding_id}") | |
def get_all_embeddings(): | |
"""Lấy tất cả embeddings từ SQLite.""" | |
cursor.execute('SELECT * FROM embeddings') | |
return cursor.fetchall() | |
def get_embedding_by_id(embedding_id): | |
"""Lấy embedding theo e_index từ SQLite.""" | |
cursor.execute('SELECT * FROM embeddings WHERE e_index = ?', (embedding_id,)) | |
return cursor.fetchone() | |
def save_faiss_index(index, index_file=FAISS_INDEX_PATH): | |
"""Lưu FAISS index vào file.""" | |
faiss.write_index(index, index_file) | |
print(f"Đã lưu FAISS index vào {index_file}") | |
def load_faiss_index(index_file=FAISS_INDEX_PATH): | |
"""Nạp FAISS index từ file.""" | |
if os.path.exists(index_file): | |
index = faiss.read_index(index_file) | |
print(f"Đã nạp FAISS index từ {index_file}") | |
return index | |
return None | |
def compute_and_save_embeddings(index_file=FAISS_INDEX_PATH): | |
"""Tính toán embeddings, lưu vào FAISS và đồng bộ metadata vào SQLite.""" | |
print("Loading model...") | |
model = SentenceTransformer('dangvantuan/vietnamese-embedding') | |
print("Model loaded") | |
# Lấy dữ liệu từ Neo4j | |
neo4j = Neo4jConnection() | |
result = neo4j.execute_query("MATCH (n) RETURN n") | |
corpus = [] | |
# Chuẩn bị corpus và lưu metadata vào SQLite | |
print("Processing Neo4j data and saving to SQLite...") | |
for index, record in enumerate(result): | |
print(record) | |
label = list(record["n"].labels)[0] | |
print(label) | |
embedding = dict(record["n"]) | |
id = embedding.pop('id') | |
name = embedding.pop('name') if 'name' in embedding else id | |
properties = embedding | |
corpus.append(name) | |
# Kiểm tra và cập nhật/thêm vào SQLite | |
cursor.execute('SELECT e_index FROM embeddings WHERE e_index = ?', (index,)) | |
existing = cursor.fetchone() | |
if existing: | |
update_embedding(index, id, name, label, properties) | |
else: | |
insert_embedding(index, id, name, label, properties) | |
# Tính toán embeddings | |
print("Tokenizing and encoding...") | |
tokenized = [tokenize(s) for s in corpus] | |
embeddings = model.encode(tokenized, show_progress_bar=False) | |
print("Encoding done") | |
# Chuẩn hóa embeddings | |
print("Normalizing...") | |
faiss.normalize_L2(embeddings) | |
print("Normalized") | |
# Tạo và lưu FAISS index | |
d = embeddings.shape[1] | |
index = faiss.IndexFlatIP(d) | |
index.add(embeddings) | |
save_faiss_index(index, index_file) | |
print("Processing completed") | |
return index, corpus, embeddings | |
def load_or_compute_embeddings(index_file=FAISS_INDEX_PATH): | |
"""Nạp hoặc tính toán embeddings và FAISS index.""" | |
# Thử nạp FAISS index | |
index = load_faiss_index(index_file) | |
# Lấy corpus từ SQLite | |
embeddings_data = get_all_embeddings() | |
corpus = [row[2] for row in embeddings_data] # Lấy cột name | |
if index is None or not corpus: | |
print("No saved index or corpus found, computing new ones...") | |
index, corpus, embeddings = compute_and_save_embeddings(index_file) | |
else: | |
print("Loaded existing index and corpus") | |
return index, corpus | |
def get_qvec_by_text(model, text): | |
q_token = tokenize(text) | |
q_vec = model.encode([q_token]) | |
faiss.normalize_L2(q_vec) | |
return q_vec | |
if __name__ == "__main__": | |
try: | |
index, corpus = load_or_compute_embeddings() | |
print(f"Index ready with {index.ntotal} embeddings, corpus size: {len(corpus)}") | |
model = SentenceTransformer('dangvantuan/vietnamese-embedding') | |
while True: | |
try: | |
query = input("Nhập câu truy vấn (nhấn Ctrl+C để thoát): ") | |
q_vec = get_qvec_by_text(model, query) | |
k = 1 # số kết quả cần lấy | |
D, I = index.search(q_vec, k) | |
print("Câu truy vấn:", query) | |
print(I[0][0]) | |
print(type(I[0][0])) | |
print("Câu gần nhất:", get_embedding_by_id(int(I[0][0])), "(khoảng cách:", D[0][0], ")") | |
print("-" * 50) | |
except KeyboardInterrupt: | |
print("\nĐã dừng chương trình!") | |
break | |
finally: | |
conn.close() | |
print("SQLite connection closed") | |