Anupam251272 committed on
Commit
37fff6b
·
verified ·
1 Parent(s): a4a03a1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -0
app.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import networkx as nx
3
+ from sentence_transformers import SentenceTransformer
4
+ import spacy
5
+ import matplotlib.pyplot as plt
6
+ import gradio as gr
7
+ from io import BytesIO
8
+ import base64
9
+
10
+ # Load the pre-trained models once, at import time, so every request reuses the same instances.
11
+ nlp = spacy.load("en_core_web_sm")  # spaCy pipeline called by extract_entities and extract_relations
12
+ sent_model = SentenceTransformer('bert-base-nli-mean-tokens')  # NOTE(review): not referenced by any function visible in this file - confirm it is needed
13
+
14
def extract_entities(text):
    """Run the spaCy pipeline over *text* and collect its named entities.

    Returns:
        A list of ``(entity_text, entity_label)`` tuples, one per span in
        ``doc.ents``.
    """
    return [(span.text, span.label_) for span in nlp(text).ents]
19
+
20
def extract_relations(text):
    """Extract (subject, predicate, object) triples from spaCy's dependency parse.

    Every token tagged as a nominal subject (``nsubj``) is paired with the
    objects attached to the same head token:

    * a direct object (``dobj``) or attribute (``attr``) child of the head
      yields ``(subject, head, object)``;
    * a preposition (``prep``) child of the head yields
      ``(subject, "head prep", object)`` for each ``pobj`` under it.

    Args:
        text: Raw input text to parse.

    Returns:
        A list of ``(subject, predicate, object)`` string tuples; empty when
        no subject/object pairing is found.
    """
    doc = nlp(text)
    relations = []
    for token in doc:
        if token.dep_ != "nsubj":
            continue
        subject = token.text
        head = token.head
        # Pair the subject with dependents of the *same* head so that the two
        # ends of a triple are distinct tokens rather than the head itself.
        for child in head.children:
            if child.dep_ in ("dobj", "attr"):
                relations.append((subject, head.text, child.text))
            elif child.dep_ == "prep":
                # The preposition is part of the predicate; its pobj
                # children are the actual objects.
                predicate = f"{head.text} {child.text}"
                relations.extend(
                    (subject, predicate, grandchild.text)
                    for grandchild in child.children
                    if grandchild.dep_ == "pobj"
                )
    return relations
31
+
32
def build_knowledge_graph(entities, relations):
    """Assemble an undirected NetworkX graph from entities and relation triples.

    Args:
        entities: Iterable of ``(name, label)`` pairs; each becomes a node
            whose ``type`` attribute is the label.
        relations: Iterable of ``(subject, predicate, object)`` triples; each
            becomes an edge between subject and object whose ``label``
            attribute is the predicate.

    Returns:
        The populated ``nx.Graph``.
    """
    graph = nx.Graph()
    graph.add_nodes_from(
        (name, {"type": label}) for name, label in entities
    )
    graph.add_edges_from(
        (head, tail, {"label": predicate})
        for head, predicate, tail in relations
    )
    return graph
40
+
41
def visualize_graph(graph):
    """Render the knowledge graph with NetworkX/Matplotlib as a base64 PNG.

    Nodes are drawn with their names as labels; any edge carrying a ``label``
    attribute has that label drawn in red along the edge.

    Args:
        graph: The NetworkX graph to draw.

    Returns:
        The PNG image encoded as a base64 string (no ``data:`` URI prefix).
    """
    pos = nx.spring_layout(graph)
    fig = plt.figure(figsize=(12, 8))
    try:
        nx.draw(
            graph,
            pos,
            with_labels=True,
            node_color='lightblue',
            node_size=2000,
            font_size=10,
            font_weight='bold',
            edge_color='gray',
        )
        edge_labels = nx.get_edge_attributes(graph, 'label')
        nx.draw_networkx_edge_labels(
            graph, pos, edge_labels=edge_labels, font_color='red'
        )

        # Write the PNG into memory rather than onto disk.
        buffer = BytesIO()
        fig.savefig(buffer, format='png')
    finally:
        # pyplot keeps every open figure alive until it is closed, so close
        # this one even when drawing or saving raises.
        plt.close(fig)

    return base64.b64encode(buffer.getvalue()).decode()
58
+
59
def run_app(input_text):
    """Gradio callback: turn *input_text* into a text report and a graph image.

    Returns:
        A ``(report, html)`` pair. On success ``report`` lists the extracted
        entities and relations and ``html`` is an ``<img>`` tag embedding the
        rendered graph as a base64 PNG data URI. If any step raises,
        ``report`` carries the error message and ``html`` is ``None``.
    """
    try:
        # Named entities, one "text (label)" line each.
        entities = extract_entities(input_text)
        entity_text = "\n".join(f"{name} ({label})" for name, label in entities)

        # Relation triples, one "subject --predicate--> object" line each.
        relations = extract_relations(input_text)
        relation_text = "\n".join(
            f"{subj} --{pred}--> {obj}" for subj, pred, obj in relations
        )

        # Build the graph and render it to a base64-encoded image.
        graph = build_knowledge_graph(entities, relations)
        encoded_png = visualize_graph(graph)
    except Exception as e:
        return f"An error occurred: {e!s}", None

    image_tag = f'<img src="data:image/png;base64,{encoded_png}" alt="Knowledge Graph">'
    summary = f"Entities:\n{entity_text}\n\nRelations:\n{relation_text}\n\nKnowledge graph created and visualized."
    return summary, image_tag
81
+
82
+ # Sample input text used as the initial value of the input textbox
83
+ sample_text = "This is a sample text. John Smith is the CEO of Apple Inc. located in Cupertino, California. The Paris Agreement is a landmark international treaty on climate change."
84
+
85
+ # Create the Gradio interface: one text input feeding run_app, whose two return values fill a textbox and an HTML panel
86
+ demo = gr.Interface(
87
+ fn=run_app,
88
+ inputs=gr.Textbox(label="Input Text", value=sample_text),
89
+ outputs=[gr.Textbox(label="Output Text"), gr.HTML(label="Knowledge Graph Visualization")],
90
+ title="Knowledge Graph Builder",
91
+ description="Enter text to generate and visualize a knowledge graph"
92
+ )
93
+
94
+ demo.launch()  # runs at module top level, i.e. whenever this file is executed or imported