EZ4Fanta committed on
Commit
2500a05
·
1 Parent(s): fed7321
Files changed (3) hide show
  1. app.py +32 -5
  2. demo.ipynb +2 -2
  3. utils.py +31 -3
app.py CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  import gradio as gr
2
  import os
3
  import json
@@ -115,14 +118,20 @@ def run_query(protein, compound, disease, pathway, go, depth,
115
  must_show = sample_dict.copy()
116
  try:
117
  print('start sampling')
118
- sub_g, new2orig, node_map_sub = subgraph_by_node(graph, sample_dict, node_map, depth=depth)
119
  id_map_sub = {k: {vv: kk for kk, vv in v.items()} for k, v in node_map_sub.items()}
 
 
 
 
 
 
120
  # 储存subgraph
121
  # save_subgraph_and_metadata(sub_g, id_map_sub, must_show)
122
  # 调用公共函数生成展示 HTML
123
  iframe_html, html_code = generate_iframe(sub_g, id_map_sub, must_show, display_limits)
124
 
125
- return iframe_html, 'success', html_code, sub_g, id_map_sub, must_show
126
  except Exception as e:
127
  return f"Error: {str(e)}", f"Error: {str(e)}", sample_dict, None, None, None
128
 
@@ -171,14 +180,32 @@ def get_text_content(file_path="static/gr_head.md"):
171
  with open(file_path, "r", encoding="utf-8") as f:
172
  return f.read()
173
 
 
 
 
 
 
 
 
 
174
  def download_entity(sub_g, id_map_sub):
175
  try:
 
176
  report_subgraph(sub_g, id_map_sub, save_root=results_root)
177
- path = os.path.join(results_root, 'triples.txt')
178
- return gr.update(value=path, visible=True)
 
 
 
 
 
 
 
 
 
179
  except Exception as e:
180
  return gr.update(value=f"Error: {e}", visible=True)
181
-
182
  # 新增:加载静态文件的函数,如果存在则加载保存的子图数据
183
  def load_static_files():
184
  subgraph_file = "static/subgraph.dgl"
 
1
+ import zipfile
2
+
3
+ import yaml
4
  import gradio as gr
5
  import os
6
  import json
 
118
  must_show = sample_dict.copy()
119
  try:
120
  print('start sampling')
121
+ sub_g, new2orig, node_map_sub, statistics = subgraph_by_node(graph, sample_dict, node_map, depth=depth)
122
  id_map_sub = {k: {vv: kk for kk, vv in v.items()} for k, v in node_map_sub.items()}
123
+ # save statistics as YAML
124
+ with open(join(results_root, "statistics.yaml"), "w") as f:
125
+ yaml.dump(statistics, f)
126
+
127
+ statistics_text = "\n".join([f"{k}: {v}" for k, v in statistics.items()])
128
+
129
  # 储存subgraph
130
  # save_subgraph_and_metadata(sub_g, id_map_sub, must_show)
131
  # 调用公共函数生成展示 HTML
132
  iframe_html, html_code = generate_iframe(sub_g, id_map_sub, must_show, display_limits)
133
 
134
+ return iframe_html, 'success', statistics_text, sub_g, id_map_sub, must_show
135
  except Exception as e:
136
  return f"Error: {str(e)}", f"Error: {str(e)}", sample_dict, None, None, None
137
 
 
180
  with open(file_path, "r", encoding="utf-8") as f:
181
  return f.read()
182
 
183
+ # def download_entity(sub_g, id_map_sub):
184
+ # try:
185
+ # report_subgraph(sub_g, id_map_sub, save_root=results_root)
186
+ # path = join(results_root, 'triples.txt')
187
+ # return gr.update(value=path, visible=True)
188
+ # except Exception as e:
189
+ # return gr.update(value=f"Error: {e}", visible=True)
190
+
191
  def download_entity(sub_g, id_map_sub):
192
  try:
193
+ # 生成统计数据和三元组文件
194
  report_subgraph(sub_g, id_map_sub, save_root=results_root)
195
+ triples_path = join(results_root, 'triples.txt')
196
+ statistics_path = join(results_root, 'statistics.yaml')
197
+
198
+ # 创建一个压缩文件
199
+ zip_path = join(results_root, 'results.zip')
200
+ with zipfile.ZipFile(zip_path, 'w') as zipf:
201
+ zipf.write(triples_path, arcname='triples.txt')
202
+ zipf.write(statistics_path, arcname='statistics.yaml')
203
+
204
+ # 返回压缩文件路径
205
+ return gr.update(value=zip_path, visible=True)
206
  except Exception as e:
207
  return gr.update(value=f"Error: {e}", visible=True)
208
+
209
  # 新增:加载静态文件的函数,如果存在则加载保存的子图数据
210
  def load_static_files():
211
  subgraph_file = "static/subgraph.dgl"
demo.ipynb CHANGED
@@ -133,7 +133,7 @@
133
  },
