AGENT_ANALYSE_RAG_dev / audit_page /knowledge_graph.py
Ilyas KHIAT
enhance graph
0222cea
import streamlit as st
from utils.kg.construct_kg import get_graph
from utils.audit.rag import get_text_from_content_for_doc,get_text_from_content_for_audio
from streamlit_agraph import agraph, Node, Edge, Config
import random
import math
from utils.audit.response_llm import generate_response_via_langchain
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import PromptTemplate
from itext2kg.models import KnowledgeGraph
def if_node_exists(nodes, node_id):
    """
    Tell whether any node in a collection carries the given id.

    Args:
        nodes (iterable): Objects exposing an ``id`` attribute.
        node_id (str): The id to look for.

    Returns:
        bool: True if at least one node's ``id`` equals ``node_id``,
        False otherwise (including when ``nodes`` is empty).
    """
    return any(candidate.id == node_id for candidate in nodes)
def generate_random_color():
    """
    Draw a random light RGB colour.

    Each of the three channels is sampled independently from the inclusive
    range 180-255, in red, green, blue order.

    Returns:
        tuple: ``(r, g, b)`` integers.
    """
    return tuple(random.randint(180, 255) for _ in range(3))
def rgb_to_hex(rgb):
    """
    Format the first three channels of an RGB sequence as ``#rrggbb``.

    Args:
        rgb: Indexable sequence whose items 0, 1 and 2 are integers.

    Returns:
        str: Lower-case hexadecimal colour string with a leading ``#``.
    """
    red, green, blue = rgb[0], rgb[1], rgb[2]
    return f"#{red:02x}{green:02x}{blue:02x}"
def get_node_types(graph):
    """
    Collect every distinct node type found in a graph.

    Types are gathered from ``graph.nodes`` and from both endpoints
    (``source`` and ``target``) of each item in ``graph.relationships``.

    Args:
        graph: Object exposing ``nodes`` and ``relationships`` iterables.

    Returns:
        set: The distinct ``type`` values encountered.
    """
    found_types = {node.type for node in graph.nodes}
    for link in graph.relationships:
        found_types.update((link.source.type, link.target.type))
    return found_types
def get_node_types_advanced(graph:KnowledgeGraph):
    """
    Collect every distinct entity label found in a ``KnowledgeGraph``.

    Labels are gathered from ``graph.entities`` and from both endpoints
    (``startEntity`` and ``endEntity``) of each item in ``graph.relationships``.

    Args:
        graph: Graph exposing ``entities`` and ``relationships`` iterables.

    Returns:
        set: The distinct ``label`` values encountered.
    """
    found_labels = {entity.label for entity in graph.entities}
    for link in graph.relationships:
        found_labels.update((link.startEntity.label, link.endEntity.label))
    return found_labels
def color_distance(color1, color2):
    """
    Return the Euclidean distance between two RGB colours.

    Only the first three components of each colour are used.

    Args:
        color1: Indexable sequence of at least three numbers.
        color2: Indexable sequence of at least three numbers.

    Returns:
        float: Square root of the summed squared channel differences.
    """
    squared_gap = sum((color1[channel] - color2[channel]) ** 2 for channel in range(3))
    return math.sqrt(squared_gap)
def generate_distinct_colors(num_colors, min_distance=30, max_attempts=200):
    """
    Generate ``num_colors`` random light colours that are visually separated.

    Colours are drawn with ``generate_random_color`` and kept only when they
    are at least ``min_distance`` (Euclidean RGB distance) away from every
    colour already kept.

    The sampled colour cube is small, so a large ``num_colors`` may be
    impossible to satisfy at the requested distance. To avoid spinning
    forever in that case, the required distance is relaxed each time
    ``max_attempts`` consecutive candidates are rejected.

    Args:
        num_colors (int): Number of colours wanted. Zero or negative yields
            an empty list.
        min_distance (float): Initial minimum distance between any two colours.
        max_attempts (int): Consecutive rejections tolerated before the
            distance constraint is loosened.

    Returns:
        list[str]: ``#rrggbb`` strings, one per requested colour.
    """
    colors = []
    rejected_in_a_row = 0
    while len(colors) < num_colors:
        new_color = generate_random_color()
        if all(color_distance(new_color, kept) >= min_distance for kept in colors):
            colors.append(new_color)
            rejected_in_a_row = 0
            continue

        rejected_in_a_row += 1
        if rejected_in_a_row >= max_attempts:
            # The constraint looks unsatisfiable at this distance: shrink it.
            # Below a distance of 1 only exact duplicates are rejected, so drop
            # the constraint entirely to guarantee termination.
            min_distance = min_distance * 0.8 if min_distance > 1 else 0
            rejected_in_a_row = 0
    return [rgb_to_hex(color) for color in colors]
