Spaces:
Runtime error
Runtime error
import spaces | |
import gradio as gr | |
from phi3_instruct_graph import MODEL_LIST, Phi3InstructGraph | |
from textwrap import dedent | |
import rapidjson | |
import spaces | |
from pyvis.network import Network | |
import networkx as nx | |
import spacy | |
from spacy import displacy | |
from spacy.tokens import Span | |
import random | |
# Sample graph in the shape the visualisation helpers consume: a list of
# "nodes" (id / type / detailed_type) and a list of "edges" (from / to / label).
json_example = {
    "nodes": [
        {"id": "Aerosmith", "type": "organization", "detailed_type": "rock band"},
        {"id": "Steven Tyler", "type": "person", "detailed_type": "lead singer"},
        {
            "id": "vocal cord injury",
            "type": "medical condition",
            "detailed_type": "fractured larynx",
        },
        {"id": "retirement", "type": "event", "detailed_type": "announcement"},
        {"id": "touring", "type": "activity", "detailed_type": "musical performance"},
        {"id": "September 2023", "type": "date", "detailed_type": "specific time"},
    ],
    "edges": [
        {"from": "Aerosmith", "to": "Steven Tyler", "label": "led by"},
        {"from": "Steven Tyler", "to": "vocal cord injury", "label": "suffered"},
        {"from": "vocal cord injury", "to": "retirement", "label": "caused"},
        {"from": "retirement", "to": "touring", "label": "ended"},
        {"from": "vocal cord injury", "to": "September 2023", "label": "occurred in"},
    ],
}
def extract(text, model):
    """Extract a graph from *text* with the given model and parse the reply.

    Args:
        text: Input text handed to the extractor.
        model: Value forwarded as ``Phi3InstructGraph(model=...)``.

    Returns:
        The object obtained by parsing the extractor's output with
        ``rapidjson.loads`` (the rest of this file reads ``"nodes"`` and
        ``"edges"`` from it).
    """
    extractor = Phi3InstructGraph(model=model)
    raw_reply = extractor.extract(text)
    parsed = rapidjson.loads(raw_reply)
    return parsed
def handle_text(text):
    """Collapse every run of whitespace in *text* to a single space.

    Leading and trailing whitespace is dropped as a side effect of splitting.
    """
    words = text.split()
    return " ".join(words)
def get_random_color():
    """Return a random colour as a ``#rrggbb`` hexadecimal string."""
    rgb_value = random.randint(0, 0xFFFFFF)
    return "#" + format(rgb_value, "06x")
def get_random_light_color():
    """Return a random light colour as a ``#rrggbb`` hexadecimal string.

    Each channel is drawn from 128..255 so the result stays on the bright
    half of the range.
    """
    # Channels are drawn in red, green, blue order.
    channels = [random.randint(128, 255) for _ in range(3)]
    return "#" + "".join(format(channel, "02x") for channel in channels)
# NOTE(review): this rebinds get_random_color, which is already defined with
# the same behaviour earlier in this file; one of the two can be dropped.
def get_random_color():
    """Return a random colour as a ``#rrggbb`` hexadecimal string."""
    return "#{:06x}".format(random.randint(0, 0xFFFFFF))
def find_token_indices(doc, substring, text):
    """Locate every token-aligned occurrence of *substring* in *text*.

    An occurrence is reported only when it begins exactly where a token
    begins and ends exactly where a token ends.

    Args:
        doc: Iterable of tokens exposing ``idx`` (character offset in
            *text*), ``i`` (token position) and ``len()``.
        substring: Text to look for. An empty string yields no matches.
        text: The text that *doc* was built from.

    Returns:
        A list of ``{"start": first_token, "end": last_token + 1}`` dicts,
        one per aligned occurrence, in order of appearance. Occurrences that
        do not line up with token boundaries are reported on stdout and
        skipped.
    """
    # An empty needle matches at every offset without ever advancing, so the
    # search loop below would never terminate.
    if not substring:
        print(f"Token boundaries not found for '{substring}'")
        return []

    # Index the token boundaries once instead of rescanning the whole
    # document for every occurrence.
    token_starting_at = {token.idx: token.i for token in doc}
    token_ending_at = {token.idx + len(token): token.i + 1 for token in doc}

    result = []
    needle_length = len(substring)
    start_index = text.find(substring)
    while start_index != -1:
        end_index = start_index + needle_length
        start_token = token_starting_at.get(start_index)
        end_token = token_ending_at.get(end_index)
        if start_token is None or end_token is None:
            print(f"Token boundaries not found for '{substring}' at index {start_index}")
        else:
            result.append({"start": start_token, "end": end_token})
        # Resume after this occurrence so matches never overlap.
        start_index = text.find(substring, end_index)

    if not result:
        print(f"Token boundaries not found for '{substring}'")
    return result
def create_custom_entity_viz(data, full_text):
    """Render the graph's nodes as highlighted spans over *full_text*.

    Every node id that occurs in the text on token boundaries becomes a
    span labelled with the node's ``type``; each distinct type gets one
    random light colour.

    Args:
        data: Mapping with a ``"nodes"`` list whose items carry ``"id"``
            and ``"type"``.
        full_text: The text the nodes were extracted from.

    Returns:
        The markup produced by ``displacy.render`` with ``style="span"``.
    """
    nlp = spacy.blank("xx")
    doc = nlp(full_text)
    token_count = len(doc)

    spans = []
    colors = {}
    for node in data["nodes"]:
        label = node["type"]
        entity_spans = find_token_indices(doc, node["id"], full_text)
        for dataentity in entity_spans:
            start = dataentity["start"]
            end = dataentity["end"]
            print("entity spans:", entity_spans)
            if start < token_count and end <= token_count:
                spans.append(Span(doc, start, end, label=label))
                if label not in colors:
                    colors[label] = get_random_light_color()

    for span in spans:
        print(f"Span: {span.text}, Label: {span.label_}")

    # doc.ents cannot hold overlapping spans and set_ents raises when a token
    # belongs to more than one of them (e.g. nodes "Steven Tyler" and "Tyler",
    # or two nodes sharing an id), so only a non-overlapping subset goes there.
    # The span group below accepts overlaps and is what gets rendered.
    doc.set_ents(spacy.util.filter_spans(spans), default="unmodified")
    doc.spans["sc"] = spans

    # NOTE(review): "ents", "style" and "manual" are passed through unchanged
    # from the original code; confirm which of them the span renderer honours.
    options = {
        "colors": colors,
        "ents": list(colors.keys()),
        "style": "ent",
        "manual": True,
    }
    html = displacy.render(doc, style="span", options=options)
    return html
