Spaces:
Running
Running
import streamlit as st | |
from utils.kg.construct_kg import get_graph | |
from utils.audit.rag import get_text_from_content_for_doc,get_text_from_content_for_audio | |
from streamlit_agraph import agraph, Node, Edge, Config | |
import random | |
import math | |
from utils.audit.response_llm import generate_response_via_langchain | |
from langchain_core.messages import AIMessage, HumanMessage | |
from langchain_core.prompts import PromptTemplate | |
from itext2kg.models import KnowledgeGraph | |
def if_node_exists(nodes, node_id):
    """Tell whether any node in ``nodes`` carries the identifier ``node_id``.

    Args:
        nodes: Iterable of objects exposing an ``id`` attribute.
        node_id: Identifier to look for.

    Returns:
        True if at least one node's ``id`` equals ``node_id``, False otherwise.
    """
    return any(candidate.id == node_id for candidate in nodes)
def generate_random_color():
    """Draw a random light RGB colour.

    Each of the three channels is drawn independently with
    ``random.randint`` from the inclusive range 180-255, in the order
    red, green, blue.

    Returns:
        A ``(r, g, b)`` tuple of ints.
    """
    lowest, highest = 180, 255
    return tuple(random.randint(lowest, highest) for _ in range(3))
def rgb_to_hex(rgb):
    """Format an RGB triple as a lowercase ``#rrggbb`` string.

    Args:
        rgb: Indexable whose first three items are integer channel values.
            Any items past the third are ignored.

    Returns:
        The colour as a ``#``-prefixed string with two hex digits per channel.
    """
    red, green, blue = rgb[0], rgb[1], rgb[2]
    return f"#{red:02x}{green:02x}{blue:02x}"
def get_node_types(graph):
    """Collect every distinct node type mentioned anywhere in ``graph``.

    Types are read from ``graph.nodes`` and from both endpoints
    (``source`` and ``target``) of each item in ``graph.relationships``.

    Args:
        graph: Object exposing ``nodes`` and ``relationships`` iterables.

    Returns:
        A set of the ``type`` values found.
    """
    found = {node.type for node in graph.nodes}
    for link in graph.relationships:
        found.update((link.source.type, link.target.type))
    return found
def get_node_types_advanced(graph:KnowledgeGraph):
    """Collect every distinct entity label mentioned anywhere in ``graph``.

    Labels are read from ``graph.entities`` and from both endpoints
    (``startEntity`` and ``endEntity``) of each item in
    ``graph.relationships``.

    Args:
        graph: Knowledge graph exposing ``entities`` and ``relationships``.

    Returns:
        A set of the ``label`` values found.
    """
    labels = {entity.label for entity in graph.entities}
    for link in graph.relationships:
        labels.update((link.startEntity.label, link.endEntity.label))
    return labels
def color_distance(color1, color2):
    """Return the Euclidean distance between two RGB colours.

    Args:
        color1: Indexable with at least three numeric channels.
        color2: Indexable with at least three numeric channels.

    Returns:
        The square root of the summed squared per-channel differences.
    """
    delta_r = color1[0] - color2[0]
    delta_g = color1[1] - color2[1]
    delta_b = color1[2] - color2[2]
    return math.sqrt(delta_r ** 2 + delta_g ** 2 + delta_b ** 2)
def generate_distinct_colors(num_colors, min_distance=30, max_attempts=200):
    """Pick ``num_colors`` random colours that are pairwise far apart in RGB.

    Candidates come from ``generate_random_color`` and are kept only when
    their ``color_distance`` to every colour already kept is at least the
    current threshold.

    Args:
        num_colors: How many colours to return. Zero or a negative value
            yields an empty list.
        min_distance: Initial minimum Euclidean RGB distance required between
            any two kept colours.
        max_attempts: Number of consecutive rejected candidates tolerated
            before the distance threshold is relaxed.

    Returns:
        A list of ``num_colors`` colour strings produced by ``rgb_to_hex``.
    """
    accepted = []
    threshold = min_distance
    rejected_in_a_row = 0
    while len(accepted) < num_colors:
        candidate = generate_random_color()
        if all(color_distance(candidate, kept) >= threshold for kept in accepted):
            accepted.append(candidate)
            rejected_in_a_row = 0
            continue
        rejected_in_a_row += 1
        if rejected_in_a_row >= max_attempts:
            # Pure rejection sampling never terminates once the requested
            # number of colours cannot fit at the requested spacing. Halve the
            # spacing, and drop it to 0 (accept anything) once it falls below
            # 2, so the loop is guaranteed to finish.
            threshold = threshold / 2 if threshold >= 2 else 0
            rejected_in_a_row = 0
    return [rgb_to_hex(color) for color in accepted]
