File size: 7,890 Bytes
b6cf9eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
import spaces
import gradio as gr
from phi3_instruct_graph import MODEL_LIST, Phi3InstructGraph
from textwrap import dedent
import rapidjson
import spaces
from pyvis.network import Network
import networkx as nx
import spacy
from spacy import displacy
from spacy.tokens import Span
import random

# Hand-written sample of the graph payload shape used throughout this file:
# a list of typed nodes plus a list of labelled, directed edges between them.
json_example = {
    'nodes': [
        {'id': 'Aerosmith', 'type': 'organization', 'detailed_type': 'rock band'},
        {'id': 'Steven Tyler', 'type': 'person', 'detailed_type': 'lead singer'},
        {'id': 'vocal cord injury', 'type': 'medical condition', 'detailed_type': 'fractured larynx'},
        {'id': 'retirement', 'type': 'event', 'detailed_type': 'announcement'},
        {'id': 'touring', 'type': 'activity', 'detailed_type': 'musical performance'},
        {'id': 'September 2023', 'type': 'date', 'detailed_type': 'specific time'},
    ],
    'edges': [
        {'from': 'Aerosmith', 'to': 'Steven Tyler', 'label': 'led by'},
        {'from': 'Steven Tyler', 'to': 'vocal cord injury', 'label': 'suffered'},
        {'from': 'vocal cord injury', 'to': 'retirement', 'label': 'caused'},
        {'from': 'retirement', 'to': 'touring', 'label': 'ended'},
        {'from': 'vocal cord injury', 'to': 'September 2023', 'label': 'occurred in'},
    ],
}

@spaces.GPU
def extract(text, model):
    """Extract a graph from ``text`` with the selected model and parse the reply.

    Args:
        text: Input text handed to the extractor.
        model: Model identifier forwarded to ``Phi3InstructGraph(model=...)``.

    Returns:
        Whatever ``rapidjson.loads`` produces from the extractor's output
        (callers in this file index it with ``"nodes"`` and ``"edges"``).
    """
    extractor = Phi3InstructGraph(model=model)
    raw_output = extractor.extract(text)
    parsed = rapidjson.loads(raw_output)
    return parsed

def handle_text(text):
    """Collapse each run of whitespace in ``text`` to one space and trim both ends."""
    words = text.split()
    single_spaced = " ".join(words)
    return single_spaced

def get_random_color():
    """Return a random colour as a lowercase ``#rrggbb`` hex string.

    The value is drawn uniformly from the full 24-bit RGB range, so the
    result may be dark or light.
    """
    return f"#{random.randint(0, 0xFFFFFF):06x}"


def get_random_light_color():
    """Return a random light colour as a lowercase ``#rrggbb`` hex string.

    Each channel is drawn independently from 128..255 so the colour stays
    light.
    """
    # Generate higher RGB values to ensure a lighter color
    r = random.randint(128, 255)
    g = random.randint(128, 255)
    b = random.randint(128, 255)
    return f"#{r:02x}{g:02x}{b:02x}"

def find_token_indices(doc, substring, text):
    """Find occurrences of ``substring`` in ``text`` that align with token boundaries.

    Occurrences are searched left to right and do not overlap. An occurrence
    is kept only if some token starts exactly at its first character and some
    token ends exactly at its last character; otherwise a message is printed
    and the occurrence is skipped.

    Args:
        doc: Iterable of tokens for ``text``. Each token must expose ``idx``
            (character offset where it starts), ``i`` (its token index) and
            support ``len()`` (its length in characters).
        substring: The text to look for. An empty string yields no matches.
        text: The string that ``doc`` was built from.

    Returns:
        A list of ``{"start": int, "end": int}`` dicts holding token indices,
        where ``end`` is one past the last token of the match.
    """
    if not substring:
        # str.find("") succeeds at every offset without consuming anything, so
        # the search loop below would never advance and would spin forever.
        print(f"Token boundaries not found for '{substring}'")
        return []

    # Index the token boundaries once instead of rescanning the whole doc for
    # every occurrence of the substring.
    token_starting_at = {}
    token_ending_at = {}
    for token in doc:
        token_starting_at[token.idx] = token.i
        token_ending_at[token.idx + len(token)] = token.i + 1

    result = []
    start_index = text.find(substring)

    while start_index != -1:
        end_index = start_index + len(substring)
        start_token = token_starting_at.get(start_index)
        end_token = token_ending_at.get(end_index)

        if start_token is None or end_token is None:
            print(f"Token boundaries not found for '{substring}' at index {start_index}")
        else:
            result.append({
                "start": start_token,
                "end": end_token
            })

        # Search for next occurrence
        start_index = text.find(substring, end_index)

    if not result:
        print(f"Token boundaries not found for '{substring}'")

    return result


