File size: 8,021 Bytes
99b658c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
import gradio as gr
import networkx as nx
import spacy
import plotly.graph_objects as go

from sources import RSS_FEEDS
from fetch import fetch_articles

# Imports for the LLM knowledge graph transformer
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_openai import ChatOpenAI
from langchain_experimental.graph_transformers import LLMGraphTransformer


def _graph_from_graph_documents(graph_documents):
    """Merge LLM graph documents into a single directed NetworkX graph.

    Each graph document is dumped to a dictionary and its ``nodes`` and
    ``relationships`` entries are folded into one ``nx.DiGraph``. Repeated
    (source, target) pairs are collapsed into one edge whose ``weight``
    counts the occurrences and whose ``relation_types`` lists the distinct
    relationship types seen for that pair.
    """
    graph = nx.DiGraph()
    for graph_doc in graph_documents:
        # Convert the Pydantic model to a plain dictionary.
        gdoc = graph_doc.model_dump()

        # ``or []`` also covers a key that is present but set to None.
        for node in gdoc.get("nodes") or []:
            node_id = node.get("id") or node.get("name")
            if node_id:
                graph.add_node(node_id)

        for rel in gdoc.get("relationships") or []:
            source = (rel.get("source") or {}).get("id")
            target = (rel.get("target") or {}).get("id")
            if not (source and target):
                continue  # Skip relationships with a missing endpoint.

            rel_type = rel.get("type", "")
            if graph.has_edge(source, target):
                edge_data = graph[source][target]
                rel_types = edge_data.setdefault("relation_types", [])
                if rel_type not in rel_types:
                    rel_types.append(rel_type)
                edge_data["weight"] += 1
            else:
                graph.add_edge(source, target, weight=1, relation_types=[rel_type])
    return graph


def _node_hover_text(graph):
    """Return a ``{node: hover_html}`` mapping describing each node's edges.

    Outgoing connections are listed first, then incoming ones, joined with
    ``<br>``. A node without any connection falls back to its own identifier.
    """
    hover = {}
    for node in graph.nodes():
        details = []
        for _, dst, data in graph.out_edges(node, data=True):
            rels = ", ".join(data.get("relation_types", []))
            details.append(f"Out: {node} - {rels} -> {dst}")
        for src, _, data in graph.in_edges(node, data=True):
            rels = ", ".join(data.get("relation_types", []))
            details.append(f"In: {src} - {rels} -> {node}")
        hover[node] = "<br>".join(details) if details else node
    return hover


def _figure_from_graph(graph):
    """Lay out ``graph`` with a spring layout and render it as a Plotly figure.

    Nodes are drawn as labelled markers with connection details on hover.
    Edges are drawn as grey lines plus a text-less arrow annotation per edge
    to show direction.
    """
    pos = nx.spring_layout(graph, k=1.2)
    hover = _node_hover_text(graph)

    nodes = list(graph.nodes())
    node_trace = go.Scatter(
        x=[pos[node][0] for node in nodes],
        y=[pos[node][1] for node in nodes],
        mode='markers+text',
        text=nodes,  # Displayed text is just the node name.
        textposition="top center",
        hoverinfo='text',
        hovertext=[hover.get(node, node) for node in nodes],
        marker=dict(
            size=10,
            color='#1f78b4'
        )
    )

    # All edges share one trace; a None entry breaks the line between
    # consecutive segments. This avoids creating one trace per edge.
    edge_x = []
    edge_y = []
    arrows = []
    for src, dst in graph.edges():
        x0, y0 = pos[src]
        x1, y1 = pos[dst]
        edge_x += [x0, x1, None]
        edge_y += [y0, y1, None]
        arrows.append(dict(
            x=x1,
            y=y1,
            ax=x0,
            ay=y0,
            xref='x',
            yref='y',
            axref='x',
            ayref='y',
            showarrow=True,
            arrowhead=3,
            arrowcolor='#888',
            arrowwidth=2,
            text="",  # No text; rely on node hover for details.
        ))

    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        mode='lines',
        line=dict(width=1, color='#888'),
        hoverinfo='none'
    )

    return go.Figure(
        data=[edge_trace, node_trace],
        layout=go.Layout(
            title='<br>Interactive Knowledge Graph (LLM-derived)',
            showlegend=False,
            hovermode='closest',
            margin=dict(b=20, l=5, r=5, t=40),
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            width=1200,   # wider figure
            height=800,   # taller figure
            dragmode='pan',
            annotations=arrows,
        )
    )


