# Knowledge-graph utilities: build, persist, visualize, and query a
# product-spec Knowledge Graph with llama_index, networkx, and pyvis.
import ast
import html
import json
import os
from datetime import datetime
from typing import List

import networkx as nx
from dotenv import load_dotenv
from llama_index.core import (
    Document,
    KnowledgeGraphIndex,
    ServiceContext,
    SimpleDirectoryReader,
    StorageContext,
    load_index_from_storage,
)
from llama_index.core.graph_stores import SimpleGraphStore
from llama_index.llms.openai import OpenAI
from pyvis.network import Network

from retrieve import get_latest_dir
# Pull OPENAI_API_KEY (and optional *_DIR overrides) from a local .env file.
load_dotenv()
# Deterministic LLM (temperature 0) shared by KG construction and querying.
llm = OpenAI(
    temperature=0.0, model="gpt-3.5-turbo", api_key=os.getenv("OPENAI_API_KEY")
)
# In-memory triplet store wired into the storage context used by create_kg().
graph_store = SimpleGraphStore()
storage_context = StorageContext.from_defaults(graph_store=graph_store)
# NOTE(review): ServiceContext is deprecated in recent llama_index releases
# (superseded by Settings) — confirm the pinned version still provides it.
service_context = ServiceContext.from_defaults(
    llm=llm, chunk_size=2048, chunk_overlap=24
)
def create_document(input_dir: str) -> List[Document]:
    """
    Load every JSON file under a directory as llama_index Documents.

    Args:
        input_dir (str): Directory to read the documents from.

    Returns:
        List[Document]: All documents collected from the directory.
    """
    json_reader = SimpleDirectoryReader(
        input_dir, exclude_hidden=True, required_exts=[".json"]
    )
    collected: List[Document] = []
    for batch in json_reader.iter_data():
        collected.extend(batch)
    return collected
def kg_triplet_extract_fn(text) -> List[tuple]:
    """
    Extract (subject, predicate, object) triplets from a product-spec chunk.

    The chunk's final blank-line-separated segment must be a JSON object
    whose "name" key identifies the product; every remaining key/value pair
    becomes one triplet anchored on that product name.

    Args:
        text (str): Chunk text ending in a JSON product specification.

    Returns:
        List[tuple]: (product_name, attribute, value) triplets.

    Raises:
        json.JSONDecodeError: If the trailing segment is not valid JSON.
        KeyError: If the JSON object has no "name" key.
    """
    # The directory reader prepends metadata; the spec JSON is the last segment.
    json_text = text.split("\n\n")[-1]
    product_spec = json.loads(json_text)
    # pop() removes the anchor key and returns it in one step (the original
    # used a separate `del`, and mis-annotated the return type as List[str]
    # even though tuples are produced).
    product_name = product_spec.pop("name")
    return [(product_name, key, value) for key, value in product_spec.items()]
def generate_graph_visualization(kg_index):
    """
    Render the KG as an interactive pyvis HTML page.

    Args:
        kg_index (KnowledgeGraphIndex): The index to visualize.

    Returns:
        str: Path of the saved HTML file (timestamped, under GRAPH_VIS_DIR).
    """
    output_directory = os.getenv("GRAPH_VIS_DIR", "graph_vis")
    # Fix: net.save_graph fails if the target directory does not exist yet.
    os.makedirs(output_directory, exist_ok=True)
    # Timestamped filename so successive runs never overwrite each other.
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    graph_output_file = f"{timestamp}.html"
    graph_output_path = os.path.join(output_directory, graph_output_file)
    g = kg_index.get_networkx_graph(limit=20000)
    net = Network(
        notebook=False,
        cdn_resources="remote",
        height="800px",
        width="100%",
        select_menu=True,
        filter_menu=False,
    )
    net.from_nx(g)
    # Force-atlas layout keeps larger graphs readable.
    net.force_atlas_2based(central_gravity=0.015, gravity=-31)
    net.save_graph(graph_output_path)
    print(f"Graph visualization saved to: {graph_output_path}")
    return graph_output_path
def plot_subgraph(triplets):
    """
    Plot a directed subgraph from serialized triplets.

    Args:
        triplets (str): repr of a list of triplet strings, each itself the
            repr of a (source, action, target) tuple.

    Returns:
        str: HTML-escaped pyvis page embedding the subgraph.
    """
    G = nx.DiGraph()
    # Security fix: ast.literal_eval only parses Python literals, unlike the
    # original eval(), which would execute arbitrary expressions embedded in
    # the (potentially LLM- or user-supplied) input.
    for edge_str in ast.literal_eval(triplets):
        source, action, target = ast.literal_eval(edge_str)
        G.add_edge(source, target, label=action)
    net = Network(notebook=True, cdn_resources="remote", height="400px", width="100%")
    net.from_nx(G)
    net.force_atlas_2based(central_gravity=0.015, gravity=-31)
    html_content = net.generate_html()
    escaped_html = html.escape(html_content)
    return escaped_html