def create_custom_entity_viz(data, full_text):
    """Render ``full_text`` as displaCy HTML with the graph's nodes highlighted.

    Every token-aligned occurrence of a node's ``id`` in ``full_text`` becomes
    a span labelled with the node's ``type``. Each distinct type is assigned
    one random light colour.

    Args:
        data: Graph payload; ``data["nodes"]`` is iterated and each node must
            provide ``"id"`` and ``"type"``.
        full_text: The text the graph was extracted from.

    Returns:
        The HTML produced by ``displacy.render`` in ``"span"`` style.
    """
    nlp = spacy.blank("xx")
    doc = nlp(full_text)

    spans = []
    colors = {}
    for node in data["nodes"]:
        label = node["type"]
        occurrences = find_token_indices(doc, node["id"], full_text)
        if occurrences:
            print("entity spans:", occurrences)

        for occurrence in occurrences:
            start = occurrence["start"]
            end = occurrence["end"]
            if start < len(doc) and end <= len(doc):
                spans.append(Span(doc, start, end, label=label))
                if label not in colors:
                    colors[label] = get_random_light_color()

    for span in spans:
        print(f"Span: {span.text}, Label: {span.label_}")

    # doc.ents allows each token in at most one entity, so set_ents raises when
    # node ids nest ("injury" inside "vocal cord injury") or repeat. Only the
    # non-overlapping subset goes to doc.ents; the span group below, which is
    # what the "span" renderer draws, keeps every span.
    doc.set_ents(spacy.util.filter_spans(spans), default="unmodified")
    doc.spans["sc"] = spans

    # NOTE(review): "style" and "manual" look like displacy.render() arguments
    # rather than renderer options; kept as-is, confirm whether they are needed.
    options = {
        "colors": colors,
        "ents": list(colors.keys()),
        "style": "ent",
        "manual": True
    }

    html = displacy.render(doc, style="span", options=options)
    return html


def create_graph(json_data):
    """Build an interactive pyvis network for the graph and wrap it in an iframe.

    Args:
        json_data: Graph payload. ``json_data['nodes']`` items must provide
            ``'id'``, ``'type'`` and ``'detailed_type'``; ``json_data['edges']``
            items must provide ``'from'``, ``'to'`` and ``'label'``.

    Returns:
        An ``<iframe>`` HTML string whose ``srcdoc`` attribute carries the
        complete page generated by pyvis.
    """
    # Imported locally under a distinct name so it cannot clash with local
    # variables called ``html`` elsewhere in this module.
    from html import escape

    # The edges have a direction (from -> to) and the pyvis network is created
    # with directed=True, so the intermediate graph must be directed as well;
    # an undirected nx.Graph does not remember which endpoint was the source.
    G = nx.DiGraph()

    for node in json_data['nodes']:
        G.add_node(node['id'], title=f"{node['type']}: {node['detailed_type']}")

    for edge in json_data['edges']:
        G.add_edge(edge['from'], edge['to'], title=edge['label'], label=edge['label'])

    nt = Network(
        width="720px",
        height="600px",
        directed=True,
        notebook=False,
        bgcolor="#FFFFFF",
        font_color="#111827"
    )
    nt.from_nx(G)
    nt.barnes_hut(
        gravity=-3000,
        central_gravity=0.3,
        spring_length=50,
        spring_strength=0.001,
        damping=0.09,
        overlap=0,
    )

    page = nt.generate_html()
    # The page is embedded inside an HTML attribute, so it has to be
    # attribute-escaped. Blindly swapping ' for " (the previous approach)
    # rewrote apostrophes inside the page itself, e.g. a label such as
    # "driver's license", and broke the generated JavaScript.
    srcdoc = escape(page, quote=True)

    return f"""<iframe style="width: 140%; height: 620px; margin: 0 auto;" name="result" 
        allow="midi; geolocation; microphone; camera; display-capture; encrypted-media;" 
        sandbox="allow-modals allow-forms allow-scripts allow-same-origin allow-popups 
        allow-top-navigation-by-user-activation allow-downloads" allowfullscreen="" 
        allowpaymentrequest="" frameborder="0" srcdoc="{srcdoc}"></iframe>"""
  

