ComfyKnowledgeGraph / graph_handler.py
jibinmathew's picture
Upload 42 files
eb957df verified
from llama_index.core import Document
from llama_index.core import KnowledgeGraphIndex, ServiceContext, StorageContext
from llama_index.llms.openai import OpenAI
from llama_index.core.graph_stores import SimpleGraphStore
from llama_index.core import SimpleDirectoryReader, load_index_from_storage
from typing import List
from dotenv import load_dotenv
import os
import json
import networkx as nx
from pyvis.network import Network
from datetime import datetime
from retrieve import get_latest_dir
import html
# Load environment configuration first: the OpenAI client below reads
# OPENAI_API_KEY from the process environment.
load_dotenv()

# LLM shared by the whole module (temperature 0.0, gpt-3.5-turbo).
llm = OpenAI(
    model="gpt-3.5-turbo",
    temperature=0.0,
    api_key=os.getenv("OPENAI_API_KEY"),
)

# Graph store and the storage context that wraps it; both are module-level
# and therefore shared by every index built through create_kg().
graph_store = SimpleGraphStore()
storage_context = StorageContext.from_defaults(graph_store=graph_store)

# Service context bundling the LLM with the chunking parameters.
service_context = ServiceContext.from_defaults(
    llm=llm,
    chunk_size=2048,
    chunk_overlap=24,
)
def create_document(input_dir: str) -> List[Document]:
    """
    Read every ``.json`` file in a directory and return the resulting documents.

    Hidden files are excluded.

    Args:
        input_dir (str): The directory to read the documents from.

    Returns:
        List[Document]: All documents produced by the reader, flattened
            into one list.
    """
    json_reader = SimpleDirectoryReader(
        input_dir,
        required_exts=[".json"],
        exclude_hidden=True,
    )
    # iter_data() is consumed batch by batch; flatten the batches in order.
    return [document for batch in json_reader.iter_data() for document in batch]
def kg_triplet_extract_fn(text: str) -> List[tuple]:
    """
    Extract (subject, predicate, object) triplets from a product spec chunk.

    Everything after the last blank-line separator in ``text`` is parsed as a
    JSON object. Its "name" entry becomes the subject of every triplet, and
    each remaining key/value pair becomes one (name, key, value) triplet.

    Args:
        text (str): The text to extract the triplets from; its final
            paragraph must be a JSON object containing a "name" key.

    Returns:
        List[tuple]: One ``(product_name, key, value)`` tuple per non-"name"
            entry, in the order the keys appear in the JSON object. Values
            are passed through exactly as parsed from JSON, so they are not
            necessarily strings.

    Raises:
        json.JSONDecodeError: If the final paragraph is not valid JSON.
        KeyError: If the JSON object has no "name" key.
    """
    # rpartition keeps the whole text when no blank-line separator is present,
    # which matches taking the last element of text.split on that separator.
    json_text = text.rpartition("\n\n")[2]
    product_spec = json.loads(json_text)
    product_name = product_spec["name"]
    return [
        (product_name, key, value)
        for key, value in product_spec.items()
        if key != "name"
    ]
def generate_graph_visualization(kg_index):
    """
    Render the KG index as an interactive HTML graph and save it to disk.

    The file is written to the directory named by the GRAPH_VIS_DIR
    environment variable (default "graph_vis"), which is created if missing,
    under a timestamped filename of the form ``YYYYmmddHHMMSS.html``.

    Args:
        kg_index (KnowledgeGraphIndex): The Knowledge Graph index to generate the visualization from.

    Returns:
        str: The path to the generated graph visualization.
    """
    output_directory = os.getenv("GRAPH_VIS_DIR", "graph_vis")
    # Make sure the target directory exists before the HTML file is written
    # into it. An empty value means "current directory", so nothing to create.
    if output_directory:
        os.makedirs(output_directory, exist_ok=True)

    # Generate a timestamp for the filename
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    graph_output_path = os.path.join(output_directory, f"{timestamp}.html")

    nx_graph = kg_index.get_networkx_graph(limit=20000)
    net = Network(
        notebook=False,
        cdn_resources="remote",
        height="800px",
        width="100%",
        select_menu=True,
        filter_menu=False,
    )
    net.from_nx(nx_graph)
    net.force_atlas_2based(central_gravity=0.015, gravity=-31)
    net.save_graph(graph_output_path)
    print(f"Graph visualization saved to: {graph_output_path}")
    return graph_output_path
def plot_subgraph(triplets):
    """
    Plot a subgraph from the triplets.

    Args:
        triplets (str): The string form of a Python list whose items are
            themselves the string form of a 3-item sequence
            ``(source, action, target)``.

    Returns:
        str: The escaped HTML content to display the subgraph.

    Raises:
        ValueError: If ``triplets`` or one of its items is not a plain
            Python literal, or an item does not unpack into three parts.
        SyntaxError: If ``triplets`` or one of its items cannot be parsed.
    """
    # Imported locally so the block does not depend on a module-level import.
    import ast

    # SECURITY: this input used to be passed to eval(), which executes any
    # expression embedded in the string. literal_eval only accepts Python
    # literals (strings, numbers, tuples, lists, ...), which is all that is
    # needed to decode stringified triplets.
    graph = nx.DiGraph()
    for edge_str in ast.literal_eval(triplets):
        source, action, target = ast.literal_eval(edge_str)
        graph.add_edge(source, target, label=action)

    net = Network(notebook=True, cdn_resources="remote", height="400px", width="100%")
    net.from_nx(graph)
    net.force_atlas_2based(central_gravity=0.015, gravity=-31)
    html_content = net.generate_html()
    # Escape the markup before handing it back to the caller.
    return html.escape(html_content)