def list_to_dict_colors(node_types:set):
    """
    Assign one generated colour to each node type.

    Args:
        node_types (set): The node types to colour.

    Returns:
        dict: Maps each node type to a ``#rrggbb`` colour string produced by
        ``generate_distinct_colors``.
    """
    palette = generate_distinct_colors(len(node_types))
    return dict(zip(node_types, palette))
def convert_neo4j_to_agraph(neo4j_graph, node_colors):
    """
    Build the streamlit-agraph nodes, edges and config for a graph.

    Every item of ``neo4j_graph.nodes`` and both endpoints of every item of
    ``neo4j_graph.relationships`` become an agraph ``Node``. A node is added
    only once: the first occurrence of an id wins. Spaces in ids are replaced
    by underscores for the agraph id, while the original id is kept as the
    visible label. The node type is stored in the node's ``title``.

    Args:
        neo4j_graph: Object exposing ``nodes`` (items with ``id`` and ``type``
            attributes) and ``relationships`` (items with ``source``,
            ``target`` and ``type`` attributes, where source and target are
            node-like objects).
        node_colors (dict): Maps a node type to the colour used to draw it.

    Returns:
        tuple: ``(edges, nodes, config)`` where ``edges`` is a list of
        ``Edge``, ``nodes`` a list of ``Node`` and ``config`` a ``Config``.

    Raises:
        KeyError: If the type of a newly added node is missing from
            ``node_colors``.
    """
    nodes = []
    edges = []
    # Ids already emitted, so deduplication is a set lookup instead of a
    # scan over the node list for every candidate.
    seen_ids = set()

    def register(raw_node):
        """Add ``raw_node`` to ``nodes`` if its id is new; return its agraph id."""
        agraph_id = raw_node.id.replace(" ", "_")
        if agraph_id not in seen_ids:
            node_type = raw_node.type
            nodes.append(
                Node(
                    id=agraph_id,
                    title=node_type,
                    label=raw_node.id,
                    size=25,
                    shape="circle",
                    color=node_colors[node_type],
                )
            )
            seen_ids.add(agraph_id)
        return agraph_id

    for node in neo4j_graph.nodes:
        register(node)

    # Relationship endpoints may not appear in ``nodes``; register them too so
    # that every edge points at an existing node.
    for relationship in neo4j_graph.relationships:
        source_id = register(relationship.source)
        target_id = register(relationship.target)
        edges.append(Edge(source=source_id, label=relationship.type, target=target_id))

    config = Config(width=1200, height=800, directed=True, physics=True, hierarchical=True,from_json="config.json")
    return edges, nodes, config
def convert_advanced_neo4j_to_agraph(neo4j_graph:KnowledgeGraph, node_colors):
    """
    Build the streamlit-agraph nodes, edges and config for a ``KnowledgeGraph``.

    One agraph ``Node`` is created per item of ``neo4j_graph.entities`` and
    one ``Edge`` per item of ``neo4j_graph.relationships``. Spaces in entity
    names are replaced by underscores for the agraph id, while the original
    name is kept as the visible label. The entity label is stored in the
    node's ``title``.

    Entities are not de-duplicated, and relationship endpoints are not added
    to the node list: only ``neo4j_graph.entities`` produces nodes.

    Args:
        neo4j_graph: Graph exposing ``entities`` (items with ``name`` and
            ``label`` attributes) and ``relationships`` (items with
            ``startEntity``, ``endEntity`` and ``name`` attributes).
        node_colors (dict): Maps an entity label to the colour used to draw it.

    Returns:
        tuple: ``(edges, nodes, config)`` where ``edges`` is a list of
        ``Edge``, ``nodes`` a list of ``Node`` and ``config`` a ``Config``.

    Raises:
        KeyError: If an entity's label is missing from ``node_colors``.
    """
    nodes = [
        Node(
            id=entity.name.replace(" ", "_"),
            title=entity.label,
            label=entity.name,
            size=25,
            shape="circle",
            color=node_colors[entity.label],
        )
        for entity in neo4j_graph.entities
    ]

    edges = [
        Edge(
            source=relationship.startEntity.name.replace(" ", "_"),
            label=relationship.name,
            target=relationship.endEntity.name.replace(" ", "_"),
        )
        for relationship in neo4j_graph.relationships
    ]

    config = Config(width=1200, height=800, directed=True, physics=True, hierarchical=True,from_json="config.json")
    return edges, nodes, config
