File size: 7,890 Bytes
b6cf9eb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 |
import spaces
import gradio as gr
from phi3_instruct_graph import MODEL_LIST, Phi3InstructGraph
from textwrap import dedent
import rapidjson
import spaces
from pyvis.network import Network
import networkx as nx
import spacy
from spacy import displacy
from spacy.tokens import Span
import random
# Sample graph payload with the same shape the rest of this file consumes:
# "nodes" entries carry 'id', 'type' and 'detailed_type'; "edges" entries carry
# 'from', 'to' and 'label'. It is only referenced from a commented-out line in
# process_and_visualize, where it can stand in for a real model call.
json_example = {'nodes': [{'id': 'Aerosmith', 'type': 'organization', 'detailed_type': 'rock band'}, {'id': 'Steven Tyler', 'type': 'person', 'detailed_type': 'lead singer'}, {'id': 'vocal cord injury', 'type': 'medical condition', 'detailed_type': 'fractured larynx'}, {'id': 'retirement', 'type': 'event', 'detailed_type': 'announcement'}, {'id': 'touring', 'type': 'activity', 'detailed_type': 'musical performance'}, {'id': 'September 2023', 'type': 'date', 'detailed_type': 'specific time'}], 'edges': [{'from': 'Aerosmith', 'to': 'Steven Tyler', 'label': 'led by'}, {'from': 'Steven Tyler', 'to': 'vocal cord injury', 'label': 'suffered'}, {'from': 'vocal cord injury', 'to': 'retirement', 'label': 'caused'}, {'from': 'retirement', 'to': 'touring', 'label': 'ended'}, {'from': 'vocal cord injury', 'to': 'September 2023', 'label': 'occurred in'}]}
@spaces.GPU
def extract(text, model):
    """Run graph extraction on ``text`` and return the parsed JSON result.

    Args:
        text: Input text handed to the extractor.
        model: Model identifier forwarded to ``Phi3InstructGraph``.

    Returns:
        The extractor's output string parsed with ``rapidjson.loads``.
    """
    # A fresh extractor is built on every call, using the requested model.
    graph_extractor = Phi3InstructGraph(model=model)
    raw_output = graph_extractor.extract(text)
    parsed = rapidjson.loads(raw_output)
    return parsed
def handle_text(text):
    """Collapse every run of whitespace in ``text`` into a single space.

    Leading and trailing whitespace is dropped as a side effect of splitting.
    """
    words = text.split()
    single_spaced = " ".join(words)
    return single_spaced
def get_random_color():
    """Return a uniformly random color as a lowercase ``#rrggbb`` hex string."""
    return f"#{random.randint(0, 0xFFFFFF):06x}"


def get_random_light_color():
    """Return a random light color as a lowercase ``#rrggbb`` hex string.

    Every channel is drawn from 128..255 so the color stays light enough to
    be used as a highlight behind dark text.
    """
    # Channels are drawn in r, g, b order so seeded runs stay reproducible.
    r = random.randint(128, 255)
    g = random.randint(128, 255)
    b = random.randint(128, 255)
    return f"#{r:02x}{g:02x}{b:02x}"
def find_token_indices(doc, substring, text):
    """Map every occurrence of ``substring`` in ``text`` to a token range.

    Args:
        doc: Iterable of tokens; each token must expose ``idx`` (character
            offset of the token in ``text``), ``i`` (token index) and
            support ``len()``.
        substring: Text to search for. An empty string yields no matches.
        text: The string that ``doc`` was tokenized from.

    Returns:
        A list of ``{"start": int, "end": int}`` dicts, one per occurrence
        whose first and last characters coincide with token boundaries.
        ``end`` is exclusive. Occurrences that start or end in the middle of
        a token are reported on stdout and skipped.
    """
    # An empty needle matches at every offset without advancing the search
    # position, so the loop below would never terminate.
    if not substring:
        print(f"Token boundaries not found for '{substring}'")
        return []

    # Index the token boundaries once instead of rescanning the whole doc
    # for each occurrence of the substring.
    token_starting_at = {}
    token_ending_at = {}
    for token in doc:
        token_starting_at[token.idx] = token.i
        token_ending_at[token.idx + len(token)] = token.i + 1

    result = []
    start_index = text.find(substring)
    while start_index != -1:
        end_index = start_index + len(substring)
        start_token = token_starting_at.get(start_index)
        end_token = token_ending_at.get(end_index)

        if start_token is None or end_token is None:
            print(f"Token boundaries not found for '{substring}' at index {start_index}")
        else:
            result.append({"start": start_token, "end": end_token})

        # Continue after this occurrence (matches never overlap).
        start_index = text.find(substring, end_index)

    if not result:
        print(f"Token boundaries not found for '{substring}'")
    return result
def create_custom_entity_viz(data, full_text):
    """Render ``full_text`` with the graph's nodes highlighted as labelled spans.

    Args:
        data: Graph dict with a "nodes" list; each node provides "id" (the
            text to highlight) and "type" (used as the span label).
        full_text: The text the graph was extracted from.

    Returns:
        The HTML produced by ``displacy.render`` with ``style="span"``.
    """
    nlp = spacy.blank("xx")
    doc = nlp(full_text)
    token_count = len(doc)

    spans = []
    colors = {}
    for node in data["nodes"]:
        label = node["type"]
        entity_spans = find_token_indices(doc, node["id"], full_text)
        for dataentity in entity_spans:
            start = dataentity["start"]
            end = dataentity["end"]
            print("entity spans:", entity_spans)

            if start < token_count and end <= token_count:
                spans.append(Span(doc, start, end, label=label))
                # One random light color per label, assigned on first use.
                if label not in colors:
                    colors[label] = get_random_light_color()

    for span in spans:
        print(f"Span: {span.text}, Label: {span.label_}")

    # Doc.set_ents refuses spans that share a token (e.g. nodes "Steven Tyler"
    # and "Tyler", or the same node id listed twice), so only a
    # non-overlapping subset is stored as entities. The span group below
    # keeps every span, and that group is what the "span" style renders.
    doc.set_ents(spacy.util.filter_spans(spans), default="unmodified")
    doc.spans["sc"] = spans

    options = {
        "colors": colors,
        "ents": list(colors.keys()),
        "style": "ent",
        "manual": True
    }

    html = displacy.render(doc, style="span", options=options)
    return html
