crop-diag-module / prepare_script /sync_neo4j_node.py
Sontranwakumo
init: move from github
88cc76c
import json
import sqlite3
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from app.models.knowledge_graph import Neo4jConnection
from sentence_transformers import SentenceTransformer
from pyvi.ViTokenizer import tokenize
import faiss
import numpy as np
"""
Script này thực hiện lấy các entity từ neo4j từ xa về và tạo ra các data lưu trong sqlite, đồng thời tạo các embeddings
dựa trên từng row.
"""
# Kết nối SQLite
VECTOR_EMBEDDINGS_DB_PATH = 'app/data/vector_embeddings.db'
FAISS_INDEX_PATH = 'app/data/faiss_index.index'
conn = sqlite3.connect(VECTOR_EMBEDDINGS_DB_PATH)
cursor = conn.cursor()
# Tạo bảng embeddings nếu chưa tồn tại
cursor.execute('''
CREATE TABLE IF NOT EXISTS embeddings (
e_index INTEGER PRIMARY KEY,
id TEXT NOT NULL,
name TEXT NOT NULL,
label TEXT NOT NULL,
properties TEXT NOT NULL
)
''')
def insert_embedding(e_index, id, name, label, properties):
"""Thêm embedding vào SQLite."""
cursor.execute('''
INSERT INTO embeddings (e_index, id, name, label, properties)
VALUES (?, ?, ?, ?, ?)
''', (e_index, id, name, label, json.dumps(properties)))
conn.commit()
print(f"Đã thêm embedding: {name}")
def update_embedding(embedding_id, id, name, label, properties):
"""Cập nhật embedding trong SQLite."""
cursor.execute('''
UPDATE embeddings
SET id = ?, name = ?, label = ?, properties = ?
WHERE e_index = ?
''', (id, name, label, json.dumps(properties), embedding_id))
conn.commit()
print(f"Đã cập nhật embedding ID: {embedding_id}")
def get_all_embeddings():
"""Lấy tất cả embeddings từ SQLite."""
cursor.execute('SELECT * FROM embeddings')
return cursor.fetchall()
def get_embedding_by_id(embedding_id):
"""Lấy embedding theo e_index từ SQLite."""
cursor.execute('SELECT * FROM embeddings WHERE e_index = ?', (embedding_id,))
return cursor.fetchone()
def save_faiss_index(index, index_file=FAISS_INDEX_PATH):
"""Lưu FAISS index vào file."""
faiss.write_index(index, index_file)
print(f"Đã lưu FAISS index vào {index_file}")
def load_faiss_index(index_file=FAISS_INDEX_PATH):
"""Nạp FAISS index từ file."""
if os.path.exists(index_file):
index = faiss.read_index(index_file)
print(f"Đã nạp FAISS index từ {index_file}")
return index
return None
def compute_and_save_embeddings(index_file=FAISS_INDEX_PATH):
"""Tính toán embeddings, lưu vào FAISS và đồng bộ metadata vào SQLite."""
print("Loading model...")
model = SentenceTransformer('dangvantuan/vietnamese-embedding')
print("Model loaded")
# Lấy dữ liệu từ Neo4j
neo4j = Neo4jConnection()
result = neo4j.execute_query("MATCH (n) RETURN n")
corpus = []
# Chuẩn bị corpus và lưu metadata vào SQLite
print("Processing Neo4j data and saving to SQLite...")
for index, record in enumerate(result):
print(record)
label = list(record["n"].labels)[0]
print(label)
embedding = dict(record["n"])
id = embedding.pop('id')
name = embedding.pop('name') if 'name' in embedding else id
properties = embedding
corpus.append(name)
# Kiểm tra và cập nhật/thêm vào SQLite
cursor.execute('SELECT e_index FROM embeddings WHERE e_index = ?', (index,))
existing = cursor.fetchone()
if existing:
update_embedding(index, id, name, label, properties)
else:
insert_embedding(index, id, name, label, properties)
# Tính toán embeddings
print("Tokenizing and encoding...")
tokenized = [tokenize(s) for s in corpus]
embeddings = model.encode(tokenized, show_progress_bar=False)
print("Encoding done")
# Chuẩn hóa embeddings
print("Normalizing...")
faiss.normalize_L2(embeddings)
print("Normalized")
# Tạo và lưu FAISS index
d = embeddings.shape[1]
index = faiss.IndexFlatIP(d)
index.add(embeddings)
save_faiss_index(index, index_file)
print("Processing completed")
return index, corpus, embeddings
def load_or_compute_embeddings(index_file=FAISS_INDEX_PATH):
"""Nạp hoặc tính toán embeddings và FAISS index."""
# Thử nạp FAISS index
index = load_faiss_index(index_file)
# Lấy corpus từ SQLite
embeddings_data = get_all_embeddings()
corpus = [row[2] for row in embeddings_data] # Lấy cột name
if index is None or not corpus:
print("No saved index or corpus found, computing new ones...")
index, corpus, embeddings = compute_and_save_embeddings(index_file)
else:
print("Loaded existing index and corpus")
return index, corpus
def get_qvec_by_text(model, text):
q_token = tokenize(text)
q_vec = model.encode([q_token])
faiss.normalize_L2(q_vec)
return q_vec
if __name__ == "__main__":
try:
index, corpus = load_or_compute_embeddings()
print(f"Index ready with {index.ntotal} embeddings, corpus size: {len(corpus)}")
model = SentenceTransformer('dangvantuan/vietnamese-embedding')
while True:
try:
query = input("Nhập câu truy vấn (nhấn Ctrl+C để thoát): ")
q_vec = get_qvec_by_text(model, query)
k = 1 # số kết quả cần lấy
D, I = index.search(q_vec, k)
print("Câu truy vấn:", query)
print(I[0][0])
print(type(I[0][0]))
print("Câu gần nhất:", get_embedding_by_id(int(I[0][0])), "(khoảng cách:", D[0][0], ")")
print("-" * 50)
except KeyboardInterrupt:
print("\nĐã dừng chương trình!")
break
finally:
conn.close()
print("SQLite connection closed")