import functools
import os
from html import escape

import gradio as gr
from pyvis.network import Network
from transformers import pipeline

# Hugging Face Hub id of the text2text model whose output is parsed as a graph.
model_id = "DReAMy-lib/t5-base-DreamBank-Generation-Act-Char"


def get_graph_dict(graph_text):
    """Parse the model's textual graph into ``(head, tail, relation)`` edges.

    The expected input looks like ``"(h1 : r1 : t1), (h2 : r2 : t2)"``: the
    outer parentheses are stripped, triples are separated by ``"), ("`` and
    the fields of each triple by ``" : "``.

    Args:
        graph_text: Text generated by the model.

    Returns:
        A list of ``(head, tail, relation)`` tuples. Spaces inside a relation
        are replaced with underscores, and a tail of ``"none"`` becomes a
        self-loop on the head. Two sentinel rows signal problems to callers:
        ``[("No_Graphs", None, None)]`` when ``graph_text`` is empty, and
        ``("Error", None, None)`` appended after whatever edges were parsed
        before a malformed triple was encountered.
    """
    if graph_text == "":
        # Keep the same 3-tuple shape as real edges so callers can always
        # unpack (head, tail, relation); a dict here would iterate as its
        # 2-tuple keys and break that unpacking.
        return [("No_Graphs", None, None)]

    edge_labels = []
    try:
        for triple in graph_text[1:-1].split("), ("):
            head, relation, tail = triple.split(" : ")
            if tail == "none":
                tail = head
            edge_labels.append((head, tail, relation.replace(" ", "_")))
    except (ValueError, TypeError, AttributeError):
        # ValueError: a triple without exactly three " : "-separated fields.
        # TypeError / AttributeError: graph_text is not a string.
        edge_labels.append(("Error", None, None))
    return edge_labels


@functools.lru_cache(maxsize=1)
def _get_pipeline():
    """Build the text2text-generation pipeline once and reuse it on later calls."""
    return pipeline(
        "text2text-generation",
        model=model_id,
        max_length=300,
        min_length=5,
    )


def _build_network(edge_labels):
    """Create a directed pyvis network from ``(head, tail, relation)`` rows."""
    net = Network(directed=True)
    for head, tail, relation in edge_labels:
        net.add_node(head, shape="circle")
        if head in ("Error", "No_Graphs"):
            # Sentinel rows from get_graph_dict carry no tail or relation.
            continue
        net.add_node(tail, shape="circle")
        net.add_edge(head, tail, title=relation, label=relation)

    # Physics layout settings.
    net.repulsion(
        node_distance=200,
        central_gravity=0.2,
        spring_length=200,
        spring_strength=0.05,
        damping=0.09,
    )
    net.set_edge_smooth("dynamic")
    return net


def text_to_graph(text):
    """Generate a graph from ``text`` with the model and render it as HTML.

    Args:
        text: Input text passed to the text2text-generation pipeline.

    Returns:
        A ``(html_s, graph_text)`` pair: ``html_s`` is an ``<iframe>`` whose
        ``srcdoc`` holds the pyvis-generated page, and ``graph_text`` is the
        raw text the model produced.
    """
    graph_text = _get_pipeline()(text)[0]["generated_text"]

    net = _build_network(get_graph_dict(graph_text))
    page_html = net.generate_html()

    # Embed the full page via srcdoc. Escaping (rather than swapping quote
    # characters) keeps apostrophes/quotes in the page's JS and labels intact.
    html_s = (
        '<iframe style="width: 100%; height: 600px; margin: 0 auto; border: none;" '
        'sandbox="allow-scripts allow-same-origin allow-popups" '
        f'srcdoc="{escape(page_html, quote=True)}"></iframe>'
    )
    return html_s, graph_text