mfbalin commited on
Commit
8b70c2e
1 Parent(s): 2add1c3
Files changed (1) hide show
  1. app.py +47 -27
app.py CHANGED
@@ -6,47 +6,39 @@ import gradio as gr
6
  import datetime
7
  import numpy as np
8
 
9
- # from dgl.data import YelpDataset
10
 
11
- # import dgl
12
- # import torch as th
13
 
14
- # from dgl.dataloading import LaborSampler, NeighborSampler
15
 
16
- # data = YelpDataset()
17
 
18
  # device = 'cuda:0'
 
19
 
20
- # g = data[0].to(device)
21
 
22
- # num_layers = 3
23
 
24
- # fanouts = [10] * num_layers
25
 
26
- # samplers = [LaborSampler(fanouts, importance_sampling=1), LaborSampler(fanouts, importance_sampling=0), NeighborSampler(fanouts)]
27
 
28
- # names = ['LABOR-1', 'LABOR-0', 'NS']
29
 
30
- # indices = th.arange(g.num_nodes()).to(device)
31
 
32
- # loaders = [dgl.dataloading.DataLoader(g, indices, sampler, batch_size=1024, shuffle=True, drop_last=True) for sampler in samplers]
33
-
34
- # results = []
35
-
36
- # for sampled in zip(*loaders):
37
- # numbers.append([s[0].shape for s in sampled])
38
- # print(numbers[-1])
39
-
40
- # th.tensor(numbers).mean(dim=0, dtype=th.float64)
41
 
42
  def get_time():
43
  return datetime.datetime.now()
44
 
45
-
46
  plot_end = 2 * math.pi
47
 
48
 
49
- def get_plot(period=1):
50
  global plot_end
51
  x = np.arange(plot_end - 2 * math.pi, plot_end, 0.02)
52
  y = np.sin(2 * math.pi * period * x)
@@ -63,6 +55,34 @@ def get_plot(period=1):
63
  plot_end = 2 * math.pi
64
  return update
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
  with gr.Blocks() as demo:
68
  with gr.Row():
@@ -72,19 +92,19 @@ with gr.Blocks() as demo:
72
  "Change the value of the slider to automatically update the plot",
73
  label="",
74
  )
75
- period = gr.Number(
76
  label="batch size", value=1024, show_label=True
77
  )
78
- plot = gr.LinePlot(show_label=False)
79
  with gr.Column():
80
  name = gr.Textbox(label="Enter your name")
81
  greeting = gr.Textbox(label="Greeting")
82
  button = gr.Button(value="Greet")
83
  button.click(lambda s: f"Hello {s}", name, greeting)
84
 
85
- demo.load(lambda: datetime.datetime.now(), None, c_time2, every=1)
86
- dep = demo.load(get_plot, None, plot, every=1)
87
- period.submit(get_plot, period, plot, every=1, cancels=[dep])
88
 
89
  if __name__ == "__main__":
90
  demo.queue().launch()
 
6
  import datetime
7
  import numpy as np
8
 
9
+ from dgl.data import YelpDataset
10
 
11
+ import dgl
12
+ import torch as th
13
 
14
+ from dgl.dataloading import LaborSampler, NeighborSampler
15
 
16
+ data = YelpDataset()
17
 
18
  # device = 'cuda:0'
19
+ device = 'cpu'
20
 
21
+ g = data[0].to(device)
22
 
23
+ num_layers = 3
24
 
25
+ fanouts = [10] * num_layers
26
 
27
+ samplers = [LaborSampler(fanouts, importance_sampling=1), LaborSampler(fanouts, importance_sampling=0), NeighborSampler(fanouts)]
28
 
29
+ names = ['LABOR-1', 'LABOR-0', 'NS']
30
 
31
+ indices = th.arange(g.num_nodes()).to(device)
32
 
33
+ loaders = [dgl.dataloading.DataLoader(g, indices, sampler, batch_size=batch_size, shuffle=True, drop_last=True) for sampler in samplers]
 
 
 
 
 
 
 
 
34
 
35
  def get_time():
36
  return datetime.datetime.now()
37
 
 
38
  plot_end = 2 * math.pi
39
 
40
 
41
+ def get_plot2(period=1):
42
  global plot_end
43
  x = np.arange(plot_end - 2 * math.pi, plot_end, 0.02)
44
  y = np.sin(2 * math.pi * period * x)
 
55
  plot_end = 2 * math.pi
56
  return update
57
 
58
+ results = []
59
+
60
+ def get_plot(batch_size=1024):
61
+ for sampled in zip(*loaders):
62
+ results.append([s[0].shape for s in sampled])
63
+ break
64
+
65
+ y = th.tensor(results)
66
+
67
+ d = {"x": [], "y": []}
68
+
69
+ for i, name in enumerate(names):
70
+ yy = y[:, i]
71
+ d[y] += yy
72
+ d[x] += [name] * yy.shape[0]
73
+
74
+ update = gr.BarPlot.update(
75
+ value=pd.DataFrame(d),
76
+ x="x",
77
+ y="y",
78
+ title="Number of sampled vertices",
79
+ width=600,
80
+ height=350
81
+ )
82
+
83
+ return update
84
+ # th.tensor(results).mean(dim=0, dtype=th.float64)
85
+
86
 
87
  with gr.Blocks() as demo:
88
  with gr.Row():
 
92
  "Change the value of the slider to automatically update the plot",
93
  label="",
94
  )
95
+ batch_size = gr.Number(
96
  label="batch size", value=1024, show_label=True
97
  )
98
+ plot = gr.BarPlot(show_label=False)
99
  with gr.Column():
100
  name = gr.Textbox(label="Enter your name")
101
  greeting = gr.Textbox(label="Greeting")
102
  button = gr.Button(value="Greet")
103
  button.click(lambda s: f"Hello {s}", name, greeting)
104
 
105
+ demo.load(lambda: datetime.datetime.now(), None, c_time2, every=10)
106
+ dep = demo.load(get_plot, None, plot, every=10)
107
+ batch_size.submit(get_plot, batch_size, plot, every=10, cancels=[dep])
108
 
109
  if __name__ == "__main__":
110
  demo.queue().launch()