# word_graph_viz / app.py
# (Hugging Face Space by "gigant"; web-page chrome — raw/history/blame, 9.55 kB — removed)
import networkx as nx
import matplotlib.pyplot as plt
import jraph
import jax.numpy as jnp
from datasets import load_dataset
import spacy
import gradio as gr
import en_core_web_trf
import numpy as np
import benepar
import re
# Talk-transcript dataset used to populate the sentence dropdowns below.
dataset = load_dataset("gigant/tib_transcripts")
# spaCy transformer pipeline (tokenization + dependency parsing).
nlp = en_core_web_trf.load()
# Download the benepar constituency model and register it on the pipeline,
# so sentences expose `sent._.parse_string`.
benepar.download('benepar_en3')
nlp.add_pipe('benepar', config={'model': 'benepar_en3'})
def parse_tree(sentence):
    """Parse an s-expression string (e.g. a constituency parse) into nested lists.

    Returns the outermost list of parsed items; raises ValueError when the
    parentheses are unbalanced in either direction.
    """
    # Split into '(' / ')' tokens and bare words, discarding whitespace runs.
    tokens = filter(None, re.compile(r'(?:([()])|\s+)').split(sentence))
    root = current = []
    parents = []  # stack of enclosing lists
    for tok in tokens:
        if tok == '(':
            # Open a new nested list and descend into it.
            parents.append(current)
            current.append([])
            current = current[-1]
        elif tok == ')':
            if not parents:
                raise ValueError("Unbalanced parentheses")
            current = parents.pop()
        else:
            current.append(tok)
    if parents:
        raise ValueError("Unbalanced parentheses")
    return root
class Tree():
    """A rooted tree with a string label and an integer id per node."""

    def __init__(self, name, children):
        self.name = name          # node label
        self.children = children  # list of child Tree instances
        self.id = None            # assigned later by set_all_ids()

    def set_id_rec(self, id=0):
        """Assign ids depth-first (pre-order) starting at `id`; return the last id used."""
        self.id = id
        last_id = id
        for child in self.children:
            last_id = child.set_id_rec(id=last_id + 1)
        return last_id

    def set_all_ids(self):
        """Number every node in the tree, the root getting id 0."""
        self.set_id_rec(0)

    def print_tree(self, level=0):
        """Render the subtree as indented text, one `|--- name (id)` line per node."""
        lines = [f'|{"-" * level} {self.name} ({self.id})']
        lines.extend(child.print_tree(level + 1) for child in self.children)
        return "\n".join(lines)

    def __str__(self):
        return self.print_tree(0)

    def get_list_nodes(self):
        """Return node names in depth-first pre-order (same order as the ids)."""
        names = [self.name]
        for child in self.children:
            names.extend(child.get_list_nodes())
        return names
def rec_const_parsing(list_nodes):
    """Convert the nested lists produced by parse_tree() into a Tree.

    A list is interpreted as [name, child, child, ...]; a bare string becomes
    a leaf node with no children.
    """
    if isinstance(list_nodes, list):
        name, children = list_nodes[0], list_nodes[1:]
    else:
        name, children = list_nodes, []
    # Original code used enumerate() here but never used the index.
    return Tree(name, [rec_const_parsing(child) for child in children])
def tree_to_graph(t):
    """Collect directed parent -> child edges of the subtree rooted at `t`.

    Returns two parallel lists (sender ids, receiver ids). Assumes node ids
    were already assigned, e.g. via Tree.set_all_ids().
    """
    senders, receivers = [], []
    for child in t.children:
        senders.append(t.id)
        receivers.append(child.id)
        child_senders, child_receivers = tree_to_graph(child)
        senders += child_senders
        receivers += child_receivers
    return senders, receivers
def construct_constituency_graph(docs):
    """Build a graph from the benepar constituency parse of the first sentence
    of the first doc.

    Returns a one-element list holding a dict with "nodes", "senders",
    "receivers" and (empty) "edge_labels", matching the other constructors.
    """
    first_sentence = next(iter(docs[0].sents))
    print(first_sentence._.parse_string)  # debug: raw bracketed parse string
    tree = rec_const_parsing(parse_tree(first_sentence._.parse_string)[0])
    tree.set_all_ids()
    senders, receivers = tree_to_graph(tree)
    return [{
        "nodes": tree.get_list_nodes(),
        "senders": senders,
        "receivers": receivers,
        "edge_labels": {},
    }]
def half_circle_layout(n_nodes, sentence_node=True):
    """Lay out node positions on a vertically squashed half circle.

    When `sentence_node` is True (default), the last node is assumed to be the
    extra "Sentence" node and is pinned just below the center at (0, -0.25),
    with the remaining n_nodes - 1 nodes spread along the arc. When False,
    all n_nodes are placed on the arc. (The original implementation accepted
    this flag but ignored it; the default behavior is unchanged.)

    Returns a dict mapping node index -> (x, y).
    """
    pos = {}
    # Number of nodes that actually sit on the arc.
    n_arc = n_nodes - 1 if sentence_node else n_nodes
    for i_node in range(n_arc):
        theta = i_node * np.pi / n_arc
        # y is halved so the arc is flattened vertically.
        pos[i_node] = (-np.cos(theta), -0.5 * np.sin(theta))
    if sentence_node:
        pos[n_nodes - 1] = (0, -0.25)
    return pos
