File size: 4,011 Bytes
6b59850
 
 
 
 
 
 
 
 
 
38ed701
6b59850
 
 
 
 
 
 
 
 
 
 
 
 
 
6ef5512
 
6b59850
38ed701
6b59850
 
 
 
 
 
 
 
 
38ed701
6b59850
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38ed701
 
 
6b59850
7084eba
6b59850
 
67ad3dd
6b59850
 
 
 
 
 
 
 
719a04f
6b59850
 
38ed701
6b59850
 
7084eba
6b59850
7084eba
6b59850
7084eba
38ed701
 
 
 
6b59850
 
 
 
 
 
38ed701
67ad3dd
6b59850
 
67ad3dd
6b59850
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from omegaconf import OmegaConf
import gradio as gr

from dataset import init_dataset, compute_input_output_dims
from extra_features import ExtraFeatures
from demo_model import LGGMText2Graph_Demo
from analysis.spectre_utils import CrossDomainSamplingMetrics
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
import torch


# --- One-time demo setup (runs at import time) --------------------------------
# Load the demo configuration; the path is relative to the current working
# directory.
cfg = OmegaConf.load('./config.yaml')
# Base path handed to init_dataset below. NOTE(review): presumably the root that
# dataset files are resolved against — confirm in dataset.init_dataset.
hydra_path = '.'


# Build the data loaders plus dataset metadata (class counts, node-count
# distribution, edge/node types, conditioning dims, ...) from the config.
data_loaders, num_classes, max_n_nodes, nodes_dist, edge_types, node_types, n_nodes, cond_dims, cond_emb = init_dataset(cfg.dataset.name, cfg.train.batch_size, hydra_path, cfg.general.condition, cfg.model.transition)

# Extra-feature computer configured from cfg.model.extra_features, sized by the
# largest graph in the dataset (max_n_nodes).
extra_features = ExtraFeatures(cfg.model.extra_features, max_n_nodes)

# Input/output dimensions derived from the training loader and extra features.
# NOTE(review): input_dims / output_dims are not referenced anywhere else in this
# source; kept in case later code or side effects depend on them — confirm.
input_dims, output_dims = compute_input_output_dims(data_loaders['train'], extra_features)

# NOTE(review): sampling_metrics is not referenced anywhere else in this source
# either; constructing it may be needed only for its side effects — confirm
# before removing.
sampling_metrics = CrossDomainSamplingMetrics(data_loaders)

# Restore the trained model from a checkpoint file in the working directory.
model = LGGMText2Graph_Demo.load_from_checkpoint('cc-deg.ckpt')
# Alternative load that forces every tensor onto the CPU via map_location.
# NOTE(review): whether this is required on CPU-only hosts depends on the
# installed Lightning version's default map_location — confirm before enabling.
# model = LGGMText2Graph_Demo.load_from_checkpoint('cc-deg.ckpt', map_location=torch.device("cpu"))

# NOTE(review): presumably loads the pretrained prompt/text encoder that
# model.generate_pretrained relies on — confirm in demo_model.
model.init_prompt_encoder_pretrained()

def calculate_average_degree(graph):
    """Return the mean node degree of ``graph``.

    Each edge adds one to the degree of both of its endpoints, so the mean
    degree is ``2 * |E| / |V|``.  A graph with no nodes yields the integer
    ``0`` instead of dividing by zero.
    """
    node_count = graph.number_of_nodes()
    edge_count = graph.number_of_edges()
    if node_count <= 0:
        # Empty graph: the average is undefined, report 0.
        return 0
    return 2 * edge_count / node_count


# Number of per-graph result slots (image, CC box, degree box) the UI displays.
_NUM_DISPLAY_SLOTS = 3


def _render_graph(g):
    """Draw ``g`` with networkx and return the figure as an (H, W, 3) uint8 RGB array."""
    fig, ax = plt.subplots()
    try:
        nx.draw(g, ax=ax)
        fig.canvas.draw()
        # ``tostring_rgb`` is deprecated since Matplotlib 3.8 (later removed);
        # ``buffer_rgba`` also carries the true pixel shape, whereas reshaping
        # by ``get_width_height()`` breaks when the canvas has a device pixel
        # ratio other than 1 (HiDPI).
        rgba = np.asarray(fig.canvas.buffer_rgba())
        # Drop the alpha channel and copy so the array outlives the figure.
        return rgba[..., :3].copy()
    finally:
        # Release the figure even if drawing fails, so a long-running server
        # does not accumulate open figures.
        plt.close(fig)


