wagnercosta committed on
Commit b6cf9eb
1 Parent(s): d289335

Upload 2 files

Files changed (2)
  1. main.py +210 -0
  2. phi3_instruct_graph.py +98 -0
main.py ADDED
@@ -0,0 +1,210 @@
import spaces
import gradio as gr
from phi3_instruct_graph import MODEL_LIST, Phi3InstructGraph
import rapidjson
from pyvis.network import Network
import networkx as nx
import spacy
from spacy import displacy
from spacy.tokens import Span
import random

# Sample extraction output, useful for testing the visualizations offline
json_example = {
    "nodes": [
        {"id": "Aerosmith", "type": "organization", "detailed_type": "rock band"},
        {"id": "Steven Tyler", "type": "person", "detailed_type": "lead singer"},
        {"id": "vocal cord injury", "type": "medical condition", "detailed_type": "fractured larynx"},
        {"id": "retirement", "type": "event", "detailed_type": "announcement"},
        {"id": "touring", "type": "activity", "detailed_type": "musical performance"},
        {"id": "September 2023", "type": "date", "detailed_type": "specific time"},
    ],
    "edges": [
        {"from": "Aerosmith", "to": "Steven Tyler", "label": "led by"},
        {"from": "Steven Tyler", "to": "vocal cord injury", "label": "suffered"},
        {"from": "vocal cord injury", "to": "retirement", "label": "caused"},
        {"from": "retirement", "to": "touring", "label": "ended"},
        {"from": "vocal cord injury", "to": "September 2023", "label": "occurred in"},
    ],
}


@spaces.GPU
def extract(text, model):
    graph_model = Phi3InstructGraph(model=model)
    result = graph_model.extract(text)
    return rapidjson.loads(result)


def handle_text(text):
    # Collapse runs of whitespace so multi-line example texts read as one line
    return " ".join(text.split())


def get_random_color():
    return f"#{random.randint(0, 0xFFFFFF):06x}"


def get_random_light_color():
    # Keep each RGB channel in the upper half of its range so the color stays light
    r = random.randint(128, 255)
    g = random.randint(128, 255)
    b = random.randint(128, 255)
    return f"#{r:02x}{g:02x}{b:02x}"


def find_token_indices(doc, substring, text):
    """Return token-level [start, end) boundaries for every occurrence of substring."""
    result = []
    start_index = text.find(substring)

    while start_index != -1:
        end_index = start_index + len(substring)
        start_token = None
        end_token = None

        # Match the character offsets of this occurrence to token boundaries
        for token in doc:
            if token.idx == start_index:
                start_token = token.i
            if token.idx + len(token) == end_index:
                end_token = token.i + 1

        if start_token is None or end_token is None:
            print(f"Token boundaries not found for '{substring}' at index {start_index}")
        else:
            result.append({"start": start_token, "end": end_token})

        # Search for the next occurrence
        start_index = text.find(substring, end_index)

    if not result:
        print(f"Token boundaries not found for '{substring}'")

    return result


def create_custom_entity_viz(data, full_text):
    # A blank multi-language pipeline is enough here: we only need tokenization
    nlp = spacy.blank("xx")
    doc = nlp(full_text)

    spans = []
    colors = {}
    for node in data["nodes"]:
        entity_spans = find_token_indices(doc, node["id"], full_text)
        for entity in entity_spans:
            start = entity["start"]
            end = entity["end"]
            if start < len(doc) and end <= len(doc):
                spans.append(Span(doc, start, end, label=node["type"]))
                if node["type"] not in colors:
                    colors[node["type"]] = get_random_light_color()

    doc.set_ents(spans, default="unmodified")
    # The span visualizer reads doc.spans["sc"] by default
    doc.spans["sc"] = spans

    options = {"colors": colors}
    html = displacy.render(doc, style="span", options=options)
    return html


def create_graph(json_data):
    # Directed graph, so edge direction matches the arrows rendered by pyvis
    G = nx.DiGraph()

    for node in json_data["nodes"]:
        G.add_node(node["id"], title=f"{node['type']}: {node['detailed_type']}")

    for edge in json_data["edges"]:
        G.add_edge(edge["from"], edge["to"], title=edge["label"], label=edge["label"])

    nt = Network(
        width="720px",
        height="600px",
        directed=True,
        notebook=False,
        bgcolor="#FFFFFF",
        font_color="#111827",
    )
    nt.from_nx(G)
    nt.barnes_hut(
        gravity=-3000,
        central_gravity=0.3,
        spring_length=50,
        spring_strength=0.001,
        damping=0.09,
        overlap=0,
    )

    html = nt.generate_html()
    # Swap single quotes for double quotes so the document survives srcdoc='...'
    html = html.replace("'", '"')

    return f"""<iframe style="width: 140%; height: 620px; margin: 0 auto;" name="result"
    allow="midi; geolocation; microphone; camera; display-capture; encrypted-media;"
    sandbox="allow-modals allow-forms allow-scripts allow-same-origin allow-popups
    allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
    allowpaymentrequest="" frameborder="0" srcdoc='{html}'></iframe>"""


def process_and_visualize(text, model):
    if not text or not model:
        raise gr.Error("Text and model must be provided.")
    json_data = extract(text, model)
    # json_data = json_example  # offline fallback for testing the visualizations
    print(json_data)

    entities_viz = create_custom_entity_viz(json_data, text)
    graph_html = create_graph(json_data)
    # Return exactly the two values wired to `outputs` in the click handler below
    return graph_html, entities_viz


with gr.Blocks(title="Phi-3 Mini 4k Instruct Graph (by Emergent Methods)") as demo:
    gr.Markdown("# Phi-3 Mini 4k Instruct Graph (by Emergent Methods)")
    gr.Markdown("Extract a JSON graph from a text input and visualize it.")

    with gr.Row():
        with gr.Column(scale=1):
            input_model = gr.Dropdown(
                MODEL_LIST, label="Model",
                # value=MODEL_LIST[0]
            )
            input_text = gr.TextArea(label="Text", info="The text to be extracted")

            examples = gr.Examples(
                examples=[
                    handle_text("""Legendary rock band Aerosmith has officially announced their retirement from touring after 54 years, citing
                        lead singer Steven Tyler's unrecoverable vocal cord injury.
                        The decision comes after months of unsuccessful treatment for Tyler's fractured larynx,
                        which he suffered in September 2023."""),
                    handle_text("""Pop star Justin Timberlake, 43, had his driver's license suspended by a New York judge during a virtual
                        court hearing on August 2, 2024. The suspension follows Timberlake's arrest for driving while intoxicated (DWI)
                        in Sag Harbor on June 18. Timberlake, who is currently on tour in Europe,
                        pleaded not guilty to the charges."""),
                ],
                inputs=input_text,
            )

            submit_button = gr.Button("Extract and Visualize")

        with gr.Column(scale=1):
            output_entity_viz = gr.HTML(label="Entities Visualization", show_label=True)
            output_graph = gr.HTML(label="Graph Visualization", show_label=True)
            # output_json = gr.JSON(label="JSON Graph")

    submit_button.click(
        fn=process_and_visualize,
        inputs=[input_text, input_model],
        outputs=[output_graph, output_entity_viz],
    )

demo.launch(share=False)
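
As a self-contained illustration of the character-to-token mapping that find_token_indices performs above, the following sketch mirrors its matching logic on a single occurrence (the sample sentence is illustrative, not taken from the app):

import spacy

# Tokenize with the same blank multi-language pipeline used by the app
nlp = spacy.blank("xx")
text = "Steven Tyler suffered a vocal cord injury."
doc = nlp(text)

substring = "vocal cord injury"
start_char = text.find(substring)
end_char = start_char + len(substring)

# Match character offsets to token boundaries, as find_token_indices does
start_token = next(t.i for t in doc if t.idx == start_char)
end_token = next(t.i + 1 for t in doc if t.idx + len(t) == end_char)

print(doc[start_token:end_token])  # -> vocal cord injury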
phi3_instruct_graph.py ADDED
@@ -0,0 +1,98 @@
import os
from textwrap import dedent

import torch
from dotenv import load_dotenv
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

load_dotenv()
login(token=os.environ["HF_TOKEN"])

MODEL_LIST = [
    "EmergentMethods/Phi-3-mini-4k-instruct-graph",
    "EmergentMethods/Phi-3-mini-128k-instruct-graph",
    "EmergentMethods/Phi-3-medium-128k-instruct-graph",
]

torch.random.manual_seed(0)


class Phi3InstructGraph:
    def __init__(self, model="EmergentMethods/Phi-3-mini-4k-instruct-graph"):
        if model not in MODEL_LIST:
            raise ValueError(f"model must be one of {MODEL_LIST}")

        self.model_path = model
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            device_map="cuda",
            torch_dtype="auto",
            trust_remote_code=True,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        self.pipe = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
        )

    def _generate(self, messages):
        # Greedy decoding: with do_sample=False, no temperature is needed
        generation_args = {
            "max_new_tokens": 2000,
            "return_full_text": False,
            "do_sample": False,
        }
        return self.pipe(messages, **generation_args)

    def _get_messages(self, text):
        messages = [
            {
                "role": "system",
                "content": dedent("""\n
                    A chat between a curious user and an artificial intelligence Assistant. The Assistant is an expert at identifying entities and relationships in text. The Assistant responds in JSON output only.

                    The User provides text in the format:

                    -------Text begin-------
                    <User provided text>
                    -------Text end-------

                    The Assistant follows the following steps before replying to the User:

                    1. **identify the most important entities** The Assistant identifies the most important entities in the text. These entities are listed in the JSON output under the key "nodes", they follow the structure of a list of dictionaries where each dict is:

                    "nodes":[{"id": <entity N>, "type": <type>, "detailed_type": <detailed type>}, ...]

                    where "type": <type> is a broad categorization of the entity. "detailed type": <detailed_type> is a very descriptive categorization of the entity.

                    2. **determine relationships** The Assistant uses the text between -------Text begin------- and -------Text end------- to determine the relationships between the entities identified in the "nodes" list defined above. These relationships are called "edges" and they follow the structure of:

                    "edges":[{"from": <entity 1>, "to": <entity 2>, "label": <relationship>}, ...]

                    The <entity N> must correspond to the "id" of an entity in the "nodes" list.

                    The Assistant never repeats the same node twice. The Assistant never repeats the same edge twice.
                    The Assistant responds to the User in JSON only, according to the following JSON schema:

                    {"type":"object","properties":{"nodes":{"type":"array","items":{"type":"object","properties":{"id":{"type":"string"},"type":{"type":"string"},"detailed_type":{"type":"string"}},"required":["id","type","detailed_type"],"additionalProperties":false}},"edges":{"type":"array","items":{"type":"object","properties":{"from":{"type":"string"},"to":{"type":"string"},"label":{"type":"string"}},"required":["from","to","label"],"additionalProperties":false}}},"required":["nodes","edges"],"additionalProperties":false}
                    """)
            },
            {
                "role": "user",
                "content": dedent(f"""\n
                    -------Text begin-------
                    {text}
                    -------Text end-------
                    """)
            }
        ]
        return messages

    def extract(self, text):
        messages = self._get_messages(text)
        pipe_output = self._generate(messages)
        return pipe_output[0]["generated_text"]
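
For reference, a minimal usage sketch of the class above (assuming a CUDA device, the dependencies imported above, and HF_TOKEN set in the environment; the sample sentence is illustrative):

import rapidjson
from phi3_instruct_graph import MODEL_LIST, Phi3InstructGraph

# Any entry from MODEL_LIST works; weights are downloaded on first use
graph_model = Phi3InstructGraph(model=MODEL_LIST[0])

# extract() returns the model's raw JSON string, parsed here as in main.py
raw = graph_model.extract("Steven Tyler suffered a vocal cord injury in September 2023.")
graph = rapidjson.loads(raw)

print([node["id"] for node in graph["nodes"]])
print([(edge["from"], edge["to"], edge["label"]) for edge in graph["edges"]])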