def get_adjacency_matrix(jraph_graph: jraph.GraphsTuple):
    """Build a dense 0/1 adjacency matrix (sender -> receiver) from a jraph graph."""
    nodes, _, receivers, senders, _, _, _ = jraph_graph
    n = len(nodes)
    adj_mat = jnp.zeros((n, n))
    # jax arrays are immutable: .at[...].set(...) returns an updated copy.
    for edge_idx in range(len(receivers)):
        adj_mat = adj_mat.at[senders[edge_idx], receivers[edge_idx]].set(1)
    return adj_mat
def dependency_parser(sentences):
    """Run the module-level spaCy pipeline on each sentence string; return the Docs."""
    return list(map(nlp, sentences))
def construct_dependency_graph(docs):
    """
    docs is a list of outputs of the SpaCy dependency parser

    For each doc, build a graph whose nodes are the tokens and whose edges run
    from each child token to its syntactic head. Edge labels are keyed
    (head index, child index) and hold the child's dependency relation.
    """
    graphs = []
    for doc in docs:
        senders, receivers = [], []
        edge_labels = {}
        for head in doc:
            for child in head.children:
                senders.append(child.i)
                receivers.append(head.i)
                # NOTE: label key is (receiver, sender), i.e. (head, child).
                edge_labels[(head.i, child.i)] = child.dep_
        graphs.append({
            "nodes": [token.text for token in doc],
            "senders": senders,
            "receivers": receivers,
            "edge_labels": edge_labels,
        })
    return graphs
def construct_both_graph(docs):
    """
    docs is a list of outputs of the SpaCy dependency parser

    Combines the structural graph (bidirectional next/previous chain plus a
    global "Sentence" node) with the dependency edges (child -> head).
    Dependency labels can overwrite a chain label at the same (head, child) key.
    """
    graphs = []
    for doc in docs:
        token_ids = [token.i for token in doc]
        nodes = [token.text for token in doc] + ["Sentence"]
        sentence_id = len(nodes) - 1
        # Linear chain: an edge in each direction between adjacent tokens.
        senders = token_ids[:-1] + token_ids[1:]
        receivers = token_ids[1:] + token_ids[:-1]
        edge_labels = {}
        for token in doc[:-1]:
            edge_labels[(token.i, token.i + 1)] = "next"
        for token in doc[:-1]:
            edge_labels[(token.i + 1, token.i)] = "previous"
        # Every token node is connected to the extra "Sentence" node.
        for node_id in range(sentence_id):
            senders.append(node_id)
            receivers.append(sentence_id)
            edge_labels[(node_id, sentence_id)] = "in"
        # Dependency edges, appended after the structural ones.
        for head in doc:
            for child in head.children:
                senders.append(child.i)
                receivers.append(head.i)
                edge_labels[(head.i, child.i)] = child.dep_
        graphs.append({"nodes": nodes, "senders": senders, "receivers": receivers, "edge_labels": edge_labels})
    return graphs
def construct_structural_graph(docs):
    """Build a purely positional graph per doc.

    Nodes are the tokens plus a global "Sentence" node; edges are a
    bidirectional next/previous chain between adjacent tokens and an "in"
    edge from every token to the "Sentence" node.
    """
    graphs = []
    for doc in docs:
        token_ids = [token.i for token in doc]
        nodes = [token.text for token in doc] + ["Sentence"]
        sentence_id = len(nodes) - 1
        # Chain edges in both directions between neighbors.
        senders = token_ids[:-1] + token_ids[1:]
        receivers = token_ids[1:] + token_ids[:-1]
        edge_labels = {}
        for token in doc[:-1]:
            edge_labels[(token.i, token.i + 1)] = "next"
            edge_labels[(token.i + 1, token.i)] = "previous"
        # Connect every token node to the "Sentence" node.
        for node_id in range(sentence_id):
            senders.append(node_id)
            receivers.append(sentence_id)
            edge_labels[(node_id, sentence_id)] = "in"
        graphs.append({"nodes": nodes, "senders": senders, "receivers": receivers, "edge_labels": edge_labels})
    return graphs
def to_jraph(graph):
    """Convert a graph dict (keys "nodes", "senders", "receivers") into a
    jraph.GraphsTuple.

    Node features are dummy zeros; edge features and globals are left empty.
    """
    num_nodes = len(graph["nodes"])
    num_edges = len(graph["senders"])
    # n_node / n_edge are stored as length-1 arrays so that several graphs
    # could later be batched into a single GraphsTuple.
    return jraph.GraphsTuple(
        nodes=jnp.array([0] * num_nodes),
        senders=jnp.array(graph["senders"]),
        receivers=jnp.array(graph["receivers"]),
        edges=None,
        n_node=jnp.array([num_nodes]),
        n_edge=jnp.array([num_edges]),
        globals=None,
    )
