# Knowledge_Graph_Generator / generate_knowledge_graph.py
# Author: Demosthene-OR
# Commit 1d1a8c7: "Update generate_knowledge_graph.py"
from langchain_experimental.graph_transformers import LLMGraphTransformer
from langchain_core.documents import Document
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import PromptTemplate
from pyvis.network import Network
from dotenv import load_dotenv
import os
import asyncio
# Load variables from a local .env file (if present) into os.environ before
# anything below reads the OpenAI credentials.
load_dotenv()
# NOTE(review): this value is never passed to the clients below; ChatOpenAI /
# OpenAIEmbeddings presumably read OPENAI_API_KEY from the environment
# themselves — confirm before relying on or removing this name.
api_key = os.environ.get("OPENAI_API_KEY")

# Single temperature-0 GPT-4o client, shared by graph extraction (through the
# transformer) and by question answering.
llm = ChatOpenAI(model_name="gpt-4o", temperature=0)
graph_transformer = LLMGraphTransformer(llm=llm)
async def extract_graph_data(text):
    """Asynchronously run the module-level LLM graph transformer over ``text``.

    The whole text is wrapped in a single ``Document`` and handed to
    ``graph_transformer.aconvert_to_graph_documents``.

    Args:
        text: Raw text to extract entities and relationships from.

    Returns:
        Whatever ``aconvert_to_graph_documents`` returns for that one-document
        list (presumably one graph document per input document; callers such
        as ``visualize_graph`` read its ``.nodes`` / ``.relationships``).
    """
    return await graph_transformer.aconvert_to_graph_documents(
        [Document(page_content=text)]
    )
def visualize_graph(graph_documents):
    """Build an interactive pyvis network from extracted graph documents.

    Nodes and relationships of every document in ``graph_documents`` are
    merged. Nodes are keyed by ``id``, and a later node with the same id
    replaces an earlier one. A relationship is drawn only if both its source
    and target ids are among the extracted nodes. Only nodes that take part
    in at least one drawn relationship are added, so isolated nodes are
    omitted.

    Args:
        graph_documents: Iterable of graph documents exposing ``.nodes`` and
            ``.relationships`` (e.g. the result of ``extract_graph_data``).
            An empty iterable yields an empty network instead of raising.

    Returns:
        pyvis.network.Network: the configured network (not saved to disk).
    """
    import logging

    logger = logging.getLogger(__name__)
    net = Network(
        height="600px",
        width="100%",
        directed=True,
        notebook=False,
        bgcolor="#222222",
        font_color="white",
        filter_menu=True,
        cdn_resources='remote',
    )

    # Merge all documents. Reading only graph_documents[0] raised IndexError
    # on an empty list and silently ignored any additional documents.
    node_dict = {}
    relationships = []
    for doc in graph_documents:
        node_dict.update((node.id, node) for node in doc.nodes)
        relationships.extend(doc.relationships)

    # Keep only edges whose endpoints are known nodes. Endpoint ids go into a
    # dict used as an ordered set, so nodes are added in first-seen order and
    # the generated HTML does not depend on string-hash randomisation (as
    # iterating a set of strings did).
    valid_edges = []
    valid_node_ids = {}
    for rel in relationships:
        if rel.source.id in node_dict and rel.target.id in node_dict:
            valid_edges.append(rel)
            valid_node_ids[rel.source.id] = None
            valid_node_ids[rel.target.id] = None

    # Best effort: skip anything pyvis refuses instead of aborting the whole
    # drawing. Catching Exception rather than using a bare except lets
    # KeyboardInterrupt/SystemExit propagate, and each dropped item is logged.
    for node_id in valid_node_ids:
        node = node_dict[node_id]
        try:
            net.add_node(node.id, label=node.id, title=node.type, group=node.type)
        except Exception as exc:
            logger.warning("Skipping node %r: %s", node.id, exc)

    for rel in valid_edges:
        try:
            net.add_edge(rel.source.id, rel.target.id, label=rel.type.lower())
        except Exception as exc:
            # An edge whose endpoint node was skipped above will likely fail here too.
            logger.warning(
                "Skipping edge %r -> %r: %s", rel.source.id, rel.target.id, exc
            )

    net.set_options('{"physics": {"forceAtlas2Based": {"gravitationalConstant": -100, "centralGravity": 0.01, "springLength": 200, "springConstant": 0.08}, "minVelocity": 0.75, "solver": "forceAtlas2Based"}}')
    return net
