from dotenv import load_dotenv
load_dotenv()
import os
import chromadb
import re
import logging
from sqlalchemy import create_engine
from uuid import uuid4
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex, Settings
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.embeddings.nebius import NebiusEmbedding
from llama_index.core import SQLDatabase
from llama_index.core.schema import TextNode
# Configure logging
logging.basicConfig(level=logging.INFO)
# --- Configuration ---
DATABASE_PATH = os.path.join('data', 'sales_database.db')
# Schema documentation lives under knowledge_base/schema
SCHEMA_DIR = os.path.join('knowledge_base', 'schema')
DATA_DICTIONARY_PATH = os.path.join(SCHEMA_DIR, 'data_dictionary.md')
SALES_SCHEMA_SQL_PATH = os.path.join(SCHEMA_DIR, 'sales_schema.sql')
# kpi_definitions.md is under business_glossary
KPI_DEFINITIONS_PATH = os.path.join('knowledge_base', 'business_glossary', 'kpi_definitions.md')
CHROMA_DB_SCHEMA_PATH = os.path.join('.', 'chroma_db_schema')
CHROMA_DB_KPI_PATH = os.path.join('.', 'chroma_db_kpi')
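# Directory layout implied by the paths above (illustrative; adjust if your
# project is organized differently):
#   data/sales_database.db
#   knowledge_base/schema/data_dictionary.md
#   knowledge_base/schema/sales_schema.sql
#   knowledge_base/business_glossary/kpi_definitions.md
#   ./chroma_db_schema/   (created by this script)
#   ./chroma_db_kpi/      (created by this script)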
# Ensure NEBIUS_API_KEY is set
if "NEBIUS_API_KEY" not in os.environ:
logging.critical("NEBIUS_API_KEY environment variable not set. Please set it before running this script.")
raise ValueError("NEBIUS_API_KEY environment variable not set. Please set it before running this script.")
# --- Initialize Nebius AI Embedding Model (Global for consistency) ---
print("\n--- Initializing Nebius AI Embedding Model ---")
try:
    embed_model = NebiusEmbedding(
        api_key=os.environ["NEBIUS_API_KEY"],
        model_name="BAAI/bge-en-icl",  # Consistent with schema_retriever_tool for now
        api_base="https://api.studio.nebius.com/v1/"  # Verify this base URL is correct for the model
    )
    test_string = "This is a test string to generate an embedding for diagnostic purposes."
    test_embedding = embed_model.get_text_embedding(test_string)
    if test_embedding is None or not isinstance(test_embedding, list) or not all(isinstance(x, (int, float)) for x in test_embedding):
        logging.critical(f"FATAL ERROR: NebiusEmbedding returned invalid output for test string. Type: {type(test_embedding)}, Value: {test_embedding[:10] if isinstance(test_embedding, list) else test_embedding}")
        raise ValueError("NebiusEmbedding test failed: returned invalid embedding. Cannot proceed with RAG index creation.")
    logging.info(f"Nebius AI Embedding Model initialized and tested successfully. Embedding length: {len(test_embedding)}")
except Exception as e:
    logging.critical(f"UNRECOVERABLE ERROR during Nebius AI Embedding Model initialization or testing: {e}")
    print(f"\n!!!! UNRECOVERABLE ERROR: {e} !!!!")
    print("Please check your NEBIUS_API_KEY, model_name, and api_base configuration carefully.")
    raise
print("--- Nebius AI Embedding Model setup complete. Proceeding to Knowledge Base setup. ---")
# --- Helper Function to parse Data Dictionary ---
def parse_data_dictionary_md(md_path: str) -> dict:
    """Parse data_dictionary.md and return a mapping of table name -> purpose text."""
    table_descriptions = {}
    if not os.path.exists(md_path):
        logging.error(f"Data dictionary file not found: {md_path}")
        return {}
    with open(md_path, 'r', encoding='utf-8') as f:
        content = f.read()
    table_headers = re.finditer(r'## Table: `(\w+)`\n\*\*Purpose:\*\* (.+?)(?=\n## Table: `|\Z)', content, re.DOTALL)
    for match in table_headers:
        table_name = match.group(1).strip()
        purpose = match.group(2).strip()
        table_descriptions[table_name] = purpose
        logging.info(f"Parsed table: {table_name}, Purpose length: {len(purpose)}")
    return table_descriptions
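# Illustrative data_dictionary.md layout that the regex above is written
# against (hypothetical example, not taken from the actual file):
#
#   ## Table: `sales`
#   **Purpose:** One row per sales transaction, with amount, date and customer.
#
#   ## Table: `customers`
#   **Purpose:** Customer master data referenced by sales via customer_id.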
# --- Setup for Schema Retriever Agent's Knowledge Base ---
print("\n--- Setting up Schema Retriever Agent's Knowledge Base (chroma_db_schema) ---")
try:
    data_dict_descriptions = parse_data_dictionary_md(DATA_DICTIONARY_PATH)
    logging.info(f"Loaded descriptions for {len(data_dict_descriptions)} tables from data_dictionary.md")
    engine = create_engine(f"sqlite:///{DATABASE_PATH}")
    sql_database = SQLDatabase(engine)
    logging.info(f"Connected to database: {DATABASE_PATH}")
    all_table_names = sql_database.get_usable_table_names()
    schema_nodes = []
    if not all_table_names:
        logging.warning("WARNING: No tables found in the database. Schema KB will be empty.")
    for table_name in all_table_names:
        ddl = sql_database.get_single_table_info(table_name)
        human_description = data_dict_descriptions.get(table_name, "No specific description available for this table.")
        combined_context = f"Table Name: {table_name}\n" \
                           f"Description: {human_description}\n" \
                           f"Table Schema (DDL):\n{ddl if ddl else ''}"
        node_embedding = embed_model.get_text_embedding(combined_context)
        if node_embedding is None:
            raise ValueError(f"Failed to generate embedding for schema table: {table_name}")
        # Adding a simple metadata dictionary
        schema_nodes.append(TextNode(text=combined_context, embedding=node_embedding, id_=table_name, metadata={"table_name": table_name, "source": "data_dictionary"}))
    chroma_client_schema = chromadb.PersistentClient(path=CHROMA_DB_SCHEMA_PATH)
    chroma_collection_schema = chroma_client_schema.get_or_create_collection(name="schema_kb")
    if chroma_collection_schema.count() > 0:
        logging.info(f"Clearing {chroma_collection_schema.count()} existing items from schema_kb before re-indexing.")
        chroma_collection_schema.delete(ids=chroma_collection_schema.get()['ids'])
    logging.info(f"Adding {len(schema_nodes)} schema nodes to ChromaDB directly...")
    if schema_nodes:
        chroma_collection_schema.add(
            documents=[node.text for node in schema_nodes],
            embeddings=[node.embedding for node in schema_nodes],
            # Pass non-empty metadata dict for each node
            metadatas=[node.metadata for node in schema_nodes],
            ids=[node.id_ for node in schema_nodes]
        )
    logging.info(f"Schema knowledge base indexed and persisted to {CHROMA_DB_SCHEMA_PATH}")
except Exception as e:
    logging.exception("Error setting up Schema KB:")
    print(f"Error setting up Schema KB: {e}")
    print("Please ensure your database is created and accessible and data_dictionary.md is correctly formatted.")
    print("If the error persists, there might be a deeper compatibility issue. See above for more detailed embedding checks.")
    raise
# --- Setup for KPI Answering Agent's Knowledge Base (kpi_definitions.md) ---
print("\n--- Setting up KPI Answering Agent's Knowledge Base (chroma_db_kpi) ---")
try:
    kpi_docs = SimpleDirectoryReader(input_files=[KPI_DEFINITIONS_PATH]).load_data()
    logging.info(f"Loaded {len(kpi_docs)} documents for KPI Agent.")
    kpi_nodes = []
    for doc in kpi_docs:
        node_embedding = embed_model.get_text_embedding(doc.get_content())
        if node_embedding is None:
            raise ValueError(f"Failed to generate embedding for KPI document: {doc.id_}")
        # Adding a simple metadata dictionary
        kpi_nodes.append(TextNode(text=doc.get_content(), embedding=node_embedding, id_=str(uuid4()), metadata={"source_file": os.path.basename(KPI_DEFINITIONS_PATH), "doc_type": "kpi_definition"}))
    chroma_client_kpi = chromadb.PersistentClient(path=CHROMA_DB_KPI_PATH)
    chroma_collection_kpi = chroma_client_kpi.get_or_create_collection(name="kpi_kb")
    if chroma_collection_kpi.count() > 0:
        logging.info(f"Clearing {chroma_collection_kpi.count()} existing items from kpi_kb before re-indexing.")
        chroma_collection_kpi.delete(ids=chroma_collection_kpi.get()['ids'])
    logging.info(f"Adding {len(kpi_nodes)} KPI nodes to ChromaDB directly...")
    if kpi_nodes:
        chroma_collection_kpi.add(
            documents=[node.text for node in kpi_nodes],
            embeddings=[node.embedding for node in kpi_nodes],
            # Pass non-empty metadata dict for each node
            metadatas=[node.metadata for node in kpi_nodes],
            ids=[node.id_ for node in kpi_nodes]
        )
    logging.info(f"KPI knowledge base indexed and persisted to {CHROMA_DB_KPI_PATH}")
except Exception as e:
    logging.exception("Error setting up KPI KB:")
    print(f"Error setting up KPI KB: {e}")
    print("Please ensure your 'kpi_definitions.md' file exists and contains valid text content.")
    print("If the error persists, there might be a deeper compatibility issue. See above for more detailed embedding checks.")
    raise
print("\nKnowledge base setup complete for RAG agents!")
print(f"Indices are saved in '{CHROMA_DB_SCHEMA_PATH}' and '{CHROMA_DB_KPI_PATH}' directories.") |