def create_graph(json_data):
    """Build an interactive pyvis graph and return it wrapped in an iframe.

    Args:
        json_data: Graph dict with "nodes" (each providing "id", "type" and
            "detailed_type") and "edges" (each providing "from", "to" and
            "label").

    Returns:
        An ``<iframe>`` HTML string whose ``srcdoc`` attribute holds the
        generated pyvis page.
    """
    # Function-scope import: the rendered page is held in local variables, and
    # this keeps the stdlib ``html`` module name out of the module namespace.
    from html import escape

    # Edges have a source and a target and are drawn with arrows
    # (directed=True below). An undirected Graph does not record which
    # endpoint was the source and merges A->B with B->A, so use a DiGraph.
    G = nx.DiGraph()

    for node in json_data['nodes']:
        G.add_node(node['id'], title=f"{node['type']}: {node['detailed_type']}")

    for edge in json_data['edges']:
        G.add_edge(edge['from'], edge['to'], title=edge['label'], label=edge['label'])

    nt = Network(
        width="720px",
        height="600px",
        directed=True,
        notebook=False,
        bgcolor="#FFFFFF",
        font_color="#111827"
    )
    nt.from_nx(G)
    nt.barnes_hut(
        gravity=-3000,
        central_gravity=0.3,
        spring_length=50,
        spring_strength=0.001,
        damping=0.09,
        overlap=0,
    )

    page_html = nt.generate_html()

    # The page is embedded in a single-quoted srcdoc attribute. Escaping it as
    # an attribute value keeps apostrophes inside the page intact (for
    # example a node id such as "driver's license"); rewriting every ' to "
    # would corrupt the quoted strings of the embedded script.
    srcdoc = escape(page_html, quote=True)

    return f"""<iframe style="width: 140%; height: 620px; margin: 0 auto;" name="result"
allow="midi; geolocation; microphone; camera; display-capture; encrypted-media;"
sandbox="allow-modals allow-forms allow-scripts allow-same-origin allow-popups
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
allowpaymentrequest="" frameborder="0" srcdoc='{srcdoc}'></iframe>"""
def process_and_visualize(text, model):
    """Extract a graph from ``text`` and build both visualizations.

    Args:
        text: Input text to analyse.
        model: Model identifier passed through to ``extract``.

    Returns:
        A ``(graph_html, entities_viz, json_data)`` tuple.

    Raises:
        gr.Error: If ``text`` or ``model`` is missing or empty.
    """
    if not (text and model):
        raise gr.Error("Text and model must be provided.")

    json_data = extract(text, model)
    print(json_data)

    # The entity view is built before the graph, matching the call order
    # this function has always used.
    entities_viz = create_custom_entity_viz(json_data, text)
    graph_html = create_graph(json_data)
    return graph_html, entities_viz, json_data
# UI definition: model selector, text input and examples in the left column;
# entity and graph visualizations in the right column.
with gr.Blocks(title="Phi-3 Mini 4k Instruct Graph (by Emergent Methods)") as demo:
    gr.Markdown("# Phi-3 Mini 4k Instruct Graph (by Emergent Methods)")
    gr.Markdown("Extract a JSON graph from a text input and visualize it.")
    with gr.Row():
        with gr.Column(scale=1):
            input_model = gr.Dropdown(
                MODEL_LIST, label="Model",
                # value=MODEL_LIST[0]
            )
            input_text = gr.TextArea(label="Text", info="The text to be extracted")

            # handle_text collapses all whitespace, so the line breaks and
            # indentation inside these literals never reach the text box.
            examples = gr.Examples(
                examples=[
                    handle_text("""Legendary rock band Aerosmith has officially announced their retirement from touring after 54 years, citing
                    lead singer Steven Tyler's unrecoverable vocal cord injury.
                    The decision comes after months of unsuccessful treatment for Tyler's fractured larynx,
                    which he suffered in September 2023."""),
                    handle_text("""Pop star Justin Timberlake, 43, had his driver's license suspended by a New York judge during a virtual
                    court hearing on August 2, 2024. The suspension follows Timberlake's arrest for driving while intoxicated (DWI)
                    in Sag Harbor on June 18. Timberlake, who is currently on tour in Europe,
                    pleaded not guilty to the charges."""),
                ],
                inputs=input_text
            )

            submit_button = gr.Button("Extract and Visualize")

        with gr.Column(scale=1):
            output_entity_viz = gr.HTML(label="Entities Visualization", show_label=True)
            output_graph = gr.HTML(label="Graph Visualization", show_label=True)
            # output_json = gr.JSON(label="JSON Graph")

    # NOTE(review): process_and_visualize returns three values (graph HTML,
    # entity HTML, JSON) but only two outputs are wired here because the JSON
    # component is disabled; confirm how Gradio treats the extra value, or
    # re-enable output_json and add it to `outputs`.
    submit_button.click(
        fn=process_and_visualize,
        inputs=[input_text, input_model],
        outputs=[output_graph, output_entity_viz]
    )

demo.launch(share=False)