# labor / app.py
# Author: mfbalin · commit 5828c6f ("fix")
import math
import pandas as pd
import gradio as gr
import datetime
import numpy as np
from dgl.data import YelpDataset
import dgl
import torch as th
from dgl.dataloading import LaborSampler, NeighborSampler
# Load the Yelp graph and place it on the sampling device.
data = YelpDataset()
device = 'cpu'  # switch to 'cuda:0' to sample on the GPU instead
g = data[0].to(device)

# All samplers share the same fanout list: a fanout of 10 for each of the
# `num_layers` layers.
num_layers = 3
fanouts = [10] * num_layers

# Samplers under comparison, each paired with the label used in the plot.
# The two LABOR variants differ only in their `importance_sampling` setting.
_sampler_specs = [
    ('LABOR-1', LaborSampler(fanouts, importance_sampling=1)),
    ('LABOR-0', LaborSampler(fanouts, importance_sampling=0)),
    ('NS', NeighborSampler(fanouts)),
]
names = [label for label, _ in _sampler_specs]
samplers = [sampler for _, sampler in _sampler_specs]

# Every node ID is a seed. Each loader shuffles the seeds and drops a trailing
# partial batch, so all batches it yields are full-size.
indices = th.arange(g.num_nodes()).to(device)
batch_size = 1024
loaders = [
    dgl.dataloading.DataLoader(
        g,
        indices,
        sampler,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
    )
    for sampler in samplers
]
def get_time():
    """Return the current local wall-clock time as a naive ``datetime``."""
    current = datetime.datetime.now()
    return current
# Right edge of the x-window plotted by get_plot2; advanced on every call.
plot_end = 2 * math.pi


def get_plot2(period=1):
    """Return a line-plot update showing one 2*pi-wide window of a sine wave.

    The window spans ``[plot_end - 2*pi, plot_end)`` sampled every 0.02. After
    the update is built, the module-level ``plot_end`` slides forward by 2*pi.
    Once it passes 1000, it wraps back to the first window.

    Args:
        period: Factor inside the sine, ``sin(2*pi*period*x)``. Despite its
            name, larger values mean more oscillations per window.

    Returns:
        A ``gr.LinePlot.update`` whose DataFrame has columns ``x`` and ``y``.
    """
    global plot_end
    window_start = plot_end - 2 * math.pi
    xs = np.arange(window_start, plot_end, 0.02)
    angular = 2 * math.pi * period
    frame = pd.DataFrame({"x": xs, "y": np.sin(angular * xs)})
    update = gr.LinePlot.update(
        value=frame,
        x="x",
        y="y",
        title="Plot (updates every second)",
        width=600,
        height=350,
    )
    # Slide the window forward, wrapping around so x never grows without bound.
    next_end = plot_end + 2 * math.pi
    plot_end = next_end if next_end <= 1000 else 2 * math.pi
    return update
# Sampled-vertex counts: one row per get_plot() call, one column per entry of
# `names` (same order as `loaders`). Never cleared, so it grows while the app
# runs.
results = []


def get_plot(batch_size=1024):
    """Draw one mini-batch from every loader and plot the sampled-vertex counts.

    Each call takes the first mini-batch from every DataLoader in ``loaders``
    in lock-step. It appends the length of each batch's ``s[0]`` to the
    module-level ``results`` list. It then returns a bar-plot update with every
    count recorded so far, labelled by the matching entry of ``names``.

    Args:
        batch_size: Currently unused. NOTE(review): the loaders were built
            with the module-level batch size, so this value has no effect.
            Honoring it would require rebuilding ``loaders`` and probably
            resetting ``results``.

    Returns:
        A ``gr.BarPlot.update`` whose DataFrame has column ``x`` (sampler
        name) and column ``y`` (number of sampled vertices).
    """
    # Only the first batch is needed. ``next`` with a default replaces the
    # for/break idiom and tolerates loaders that yield no batch at all.
    sampled = next(zip(*loaders), None)
    if sampled is not None:
        # s[0] is presumably the input-node tensor of a DGL block-sampler
        # batch (input_nodes, output_nodes, blocks); verify against the DGL
        # version in use. Store its length as an int rather than the
        # torch.Size, so each row is a flat list of numbers.
        results.append([s[0].shape[0] for s in sampled])

    # Long format for the bar plot: all counts for names[0], then names[1], ...
    # The literal keys "x"/"y" must match the x=/y= arguments below.
    d = {"x": [], "y": []}
    for i, name in enumerate(names):
        column = [row[i] for row in results]
        d["y"] += column
        d["x"] += [name] * len(column)
    update = gr.BarPlot.update(
        value=pd.DataFrame(d),
        x="x",
        y="y",
        title="Number of sampled vertices",
        width=600,
        height=350,
    )
    return update
# th.tensor(results).mean(dim=0, dtype=th.float64)
# Build the Gradio UI.
# Left column: a clock, a batch-size input and the sampler bar plot.
# Right column: a simple greeting form.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # Label matches the every=10 refresh interval configured below.
            c_time2 = gr.Textbox(label="Current Time refreshed every 10 seconds")
            gr.Textbox(
                "Change the batch size and press Enter to update the plot",
                label="",
            )
            # Named `batch_size_input` so it does not rebind the module-level
            # integer `batch_size` used to build the DataLoaders. Its default
            # is seeded from that integer so the two stay in sync.
            batch_size_input = gr.Number(
                label="batch size", value=batch_size, show_label=True
            )
            plot = gr.BarPlot(show_label=False)
        with gr.Column():
            name = gr.Textbox(label="Enter your name")
            greeting = gr.Textbox(label="Greeting")
            button = gr.Button(value="Greet")
            button.click(lambda s: f"Hello {s}", name, greeting)
    # Refresh the clock and the sampling plot every 10 seconds after page load.
    demo.load(get_time, None, c_time2, every=10)
    dep = demo.load(get_plot, None, plot, every=10)
    # Submitting a batch size cancels the load-time periodic plot job. It then
    # starts a new periodic job that passes the submitted value to get_plot.
    batch_size_input.submit(
        get_plot, batch_size_input, plot, every=10, cancels=[dep]
    )
if __name__ == "__main__":
    # Enable Gradio's request queue (the periodic `every=` events rely on it),
    # then start the server.
    queued = demo.queue()
    queued.launch()