"""Gradio demo comparing how many vertices different DGL samplers select.

On every refresh, one shuffled mini-batch is drawn from each sampler
(LABOR-1, LABOR-0 and plain neighbor sampling) on the Yelp graph. The number
of input vertices each sampler produced is recorded and shown as a bar chart.
"""
import datetime
import functools
import math

import dgl
import gradio as gr
import numpy as np
import pandas as pd
import torch as th
from dgl.data import YelpDataset
from dgl.dataloading import LaborSampler, NeighborSampler

data = YelpDataset()
# Set to 'cuda:0' to sample on a GPU instead.
device = 'cpu'
g = data[0].to(device)

num_layers = 3
fanouts = [10] * num_layers
samplers = [
    LaborSampler(fanouts, importance_sampling=1),
    LaborSampler(fanouts, importance_sampling=0),
    NeighborSampler(fanouts),
]
# Display names, parallel to `samplers`.
names = ['LABOR-1', 'LABOR-0', 'NS']
indices = th.arange(g.num_nodes()).to(device)
# Default number of seed vertices per mini-batch.
batch_size = 1024

# Seconds between automatic refreshes of the clock and the plot.
REFRESH_SECONDS = 10


@functools.lru_cache(maxsize=8)
def _make_loaders(batch_size):
    """Build one shuffled DataLoader per entry of `samplers` for `batch_size`.

    Cached so periodic refreshes with an unchanged batch size reuse the same
    loaders instead of rebuilding them every time.
    """
    return [
        dgl.dataloading.DataLoader(
            g,
            indices,
            sampler,
            batch_size=batch_size,
            shuffle=True,
            drop_last=True,
        )
        for sampler in samplers
    ]


# Loaders for the default batch size.
loaders = _make_loaders(batch_size)


def get_time():
    """Return the current local (naive) datetime."""
    return datetime.datetime.now()


plot_end = 2 * math.pi


def get_plot2(period=1):
    """Return a LinePlot update for a 2*pi-wide sine window, then advance it.

    Each call shifts the window right by 2*pi and wraps it back to the start
    once the end passes 1000. Not wired to any event in this demo.
    """
    global plot_end
    x = np.arange(plot_end - 2 * math.pi, plot_end, 0.02)
    y = np.sin(2 * math.pi * period * x)
    update = gr.LinePlot.update(
        value=pd.DataFrame({"x": x, "y": y}),
        x="x",
        y="y",
        title="Plot (updates every second)",
        width=600,
        height=350,
    )
    plot_end += 2 * math.pi
    if plot_end > 1000:
        plot_end = 2 * math.pi
    return update


# One row per refresh, holding the input-vertex count of each sampler (same
# order as `names`).
results = []
# Batch size the rows in `results` were sampled with. Counts taken at
# different batch sizes are not comparable, so `results` is reset on change.
results_batch_size = batch_size


def get_plot(batch_size=1024):
    """Sample one mini-batch per sampler and return an updated bar chart.

    Args:
        batch_size: seed vertices per mini-batch. This comes from the
            ``gr.Number`` input, so it may be a float or None. None falls
            back to 1024, and the value is coerced to an int of at least 1.

    Returns:
        A ``gr.BarPlot.update`` showing every input-vertex count collected
        so far for the current batch size, grouped by sampler.
    """
    global results_batch_size
    if batch_size is None:
        batch_size = 1024
    batch_size = max(1, int(batch_size))
    if batch_size != results_batch_size:
        results.clear()
        results_batch_size = batch_size

    # Draw one fresh (shuffled) mini-batch from every loader in lockstep.
    # None means some loader yields nothing, e.g. batch size > #nodes with
    # drop_last=True.
    sampled = next(zip(*_make_loaders(batch_size)), None)
    if sampled is not None:
        # DGL's node DataLoader yields (input_nodes, output_nodes, blocks).
        # Record the number of input vertices as a plain int; the bare
        # `.shape` would be a torch.Size and yield nested lists below.
        results.append([s[0].shape[0] for s in sampled])

    x = "sampler"
    y = "# vertices"
    d = {x: [], y: []}
    for i, name in enumerate(names):
        counts = [row[i] for row in results]
        d[y] += counts
        d[x] += [name] * len(counts)
    update = gr.BarPlot.update(
        value=pd.DataFrame(d),
        x=x,
        y=y,
        title="Number of sampled vertices",
        width=600,
        height=350,
    )
    return update


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            c_time2 = gr.Textbox(
                label=f"Current Time refreshed every {REFRESH_SECONDS} seconds"
            )
            gr.Textbox(
                "Enter a batch size and press Enter to update the plot",
                label="",
            )
            # Named differently from the `batch_size` int so it does not
            # shadow it; precision=0 makes Gradio pass an int.
            batch_size_input = gr.Number(
                label="batch size", value=batch_size, show_label=True, precision=0
            )
            plot = gr.BarPlot(show_label=False)
        with gr.Column():
            name = gr.Textbox(label="Enter your name")
            greeting = gr.Textbox(label="Greeting")
            button = gr.Button(value="Greet")
            button.click(lambda s: f"Hello {s}", name, greeting)

    demo.load(get_time, None, c_time2, every=REFRESH_SECONDS)
    dep = demo.load(get_plot, None, plot, every=REFRESH_SECONDS)
    # A submitted batch size replaces the default-size periodic refresh.
    batch_size_input.submit(
        get_plot, batch_size_input, plot, every=REFRESH_SECONDS, cancels=[dep]
    )

if __name__ == "__main__":
    demo.queue().launch()