def create_kg(max_features: int = 60):
    """
    Build a Knowledge Graph index from the product spec directory and render it.

    The input directory comes from the PROD_SPEC_DIR environment variable
    (default "prod_spec"). Triplets are extracted with
    ``kg_triplet_extract_fn`` and the module-level storage and service
    contexts are used.

    Args:
        max_features (int): Forwarded as ``max_triplets_per_chunk``.

    Returns:
        tuple: ``(kg_index, graphvis_path)`` - the built index and the path
            returned by ``generate_graph_visualization`` for it.
    """
    spec_dir = os.getenv("PROD_SPEC_DIR", "prod_spec")
    index_options = dict(
        documents=create_document(spec_dir),
        max_triplets_per_chunk=max_features,
        storage_context=storage_context,
        service_context=service_context,
        show_progress=True,
        include_embeddings=True,
        kg_triplet_extract_fn=kg_triplet_extract_fn,
    )
    kg_index = KnowledgeGraphIndex.from_documents(**index_options)
    return kg_index, generate_graph_visualization(kg_index)
def persist_kg(kg_index: KnowledgeGraphIndex) -> str:
    """
    Save the index's storage context into a new timestamped directory.

    The parent directory comes from the GRAPH_DIR environment variable
    (default "graphs"); the sub-directory name is the current local time
    formatted as ``YYYYmmddHHMMSS``.

    Args:
        kg_index (KnowledgeGraphIndex): The Knowledge Graph index to persist.

    Returns:
        str: The path to the persisted KG index.
    """
    base_dir = os.getenv("GRAPH_DIR", "graphs")
    stamp = datetime.now().strftime("%Y%m%d%H%M%S")
    # Joined with a literal "/" to keep the exact path string callers receive.
    target_path = "/".join((base_dir, stamp))
    kg_index.storage_context.persist(target_path)
    return target_path
def load_kg(kg_dir: str) -> KnowledgeGraphIndex:
    """
    Restore a persisted KG index from under a parent directory.

    The sub-directory to load is whatever ``get_latest_dir`` selects for
    ``kg_dir``.

    Args:
        kg_dir (str): The parent directory to load the KG index from.

    Returns:
        KnowledgeGraphIndex: The loaded Knowledge Graph index.
    """
    latest_dir = get_latest_dir(kg_dir)
    restored_context = StorageContext.from_defaults(persist_dir=latest_dir)
    return load_index_from_storage(restored_context)
def query(kg_dir: str, query: str):
    """
    Load the persisted KG index and answer a question with it.

    Args:
        kg_dir (str): The directory to load the KG index from.
        query (str): The question to ask the KG index.

    Returns:
        Response: The response from the KG index.
    """
    # Fixed engine settings used for every question asked through this helper.
    engine_options = {
        "include_text": True,
        "response_mode": "refine",
        "graph_store_query_depth": 6,
        "similarity_top_k": 5,
    }
    engine = load_kg(kg_dir).as_query_engine(**engine_options)
    return engine.query(query)
def query_graph_qa(graph_rag_index, query, search_level):
    """
    Query the Graph-RAG model for a given query.

    The question is answered by a query engine built from the index, and the
    same question is also sent to a separately built retriever to collect
    the supporting node texts.

    Args:
        graph_rag_index (KnowledgeGraphIndex): The Graph-RAG model index.
        query (str): The query to ask the Graph-RAG model.
        search_level (int): Passed as ``similarity_top_k`` to both the
            retriever and the query engine.

    Returns:
        tuple: ``(response, reference, reference_text)`` where ``reference``
            is the "kg_rel_texts" value of the first metadata entry that has
            one (an empty list if none does or there is no metadata) and
            ``reference_text`` is the list of retrieved node texts.
    """
    myretriever = graph_rag_index.as_retriever(
        include_text=True,
        similarity_top_k=search_level,
    )
    query_engine = graph_rag_index.as_query_engine(
        sub_retrievers=[myretriever],
        graph_store_query_depth=6,
        include_text=True,
        similarity_top_k=search_level,
    )
    response = query_engine.query(query)
    nodes = myretriever.retrieve(query)

    # Guard against a response without metadata: calling .items() on None
    # would raise AttributeError instead of returning an empty reference.
    # NOTE(review): LlamaIndex appears to allow metadata=None - confirm.
    metadata = response.metadata or {}
    reference = next(
        (
            value["kg_rel_texts"]
            for value in metadata.values()
            if isinstance(value, dict) and "kg_rel_texts" in value
        ),
        [],
    )
    reference_text = [node.text for node in nodes]
    return response, reference, reference_text
if __name__ == "__main__":
    # Demo run: build and persist the graph, reload it from disk, render it
    # again, then ask one sample question.
    graph_dir = os.getenv("GRAPH_DIR", "graphs")

    kg_index, graphvis_path = create_kg()
    persist_kg(kg_index)

    kg_index = load_kg(graph_dir)
    generate_graph_visualization(kg_index)

    response = query(
        graph_dir,
        "Tell me the Built-in memory in Apple iPhone 15 Pro Max 256Gb Blue Titanium?",
    )
    print(response)

    # Show the metadata entry stored under the last key of the response.
    last_key = list(response.metadata)[-1]
    print(response.metadata[last_key])