def _fill_slots(values):
    """Pad with ``None`` / truncate ``values`` to exactly ``_NUM_DISPLAY_SLOTS`` entries."""
    slots = list(values[:_NUM_DISPLAY_SLOTS])
    slots.extend([None] * (_NUM_DISPLAY_SLOTS - len(slots)))
    return slots


def predict(text, num_nodes=None):
    """Generate graphs from a text prompt and summarize them for the UI.

    Args:
        text: Prompt forwarded to ``model.generate_pretrained``.
        num_nodes: Requested node count (the UI slider supplies a number).
            ``None`` is forwarded unchanged. NOTE(review): assumes
            ``generate_pretrained`` accepts ``None`` — confirm.

    Returns:
        An 11-tuple in output-component order: three RGB images, three
        average clustering coefficients, three average degrees, then the mean
        clustering coefficient and the mean degree over *all* generated graphs.
        Slots with no corresponding graph (and the means, if nothing was
        generated) are ``None``.
    """
    # Only cast when a value was given: ``int(None)`` would raise TypeError.
    requested = int(num_nodes) if num_nodes is not None else None
    graphs = model.generate_pretrained(text, requested)

    images, ccs, degs = [], [], []
    for g in graphs:
        ccs.append(nx.average_clustering(g))
        degs.append(calculate_average_degree(g))
        images.append(_render_graph(g))

    # Aggregates cover every generated graph, not only the displayed ones;
    # guard the empty case, where np.mean would return nan and warn.
    avg_cc = np.mean(ccs) if ccs else None
    avg_deg = np.mean(degs) if degs else None

    return (*_fill_slots(images), *_fill_slots(ccs), *_fill_slots(degs), avg_cc, avg_deg)

def clear(input_text):
    """Return a blank value for every output component of the demo.

    ``input_text`` is ignored; the parameter exists only because the Clear
    button is wired with the prompt textbox as its input.
    """
    # 3 images + 3 clustering boxes + 3 degree boxes + 2 averages = 11 outputs.
    return (None,) * 11


with gr.Blocks() as demo:
    gr.Markdown("## Text2Graph Generation Demo")

    # --- Inputs: prompt text, requested node count, example prompts ---------
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(label="Input your text prompt here", placeholder="Type here...")
        with gr.Column():
            input_num = gr.Slider(5, 100, value=25, step=1, label="Count", info="Number of nodes in the graph to be generated")
        with gr.Column():
            gr.Markdown("### Suggested Prompts")
            gr.Markdown("1. Create a complex network with high clustering coefficient.\n2. Create a graph with extremely low number of triangles.\n 3. Please give me a Power Network with extremely low number of triangles but with medium level of average degree.")

    # --- Outputs: three generated graphs and their per-graph statistics -----
    with gr.Row() as output_row:
        output_images = [gr.Image(label=f"Generated Network #{i}") for i in range(3)]
    with gr.Row():
        output_texts_cc = [gr.Textbox(label=f"CC #{i}") for i in range(3)]
    with gr.Row():
        output_texts_deg = [gr.Textbox(label=f"DEG #{i}") for i in range(3)]

    with gr.Row():
        avg_cc_text = gr.Textbox(label="Average Clustering Coefficient")
        avg_deg_text = gr.Textbox(label="Average Degree")

    with gr.Row():
        submit_button = gr.Button("Submit")
        clear_button = gr.Button("Clear")

    # Single ordered list of the 11 output components, built once so every
    # event uses the same order: 3 images, 3 CC boxes, 3 degree boxes, average
    # CC, average degree. It must match the 11-tuples returned by predict()
    # and clear().
    all_outputs = output_images + output_texts_cc + output_texts_deg + [avg_cc_text, avg_deg_text]
    predict_inputs = [input_text, input_num]

    # Run generation from either the Submit button or Enter in the prompt box.
    submit_button.click(fn=predict, inputs=predict_inputs, outputs=all_outputs)
    input_text.submit(fn=predict, inputs=predict_inputs, outputs=all_outputs)

    # Clear blanks every output component. The prompt textbox is only an
    # input here (it is not in all_outputs), so its text is left unchanged.
    clear_button.click(fn=clear, inputs=[input_text], outputs=all_outputs)

demo.launch()