def create_graph(json_data):
    """Build an interactive network view of the graph, wrapped in an iframe.

    Args:
        json_data: Mapping with a ``"nodes"`` list (items carrying ``"id"``,
            ``"type"`` and ``"detailed_type"``) and an ``"edges"`` list
            (items carrying ``"from"``, ``"to"`` and ``"label"``).

    Returns:
        An ``<iframe>`` element, as a string, whose ``srcdoc`` attribute
        holds the page generated by pyvis.
    """
    # Function-scope import: the name ``html`` is used for a local below.
    from html import escape

    # The edges have a direction ("from" -> "to") and the view is drawn with
    # arrows, so the graph must be directed; an undirected graph does not
    # remember which endpoint was the source and merges opposite edges.
    G = nx.DiGraph()
    for node in json_data['nodes']:
        G.add_node(node['id'], title=f"{node['type']}: {node['detailed_type']}")
    for edge in json_data['edges']:
        G.add_edge(edge['from'], edge['to'], title=edge['label'], label=edge['label'])

    nt = Network(
        width="720px",
        height="600px",
        directed=True,
        notebook=False,
        bgcolor="#FFFFFF",
        font_color="#111827",
    )
    nt.from_nx(G)
    nt.barnes_hut(
        gravity=-3000,
        central_gravity=0.3,
        spring_length=50,
        spring_strength=0.001,
        damping=0.09,
        overlap=0,
    )

    html = nt.generate_html()
    # Escape the page for use as an attribute value. Swapping quote
    # characters instead would corrupt any apostrophe in the page, such as
    # one inside a node or edge label.
    escaped_page = escape(html, quote=True)
    return (
        '<iframe style="width: 140%; height: 620px; margin: 0 auto;" name="result"\n'
        'allow="midi; geolocation; microphone; camera; display-capture; encrypted-media;"\n'
        'sandbox="allow-modals allow-forms allow-scripts allow-same-origin allow-popups\n'
        'allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""\n'
        f'allowpaymentrequest="" frameborder="0" srcdoc="{escaped_page}"></iframe>'
    )
def process_and_visualize(text, model):
    """Extract a graph from *text* and build both visualisations.

    Args:
        text: Input text; must be non-empty.
        model: Model selection forwarded to ``extract``; must be truthy.

    Returns:
        A ``(graph_html, entities_viz, json_data)`` tuple: the graph iframe
        markup, the entity highlighting markup and the parsed graph.

    Raises:
        gr.Error: If either argument is missing or empty.
    """
    if not (text and model):
        raise gr.Error("Text and model must be provided.")

    graph_json = extract(text, model)
    print(graph_json)

    entity_markup = create_custom_entity_viz(graph_json, text)
    graph_markup = create_graph(graph_json)
    return graph_markup, entity_markup, graph_json
# Page layout: model picker, text input and examples on the left, the two
# HTML visualisations on the right.
with gr.Blocks(title="Phi-3 Mini 4k Instruct Graph (by Emergent Methods)") as demo:
    gr.Markdown("# Phi-3 Mini 4k Instruct Graph (by Emergent Methods)")
    gr.Markdown("Extract a JSON graph from a text input and visualize it.")
    with gr.Row():
        with gr.Column(scale=1):
            input_model = gr.Dropdown(
                MODEL_LIST, label="Model",
                # value=MODEL_LIST[0]
            )
            input_text = gr.TextArea(label="Text", info="The text to be extracted")

            # handle_text collapses the line breaks and indentation of the
            # triple-quoted samples into single spaces.
            examples = gr.Examples(
                examples=[
                    handle_text("""Legendary rock band Aerosmith has officially announced their retirement from touring after 54 years, citing
                    lead singer Steven Tyler's unrecoverable vocal cord injury.
                    The decision comes after months of unsuccessful treatment for Tyler's fractured larynx,
                    which he suffered in September 2023."""),
                    handle_text("""Pop star Justin Timberlake, 43, had his driver's license suspended by a New York judge during a virtual
                    court hearing on August 2, 2024. The suspension follows Timberlake's arrest for driving while intoxicated (DWI)
                    in Sag Harbor on June 18. Timberlake, who is currently on tour in Europe,
                    pleaded not guilty to the charges."""),
                ],
                inputs=input_text
            )
            submit_button = gr.Button("Extract and Visualize")
        with gr.Column(scale=1):
            output_entity_viz = gr.HTML(label="Entities Visualization", show_label=True)
            output_graph = gr.HTML(label="Graph Visualization", show_label=True)
            # output_json = gr.JSON(label="JSON Graph")

    # NOTE(review): process_and_visualize returns a third value (the parsed
    # graph) that is not wired to any component here; only the first two
    # values have a matching output.
    submit_button.click(
        fn=process_and_visualize,
        inputs=[input_text, input_model],
        outputs=[output_graph, output_entity_viz]
    )

demo.launch(share=False)