def process_and_visualize(text, model):
    """Run extraction on ``text`` and build both visualisations.

    Args:
        text: The input text; must be non-empty.
        model: The selected model; must be non-empty.

    Returns:
        A 3-tuple ``(graph_html, entities_html, graph_json)``.

    Raises:
        gr.Error: If ``text`` or ``model`` is missing.
    """
    inputs_present = bool(text and model)
    if not inputs_present:
        raise gr.Error("Text and model must be provided.")

    graph_json = extract(text, model)
    print(graph_json)

    # Entity highlighting is built first, then the network view.
    entities_html = create_custom_entity_viz(graph_json, text)
    network_html = create_graph(graph_json)

    return network_html, entities_html, graph_json



# Gradio UI: model picker, text box and examples on the left; the entity
# highlighting and the graph view on the right.
with gr.Blocks(title="Phi-3 Mini 4k Instruct Graph (by Emergent Methods)") as demo:
    gr.Markdown("# Phi-3 Mini 4k Instruct Graph (by Emergent Methods)")
    gr.Markdown("Extract a JSON graph from a text input and visualize it.")

    with gr.Row():
        with gr.Column(scale=1):
            input_model = gr.Dropdown(
                MODEL_LIST, label="Model",
                # value=MODEL_LIST[0]
            )
            input_text = gr.TextArea(label="Text", info="The text to be extracted")

            # handle_text flattens the indented multi-line literals into
            # single-line example texts.
            examples = gr.Examples(
                examples=[
                    handle_text("""Legendary rock band Aerosmith has officially announced their retirement from touring after 54 years, citing 
                    lead singer Steven Tyler's unrecoverable vocal cord injury. 
                    The decision comes after months of unsuccessful treatment for Tyler's fractured larynx, 
                    which he suffered in September 2023."""),
                    handle_text("""Pop star Justin Timberlake, 43, had his driver's license suspended by a New York judge during a virtual 
                    court hearing on August 2, 2024. The suspension follows Timberlake's arrest for driving while intoxicated (DWI) 
                    in Sag Harbor on June 18. Timberlake, who is currently on tour in Europe, 
                    pleaded not guilty to the charges."""),
                ],
                inputs=input_text
            )

            submit_button = gr.Button("Extract and Visualize")
    
        with gr.Column(scale=1):
            output_entity_viz = gr.HTML(label="Entities Visualization", show_label=True)
            output_graph = gr.HTML(label="Graph Visualization", show_label=True)            
            # process_and_visualize returns three values (graph HTML, entity
            # HTML, raw JSON). The JSON is received by a hidden component so
            # the number of outputs matches the handler without changing the
            # visible layout.
            output_json = gr.JSON(label="JSON Graph", visible=False)

            submit_button.click(
                fn=process_and_visualize, 
                inputs=[input_text, input_model],
                outputs=[output_graph, output_entity_viz, output_json]
            )

demo.launch(share=False)