def generate_knowledge_graph(text):
    """Synchronously extract a knowledge graph from ``text`` and visualise it.

    Drives the async ``extract_graph_data`` to completion with
    ``asyncio.run`` and then renders the result with ``visualize_graph``.

    NOTE: per the asyncio docs, ``asyncio.run`` cannot be called while
    another event loop is running in the same thread (e.g. inside a notebook
    cell or an async framework handler).

    Args:
        text: Raw text to extract the graph from.

    Returns:
        tuple: ``(net, graph_documents)``, i.e. the pyvis network and the
        raw extraction result.
    """
    docs = asyncio.run(extract_graph_data(text))
    return visualize_graph(docs), docs
def _retrieve_relevant_relationships(question, relationships, k):
    """Return ``(used_relationships, context)`` for the top-``k`` matches.

    Each relationship becomes a one-sentence French description tagged with
    its list index. The sentences are embedded with ``text-embedding-3-small``
    and indexed in an in-memory FAISS store. The hits for ``question`` are
    returned both as the original relationship objects and as a
    newline-joined context string, in search-result order.
    """
    rel_docs = [
        Document(
            page_content=f"L'entité '{rel.source.id}' a pour relation '{rel.type}' avec l'entité '{rel.target.id}'.",
            metadata={"rel_index": i},
        )
        for i, rel in enumerate(relationships)
    ]
    embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
    vectorstore = FAISS.from_documents(rel_docs, embeddings)
    retrieved_docs = vectorstore.similarity_search(question, k=k)
    # Map each hit back to its relationship through the stored index.
    used_relationships = [relationships[doc.metadata["rel_index"]] for doc in retrieved_docs]
    context = "\n".join(doc.page_content for doc in retrieved_docs)
    return used_relationships, context


def _build_subgraph_network(relationships):
    """Return a pyvis network containing only ``relationships`` and their endpoints."""
    import logging

    net = Network(height="450px", width="100%", directed=True, bgcolor="#222222", font_color="white")
    nodes_added = set()
    for rel in relationships:
        # Source first, then target; each node id is added only once.
        for endpoint in (rel.source, rel.target):
            if endpoint.id not in nodes_added:
                net.add_node(endpoint.id, label=endpoint.id, title=endpoint.type, group=endpoint.type)
                nodes_added.add(endpoint.id)
        try:
            net.add_edge(rel.source.id, rel.target.id, label=rel.type)
        except Exception as exc:
            # Still best effort, but no longer a bare ``except: pass`` that
            # hid every error (including KeyboardInterrupt/SystemExit).
            logging.getLogger(__name__).warning(
                "Skipping edge %r -> %r: %s", rel.source.id, rel.target.id, exc
            )
    # vis-network ignores forceAtlas2Based.* settings unless that solver is
    # selected (its default solver is barnesHut), so select it explicitly.
    net.set_options('{"physics": {"forceAtlas2Based": {"gravitationalConstant": -50}, "solver": "forceAtlas2Based"}}')
    return net


def answer_question_with_graph(question, graph_documents, k_relations=5):
    """Answer ``question`` using the relationships of an extracted knowledge graph.

    Every relationship of every document is turned into a French sentence,
    then embedded and indexed with FAISS. At most ``k_relations`` of the
    closest sentences are passed as context to ``llm``. The sub-graph formed
    by those relationships is rendered and written to ``filtered_graph.html``
    in the current working directory.

    Args:
        question: Natural-language question (the prompt asks for a French answer).
        graph_documents: Graph documents exposing ``.relationships`` (and
            ``.nodes`` for the fallback visualisation).
        k_relations: Number of relationships to retrieve as context.

    Returns:
        tuple: ``(answer, net)``. ``answer`` is the model's reply text and
        ``net`` is the pyvis network of the retrieved sub-graph. If the graph
        has no relationships at all, it instead returns a fixed French message
        and the full-graph network from ``visualize_graph``; in that case no
        API call is made and no file is written.
    """
    all_relationships = [rel for doc in graph_documents for rel in doc.relationships]
    if not all_relationships:
        return "Aucune relation trouvée dans le graphe.", visualize_graph(graph_documents)

    used_relationships, context = _retrieve_relevant_relationships(
        question, all_relationships, k_relations
    )

    prompt = PromptTemplate(
        template='''Tu es un assistant expert qui répond aux questions en se basant UNIQUEMENT sur ce sous-ensemble de relations extraites d'un graphe de connaissances.\n\nContexte (Relations pertinentes trouvées) :\n{context}\n\nQuestion : {question}\n\nRéponds de manière claire et concise en français. Si la réponse n'est pas dans le contexte fourni, dis-le explicitement.''',
        input_variables=["context", "question"]
    )
    chain = prompt | llm
    answer = chain.invoke({"context": context, "question": question}).content

    # Built only after the LLM call succeeds, as before, so a failed call
    # writes no HTML file.
    net = _build_subgraph_network(used_relationships)
    net.save_graph("filtered_graph.html")
    return answer, net