Spaces:
Runtime error
Runtime error
| import os | |
| import networkx as nx | |
| from sentence_transformers import SentenceTransformer | |
| import spacy | |
| import matplotlib.pyplot as plt | |
| import gradio as gr | |
| from io import BytesIO | |
| import base64 | |
# Load the small English spaCy pipeline, downloading it on first run if missing.
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    import importlib
    import subprocess
    import sys

    # Use the interpreter that is running this script (not whatever "python"
    # happens to be on PATH) so the model lands in the same environment, and
    # fail loudly if the download itself fails rather than surfacing a
    # confusing second OSError from spacy.load below.
    subprocess.run(
        [sys.executable, "-m", "spacy", "download", "en_core_web_sm"],
        check=True,
    )
    # The model package was installed after interpreter start-up; refresh the
    # import system's finder caches so the new package can be found.
    importlib.invalidate_caches()
    nlp = spacy.load("en_core_web_sm")

# NOTE(review): sent_model is not referenced by any function in the visible
# code; loading it only costs start-up time and memory. Kept so module-level
# behavior is unchanged — confirm it is unused before removing.
sent_model = SentenceTransformer('bert-base-nli-mean-tokens')
def extract_entities(text):
    """Return the named entities spaCy finds in *text*.

    Args:
        text: raw input string.

    Returns:
        list of ``(entity_text, entity_label)`` tuples, in document order.
    """
    spans = nlp(text).ents
    return [(span.text, span.label_) for span in spans]
def extract_relations(text):
    """Extract ``(subject, predicate, object)`` triples from spaCy's dependency parse.

    For every verb or auxiliary that has a nominal subject (``nsubj`` /
    ``nsubjpass``), each subject is paired with each of the verb's objects:

    * object-like children (``dobj``, ``attr``, ``dative``, ``oprd``) yield
      ``(subject, verb, object)``;
    * prepositional children yield ``(subject, "verb prep", pobj)`` for each
      ``pobj`` under the preposition.

    Tokens that fall inside a named entity are reported as that entity's full
    text (e.g. "John Smith" rather than "Smith"), so the triples connect to
    the same node names that ``extract_entities`` returns.

    Args:
        text: raw input string.

    Returns:
        list of ``(subject, predicate, object)`` string tuples; empty when no
        subject/object pair is found.
    """
    doc = nlp(text)
    # Token index -> full entity text, for every token inside a named entity.
    entity_by_index = {
        i: ent.text for ent in doc.ents for i in range(ent.start, ent.end)
    }

    def _name(tok):
        """Return the entity text covering *tok*, or the token's own text."""
        return entity_by_index.get(tok.i, tok.text)

    relations = []
    for verb in doc:
        if verb.pos_ not in ("VERB", "AUX"):
            continue
        subjects = [c for c in verb.children if c.dep_ in ("nsubj", "nsubjpass")]
        if not subjects:
            continue
        # (predicate, object-token) pairs hanging off this verb.
        targets = []
        for child in verb.children:
            if child.dep_ in ("dobj", "attr", "dative", "oprd"):
                targets.append((verb.text, child))
            elif child.dep_ == "prep":
                targets.extend(
                    (f"{verb.text} {child.text}", pobj)
                    for pobj in child.children
                    if pobj.dep_ == "pobj"
                )
        for subj in subjects:
            for predicate, obj in targets:
                relations.append((_name(subj), predicate, _name(obj)))
    return relations
def build_knowledge_graph(entities, relations):
    """Assemble an undirected NetworkX graph from entities and relations.

    Args:
        entities: iterable of ``(name, label)`` pairs; each becomes a node
            whose ``type`` attribute is the label.
        relations: iterable of ``(subject, predicate, object)`` triples; each
            becomes a subject–object edge whose ``label`` attribute is the
            predicate. Endpoints not already present are added as nodes
            without a ``type`` attribute.

    Returns:
        networkx.Graph containing those nodes and edges.
    """
    kg = nx.Graph()
    kg.add_nodes_from((name, {"type": label}) for name, label in entities)
    kg.add_edges_from(
        (head, tail, {"label": verb}) for head, verb, tail in relations
    )
    return kg
