SauravMaheshkar commited on
Commit
2891210
1 Parent(s): 778ea0f

feat: add initial gradio app

Browse files
Files changed (3) hide show
  1. .gitignore +5 -0
  2. app.py +172 -0
  3. requirements.txt +5 -0
.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ **/.DS_Store
2
+
3
+ .venv/
4
+ artifacts/
5
+
app.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import os
3
+ from copy import deepcopy
4
+
5
+ import dhg
6
+ import gradio as gr
7
+ import matplotlib.pyplot as plt
8
+ import pandas as pd
9
+ from dhg.visualization.structure.defaults import (default_hypergraph_strength,
10
+ default_hypergraph_style,
11
+ default_size)
12
+ from dhg.visualization.structure.layout import force_layout
13
+ from dhg.visualization.structure.utils import draw_circle_edge, draw_vertex
14
+ from huggingface_hub import hf_hub_download
15
+
16
+
17
+ def draw_hypergraph(
18
+ hg: "dhg.Hypergraph",
19
+ e_style="circle",
20
+ v_label=None,
21
+ v_size=1.0,
22
+ v_color="r",
23
+ v_line_width=1.0,
24
+ e_color="gray",
25
+ e_fill_color="whitesmoke",
26
+ e_line_width=1.0,
27
+ font_size=1.0,
28
+ font_family="sans-serif",
29
+ push_v_strength=1.0,
30
+ push_e_strength=1.0,
31
+ pull_e_strength=1.0,
32
+ pull_center_strength=1.0,
33
+ ):
34
+ fig, ax = plt.subplots(figsize=(6, 6))
35
+
36
+ num_v, e_list = hg.num_v, deepcopy(hg.e[0])
37
+ # default configures
38
+ v_color, e_color, e_fill_color = default_hypergraph_style(
39
+ hg.num_v, hg.num_e, v_color, e_color, e_fill_color
40
+ )
41
+ v_size, v_line_width, e_line_width, font_size = default_size(
42
+ num_v, e_list, v_size, v_line_width, e_line_width
43
+ )
44
+ (
45
+ push_v_strength,
46
+ push_e_strength,
47
+ pull_e_strength,
48
+ pull_center_strength,
49
+ ) = default_hypergraph_strength(
50
+ num_v,
51
+ e_list,
52
+ push_v_strength,
53
+ push_e_strength,
54
+ pull_e_strength,
55
+ pull_center_strength,
56
+ )
57
+ # layout
58
+ v_coor = force_layout(
59
+ num_v,
60
+ e_list,
61
+ push_v_strength,
62
+ push_e_strength,
63
+ pull_e_strength,
64
+ pull_center_strength,
65
+ )
66
+ draw_circle_edge(
67
+ ax,
68
+ v_coor,
69
+ v_size,
70
+ e_list,
71
+ e_color,
72
+ e_fill_color,
73
+ e_line_width,
74
+ )
75
+
76
+ draw_vertex(
77
+ ax,
78
+ v_coor,
79
+ v_label,
80
+ font_size,
81
+ font_family,
82
+ v_size,
83
+ v_color,
84
+ v_line_width,
85
+ )
86
+
87
+ plt.xlim((0, 1.0))
88
+ plt.ylim((0, 1.0))
89
+ plt.axis("off")
90
+ fig.tight_layout()
91
+
92
+ return fig
93
+
94
+
95
+ def plot_dataset(dataset_choice: str, sampling_choice: str, split_choice: str):
96
+ os.makedirs("artifacts", exist_ok=True)
97
+ hf_hub_download(
98
+ filename=f"processed/{sampling_choice}/{split_choice}_df.csv",
99
+ local_dir="./artifacts/",
100
+ repo_id=f"SauravMaheshkar/{dataset_choice}",
101
+ repo_type="dataset",
102
+ )
103
+
104
+ df = pd.read_csv(f"artifacts/processed/{sampling_choice}/{split_choice}_df.csv")
105
+
106
+ num_vertices = len(df)
107
+ edge_list = df["nodes"].values.tolist()
108
+ edge_list = [ast.literal_eval(edges) for edges in edge_list]
109
+
110
+ hypergraph = dhg.Hypergraph(num_vertices, edge_list)
111
+
112
+ fig = draw_hypergraph(hypergraph)
113
+ return fig
114
+
115
+
116
+ with gr.Blocks() as demo:
117
+
118
+ with gr.Row():
119
+ dataset_choices = gr.Dropdown(
120
+ choices=[
121
+ "email-Eu",
122
+ "email-Enron",
123
+ "NDC-classes",
124
+ "tags-math-sx",
125
+ "email-Eu-25",
126
+ "NDC-substances",
127
+ "congress-bills",
128
+ "tags-ask-ubuntu",
129
+ "email-Enron-25",
130
+ "NDC-classes-25",
131
+ "threads-ask-ubuntu",
132
+ "contact-high-school",
133
+ "NDC-substances-25",
134
+ "congress-bills-25",
135
+ "contact-primary-school",
136
+ ],
137
+ value="email-Enron-25",
138
+ label="Please choose a dataset",
139
+ interactive=True,
140
+ )
141
+
142
+ sampling_choice = gr.Dropdown(
143
+ choices=[
144
+ "transductive",
145
+ "inductive",
146
+ ],
147
+ value="inductive",
148
+ label="Choose sampling type",
149
+ interactive=True,
150
+ )
151
+
152
+ split_choice = gr.Dropdown(
153
+ choices=[
154
+ "train",
155
+ "valid",
156
+ "test",
157
+ ],
158
+ value="test",
159
+ label="Choose split",
160
+ interactive=True,
161
+ )
162
+
163
+ output_plot = gr.Plot(label="Hypergraph plot")
164
+
165
+ btn = gr.Button("Visualise")
166
+ btn.click(
167
+ fn=plot_dataset,
168
+ inputs=[dataset_choices, sampling_choice, split_choice],
169
+ outputs=output_plot,
170
+ )
171
+
172
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ dhg
2
+ gradio
3
+ huggingface_hub
4
+ matplotlib
5
+ pandas