# app.py (author: Zaherrr)
# Last commit: 50ae381 "Update app.py" (verified) — 3.37 kB
import gradio as gr
from datasets import load_dataset, Dataset
from PIL import Image
import io
import base64
import json
from graph_visualization import visualize_graph
# Dataset revision (git branch) that was previously passed to load_dataset as
# ``revision=branch_name``; an earlier value was "edges-sorted-ascending".
# Older source: load_dataset("Zaherrr/OOP_KG_Dataset", split='data', revision=branch_name)
# It is currently unused: the dataset below is loaded from its default revision.
branch_name = "Sorted_edges"

# Load the dataset
dataset = load_dataset("Zaherrr/OOP_KG_MarkMap_Synthetic_Dataset", split="data")

# Startup diagnostics printed to stdout.
print(f"This is the dataset: {dataset}")
print(dataset.info)
# dataset[-5] raises IndexError when the dataset has fewer than five rows;
# only print the sample when it exists so a small dataset cannot abort startup.
if len(dataset) >= 5:
    print(f"This is an example: {dataset[-5]}")
def reshape_json_data_to_fit_visualize_graph(graph_data, *, edge_type="->"):
    """Convert column-oriented graph data into per-node / per-edge dicts.

    The input stores nodes and edges as parallel columns::

        {"nodes": {"id": [...], "label": [...]},
         "edges": {"source": [...], "target": [...]}}

    and the result is the row-oriented form passed to ``visualize_graph``::

        {"nodes": [{"id": ..., "label": ...}, ...],
         "edges": [{"source": ..., "target": ..., "type": edge_type}, ...]}

    Args:
        graph_data: Mapping with ``"nodes"`` (parallel ``"id"`` / ``"label"``
            sequences) and ``"edges"`` (parallel ``"source"`` / ``"target"``
            sequences). It is not modified.
        edge_type: Value stored under ``"type"`` for every edge; defaults to
            ``"->"``.

    Returns:
        A new dict with ``"nodes"`` and ``"edges"`` lists as shown above.

    Raises:
        ValueError: If the node columns or the edge columns differ in length.
    """
    nodes = graph_data["nodes"]
    edges = graph_data["edges"]

    node_ids, node_labels = nodes["id"], nodes["label"]
    sources, targets = edges["source"], edges["target"]

    # The columns are parallel arrays; a length mismatch means the record is
    # malformed, so fail loudly rather than silently dropping entries.
    if len(node_ids) != len(node_labels):
        raise ValueError(
            f"nodes: {len(node_ids)} ids but {len(node_labels)} labels"
        )
    if len(sources) != len(targets):
        raise ValueError(
            f"edges: {len(sources)} sources but {len(targets)} targets"
        )

    return {
        "nodes": [
            {"id": node_id, "label": label}
            for node_id, label in zip(node_ids, node_labels)
        ],
        "edges": [
            {"source": source, "target": target, "type": edge_type}
            for source, target in zip(sources, targets)
        ],
    }
def display_example(index):
    """Render dataset row ``index`` for the Gradio output components.

    Args:
        index: Row index chosen on the slider. Gradio documents the Slider
            value as a float, while ``Dataset.__getitem__`` only accepts
            integer keys, so it is converted with ``int()`` first.

    Returns:
        A 5-tuple, in the order of the ``outputs`` list wired up in
        ``create_interface``:
        (PIL image, graph HTML string, pretty-printed graph JSON string,
        transformed graph dict, "Width: <w>px, Height: <h>px" string).
    """
    example = dataset[int(index)]
    img = example["image"]
    img_width, img_height = img.size

    graph_data = {"nodes": example["nodes"], "edges": example["edges"]}
    transformed_graph_data = reshape_json_data_to_fit_visualize_graph(graph_data)

    # Replace a full-viewport height in the generated markup with a fixed
    # 500px, presumably to line up with the 500px-high image component.
    # NOTE(review): relies on visualize_graph emitting exactly "height: 100vh;".
    graph_html = visualize_graph(transformed_graph_data).replace(
        "height: 100vh;", "height: 500px;"
    )

    # ensure_ascii=False keeps non-ASCII node labels readable in the code panel.
    json_data = json.dumps(transformed_graph_data, indent=2, ensure_ascii=False)

    dimensions = f"Width: {img_width}px, Height: {img_height}px"
    return img, graph_html, json_data, transformed_graph_data, dimensions
def create_interface():
    """Build the Gradio Blocks app for browsing the knowledge-graph dataset.

    Layout: a title, an "Example Index" slider covering every row of
    ``dataset``, the example image beside its rendered graph, an image
    dimensions box, and the graph as JSON next to a text box. Moving the
    slider calls ``display_example``. The same handler also runs when the page
    loads, so the first example is shown immediately.

    Returns:
        The assembled (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# Knowledge Graph Visualizer for the [Zaherrr/OOP_KG_MarkMap_Synthetic_Dataset](https://huggingface.co/datasets/Zaherrr/OOP_KG_MarkMap_Synthetic_Dataset) dataset")
        with gr.Row():
            index_slider = gr.Slider(
                minimum=0,
                maximum=len(dataset) - 1,
                value=0,
                step=1,
                label="Example Index",
            )
        with gr.Row():
            image_output = gr.Image(type="pil", label="Image", height=500)
            graph_output = gr.HTML(label="Knowledge Graph")
        with gr.Row():
            dimensions_output = gr.Textbox(
                label="Image Dimensions (pixels)",
                placeholder="Width and Height will appear here",
                interactive=False,
            )
        with gr.Row():
            json_output = gr.Code(language="json", label="Graph JSON Data")
            text_output = gr.Textbox(
                label="Graph Text Data",
                placeholder="Text data will appear here",
                interactive=False,
            )

        # Order must match the tuple returned by display_example.
        outputs = [image_output, graph_output, json_output, text_output, dimensions_output]
        index_slider.change(fn=display_example, inputs=[index_slider], outputs=outputs)
        # The slider's initial value does not trigger .change; without this
        # load handler the outputs stay empty until the slider is moved.
        demo.load(fn=display_example, inputs=[index_slider], outputs=outputs)
    return demo
# Script entry point: assemble the Blocks UI, then start serving it.
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()