def convert_jraph_to_networkx_graph(jraph_graph: jraph.GraphsTuple) -> nx.Graph:
    """Turn a jraph.GraphsTuple into a networkx DiGraph.

    Node features (when present) are stored under the "node_feature"
    attribute and edge features under "edge_feature".
    """
    nodes, edges, receivers, senders, _, _, _ = jraph_graph
    nx_graph = nx.DiGraph()
    for n in range(jraph_graph.n_node[0]):
        if nodes is None:
            nx_graph.add_node(n)
        else:
            nx_graph.add_node(n, node_feature=nodes[n])
    for e in range(jraph_graph.n_edge[0]):
        if edges is None:
            nx_graph.add_edge(int(senders[e]), int(receivers[e]))
        else:
            nx_graph.add_edge(
                int(senders[e]), int(receivers[e]), edge_feature=edges[e])
    return nx_graph
def plot_graph_sentence(sentence, graph_type="constituency"):
    """Parse `sentence`, build the requested word graph, and plot it.

    Returns a two-element list of gradio updates: the graph figure and its
    adjacency-matrix figure.

    Raises ValueError for an unknown graph_type. (In the original if/elif
    chain, an unknown type left `graphs` undefined and crashed later with a
    NameError.)
    """
    builders = {
        "dependency": construct_dependency_graph,
        "structural": construct_structural_graph,
        "structural+dependency": construct_both_graph,
        "constituency": construct_constituency_graph,
    }
    if graph_type not in builders:
        raise ValueError(f"Unknown graph type: {graph_type!r}")
    docs = dependency_parser([sentence])
    graphs = builders[graph_type](docs)
    g = to_jraph(graphs[0])
    adj_mat = get_adjacency_matrix(g)
    nx_graph = convert_jraph_to_networkx_graph(g)
    # Constituency trees are planar; other graph types use the half-circle
    # layout with the extra "Sentence" node pinned below the center. (The
    # original computed the half-circle layout unconditionally and then
    # overwrote it for constituency graphs.)
    if graph_type == "constituency":
        pos = nx.planar_layout(nx_graph)
    else:
        pos = half_circle_layout(len(graphs[0]["nodes"]))
    plot = plt.figure(figsize=(12, 6))
    nx.draw(nx_graph, pos=pos,
            labels={i: name for i, name in enumerate(graphs[0]["nodes"])},
            with_labels=True, edge_color="blue",
            node_size=1000, font_color='black', node_color="yellow")
    nx.draw_networkx_edge_labels(
        nx_graph, pos=pos,
        edge_labels=graphs[0]["edge_labels"],
        font_color='red'
    )
    adj_mat_plot, ax = plt.subplots(figsize=(6, 6))
    ax.matshow(adj_mat)
    return [gr.update(value=plot), gr.update(value=adj_mat_plot)]
def get_list_sentences(id):
    """Return a gradio update filling the sentence dropdown with the sentences
    (split on '.') of transcript `id`, clamped to a valid dataset index.

    The original only capped the upper bound, so a negative id would silently
    index from the end of the dataset.
    """
    idx = min(max(int(id), 0), len(dataset["train"]) - 1)
    return gr.update(choices=dataset["train"][idx]["transcript"].split("."))
# --- Gradio UI: graph-type selector, two sentence sources, and two plots -----
with gr.Blocks() as demo:
    with gr.Row():
        # Which word-graph construction to visualize.
        graph_type = gr.Dropdown(label="Graph type", choices=["structural", "dependency", "structural+dependency", "constituency"], value="structural+dependency", interactive = True)
    with gr.Tab("From transcript"):
        with gr.Row():
            with gr.Column():
                # Transcript index into dataset["train"].
                id = gr.Number(label="Transcript")
            with gr.Column(scale=3):
                # Sentences of the selected transcript (split on '.').
                sentence_transcript = gr.Dropdown(label="Sentence", choices = dataset["train"][0]["transcript"].split(".")[1:], interactive = True)
    with gr.Tab("Type sentence"):
        with gr.Row():
            # Free-form sentence input as an alternative to the dropdown.
            sentence_typed = gr.Textbox(label="Sentence", interactive = True)
    with gr.Row():
        with gr.Column(scale=2):
            plot_graph = gr.Plot(label="Word graph")
        with gr.Column():
            plot_adj = gr.Plot(label="Word graph adjacency matrix")
    # Wiring: changing the transcript repopulates the sentence list; picking
    # or typing a sentence redraws both plots for the selected graph type.
    id.change(get_list_sentences, id, sentence_transcript)
    sentence_transcript.change(plot_graph_sentence, [sentence_transcript, graph_type], [plot_graph, plot_adj])
    sentence_typed.change(plot_graph_sentence, [sentence_typed, graph_type], [plot_graph, plot_adj])
demo.launch()