134
  {
135
  "cell_type": "code",
136
- "execution_count": 3,
137
  "metadata": {},
138
  "outputs": [
139
  {
@@ -150,7 +150,7 @@
150
  "source": [
151
  "# 生成子图\n",
152
  "# 使用 subgraph_by_node 函数获取子图\n",
153
- "sub_g, new2orig, node_map = subgraph_by_node(graph, sample_dict, node_map, depth=depth)\n",
154
  "# 打印子图的基本信息\n",
155
  "print(sub_g)\n",
156
  "# 获取子图的实体和三元组信息\n",
 
133
  },
134
  {
135
  "cell_type": "code",
136
+ "execution_count": null,
137
  "metadata": {},
138
  "outputs": [
139
  {
 
150
  "source": [
151
  "# 生成子图\n",
152
  "# 使用 subgraph_by_node 函数获取子图\n",
153
+ "sub_g, new2orig, node_map, _ = subgraph_by_node(graph, sample_dict, node_map, depth=depth)\n",
154
  "# 打印子图的基本信息\n",
155
  "print(sub_g)\n",
156
  "# 获取子图的实体和三元组信息\n",
utils.py CHANGED
@@ -104,6 +104,32 @@ def degree_search(graph, node_type, node_name, node_map):
104
  if deg > 0:
105
  print(f" {etype}: {deg}")
106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  def subgraph_by_node(graph, sample_dict, node_map, depth=1):
108
  """
109
  Get a subgraph centered around a specific node.
@@ -115,6 +141,7 @@ def subgraph_by_node(graph, sample_dict, node_map, depth=1):
115
  Output:
116
  - full_g: The subgraph centered around the node.
117
  """
 
118
  # print(f"Getting subgraph from: {sample_dict}")
119
  for node_type, node_names in sample_dict.items():
120
  for node_name in node_names:
@@ -123,7 +150,9 @@ def subgraph_by_node(graph, sample_dict, node_map, depth=1):
123
  return
124
  # convert node names to node IDs
125
  sample_dict[node_type] = [node_map[node_type][node_name] for node_name in node_names]
126
-
 
 
127
 
128
  out_g, _ = dgl.khop_out_subgraph(graph, sample_dict, k=depth,
129
  relabel_nodes=True, store_ids=True)
@@ -167,8 +196,7 @@ def subgraph_by_node(graph, sample_dict, node_map, depth=1):
167
  # for etype in full_g.canonical_etypes:
168
  # print(f"{etype}: {full_g.num_edges(etype)}")
169
 
170
- return full_g, new2orig, new_node_map
171
-
172
 
173
  def report_subgraph(graph, id_map, save_root='static'):
174
  entities = defaultdict(list)
 
104
  if deg > 0:
105
  print(f" {etype}: {deg}")
106
 
107
+
108
def analyze_connections(graph, sample_dict, id_map):
    """Count the direct neighbours of each seed node, grouped by type and relation.

    For every node listed in ``sample_dict`` the function walks all canonical
    edge types of ``graph``. When the edge type's source type (element 0)
    equals the node's type, the node's successors are counted; when its
    destination type (element 2) equals the node's type, the predecessors are
    counted. An edge type whose source and destination types are both the
    node's type therefore contributes both counts.

    Args:
        graph: Object exposing ``canonical_etypes`` (iterable of indexable edge
            types) and ``successors(node_id, etype=...)`` /
            ``predecessors(node_id, etype=...)`` whose results offer
            ``.tolist()``.
        sample_dict: Mapping ``node_type -> iterable of node IDs``.
        id_map: Mapping ``node_type -> {node_id: label}``; the label is used in
            the result key.

    Returns:
        dict keyed by ``(node_type, id_map[node_type][node_id])``. Each value is
        ``{"connected_nodes": {neighbour_type: count},
        "connected_edges": {etype: count}}``.

    Raises:
        KeyError: If a seed node has no entry in ``id_map``.
    """
    # The edge-type list does not depend on the seed node; read it once.
    etypes = list(graph.canonical_etypes)

    connection_stats = {}
    for node_type, node_ids in sample_dict.items():
        for node_id in node_ids:
            node_stats = {"connected_nodes": {}, "connected_edges": {}}
            for etype in etypes:
                src_type, dst_type = etype[0], etype[2]
                # Out-going edges: neighbours are of the destination type.
                if src_type == node_type:
                    successors = graph.successors(node_id, etype=etype).tolist()
                    _tally_neighbors(node_stats, dst_type, etype, len(successors))
                # In-coming edges: neighbours are of the source type.
                if dst_type == node_type:
                    predecessors = graph.predecessors(node_id, etype=etype).tolist()
                    _tally_neighbors(node_stats, src_type, etype, len(predecessors))
            connection_stats[(node_type, id_map[node_type][node_id])] = node_stats
    return connection_stats


def _tally_neighbors(node_stats, neighbor_type, etype, count):
    """Add ``count`` to the per-neighbour-type and per-edge-type totals."""
    by_node = node_stats["connected_nodes"]
    by_edge = node_stats["connected_edges"]
    by_node[neighbor_type] = by_node.get(neighbor_type, 0) + count
    by_edge[etype] = by_edge.get(etype, 0) + count
130
+
131
+
132
+
133
  def subgraph_by_node(graph, sample_dict, node_map, depth=1):
134
  """
135
  Get a subgraph centered around a specific node.
 
141
  Output:
142
  - full_g: The subgraph centered around the node.
143
  """
144
+ cur_id_map = {}
145
  # print(f"Getting subgraph from: {sample_dict}")
146
  for node_type, node_names in sample_dict.items():
147
  for node_name in node_names:
 
150
  return
151
  # convert node names to node IDs
152
  sample_dict[node_type] = [node_map[node_type][node_name] for node_name in node_names]
153
+ cur_id_map[node_type] = {node_map[node_type][node_name]: node_name for node_name in node_names}
154
+
155
+ connection_stats = analyze_connections(graph, sample_dict, cur_id_map)
156
 
157
  out_g, _ = dgl.khop_out_subgraph(graph, sample_dict, k=depth,
158
  relabel_nodes=True, store_ids=True)
 
196
  # for etype in full_g.canonical_etypes:
197
  # print(f"{etype}: {full_g.num_edges(etype)}")
198
 
199
+ return full_g, new2orig, new_node_map, connection_stats
 
200
 
201
  def report_subgraph(graph, id_map, save_root='static'):
202
  entities = defaultdict(list)