def build_interactive_knowledge_graph(feed_items):
    """
    Build an interactive knowledge graph from aggregated RSS feed text using an LLM.

    Steps:
      1. Combine the title and summary of all feed items.
      2. Create a Document and split it into chunks.
      3. Use ChatOpenAI and LLMGraphTransformer to get graph information.
      4. Merge nodes and relationships into a directed NetworkX graph.
      5. Compute a spring layout and node hover text listing all
         outgoing/incoming connections.
      6. Render nodes, edges and direction arrows as a Plotly figure.

    Args:
        feed_items: Iterable of mappings that each provide ``'title'`` and
            ``'summary'`` entries.

    Returns:
        A ``plotly.graph_objects.Figure``. When ``feed_items`` is empty the
        figure contains no nodes or edges and the LLM is never called.
    """
    # Nothing to analyse: skip the LLM client entirely and return an empty figure.
    if not feed_items:
        return _figure_from_graph(nx.DiGraph())

    # 1. Combine all feed items into one aggregated text.
    combined_text = "\n\n".join(
        f"{item['title']}. {item['summary']}" for item in feed_items
    )

    # 2. Create a Document and split it.
    doc = Document(page_content=combined_text)
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    docs = text_splitter.split_documents([doc])

    # 3. Initialize the LLM and transformer.
    # NOTE(review): presumably requires OpenAI credentials in the environment - confirm.
    llm = ChatOpenAI(temperature=0, model="gpt-4o")
    llm_transformer = LLMGraphTransformer(llm=llm)
    graph_documents = llm_transformer.convert_to_graph_documents(docs)

    # 4. Build a directed NetworkX graph.
    graph = _graph_from_graph_documents(graph_documents)

    # 5-6. Lay the graph out and convert it to a Plotly figure.
    return _figure_from_graph(graph)

def get_combined_feed(source_choice, selected_news_sites):
    """
    Collect articles from the chosen RSS sources and turn them into
    a text listing plus an interactive Plotly knowledge graph.

    Returns a ``(feed_text, graph_fig)`` pair: the listing has one entry per
    article (title, published date and link), and the figure is whatever
    ``build_interactive_knowledge_graph`` produces for those articles.
    """
    articles = []

    # Only query the feeds when "News" is ticked and at least one site is chosen.
    if "News" in source_choice and selected_news_sites:
        chosen_feeds = {
            site: feed_url
            for site, feed_url in RSS_FEEDS.items()
            if site in selected_news_sites
        }
        articles.extend(fetch_articles(chosen_feeds, limit=6))

    # One listing entry per article, separated by blank lines.
    listing_entries = []
    for article in articles:
        listing_entries.append(
            f"🔹 {article['title']} ({article['published']})\n{article['link']}"
        )
    listing = "\n\n".join(listing_entries)

    # The graph is derived from the same set of articles.
    figure = build_interactive_knowledge_graph(articles)

    return listing, figure

# Define the Gradio interface with a button to trigger processing.
# Layout: a first row with two columns (source/site selectors, then the
# aggregated feed textbox), a second row holding the graph plot, and a
# button below that runs get_combined_feed.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # Source categories; "News" is the only choice and is pre-selected.
            source_selector = gr.CheckboxGroup(
                ["News"], value=["News"], label="Select Sources"
            )
            # One checkbox per configured RSS feed name.
            # NOTE(review): assumes "BBC" and "Wired" are keys of RSS_FEEDS - confirm.
            news_site_selector = gr.CheckboxGroup(
                list(RSS_FEEDS.keys()), value=["BBC", "Wired"], label="News Sites"
            )
        with gr.Column():
            # Receives the first value returned by get_combined_feed (the feed text).
            feed_output = gr.Textbox(label="Aggregated Feed", lines=20)
    with gr.Row():
        with gr.Column():
            # Receives the second value returned by get_combined_feed (the figure).
            graph_output = gr.Plot(label="Interactive Knowledge Graph")
    
    # Button to trigger graph generation.
    generate_button = gr.Button("Generate Graph")
    # The inputs map positionally onto get_combined_feed's parameters and the
    # outputs onto its (feed_text, graph_fig) return pair.
    generate_button.click(
        fn=get_combined_feed,
        inputs=[source_selector, news_site_selector],
        outputs=[feed_output, graph_output]
    )

# Runs at module top level, so the app is launched whenever this file is executed.
demo.launch()