import spacy
import networkx as nx
import matplotlib.pyplot as plt
import io
from PIL import Image
import gradio as gr

# Load the spaCy model for dependency parsing
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    # Download the model in case it's not available
    from spacy.cli import download
    download("en_core_web_sm")
    nlp = spacy.load("en_core_web_sm")
    
# Function to extract named entities with spaCy NER (standalone helper; not used by the graph pipeline below)
def extract_entities(text):
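    # Example (exact spans depend on the model version):
    #   extract_entities("Barack Obama was born in Hawaii")
    #   -> [("Barack Obama", "PERSON"), ("Hawaii", "GPE")]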
    doc = nlp(text)
    entities = [(ent.text, ent.label_) for ent in doc.ents]
    return entities

# Function to extract (subject, verb, object) relationships from the dependency parse
def extract_relationships(text):
    relationships = []
    doc = nlp(text.lower())
    subject, verb, obj, noun = None, None, None, None
    entities = []
    for token in doc:
        # Collect compound modifiers (e.g. "machine" in "machine learning") so
        # multi-word nouns are kept together with their head noun.
        if token.dep_ == "compound":
            noun = (noun or "") + token.text + " "
            continue
        if not noun:
            if token.dep_ in ("nsubj", "nsubjpass"):
                subject = token.text
            if token.dep_ in ("dobj", "attr", "pobj"):
                obj = token.text
                entities.append(obj)
            if token.dep_ in ("ROOT", "xcomp", "ccomp"):
                verb = token.text
        else:
            # This token is the head of a compound chain, so prepend the
            # collected modifiers to it.
            if token.dep_ in ("nsubj", "nsubjpass"):
                subject = noun + token.text
                entities.append(subject)
            if token.dep_ in ("dobj", "attr", "pobj"):
                obj = noun + token.text
                entities.append(obj)
            noun = None
        # Prepositions start a new relation anchored on the most recent entity,
        # e.g. "works at" when the preposition hangs off the root verb.
        if token.dep_ == "prep":
            if entities:
                subject = entities[-1]
            if token.head.dep_ == "ROOT":
                verb = token.head.text + " " + token.text
            else:
                verb = token.text
        # Emit a triple once subject, verb and object have all been seen.
        if subject and verb and obj:
            relationships.append((subject.strip(), verb.strip(), obj.strip()))
            subject, verb, obj = None, None, None
    return relationships, entities

# Function to create the knowledge graph
def create_knowledge_graph(entities, relationships):
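    # Note: only entities that take part in at least one relationship become
    # nodes; the `entities` argument is accepted here but not used directly.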
    G = nx.DiGraph()
    involved_entities = set()
    for subj, rel, obj in relationships:
        involved_entities.add(subj)
        involved_entities.add(obj)
    for entity in involved_entities:
        G.add_node(entity)
    for subj, rel, obj in relationships:
        G.add_edge(subj, obj, label=rel)
    return G

# Function to visualize the graph
def visualize_graph(G):
    pos = nx.spring_layout(G)
    edge_labels = nx.get_edge_attributes(G, 'label')
    plt.figure(figsize=(12, 8))
    nx.draw(G, pos, with_labels=True, node_size=2000, node_color="lightblue", font_size=10, font_weight="bold")
    nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels)
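    # Render the figure to an in-memory PNG buffer and hand back a PIL image
    # so Gradio can display it.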
    buf = io.BytesIO()
    plt.savefig(buf, format="png")
    buf.seek(0)
    plt.close()
    pil_image = Image.open(buf)
    return pil_image

# Function to process input and generate output
def process_text(text: str):
    relationships, entities = extract_relationships(text)
    G = create_knowledge_graph(entities, relationships)
    return visualize_graph(G)
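
# Quick local check (kept commented out so the Gradio app below launches as usual).
# The exact triples depend on spaCy's parse; a sentence like the one below should
# produce something close to ("alice", "works at", "startup"):
#   img = process_text("Alice works at a startup in Berlin")
#   img.save("knowledge_graph.png")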

# Gradio Interface
gr.Interface(
    fn=process_text,
    inputs=gr.Textbox(placeholder="Enter knowledge prompt here"),
    outputs=gr.Image(type="pil"),
    title="Knowledge Graph Generator"
).launch()