import ast
import os
from copy import deepcopy

import dhg
import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
from dhg.visualization.structure.defaults import (
    default_hypergraph_strength,
    default_hypergraph_style,
    default_size,
)
from dhg.visualization.structure.layout import force_layout
from dhg.visualization.structure.utils import draw_circle_edge, draw_vertex
from huggingface_hub import hf_hub_download

DATASET_CHOICES = [
    "email-Eu",
    "email-Enron",
    "NDC-classes",
    "tags-math-sx",
    "email-Eu-25",
    "NDC-substances",
    "congress-bills",
    "tags-ask-ubuntu",
    "email-Enron-25",
    "NDC-classes-25",
    "threads-ask-ubuntu",
    "contact-high-school",
    "NDC-substances-25",
    "congress-bills-25",
    "contact-primary-school",
]
SAMPLING_CHOICES = ["transductive", "inductive"]
SPLIT_CHOICES = ["train", "valid", "test"]


def draw_hypergraph(
    hg: "dhg.Hypergraph",
    e_style="circle",
    v_label=None,
    v_size=1.0,
    v_color="r",
    v_line_width=1.0,
    e_color="gray",
    e_fill_color="whitesmoke",
    e_line_width=1.0,
    font_size=1.0,
    font_family="sans-serif",
    push_v_strength=1.0,
    push_e_strength=1.0,
    pull_e_strength=1.0,
    pull_center_strength=1.0,
):
    """Render ``hg`` with a force-directed layout and return the Figure.

    Styling, size and force-strength arguments are passed through DHG's
    ``default_*`` helpers, which fill in defaults derived from the number of
    vertices/edges, before layout (``force_layout``) and drawing
    (``draw_circle_edge`` / ``draw_vertex``).

    Args:
        hg: The hypergraph to draw.
        e_style: Accepted for signature compatibility; edges are always drawn
            as circles (``draw_circle_edge``) regardless of this value.
        v_label: Optional vertex labels, forwarded to ``draw_vertex``.
        font_size: NOTE(review): overwritten by the value returned from
            ``default_size`` below, so it currently has no effect — confirm
            whether ``default_size`` can take a font-size scale.
        Remaining arguments are forwarded to the DHG default/layout/draw
        helpers.

    Returns:
        The ``matplotlib`` Figure. It is already closed in pyplot (see the end
        of the function), but can still be saved/rendered via ``savefig``.
    """
    fig, ax = plt.subplots(figsize=(6, 6))
    num_v, e_list = hg.num_v, deepcopy(hg.e[0])

    # Fill in per-vertex / per-edge styling defaults.
    v_color, e_color, e_fill_color = default_hypergraph_style(
        hg.num_v, hg.num_e, v_color, e_color, e_fill_color
    )
    v_size, v_line_width, e_line_width, font_size = default_size(
        num_v, e_list, v_size, v_line_width, e_line_width
    )
    (
        push_v_strength,
        push_e_strength,
        pull_e_strength,
        pull_center_strength,
    ) = default_hypergraph_strength(
        num_v,
        e_list,
        push_v_strength,
        push_e_strength,
        pull_e_strength,
        pull_center_strength,
    )

    # Layout, then draw edges underneath vertices.
    v_coor = force_layout(
        num_v,
        e_list,
        push_v_strength,
        push_e_strength,
        pull_e_strength,
        pull_center_strength,
    )
    draw_circle_edge(
        ax,
        v_coor,
        v_size,
        e_list,
        e_color,
        e_fill_color,
        e_line_width,
    )
    draw_vertex(
        ax,
        v_coor,
        v_label,
        font_size,
        font_family,
        v_size,
        v_color,
        v_line_width,
    )

    # Operate on our own axes rather than pyplot's global "current axes",
    # which another request thread could have switched.
    ax.set_xlim((0, 1.0))
    ax.set_ylim((0, 1.0))
    ax.axis("off")
    fig.tight_layout()

    # pyplot keeps every figure alive until it is explicitly closed; this runs
    # once per button click in a long-lived server, so release it here. The
    # Figure object itself remains usable for savefig (used by gr.Plot).
    plt.close(fig)
    return fig


def plot_dataset(dataset_choice: str, sampling_choice: str, split_choice: str):
    """Download one split of a hypergraph dataset from the Hub and plot it.

    Fetches ``processed/{sampling_choice}/{split_choice}_df.csv`` from the
    ``SauravMaheshkar/{dataset_choice}`` dataset repo into ``./artifacts/``,
    parses each row's ``nodes`` cell as one hyperedge, and draws the result.

    Returns:
        The matplotlib Figure produced by ``draw_hypergraph``.
    """
    os.makedirs("artifacts", exist_ok=True)
    # hf_hub_download returns the local path of the downloaded file; use it
    # directly so the read cannot drift from where the file actually landed.
    csv_path = hf_hub_download(
        filename=f"processed/{sampling_choice}/{split_choice}_df.csv",
        local_dir="./artifacts/",
        repo_id=f"SauravMaheshkar/{dataset_choice}",
        repo_type="dataset",
    )
    df = pd.read_csv(csv_path)

    # Each "nodes" cell is a stringified Python literal (e.g. "[0, 3, 7]");
    # literal_eval only evaluates literals, so it is safe on downloaded data.
    edge_list = [ast.literal_eval(edges) for edges in df["nodes"].tolist()]

    # len(df) counts rows (hyperedges), not vertices. Keep it as the floor for
    # backward compatibility, but make sure every referenced vertex id is in
    # range. Presumably ids are 0-based ints — TODO confirm against the data.
    max_vertex_id = max((v for edge in edge_list for v in edge), default=-1)
    num_vertices = max(len(df), max_vertex_id + 1)

    hypergraph = dhg.Hypergraph(num_vertices, edge_list)
    return draw_hypergraph(hypergraph)


with gr.Blocks() as demo:
    with gr.Row():
        dataset_choices = gr.Dropdown(
            choices=DATASET_CHOICES,
            value="email-Enron-25",
            label="Please choose a dataset",
            interactive=True,
        )
        sampling_choice = gr.Dropdown(
            choices=SAMPLING_CHOICES,
            value="inductive",
            label="Choose sampling type",
            interactive=True,
        )
        split_choice = gr.Dropdown(
            choices=SPLIT_CHOICES,
            value="test",
            label="Choose split",
            interactive=True,
        )

    output_plot = gr.Plot(label="Hypergraph plot")
    btn = gr.Button("Visualise")
    btn.click(
        fn=plot_dataset,
        inputs=[dataset_choices, sampling_choice, split_choice],
        outputs=output_plot,
    )

demo.launch()