# Anupam251272's picture
# Update app.py
# 39c1e76 verified
import os
import networkx as nx
from sentence_transformers import SentenceTransformer
import spacy
import matplotlib.pyplot as plt
import gradio as gr
from io import BytesIO
import base64
# Load the small English spaCy pipeline, downloading it first if it is missing.
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    # The model package is not installed yet: fetch it, then retry the load.
    import subprocess
    import sys

    # Run the download with the interpreter executing this script so the
    # model lands in the same environment; a bare "python" may resolve to a
    # different installation. check=True makes a failed download raise here
    # instead of surfacing as a second, less informative OSError below.
    subprocess.run(
        [sys.executable, "-m", "spacy", "download", "en_core_web_sm"],
        check=True,
    )
    nlp = spacy.load("en_core_web_sm")

# Sentence-embedding model. NOTE(review): nothing in the visible part of this
# file reads sent_model - confirm it is still needed before paying its load cost.
sent_model = SentenceTransformer('bert-base-nli-mean-tokens')
def extract_entities(text):
    """Return the named entities spaCy finds in ``text``.

    The result is a list of ``(entity_text, entity_label)`` tuples, in the
    order the entities are listed in ``doc.ents``.
    """
    found = []
    for span in nlp(text).ents:
        found.append((span.text, span.label_))
    return found
def extract_relations(text):
    """Extract (subject, predicate, object) triples from the dependency parse.

    Two patterns are recognised:

    * verb relations - a token attached as ``nsubj`` is the subject, its head
      is the predicate, and each ``dobj`` or ``attr`` child of that head is
      an object;
    * prepositional relations - for a ``prep`` token, its head is the
      subject, the preposition itself is the predicate, and each ``pobj``
      child is an object.

    A token that lies inside a named entity is reported using the full
    entity text, so the triples use the same strings that ``doc.ents``
    yields. Triples whose subject and object resolve to the same string are
    dropped, since they would only produce self-loops.

    Args:
        text: Raw input text to parse.

    Returns:
        A list of ``(subject, predicate, object)`` tuples of strings; empty
        when no pattern matches.
    """
    doc = nlp(text)

    # Map every token index covered by a named entity to that entity's text.
    entity_text_at = {
        index: ent.text
        for ent in doc.ents
        for index in range(ent.start, ent.end)
    }

    def display(token):
        # Prefer the enclosing entity's text over the bare token text.
        return entity_text_at.get(token.i, token.text)

    relations = []

    def add(subject_token, predicate_token, object_token):
        subject, obj = display(subject_token), display(object_token)
        if subject != obj:
            relations.append((subject, predicate_token.text, obj))

    for token in doc:
        if token.dep_ == "nsubj":
            # token is the subject; its head is the governing predicate.
            predicate = token.head
            for child in predicate.children:
                if child.dep_ in ("dobj", "attr"):
                    add(token, predicate, child)
        elif token.dep_ == "prep":
            # head --preposition--> object of the preposition
            for child in token.children:
                if child.dep_ == "pobj":
                    add(token.head, token, child)
    return relations
def build_knowledge_graph(entities, relations):
    """Assemble an undirected NetworkX graph from entities and relations.

    Each ``(name, label)`` pair in ``entities`` becomes a node whose ``type``
    attribute holds the label. Each ``(subject, predicate, object)`` triple
    in ``relations`` becomes an edge between subject and object whose
    ``label`` attribute holds the predicate.
    """
    graph = nx.Graph()
    graph.add_nodes_from(
        (name, {"type": kind}) for name, kind in entities
    )
    graph.add_edges_from(
        (source, target, {"label": verb}) for source, verb, target in relations
    )
    return graph
def visualize_graph(graph):
    """Render the knowledge graph and return it as a base64-encoded PNG.

    Nodes are positioned with a spring layout and drawn with their names;
    each edge's ``label`` attribute, where present, is drawn in red along
    the edge.

    Args:
        graph: The NetworkX graph to draw.

    Returns:
        The PNG image data encoded as a base64 ``str``.
    """
    pos = nx.spring_layout(graph)
    figure = plt.figure(figsize=(12, 8))
    try:
        nx.draw(graph, pos, with_labels=True, node_color='lightblue', node_size=2000, font_size=10, font_weight='bold', edge_color='gray')
        edge_labels = nx.get_edge_attributes(graph, 'label')
        nx.draw_networkx_edge_labels(graph, pos, edge_labels=edge_labels, font_color='red')
        # Write the PNG into memory rather than to a file on disk.
        img = BytesIO()
        figure.savefig(img, format='png')
    finally:
        # Always release the figure: pyplot keeps every open figure alive,
        # so a failure while drawing or saving would otherwise leak one
        # figure per failed call.
        plt.close(figure)
    # getvalue() returns the whole buffer regardless of the stream position.
    return base64.b64encode(img.getvalue()).decode()
def run_app(input_text):
    """Gradio callback: turn ``input_text`` into a text summary and a graph image.

    Runs entity extraction, relation extraction, graph construction and
    rendering in that order.

    Returns:
        A ``(summary, html)`` pair: the summary lists the entities and the
        relations, and ``html`` is an ``<img>`` tag embedding the rendered
        graph as a base64 PNG. If any step raises, the pair is
        ``("An error occurred: <message>", None)`` instead.
    """
    try:
        # Named entities, one "text (LABEL)" line each.
        entities = extract_entities(input_text)
        entity_text = "\n".join(f"{ent[0]} ({ent[1]})" for ent in entities)

        # Relations, one "subject --predicate--> object" line each.
        relations = extract_relations(input_text)
        relation_text = "\n".join(
            f"{rel[0]} --{rel[1]}--> {rel[2]}" for rel in relations
        )

        # Build the graph and render it to base64 PNG data.
        plot_data = visualize_graph(build_knowledge_graph(entities, relations))

        # Embed the image inline so no file has to be served.
        plot_html = f'<img src="data:image/png;base64,{plot_data}" alt="Knowledge Graph">'
        summary = f"Entities:\n{entity_text}\n\nRelations:\n{relation_text}\n\nKnowledge graph created and visualized."
        return summary, plot_html
    except Exception as e:
        # UI boundary: report the failure in the text box instead of raising.
        return f"An error occurred: {str(e)}", None
# Example text pre-filled in the input box.
sample_text = "This is a sample text. John Smith is the CEO of Apple Inc. located in Cupertino, California. The Paris Agreement is a landmark international treaty on climate change."

# Gradio UI: one text input, with a text summary and an HTML pane (for the
# rendered graph image) as outputs.
demo = gr.Interface(
    title="Knowledge Graph Builder",
    description="Enter text to generate and visualize a knowledge graph",
    fn=run_app,
    inputs=gr.Textbox(value=sample_text, label="Input Text"),
    outputs=[
        gr.Textbox(label="Output Text"),
        gr.HTML(label="Knowledge Graph Visualization"),
    ],
)

# Launch the app as soon as this module is executed.
demo.launch()