def list_to_dict_colors(node_types:set):
    """Assign one generated colour to each node type.

    Args:
        node_types: Collection of node type names.

    Returns:
        A dict mapping each node type to a colour string obtained from
        ``generate_distinct_colors``, paired in iteration order.
    """
    palette = generate_distinct_colors(len(node_types))
    return dict(zip(node_types, palette))
def convert_neo4j_to_agraph(neo4j_graph, node_colors):
    """Convert a graph object into streamlit-agraph nodes, edges and config.

    Args:
        neo4j_graph: Object exposing ``nodes`` (items with ``id`` and
            ``type`` attributes) and ``relationships`` (items with
            ``source``, ``target`` and ``type`` attributes, where the
            endpoints are node-like objects).
        node_colors: Mapping from node type to the colour used to draw it.
            Every type met in the graph must be present.

    Returns:
        A ``(edges, nodes, config)`` tuple: the list of ``Edge`` objects, the
        list of ``Node`` objects de-duplicated by id, and the ``Config``.

    Raises:
        KeyError: If a node type is missing from ``node_colors``.
    """
    nodes = []
    edges = []
    # Ids already emitted; a set keeps the duplicate check O(1) per node
    # instead of rescanning the whole node list each time.
    seen_ids = set()

    def register(raw_id, node_type):
        """Add the node if its id is new and return its Agraph id."""
        agraph_id = raw_id.replace(" ", "_")  # Agraph ids carry no spaces
        agraph_node = Node(
            id=agraph_id,
            title=node_type,  # the title stores the node type
            label=raw_id,
            size=25,
            shape="circle",
            color=node_colors[node_type],
        )
        if agraph_id not in seen_ids:
            seen_ids.add(agraph_id)
            nodes.append(agraph_node)
        return agraph_id

    for node in neo4j_graph.nodes:
        register(node.id, node.type)

    # Relationship endpoints may not be listed in ``nodes``; register them too
    # so that every edge has both of its ends.
    for relationship in neo4j_graph.relationships:
        source = relationship.source
        target = relationship.target
        source_id = register(source.id, source.type)
        target_id = register(target.id, target.type)
        edges.append(Edge(source=source_id, label=relationship.type, target=target_id))

    config = Config(
        width=1200,
        height=800,
        directed=True,
        physics=True,
        hierarchical=True,
        from_json="config.json",
    )
    return edges, nodes, config
def convert_advanced_neo4j_to_agraph(neo4j_graph:KnowledgeGraph, node_colors):
    """Convert a ``KnowledgeGraph`` into streamlit-agraph nodes, edges, config.

    Args:
        neo4j_graph: Knowledge graph exposing ``entities`` (items with
            ``name`` and ``label``) and ``relationships`` (items with
            ``startEntity``, ``endEntity`` and ``name``).
        node_colors: Mapping from entity label to the colour used to draw it.
            Every label found in ``entities`` must be present.

    Returns:
        A ``(edges, nodes, config)`` tuple: the list of ``Edge`` objects, the
        list of ``Node`` objects (one per entity, in order, not
        de-duplicated), and the ``Config``.

    Raises:
        KeyError: If an entity label is missing from ``node_colors``.
    """
    # One Agraph node per entity. No de-duplication is done here.
    # NOTE(review): presumably entity names are already unique - confirm.
    nodes = [
        Node(
            id=entity.name.replace(" ", "_"),  # Agraph ids carry no spaces
            title=entity.label,  # the title stores the entity label
            label=entity.name,
            size=25,
            shape="circle",
            color=node_colors[entity.label],
        )
        for entity in neo4j_graph.entities
    ]

    # Edges only need the ids of their endpoints. Endpoint nodes are never
    # added to ``nodes`` here, so no Node objects are built for them.
    edges = [
        Edge(
            source=relationship.startEntity.name.replace(" ", "_"),
            label=relationship.name,
            target=relationship.endEntity.name.replace(" ", "_"),
        )
        for relationship in neo4j_graph.relationships
    ]

    config = Config(
        width=1200,
        height=800,
        directed=True,
        physics=True,
        hierarchical=True,
        from_json="config.json",
    )
    return edges, nodes, config
