# Page metadata captured alongside the file (not part of the program):
#   SauravMaheshkar's picture
#   feat: add initial gradio app
#   2891210 unverified
#   raw
#   history blame
#   No virus
#   4.35 kB
import ast
import os
from copy import deepcopy
import dhg
import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
from dhg.visualization.structure.defaults import (default_hypergraph_strength,
default_hypergraph_style,
default_size)
from dhg.visualization.structure.layout import force_layout
from dhg.visualization.structure.utils import draw_circle_edge, draw_vertex
from huggingface_hub import hf_hub_download
def draw_hypergraph(
    hg: "dhg.Hypergraph",
    e_style="circle",
    v_label=None,
    v_size=1.0,
    v_color="r",
    v_line_width=1.0,
    e_color="gray",
    e_fill_color="whitesmoke",
    e_line_width=1.0,
    font_size=1.0,
    font_family="sans-serif",
    push_v_strength=1.0,
    push_e_strength=1.0,
    pull_e_strength=1.0,
    pull_center_strength=1.0,
):
    """Draw a hypergraph on a new 6x6 matplotlib figure and return the figure.

    The style, size and strength arguments are first passed through the
    ``dhg`` default helpers, vertex coordinates are then computed with
    ``force_layout``, and finally hyperedges and vertices are drawn onto the
    axes with ``draw_circle_edge`` and ``draw_vertex``.

    Args:
        hg: Hypergraph to draw. ``hg.num_v``, ``hg.num_e`` and ``hg.e[0]``
            are read.
        e_style: Accepted for signature compatibility but not consulted
            here; hyperedges are always drawn with ``draw_circle_edge``.
        v_label: Vertex labels, forwarded unchanged to ``draw_vertex``.
        v_size, v_line_width, e_line_width: Size settings, normalised by
            ``default_size`` before use.
        v_color, e_color, e_fill_color: Colour settings, normalised by
            ``default_hypergraph_style`` before use.
        font_size: See the review note in the body -- the incoming value is
            replaced by what ``default_size`` returns.
        font_family: Font family, forwarded unchanged to ``draw_vertex``.
        push_v_strength, push_e_strength, pull_e_strength,
        pull_center_strength: Layout strengths, normalised by
            ``default_hypergraph_strength`` and then given to
            ``force_layout``.

    Returns:
        The matplotlib figure that was drawn on.
    """
    fig, ax = plt.subplots(figsize=(6, 6))

    # Work on a copy of the edge list so the helpers below never share the
    # list object owned by ``hg``.
    num_v, e_list = hg.num_v, deepcopy(hg.e[0])

    # Default configuration.
    v_color, e_color, e_fill_color = default_hypergraph_style(
        num_v, hg.num_e, v_color, e_color, e_fill_color
    )
    # NOTE(review): ``font_size`` is not forwarded to ``default_size``, so
    # the caller's value is discarded and replaced by the helper's result.
    # Confirm against the helper's signature whether it should be passed.
    v_size, v_line_width, e_line_width, font_size = default_size(
        num_v, e_list, v_size, v_line_width, e_line_width
    )
    (
        push_v_strength,
        push_e_strength,
        pull_e_strength,
        pull_center_strength,
    ) = default_hypergraph_strength(
        num_v,
        e_list,
        push_v_strength,
        push_e_strength,
        pull_e_strength,
        pull_center_strength,
    )

    # Layout.
    v_coor = force_layout(
        num_v,
        e_list,
        push_v_strength,
        push_e_strength,
        pull_e_strength,
        pull_center_strength,
    )

    # Edges are drawn before vertices (same order as before).
    draw_circle_edge(
        ax,
        v_coor,
        v_size,
        e_list,
        e_color,
        e_fill_color,
        e_line_width,
    )
    draw_vertex(
        ax,
        v_coor,
        v_label,
        font_size,
        font_family,
        v_size,
        v_color,
        v_line_width,
    )

    # Configure *this* axes explicitly. The pyplot equivalents (plt.xlim,
    # plt.ylim, plt.axis) act on the process-wide "current axes", which may
    # belong to another figure when several requests are served at once.
    ax.set_xlim((0, 1.0))
    ax.set_ylim((0, 1.0))
    ax.axis("off")
    fig.tight_layout()
    return fig
