"""Turn free text into an interactive relation graph rendered with pyvis.

A text2text model produces a linearised list of (head # relation # tail)
triples; these are parsed into an edge dictionary and drawn as a directed graph.
"""
import functools
import html as html_lib
import os

import gradio as gr
from pyvis.network import Network
from transformers import pipeline

model_id = "DReAMy-lib/t5-base-DreamBank-Generation-Act-Char"

# Head labels of the single-node placeholder graphs returned by get_graph_dict.
_NO_GRAPH_NODE = "No_Graphs"
_ERROR_NODE = "Error"


def get_graph_dict(graph_text):
    """Parse the model's linearised triples into an ``{(head, tail): relation}`` dict.

    Expected layout: one wrapper character at each end of the whole string,
    triples separated by ``" | "``, each triple wrapped in one character at
    each end, and its three fields separated by ``" # "`` -- for example
    ``"[(a # r # b) | (c # s # d)]"``.

    Returns:
        ``{("No_Graphs", None): None}`` when *graph_text* is empty, ``None`` or
        blank; ``{("Error", None): None}`` when any triple does not split into
        exactly three fields; otherwise the parsed edge dict.  A later triple
        with the same (head, tail) pair overwrites an earlier one.
    """
    if not graph_text or not graph_text.strip():
        return {(_NO_GRAPH_NODE, None): None}

    edge_labels = {}
    try:
        for triple in graph_text[1:-1].split(" | "):
            head, relation, tail = triple[1:-1].split(" # ")
            edge_labels[(head, tail)] = relation
    except ValueError:
        # Unpacking failed (too few/many fields): the whole output is unusable.
        return {(_ERROR_NODE, None): None}
    return edge_labels


@functools.lru_cache(maxsize=1)
def _get_pipeline():
    """Build the text2text-generation pipeline once and reuse it."""
    # Loading model weights is expensive, so a single instance serves every request.
    return pipeline(
        "text2text-generation",
        model=model_id,
        max_length=300,
        min_length=5,
    )


def text_to_graph(text):
    """Generate a relation graph for *text* and render it as embeddable HTML.

    Returns:
        ``(html_s, graph_text)``: an ``<iframe>`` whose ``srcdoc`` holds the
        complete pyvis page, and the raw text generated by the model.
    """
    graph_text = _get_pipeline()(text)[0]["generated_text"]
    edge_labels = get_graph_dict(graph_text)

    net = Network(directed=True)
    for (head, tail), relation in edge_labels.items():
        net.add_node(head, shape="circle")
        if tail is None:
            # Placeholder entry ("No_Graphs"/"Error"): draw just the lone node.
            continue
        net.add_node(tail, shape="circle")
        net.add_edge(head, tail, title=relation, label=relation)

    # Physics layout.
    net.repulsion(
        node_distance=200,
        central_gravity=0.2,
        spring_length=200,
        spring_strength=0.05,
        damping=0.09,
    )
    net.set_edge_smooth("dynamic")

    page = net.generate_html()
    # Escape the whole document so quotes/ampersands in the page (including its
    # scripts) survive intact inside the srcdoc attribute.
    srcdoc = html_lib.escape(page, quote=True)
    html_s = (
        '<iframe style="width: 100%; height: 600px; margin: 0 auto" '
        'sandbox="allow-scripts allow-same-origin" frameborder="0" '
        f'srcdoc="{srcdoc}"></iframe>'
    )
    return html_s, graph_text