def visualize_graph(graph):
    """Render *graph* as a PNG and return it base64-encoded.

    A standalone ``matplotlib.figure.Figure`` is used instead of pyplot, so no
    global pyplot "current figure" state is touched. This matters because the
    Gradio handler that calls this may presumably run in worker threads, and
    pyplot's global state is not thread-safe. A standalone Figure is not
    registered with pyplot, so it cannot be leaked into pyplot's figure
    registry if drawing raises. It is simply garbage-collected, and no
    ``plt.close`` is needed.

    Args:
        graph: NetworkX graph; edges may carry a ``label`` attribute, which is
            drawn next to the edge.

    Returns:
        str: ASCII base64 encoding of the PNG image.
    """
    from matplotlib.figure import Figure

    fig = Figure(figsize=(12, 8))
    ax = fig.subplots()

    pos = nx.spring_layout(graph)
    nx.draw(
        graph,
        pos,
        ax=ax,
        with_labels=True,
        node_color='lightblue',
        node_size=2000,
        font_size=10,
        font_weight='bold',
        edge_color='gray',
    )
    edge_labels = nx.get_edge_attributes(graph, 'label')
    nx.draw_networkx_edge_labels(
        graph, pos, edge_labels=edge_labels, font_color='red', ax=ax
    )

    # Serialize to an in-memory PNG; getvalue() reads the whole buffer, so no
    # seek is needed.
    img = BytesIO()
    fig.savefig(img, format='png')
    return base64.b64encode(img.getvalue()).decode('ascii')
def run_app(input_text):
    """Gradio handler: build and plot a knowledge graph from *input_text*.

    Args:
        input_text: text from the input textbox.

    Returns:
        tuple ``(summary_text, plot_html)``.

        * ``summary_text`` lists the extracted entities and relations.
        * ``plot_html`` is an ``<img>`` tag embedding the rendered graph.
        * For empty input, or when any step fails, a message and ``None`` are
          returned instead; failures are logged with their traceback.
    """
    import logging

    if not input_text or not input_text.strip():
        return "Please enter some text.", None

    try:
        entities = extract_entities(input_text)
        relations = extract_relations(input_text)
        graph = build_knowledge_graph(entities, relations)
        plot_data = visualize_graph(graph)
    except Exception as e:
        # Top-level UI boundary: keep the app responsive and show the error to
        # the user, but keep the traceback in the logs for debugging.
        logging.getLogger(__name__).exception("Knowledge graph generation failed")
        return f"An error occurred: {str(e)}", None

    entity_text = "\n".join(f"{name} ({label})" for name, label in entities)
    relation_text = "\n".join(f"{s} --{p}--> {o}" for s, p, o in relations)
    plot_html = f'<img src="data:image/png;base64,{plot_data}" alt="Knowledge Graph">'
    summary = (
        f"Entities:\n{entity_text}\n\nRelations:\n{relation_text}\n\n"
        "Knowledge graph created and visualized."
    )
    return summary, plot_html
# Example text pre-filled in the input box so the demo works out of the box.
sample_text = "This is a sample text. John Smith is the CEO of Apple Inc. located in Cupertino, California. The Paris Agreement is a landmark international treaty on climate change."

# UI wiring: one text input -> (summary text, rendered graph HTML).
# Components are created in the same order as the Interface arguments.
input_box = gr.Textbox(label="Input Text", value=sample_text)
output_components = [
    gr.Textbox(label="Output Text"),
    gr.HTML(label="Knowledge Graph Visualization"),
]

demo = gr.Interface(
    fn=run_app,
    inputs=input_box,
    outputs=output_components,
    title="Knowledge Graph Builder",
    description="Enter text to generate and visualize a knowledge graph",
)
demo.launch()