def display_graph(edges, nodes, config):
    """Render the graph through ``agraph`` and hand back its result.

    Args:
        edges: Edges to draw.
        nodes: Nodes to draw.
        config: Agraph configuration.

    Returns:
        Whatever ``agraph`` returns; ``kg_main`` treats it as the currently
        selected node, or ``None`` when nothing is selected.
    """
    rendered = agraph(nodes=nodes, edges=edges, config=config)
    return rendered
def filter_nodes_by_types(nodes:list[Node], node_types_filter:list) -> list[Node]:
    """Keep only the nodes whose type is listed in ``node_types_filter``.

    The type of a node is read from its ``title`` attribute.

    Args:
        nodes: Nodes to filter.
        node_types_filter: Node types to keep.

    Returns:
        A new list with the matching nodes, in their original order.
    """
    return [node for node in nodes if node.title in node_types_filter]
def change_view_dialog() -> None:
    """Render the list of saved filter views with select/rename/delete controls.

    Reads and mutates ``st.session_state.filter_views`` (view name -> list of
    labels) and ``st.session_state.current_view``. Every action that changes
    state ends with ``st.rerun()``.

    NOTE(review): the nesting below was reconstructed from a source whose
    indentation was lost; confirm the placement of the final ``st.rerun()``,
    of the ``st.markdown`` call and of the trailing ``else`` against the
    original file.
    """
    st.write("Changer la vue")
    for index, item in enumerate(st.session_state.filter_views.keys()):
        # Placeholder holding the whole row so that it can be cleared below.
        emp = st.empty()
        col1, col2, col3 = emp.columns([8, 1, 1])
        # The view at index 0 gets no delete button (presumably the default
        # view - confirm). Deleting falls back to the default view name.
        if index > 0 and col2.button("🗑️", key=f"del{index}"):
            del st.session_state.filter_views[item]
            st.session_state.current_view = "Vue par défaut"
            st.rerun()
        # The active view shows a check mark, the others a magnifier.
        but_content = "🔍" if st.session_state.current_view != item else "✅"
        if col3.button(but_content, key=f"valid{index}"):
            st.session_state.current_view = item
            st.rerun()
        if len(st.session_state.filter_views.keys()) > index:
            with col1.expander(item):
                # Renaming is offered for every view except the one at index 0.
                if index > 0:
                    change_name = st.text_input("Nom de la vue", label_visibility="collapsed", placeholder="Changez le nom de la vue",key=f"change_name{index}")
                    if st.button("Renommer",key=f"rename{index}"):
                        # An empty name is ignored. The view is re-inserted
                        # under the new key, so it moves to the end of the dict.
                        if change_name != "":
                            st.session_state.filter_views[change_name] = st.session_state.filter_views.pop(item)
                            st.session_state.current_view = change_name
                            st.rerun()
                # Bullet list of the labels stored in this view.
                st.markdown("\n".join(f"- {label.strip()}" for label in st.session_state.filter_views[item]))
        else:
            # Clear the row when the index is past the current number of views.
            emp.empty()
def add_view_dialog(filters):
    """Render the form that saves the current filters as a named view.

    On confirmation the filters are stored in
    ``st.session_state.filter_views`` under the typed name, that view becomes
    ``st.session_state.current_view`` and the script is rerun. Saving under
    an existing name replaces that view.

    Args:
        filters: The labels currently selected in the filter widget.
    """
    st.write("Ajouter une vue")
    view_name = st.text_input("Nom de la vue")
    st.markdown("les filtres actuels:")
    st.write(filters)
    if st.button("Ajouter la vue"):
        # A blank name would create a view stored under "" that shows up
        # without a label; refuse it, as the rename form ignores empty names.
        if not view_name.strip():
            st.warning("Veuillez saisir un nom pour la vue.")
            return
        st.session_state.filter_views[view_name] = filters
        st.session_state.current_view = view_name
        st.rerun()