def plot_dataset(dataset_choice: str, sampling_choice: str, split_choice: str):
    """Download one processed split of a dataset and plot it as a hypergraph.

    Args:
        dataset_choice: Dataset name; used as the repository name under the
            ``SauravMaheshkar`` namespace on the Hugging Face Hub.
        sampling_choice: Sub-directory of ``processed/`` to read from.
        split_choice: Split name; the file fetched is
            ``processed/<sampling_choice>/<split_choice>_df.csv``.

    Returns:
        The matplotlib figure produced by ``draw_hypergraph``.
    """
    os.makedirs("artifacts", exist_ok=True)

    # Use the path returned by the download call instead of rebuilding it by
    # hand, so the file that is read is always the file that was fetched.
    csv_path = hf_hub_download(
        filename=f"processed/{sampling_choice}/{split_choice}_df.csv",
        local_dir="./artifacts/",
        repo_id=f"SauravMaheshkar/{dataset_choice}",
        repo_type="dataset",
    )
    df = pd.read_csv(csv_path)

    # NOTE(review): the vertex count is taken from the number of rows, while
    # each row's "nodes" cell is used as one hyperedge. Confirm that every
    # node id in the file is smaller than len(df).
    num_vertices = len(df)

    # Each "nodes" cell holds a Python literal in text form; parse it with
    # literal_eval (which only accepts literals, not arbitrary code).
    edge_list = [ast.literal_eval(nodes) for nodes in df["nodes"].tolist()]

    hypergraph = dhg.Hypergraph(num_vertices, edge_list)
    fig = draw_hypergraph(hypergraph)

    # draw_hypergraph creates the figure through pyplot, which keeps a
    # reference to every open figure. Deregister it so figures do not pile
    # up across button clicks; the Figure object itself stays renderable.
    plt.close(fig)
    return fig
# Gradio UI: three dropdowns select dataset / sampling type / split, and the
# button renders the chosen split with ``plot_dataset``.
# NOTE(review): the nesting below (dropdowns inside the row, plot and button
# beneath it) was reconstructed from a copy of the file that had lost its
# indentation -- confirm it matches the intended layout.
with gr.Blocks() as demo:
    with gr.Row():
        dataset_choices = gr.Dropdown(
            choices=[
                "email-Eu",
                "email-Enron",
                "NDC-classes",
                "tags-math-sx",
                "email-Eu-25",
                "NDC-substances",
                "congress-bills",
                "tags-ask-ubuntu",
                "email-Enron-25",
                "NDC-classes-25",
                "threads-ask-ubuntu",
                "contact-high-school",
                "NDC-substances-25",
                "congress-bills-25",
                "contact-primary-school",
            ],
            value="email-Enron-25",
            label="Please choose a dataset",
            interactive=True,
        )
        sampling_choice = gr.Dropdown(
            choices=[
                "transductive",
                "inductive",
            ],
            value="inductive",
            label="Choose sampling type",
            interactive=True,
        )
        split_choice = gr.Dropdown(
            choices=[
                "train",
                "valid",
                "test",
            ],
            value="test",
            label="Choose split",
            interactive=True,
        )

    output_plot = gr.Plot(label="Hypergraph plot")
    btn = gr.Button("Visualise")

    # The input order must match plot_dataset's parameter order.
    btn.click(
        fn=plot_dataset,
        inputs=[dataset_choices, sampling_choice, split_choice],
        outputs=output_plot,
    )

# Start the server only when executed as a script, so importing this module
# (for tooling or reuse of ``demo``) does not launch it as a side effect.
if __name__ == "__main__":
    demo.launch()