def create_kg(max_features: int = 60):
    """
    Build a Knowledge Graph index from the product-spec directory.

    Args:
        max_features (int): Maximum number of triplets extracted per chunk.

    Returns:
        tuple: (KnowledgeGraphIndex, path to the generated HTML visualization).
    """
    spec_dir = os.getenv("PROD_SPEC_DIR", "prod_spec")
    docs = create_document(spec_dir)
    index = KnowledgeGraphIndex.from_documents(
        documents=docs,
        max_triplets_per_chunk=max_features,
        storage_context=storage_context,
        service_context=service_context,
        show_progress=True,
        include_embeddings=True,
        kg_triplet_extract_fn=kg_triplet_extract_fn,
    )
    vis_path = generate_graph_visualization(index)
    return index, vis_path
def persist_kg(kg_index: KnowledgeGraphIndex) -> str:
    """
    Persist the KG index under a timestamped subdirectory of GRAPH_DIR.

    Args:
        kg_index (KnowledgeGraphIndex): The index to persist.

    Returns:
        str: The directory the index was persisted to.
    """
    output_dir = os.getenv("GRAPH_DIR", "graphs")
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    # os.path.join instead of a hand-built f-string path keeps the separator
    # portable and matches generate_graph_visualization's path handling.
    kg_path = os.path.join(output_dir, timestamp)
    kg_index.storage_context.persist(kg_path)
    return kg_path
def load_kg(kg_dir: str) -> KnowledgeGraphIndex:
    """
    Load the most recently persisted KG index below a parent directory.

    Args:
        kg_dir (str): Parent directory containing timestamped index dirs.

    Returns:
        KnowledgeGraphIndex: The loaded Knowledge Graph index.
    """
    latest = get_latest_dir(kg_dir)
    ctx = StorageContext.from_defaults(persist_dir=latest)
    return load_index_from_storage(ctx)
def query(kg_dir: str, query: str):
    """
    Answer a question against the latest persisted KG index.

    Args:
        kg_dir (str): Directory holding persisted KG indexes.
        query (str): Natural-language question for the index.

    Returns:
        Response: The engine's response to the question.
    """
    # NOTE: the parameter deliberately shares the function's name; renaming
    # it would break keyword callers.
    engine = load_kg(kg_dir).as_query_engine(
        include_text=True,
        response_mode="refine",
        graph_store_query_depth=6,
        similarity_top_k=5,
    )
    return engine.query(query)
def query_graph_qa(graph_rag_index, query, search_level):
    """
    Run a Graph-RAG query and collect its supporting references.

    Args:
        graph_rag_index (KnowledgeGraphIndex): The Graph-RAG index.
        query (str): The question to answer.
        search_level (int): similarity_top_k used for retrieval.

    Returns:
        tuple: (response, kg relation texts, retrieved node texts).
    """
    retriever = graph_rag_index.as_retriever(
        include_text=True,
        similarity_top_k=search_level,
    )
    engine = graph_rag_index.as_query_engine(
        sub_retrievers=[retriever],
        graph_store_query_depth=6,
        include_text=True,
        similarity_top_k=search_level,
    )
    response = engine.query(query)
    nodes = retriever.retrieve(query)
    # First metadata entry carrying "kg_rel_texts" supplies the references;
    # fall back to an empty list when none is present.
    reference = next(
        (
            value["kg_rel_texts"]
            for value in response.metadata.values()
            if isinstance(value, dict) and "kg_rel_texts" in value
        ),
        [],
    )
    reference_text = [node.text for node in nodes]
    return response, reference, reference_text
if __name__ == "__main__":
    # End-to-end smoke test: build, persist, reload, re-visualize, then query.
    kg_index, graphvis_path = create_kg()
    persist_kg(kg_index)
    kg_index = load_kg(os.getenv("GRAPH_DIR", "graphs"))
    generate_graph_visualization(kg_index)
    response = query(
        os.getenv("GRAPH_DIR", "graphs"),
        "Tell me the Built-in memory in Apple iPhone 15 Pro Max 256Gb Blue Titanium?",
    )
    print(response)
    # Dump the metadata of the last entry for manual inspection.
    key = list(response.metadata)[-1]
    print(response.metadata[key])