def display_graph(edges, nodes, config):
    """
    Render the graph through the streamlit-agraph component.

    Args:
        edges: Edges to draw.
        nodes: Nodes to draw.
        config: Component configuration.

    Returns:
        Whatever ``agraph`` returns for this render.
    """
    rendered = agraph(nodes=nodes, edges=edges, config=config)
    return rendered
def filter_nodes_by_types(nodes:list[Node], node_types_filter:list) -> list[Node]:
    """
    Keep only the nodes whose type is among the requested ones.

    A node's type is read from its ``title`` attribute.

    Args:
        nodes: Nodes to filter.
        node_types_filter: Types to keep.

    Returns:
        A new list with the matching nodes, in their original order.
    """
    return [candidate for candidate in nodes if candidate.title in node_types_filter]
@st.dialog(title="Changer la vue")
def change_view_dialog():
    """
    Dialog listing the saved filter views, with delete, select and rename actions.

    Reads and mutates ``st.session_state.filter_views`` (view name -> list of
    labels) and ``st.session_state.current_view``. Each action that changes
    the state is immediately followed by ``st.rerun()``.

    The first view in the dict (index 0) gets neither a delete button nor a
    rename form.
    """
    st.write("Changer la vue")
    for index, item in enumerate(st.session_state.filter_views.keys()):
        # One placeholder per view, holding a row: [expander | delete | select].
        emp = st.empty()
        col1, col2, col3 = emp.columns([8, 1, 1])
        # Delete button, not offered for the first view.
        # NOTE(review): the dict is mutated while being iterated; this presumably
        # relies on st.rerun() aborting the current run right away — confirm.
        if index > 0 and col2.button("🗑️", key=f"del{index}"):
            del st.session_state.filter_views[item]
            st.session_state.current_view = "Vue par défaut"
            st.rerun()
        # The currently active view shows a check mark instead of the magnifier.
        but_content = "🔍" if st.session_state.current_view != item else "✅"
        if col3.button(but_content, key=f"valid{index}"):
            st.session_state.current_view = item
            st.rerun()
        # NOTE(review): while enumerating this same dict, its length looks always
        # greater than index, so the else branch may be unreachable — confirm.
        if len(st.session_state.filter_views.keys()) > index:
            with col1.expander(item):
                if index > 0:
                    change_name = st.text_input("Nom de la vue", label_visibility="collapsed", placeholder="Changez le nom de la vue",key=f"change_name{index}")
                    if st.button("Renommer",key=f"rename{index}"):
                        if change_name != "":
                            # Move the labels under the new key and make it the current view.
                            # NOTE(review): an existing view with the same name would be
                            # overwritten without warning.
                            st.session_state.filter_views[change_name] = st.session_state.filter_views.pop(item)
                            st.session_state.current_view = change_name
                            st.rerun()
                # Bullet list of the labels contained in this view.
                st.markdown("\n".join(f"- {label.strip()}" for label in st.session_state.filter_views[item]))
        else:
            emp.empty()
@st.dialog(title="Ajouter une vue")
def add_view_dialog(filters):
    """
    Dialog saving the current label filters as a new named view.

    On confirmation the view is stored in ``st.session_state.filter_views``
    under the typed name, becomes ``st.session_state.current_view`` and the
    app is rerun. An empty name or a name that is already used is refused
    with an error message, so an existing view is never silently replaced.

    Args:
        filters: The labels currently selected in the filter widget.
    """
    st.write("Ajouter une vue")
    view_name = st.text_input("Nom de la vue")
    st.markdown("les filtres actuels:")
    st.write(filters)
    if st.button("Ajouter la vue"):
        cleaned_name = view_name.strip()
        if not cleaned_name:
            st.error("Veuillez saisir un nom pour la vue.")
        elif cleaned_name in st.session_state.filter_views:
            st.error("Une vue portant ce nom existe déjà.")
        else:
            # Store a copy so later changes to the caller's list cannot alter
            # the saved view.
            st.session_state.filter_views[cleaned_name] = list(filters)
            st.session_state.current_view = cleaned_name
            st.rerun()
