Commit d91dab1 by gigant · 1 Parent(s): 357a26a

adding structural graph information

Files changed (1):
  1. app.py +69 -7
app.py CHANGED
@@ -6,11 +6,19 @@ from datasets import load_dataset
  import spacy
  import gradio as gr
  import en_core_web_trf
+ import numpy as np

  dataset = load_dataset("gigant/tib_transcripts")

  nlp = en_core_web_trf.load()

+ def half_circle_layout(n_nodes, sentence_node=True):
+     pos = {}
+     for i_node in range(n_nodes - 1):
+         pos[i_node] = ((- np.cos(i_node * np.pi/(n_nodes - 1))), 0.5 * (-np.sin(i_node * np.pi/(n_nodes - 1))))
+     pos[n_nodes - 1] = (0, -0.25)
+     return pos
+
  def dependency_parser(sentences):
      return [nlp(sentence) for sentence in sentences]

@@ -23,11 +31,51 @@ def construct_dependency_graph(docs):
          nodes = [token.text for token in doc]
          senders = []
          receivers = []
+         edge_labels = {}
+         for token in doc:
+             for child in token.children:
+                 senders.append(token.i)
+                 receivers.append(child.i)
+                 edge_labels[(token.i, child.i)] = token.dep_
+         graphs.append({"nodes": nodes, "senders": senders, "receivers": receivers, "edge_labels": edge_labels})
+     return graphs
+
+ def construct_both_graph(docs):
+     """
+     docs is a list of outputs of the SpaCy dependency parser
+     """
+     graphs = []
+     for doc in docs:
+         nodes = [token.text for token in doc]
+         nodes.append("Sentence")
+         senders = [token.i for token in doc][:-1]
+         receivers = [token.i for token in doc][1:]
+         edge_labels = {(token.i, token.i + 1): "next" for token in doc[:-1]}
+         for node in range(len(nodes) - 1):
+             senders.append(node)
+             receivers.append(len(nodes) - 1)
+             edge_labels[(node, len(nodes) - 1)] = "in"
          for token in doc:
              for child in token.children:
                  senders.append(token.i)
                  receivers.append(child.i)
-         graphs.append({"nodes": nodes, "senders": senders, "receivers": receivers})
+                 edge_labels[(token.i, child.i)] = token.dep_
+         graphs.append({"nodes": nodes, "senders": senders, "receivers": receivers, "edge_labels": edge_labels})
+     return graphs
+
+ def construct_structural_graph(docs):
+     graphs = []
+     for doc in docs:
+         nodes = [token.text for token in doc]
+         nodes.append("Sentence")
+         senders = [token.i for token in doc][:-1]
+         receivers = [token.i for token in doc][1:]
+         edge_labels = {(token.i, token.i + 1): "next" for token in doc[:-1]}
+         for node in range(len(nodes) - 1):
+             senders.append(node)
+             receivers.append(len(nodes) - 1)
+             edge_labels[(node, len(nodes) - 1)] = "in"
+         graphs.append({"nodes": nodes, "senders": senders, "receivers": receivers, "edge_labels": edge_labels})
      return graphs

  def to_jraph(graph):
@@ -72,15 +120,29 @@ def convert_jraph_to_networkx_graph(jraph_graph: jraph.GraphsTuple) -> nx.Graph:
          int(senders[e]), int(receivers[e]), edge_feature=edges[e])
      return nx_graph

- def plot_graph_sentence(sentence):
+ def plot_graph_sentence(sentence, graph_type="both"):
+     # sentences = dataset["train"][0]["abstract"].split(".")
      docs = dependency_parser([sentence])
-     graphs = construct_dependency_graph(docs)
+     if graph_type == "dependency":
+         graphs = construct_dependency_graph(docs)
+     elif graph_type == "structural":
+         graphs = construct_structural_graph(docs)
+     elif graph_type == "both":
+         graphs = construct_both_graph(docs)
      g = to_jraph(graphs[0])
      nx_graph = convert_jraph_to_networkx_graph(g)
-     pos = nx.spring_layout(nx_graph)
-     plot = plt.figure(figsize=(6, 6))
-     nx.draw(nx_graph, pos=pos, labels={i: e for i,e in enumerate(graphs[0]["nodes"])}, with_labels = True,
-             node_size=800, font_color='black', node_color="yellow")
+     pos = half_circle_layout(len(graphs[0]["nodes"]))
+     plot = plt.figure(figsize=(25, 6))
+     nx.draw(nx_graph, pos=pos,
+             labels={i: e for i,e in enumerate(graphs[0]["nodes"])},
+             with_labels = True, edge_color="blue",
+             # connectionstyle="arc3,rad=0.1",
+             node_size=1000, font_color='black', node_color="yellow")
+     nx.draw_networkx_edge_labels(
+         nx_graph, pos=pos,
+         edge_labels=graphs[0]["edge_labels"],
+         font_color='red'
+     )
      return plot

  def get_list_sentences(id):
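
All of the builders in this commit return the same plain-dict graph format — "nodes", "senders", "receivers", and now "edge_labels" keyed by (sender, receiver) pairs. A minimal sketch of how they could be exercised outside the Gradio app, assuming the imports and functions from app.py above; the example sentence is illustrative and the exact indices depend on spaCy's tokenization:

    sentence = "The model parses this sentence."
    docs = dependency_parser([sentence])

    # Structural graph: tokens chained left-to-right with "next" edges, plus an
    # extra "Sentence" node that every token is connected to with an "in" edge.
    structural = construct_structural_graph(docs)[0]
    print(structural["nodes"])        # e.g. ['The', 'model', 'parses', 'this', 'sentence', '.', 'Sentence']
    print(structural["edge_labels"])  # e.g. {(0, 1): 'next', ..., (0, 6): 'in', ...}

    # Dependency graph: one edge per head -> child arc, labelled with token.dep_.
    dependency = construct_dependency_graph(docs)[0]

    # "Both": structural edges and dependency edges merged into a single graph.
    both = construct_both_graph(docs)[0]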
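
The updated plot_graph_sentence selects a builder via graph_type ("dependency", "structural", or "both", the default) and swaps the old spring layout for half_circle_layout, which spreads the token nodes along the lower half of an ellipse and places the trailing "Sentence" node at (0, -0.25) so the "in" edges fan out from the middle. A small sketch, again assuming the app.py definitions above; the return value is a matplotlib Figure:

    # Half-circle positions for a 4-node graph (3 tokens + the "Sentence" node),
    # approximately {0: (-1.0, 0.0), 1: (-0.5, -0.43), 2: (0.5, -0.43), 3: (0, -0.25)}.
    print(half_circle_layout(4))

    # Render the combined graph with labelled edges and save it to disk.
    fig = plot_graph_sentence("The model parses this sentence.", graph_type="both")
    fig.savefig("sentence_graph.png")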