def change_color_dialog():
    """Render one colour picker per node type and store the picked colours.

    Each picker starts from the colour currently held in
    ``st.session_state.node_types`` and its value is written straight back
    under the same key. The "Valider" button reruns the script.
    """
    st.write("Changer la couleur")
    palette = st.session_state.node_types
    # Only existing keys are reassigned, so the dict size never changes while
    # it is being iterated.
    for node_type, current in palette.items():
        palette[node_type] = st.color_picker(f"La couleur de l'entité **{node_type.strip()}**", current)
    if st.button("Valider"):
        st.rerun()
def kg_main() -> None:
    """Render the knowledge-graph page: graph generation, viewer and chat.

    Requires ``st.session_state.audit`` (non-empty) and
    ``st.session_state.cr``; otherwise an error is shown and the function
    returns. It then seeds the session keys it uses, optionally builds the
    graph with ``get_graph``, draws it with per-type colours and filter
    views, and shows a chat panel answering questions about the graph.

    NOTE(review): the nesting below was reconstructed from a source whose
    indentation was lost; confirm block boundaries against the original file.
    """
    #st.set_page_config(page_title="Graphe de connaissance", page_icon="", layout="wide")
    # Guard clauses: both an audit and a report ("cr") must exist first.
    if "audit" not in st.session_state or st.session_state.audit == {}:
        st.error("Veuillez d'abord effectuer un audit pour visualiser le graphe de connaissance.")
        return
    if "cr" not in st.session_state:
        st.error("Veuillez d'abord effectuer un compte rendu pour visualiser le graphe de connaissance.")
        return
    # Seed the session keys used by this page.
    if "graph" not in st.session_state:
        st.session_state.graph = None
    if "filter_views" not in st.session_state:
        st.session_state.filter_views = {}
    if "current_view" not in st.session_state:
        st.session_state.current_view = None
    st.title("Graphe de connaissance")
    if "node_types" not in st.session_state:
        st.session_state.node_types = None
    if "summary" not in st.session_state:
        st.session_state.summary = None
    if "chat_graph_history" not in st.session_state:
        st.session_state.chat_graph_history = []
    # NOTE(review): the guard above checks "audit" but this reads
    # "audit_simplified" - presumably both are set together; confirm.
    audit = st.session_state.audit_simplified
    # content = st.session_state.audit["content"]
    # if audit["type de fichier"] == "pdf":
    # text = get_text_from_content_for_doc(content)
    # elif audit["type de fichier"] == "audio":
    # text = get_text_from_content_for_audio(content)
    # Text handed to the graph builder: the report followed by the keywords,
    # concatenated without any separator.
    text = st.session_state.cr + "mots clés" + audit["Mots clés"]
    # Disabled: French LLM prompt once used to summarise the documents first.
    #summary_prompt = f"Voici un ensemble de documents : {text}. À partir de ces documents, veuillez fournir des résumés concis en vous concentrant sur l'extraction des relations essentielles et des événements. Il est crucial d'inclure les dates des actions ou des événements, car elles seront utilisées pour l'analyse chronologique. Par exemple : 'Sam a été licencié par le conseil d'administration d'OpenAI le 17 novembre 2023 (17 novembre, vendredi)', ce qui illustre la relation entre Sam et OpenAI ainsi que la date de l'événement."
    if st.button("Générer le graphe"):
        # with st.spinner("Extractions des relations..."):
        # sum = generate_response_openai(summary_prompt,model="gpt-4o")
        # st.session_state.summary = sum
        with st.spinner("Génération du graphe..."):
            # NOTE(review): the split items are not stripped, so keywords
            # after a comma keep their leading space and an empty keyword
            # string yields [""] - confirm this is intended.
            keywords_list = audit["Mots clés"].strip().split(",")
            allowed_nodes_types =keywords_list+ ["Person","Organization","Location","Event","Date","Time","Ressource","Concept"]
            graph = get_graph(text,allowed_nodes=allowed_nodes_types)
            st.session_state.graph = graph
            # graph[0] is the element passed to the helpers below.
            node_types = get_node_types(graph[0])
            nodes_type_dict = list_to_dict_colors(node_types)
            st.session_state.node_types = nodes_type_dict
            # (Re)create the default view containing every node type.
            st.session_state.filter_views["Vue par défaut"] = list(node_types)
            st.session_state.current_view = "Vue par défaut"
    else:
        # No new generation requested: reuse the graph kept in the session.
        graph = st.session_state.graph
    if graph is not None:
        #st.write(graph)
        edges,nodes,config = convert_neo4j_to_agraph(graph[0],st.session_state.node_types)
        col1, col2 = st.columns([2.5, 1.5])
        # Left panel: toolbar and graph viewer.
        with col1.container(border=True,height=800):
            st.write("##### Visualisation du graphe (**"+st.session_state.current_view+"**)")
            filter_col,add_view_col,change_view_col,color_col = st.columns([9,1,1,1])
            # NOTE(review): the dialog helpers are called as plain functions;
            # unless they are wrapped as dialogs elsewhere, the buttons inside
            # them only exist during the run triggered by these buttons -
            # verify the helpers' inner buttons can actually fire.
            if color_col.button("🎨",help="Changer la couleur"):
                change_color_dialog()
            if change_view_col.button("🔍",help="Changer de vue"):
                change_view_dialog()
            # Disabled: append a "(mot clé)" marker to every label that is one
            # of the audit keywords.
            #filter_labels = [ label + " (mot clé)" if label.strip().lower() in audit["Mots clés"].strip().lower().split(",") else label for label in st.session_state.filter_views[st.session_state.current_view] ]
            # Label filter, preselected with the labels of the current view.
            filter = filter_col.multiselect("Filtrer selon l'étiquette",st.session_state.node_types.keys(),placeholder="Sélectionner une ou plusieurs étiquettes",default=st.session_state.filter_views[st.session_state.current_view],label_visibility="collapsed")
            if add_view_col.button("➕",help="Ajouter une vue"):
                add_view_dialog(filter)
            # Only the nodes are filtered; the edge list is passed unchanged.
            if filter:
                nodes = filter_nodes_by_types(nodes,filter)
            selected = display_graph(edges,nodes,config)
        # Right panel: chat about the graph.
        with col2.container(border=True,height=800):
            st.markdown("##### Dialoguer avec le graphe")
            user_query = st.chat_input("Par ici ...")
            if user_query is not None and user_query != "":
                st.session_state.chat_graph_history.append(HumanMessage(content=user_query))
            with st.container(height=650, border=False):
                # Replay the stored conversation.
                for message in st.session_state.chat_graph_history:
                    if isinstance(message, AIMessage):
                        with st.chat_message("AI"):
                            st.markdown(message.content)
                    elif isinstance(message, HumanMessage):
                        with st.chat_message("Moi"):
                            st.write(message.content)
                # Answer only when the last message is an unanswered human one.
                if len(st.session_state.chat_graph_history) > 0:
                    last_message = st.session_state.chat_graph_history[-1]
                    if isinstance(last_message, HumanMessage):
                        with st.chat_message("AI"):
                            # NOTE(review): assumes st.session_state.vectorstore
                            # was set elsewhere; it is not checked here.
                            retreive = st.session_state.vectorstore.as_retriever()
                            context = retreive.invoke(last_message.content)
                            # The prompt embeds the retrieved context, the
                            # formatted graph object and the question itself.
                            wrapped_prompt = f"Étant donné le contexte suivant {context}, et le graph de connaissance: {graph}, {last_message.content}"
                            response = st.write_stream(generate_response_via_langchain(wrapped_prompt,stream=True))
                            st.session_state.chat_graph_history.append(AIMessage(content=response))
                # When a node is selected in the viewer, offer canned prompts.
                if selected is not None:
                    with st.chat_message("AI"):
                        st.markdown(f" EXPLORER LES DONNEES CONTENUES DANS **{selected}**")
                        prompts = [f"Extrait moi toutes les informations du noeud ''{selected}'' ➡️",
                                   f"Montre moi les conversations autour du noeud ''{selected}'' ➡️"]
                        for i,prompt in enumerate(prompts):
                            # ``i=i`` binds the index at definition time so
                            # each callback appends its own prompt.
                            button = st.button(prompt,key=f"p_{i}",on_click=lambda i=i: st.session_state.chat_graph_history.append(HumanMessage(content=prompts[i])))
    # This local is not used afterwards.
    node_types = st.session_state.node_types