@st.dialog(title="Changer la couleur")
def change_color_dialog():
    """
    Dialog exposing one colour picker per node type.

    Each picker starts from the colour stored in
    ``st.session_state.node_types`` and its value is written back to that
    mapping on every run. The "Valider" button triggers a rerun.
    """
    st.write("Changer la couleur")
    palette = st.session_state.node_types
    for node_type in list(palette):
        picked = st.color_picker(f"La couleur de l'entité **{node_type.strip()}**", palette[node_type])
        palette[node_type] = picked
    if st.button("Valider"):
        st.rerun()
def kg_main() -> None:
    """
    Render the knowledge-graph page.

    The page stops early with an error message unless ``audit`` (non-empty)
    and ``cr`` are present in ``st.session_state``. It then initialises the
    session keys it uses, lets the user generate a graph with ``get_graph``
    from the report text plus the audit keywords, shows the graph with a
    label filter and view/colour dialogs in the left column, and a chat about
    the graph in the right column.

    Session keys read or written here: ``audit``, ``cr``, ``graph``,
    ``filter_views``, ``current_view``, ``node_types``, ``summary``,
    ``chat_graph_history``, ``audit_simplified``, ``vectorstore``.
    """
    #st.set_page_config(page_title="Graphe de connaissance", page_icon="", layout="wide")
    # Guard clauses: both an audit and a report must exist before this page is usable.
    if "audit" not in st.session_state or st.session_state.audit == {}:
        st.error("Veuillez d'abord effectuer un audit pour visualiser le graphe de connaissance.")
        return
    if "cr" not in st.session_state:
        st.error("Veuillez d'abord effectuer un compte rendu pour visualiser le graphe de connaissance.")
        return
    # First-visit defaults for the session keys used below.
    if "graph" not in st.session_state:
        st.session_state.graph = None
    if "filter_views" not in st.session_state:
        st.session_state.filter_views = {}
    if "current_view" not in st.session_state:
        st.session_state.current_view = None
    st.title("Graphe de connaissance")
    if "node_types" not in st.session_state:
        st.session_state.node_types = None
    if "summary" not in st.session_state:
        st.session_state.summary = None
    if "chat_graph_history" not in st.session_state:
        st.session_state.chat_graph_history = []
    # NOTE(review): the guard above checks ``audit`` but this reads
    # ``audit_simplified``; presumably both are set together — confirm.
    audit = st.session_state.audit_simplified
    # content = st.session_state.audit["content"]
    # if audit["type de fichier"] == "pdf":
    # text = get_text_from_content_for_doc(content)
    # elif audit["type de fichier"] == "audio":
    # text = get_text_from_content_for_audio(content)
    # Text given to the graph builder: the report followed by the audit keywords.
    text = st.session_state.cr + "mots clés" + audit["Mots clés"]
    #summary_prompt = f"Voici un ensemble de documents : {text}. À partir de ces documents, veuillez fournir des résumés concis en vous concentrant sur l'extraction des relations essentielles et des événements. Il est crucial d'inclure les dates des actions ou des événements, car elles seront utilisées pour l'analyse chronologique. Par exemple : 'Sam a été licencié par le conseil d'administration d'OpenAI le 17 novembre 2023 (17 novembre, vendredi)', ce qui illustre la relation entre Sam et OpenAI ainsi que la date de l'événement."
    if st.button("Générer le graphe"):
        # with st.spinner("Extractions des relations..."):
        # sum = generate_response_openai(summary_prompt,model="gpt-4o")
        # st.session_state.summary = sum
        with st.spinner("Génération du graphe..."):
            # Allowed node types are the comma-separated audit keywords plus a fixed list.
            # NOTE(review): only the whole string is stripped, so individual
            # keywords may keep surrounding spaces — confirm this is intended.
            keywords_list = audit["Mots clés"].strip().split(",")
            allowed_nodes_types =keywords_list+ ["Person","Organization","Location","Event","Date","Time","Ressource","Concept"]
            graph = get_graph(text,allowed_nodes=allowed_nodes_types)
            st.session_state.graph = graph
            # The first element of the result is the graph object used for display.
            node_types = get_node_types(graph[0])
            nodes_type_dict = list_to_dict_colors(node_types)
            st.session_state.node_types = nodes_type_dict
            # A fresh graph resets the default view to "all node types".
            st.session_state.filter_views["Vue par défaut"] = list(node_types)
            st.session_state.current_view = "Vue par défaut"
    else:
        # No generation requested on this run: reuse the graph kept in the session, if any.
        graph = st.session_state.graph
    if graph is not None:
        #st.write(graph)
        edges,nodes,config = convert_neo4j_to_agraph(graph[0],st.session_state.node_types)
        col1, col2 = st.columns([2.5, 1.5])
        # Left column: graph visualisation with its toolbar.
        with col1.container(border=True,height=800):
            st.write("##### Visualisation du graphe (**"+st.session_state.current_view+"**)")
            filter_col,add_view_col,change_view_col,color_col = st.columns([9,1,1,1])
            if color_col.button("🎨",help="Changer la couleur"):
                change_color_dialog()
            if change_view_col.button("🔍",help="Changer de vue"):
                change_view_dialog()
            #add mots cles to evry label in audit["Mots clés"]
            #filter_labels = [ label + " (mot clé)" if label.strip().lower() in audit["Mots clés"].strip().lower().split(",") else label for label in st.session_state.filter_views[st.session_state.current_view] ]
            # Label filter, pre-filled with the labels of the current view.
            filter = filter_col.multiselect("Filtrer selon l'étiquette",st.session_state.node_types.keys(),placeholder="Sélectionner une ou plusieurs étiquettes",default=st.session_state.filter_views[st.session_state.current_view],label_visibility="collapsed")
            if add_view_col.button("➕",help="Ajouter une vue"):
                add_view_dialog(filter)
            # An empty selection leaves all nodes visible; edges are passed unfiltered.
            if filter:
                nodes = filter_nodes_by_types(nodes,filter)
            # The value returned by the graph component is used below as the selected node.
            selected = display_graph(edges,nodes,config)
        # Right column: chat about the graph.
        with col2.container(border=True,height=800):
            st.markdown("##### Dialoguer avec le graphe")
            user_query = st.chat_input("Par ici ...")
            if user_query is not None and user_query != "":
                st.session_state.chat_graph_history.append(HumanMessage(content=user_query))
            with st.container(height=650, border=False):
                # Replay the conversation so far.
                for message in st.session_state.chat_graph_history:
                    if isinstance(message, AIMessage):
                        with st.chat_message("AI"):
                            st.markdown(message.content)
                    elif isinstance(message, HumanMessage):
                        with st.chat_message("Moi"):
                            st.write(message.content)
                #check if last message is human message
                # A trailing human message is still unanswered: answer it now.
                if len(st.session_state.chat_graph_history) > 0:
                    last_message = st.session_state.chat_graph_history[-1]
                    if isinstance(last_message, HumanMessage):
                        with st.chat_message("AI"):
                            # NOTE(review): ``vectorstore`` is read from the session
                            # without a presence check — confirm it is always set.
                            retreive = st.session_state.vectorstore.as_retriever()
                            context = retreive.invoke(last_message.content)
                            # The prompt embeds the retrieved context, the graph and the question.
                            wrapped_prompt = f"Étant donné le contexte suivant {context}, et le graph de connaissance: {graph}, {last_message.content}"
                            response = st.write_stream(generate_response_via_langchain(wrapped_prompt,stream=True))
                            st.session_state.chat_graph_history.append(AIMessage(content=response))
                # When a node is selected, offer ready-made questions about it.
                if selected is not None:
                    with st.chat_message("AI"):
                        st.markdown(f" EXPLORER LES DONNEES CONTENUES DANS **{selected}**")
                        prompts = [f"Extrait moi toutes les informations du noeud ''{selected}'' ➡️",
                                   f"Montre moi les conversations autour du noeud ''{selected}'' ➡️"]
                        for i,prompt in enumerate(prompts):
                            # ``i=i`` binds the index at definition time; clicking queues the
                            # prompt as a human message.
                            button = st.button(prompt,key=f"p_{i}",on_click=lambda i=i: st.session_state.chat_graph_history.append(HumanMessage(content=prompts[i])))
    # NOTE(review): this value is not used in the visible part of the function.
    node_types = st.session_state.node_types