import networkx as nx
from sentence_transformers import SentenceTransformer
import spacy
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so plotting works outside the main thread
import matplotlib.pyplot as plt
import gradio as gr
from io import BytesIO
import base64

# Install the spaCy model if it is not already installed
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    import subprocess
    import sys
    # Download the model with the current interpreter, then retry loading it
    subprocess.run([sys.executable, "-m", "spacy", "download", "en_core_web_sm"], check=True)
    nlp = spacy.load("en_core_web_sm")

# Sentence embedding model (loaded here but not used by the pipeline below;
# kept for potential semantic-similarity extensions)
sent_model = SentenceTransformer('bert-base-nli-mean-tokens')


def extract_entities(text):
    """Extract named entities from text using spaCy."""
    doc = nlp(text)
    entities = [(e.text, e.label_) for e in doc.ents]
    return entities


def extract_relations(text):
    """Extract simple (subject, predicate, object) triples using spaCy's dependency parser.

    This is a naive heuristic: it pairs a nominal subject with the verb's direct
    object or attribute, and links prepositions to their objects.
    """
    doc = nlp(text)
    relations = []
    for token in doc:
        if token.dep_ == "nsubj":
            # subject --verb--> direct object / attribute
            for child in token.head.children:
                if child.dep_ in ("dobj", "attr"):
                    relations.append((token.text, token.head.text, child.text))
        elif token.dep_ == "prep":
            # head --preposition--> object of the preposition
            for child in token.children:
                if child.dep_ == "pobj":
                    relations.append((token.head.text, token.text, child.text))
    return relations


def build_knowledge_graph(entities, relations):
    """Construct the knowledge graph using NetworkX."""
    G = nx.Graph()
    for entity, entity_type in entities:
        G.add_node(entity, type=entity_type)
    for subject, predicate, obj in relations:
        # Nodes mentioned only in relations are added automatically by add_edge
        G.add_edge(subject, obj, label=predicate)
    return G


def visualize_graph(graph):
    """Visualize the knowledge graph using NetworkX and Matplotlib."""
    pos = nx.spring_layout(graph)
    plt.figure(figsize=(12, 8))
    nx.draw(graph, pos, with_labels=True, node_color='lightblue', node_size=2000,
            font_size=10, font_weight='bold', edge_color='gray')
    edge_labels = nx.get_edge_attributes(graph, 'label')
    nx.draw_networkx_edge_labels(graph, pos, edge_labels=edge_labels, font_color='red')

    # Save the plot to a BytesIO object
    img = BytesIO()
    plt.savefig(img, format='png')
    img.seek(0)
    plt.close()

    # Encode the image to base64
    plot_data = base64.b64encode(img.getvalue()).decode()
    return plot_data


def run_app(input_text):
    try:
        # Extract entities
        entities = extract_entities(input_text)
        entity_text = "\n".join([f"{e[0]} ({e[1]})" for e in entities])

        # Extract relations
        relations = extract_relations(input_text)
        relation_text = "\n".join([f"{r[0]} --{r[1]}--> {r[2]}" for r in relations])

        # Build knowledge graph
        graph = build_knowledge_graph(entities, relations)

        # Visualize graph
        plot_data = visualize_graph(graph)

        # Embed the base64-encoded PNG in an HTML img tag
        plot_html = f'<img src="data:image/png;base64,{plot_data}" alt="Knowledge Graph">'

        return (f"Entities:\n{entity_text}\n\nRelations:\n{relation_text}\n\n"
                f"Knowledge graph created and visualized."), plot_html
    except Exception as e:
        return f"An error occurred: {str(e)}", None


# Sample input text
sample_text = ("This is a sample text. John Smith is the CEO of Apple Inc. located in "
               "Cupertino, California. The Paris Agreement is a landmark international "
               "treaty on climate change.")

# Create Gradio interface
demo = gr.Interface(
    fn=run_app,
    inputs=gr.Textbox(label="Input Text", value=sample_text),
    outputs=[gr.Textbox(label="Output Text"), gr.HTML(label="Knowledge Graph Visualization")],
    title="Knowledge Graph Builder",
    description="Enter text to generate and visualize a knowledge graph"
)

demo.launch()
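

# Optional sketch: programmatic use of the pipeline without the Gradio UI.
# (demo.launch() above blocks until the server is stopped, so run these lines
# in a separate script or before launching the interface.)
# entities = extract_entities(sample_text)
# relations = extract_relations(sample_text)
# graph = build_knowledge_graph(entities, relations)
# print("Nodes:", list(graph.nodes(data=True)))
# print("Edges:", list(graph.edges(data=True)))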