vietexob committed on
Commit
9021458
·
1 Parent(s): 855980b

Multiple updates

Browse files
Files changed (4) hide show
  1. README.md +2 -2
  2. app.py +117 -68
  3. app-backup.py → app_old.py +0 -0
  4. phi3_instruct_graph.py +32 -24
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
- title: Graph Mind
3
- emoji: 👀
4
  colorFrom: purple
5
  colorTo: pink
6
  sdk: gradio
 
1
  ---
2
+ title: Text2Graph
3
+ emoji: 🌐
4
  colorFrom: purple
5
  colorTo: pink
6
  sdk: gradio
app.py CHANGED
@@ -1,19 +1,20 @@
1
  # import spaces
 
 
 
 
 
2
  import gradio as gr
 
 
3
  from phi3_instruct_graph import Phi3InstructGraph
4
- import rapidjson
5
  from pyvis.network import Network
6
- import networkx as nx
7
- import spacy
8
  from spacy import displacy
9
  from spacy.tokens import Span
10
- import random
11
- import os
12
- import pickle
13
 
14
  # Constants
15
- TITLE = "🌐 GraphMind: Phi-3 Instruct Graph Explorer"
16
- SUBTITLE = "✨ Extract and visualize knowledge graphs from any text in multiple languages"
17
 
18
  # Basic CSS for styling
19
  CUSTOM_CSS = """
@@ -29,40 +30,63 @@ EXAMPLE_CACHE_FILE = os.path.join(CACHE_DIR, "first_example_cache.pkl")
29
  # Create cache directory if it doesn't exist
30
  os.makedirs(CACHE_DIR, exist_ok=True)
31
 
32
- # Color utilities
33
  def get_random_light_color():
 
 
 
 
34
  r = random.randint(140, 255)
35
  g = random.randint(140, 255)
36
  b = random.randint(140, 255)
 
37
  return f"#{r:02x}{g:02x}{b:02x}"
38
 
39
- # Text preprocessing
40
- def handle_text(text):
 
 
 
 
 
 
 
41
  return " ".join(text.split())
42
 
43
- # Main processing functions
44
  # @spaces.GPU
45
- def extract(text):
 
 
 
 
 
 
 
46
  try:
47
- model = Phi3InstructGraph()
48
  result = model.extract(text)
49
  return rapidjson.loads(result)
50
  except Exception as e:
51
- raise gr.Error(f"Extraction error: {str(e)}")
52
 
53
  def find_token_indices(doc, substring, text):
 
 
 
 
 
54
  result = []
55
- start_index = text.find(substring)
56
 
57
- while start_index != -1:
58
- end_index = start_index + len(substring)
59
  start_token = None
60
  end_token = None
61
 
62
  for token in doc:
63
- if token.idx == start_index:
64
  start_token = token.i
65
- if token.idx + len(token) == end_index:
66
  end_token = token.i + 1
67
 
68
  if start_token is not None and end_token is not None:
@@ -72,35 +96,41 @@ def find_token_indices(doc, substring, text):
72
  })
73
 
74
  # Search for next occurrence
75
- start_index = text.find(substring, end_index)
76
 
77
  return result
78
 
79
  def create_custom_entity_viz(data, full_text):
 
 
 
 
80
  nlp = spacy.blank("xx")
81
  doc = nlp(full_text)
82
-
83
  spans = []
84
  colors = {}
 
85
  for node in data["nodes"]:
86
  entity_spans = find_token_indices(doc, node["id"], full_text)
87
- for dataentity in entity_spans:
88
- start = dataentity["start"]
89
- end = dataentity["end"]
 
90
 
91
  if start < len(doc) and end <= len(doc):
92
  # Check for overlapping spans
93
  overlapping = any(s.start < end and start < s.end for s in spans)
 
94
  if not overlapping:
95
  node_type = node.get("type", "Entity")
96
  span = Span(doc, start, end, label=node_type)
97
  spans.append(span)
 
98
  if node_type not in colors:
99
  colors[node_type] = get_random_light_color()
100
 
101
  doc.set_ents(spans, default="unmodified")
102
  doc.spans["sc"] = spans
103
-
104
  options = {
105
  "colors": colors,
106
  "ents": list(colors.keys()),
@@ -111,24 +141,30 @@ def create_custom_entity_viz(data, full_text):
111
  html = displacy.render(doc, style="span", options=options)
112
  # Add custom styling to the entity visualization
113
  styled_html = f"""
114
- <div style="padding: 20px; border-radius: 12px; background-color: white; box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);">
115
  {html}
116
  </div>
117
  """
 
118
  return styled_html
119
 
120
  def create_graph(json_data):
 
 
 
 
121
  G = nx.Graph()
122
 
123
- # Add nodes with tooltips - with error handling for missing keys
124
  for node in json_data['nodes']:
125
  # Get node type with fallback
126
- node_type = node.get("type", "Entity")
 
127
  # Get detailed type with fallback
128
- detailed_type = node.get("detailed_type", node_type)
129
 
130
  # Use node ID and type info for the tooltip
131
- G.add_node(node['id'], title=f"{node_type}: {detailed_type}")
132
 
133
  # Add edges with labels
134
  for edge in json_data['edges']:
@@ -138,18 +174,18 @@ def create_graph(json_data):
138
  G.add_edge(edge['from'], edge['to'], title=label, label=label)
139
 
140
  # Create network visualization
141
- nt = Network(
142
  width="100%",
143
  height="700px",
144
- directed=True,
145
  notebook=False,
146
  bgcolor="#f8fafc",
147
  font_color="#1e293b"
148
  )
149
 
150
  # Configure network display
151
- nt.from_nx(G)
152
- nt.barnes_hut(
153
  gravity=-3000,
154
  central_gravity=0.3,
155
  spring_length=50,
@@ -159,21 +195,21 @@ def create_graph(json_data):
159
  )
160
 
161
  # Customize edge appearance
162
- for edge in nt.edges:
163
  edge['width'] = 2
164
- edge['arrows'] = {'to': {'enabled': True, 'type': 'arrow'}}
165
  edge['color'] = {'color': '#6366f1', 'highlight': '#4f46e5'}
166
  edge['font'] = {'size': 12, 'color': '#4b5563', 'face': 'Arial'}
167
 
168
  # Customize node appearance
169
- for node in nt.nodes:
170
  node['color'] = {'background': '#e0e7ff', 'border': '#6366f1', 'highlight': {'background': '#c7d2fe', 'border': '#4f46e5'}}
171
  node['font'] = {'size': 14, 'color': '#1e293b'}
172
  node['shape'] = 'dot'
173
  node['size'] = 25
174
 
175
  # Generate HTML with iframe to isolate styles
176
- html = nt.generate_html()
177
  html = html.replace("'", '"')
178
 
179
  return f"""<iframe style="width: 100%; height: 700px; margin: 0 auto; border-radius: 12px; box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.1), 0 4px 6px -4px rgba(0, 0, 0, 0.1);"
@@ -183,6 +219,10 @@ def create_graph(json_data):
183
  allowpaymentrequest="" frameborder="0" srcdoc='{html}'></iframe>"""
184
 
185
  def process_and_visualize(text, progress=gr.Progress()):
 
 
 
 
186
  if not text:
187
  raise gr.Error("⚠️ Text must be provided!")
188
 
@@ -200,10 +240,10 @@ def process_and_visualize(text, progress=gr.Progress()):
200
  return cache_data["graph_html"], cache_data["entities_viz"], cache_data["json_data"], cache_data["stats"]
201
  except Exception as e:
202
  print(f"Cache loading error: {str(e)}")
203
- # Continue with normal processing if cache fails
204
-
205
  progress(0, desc="Starting extraction...")
206
- json_data = extract(text)
207
 
208
  progress(0.5, desc="Creating entity visualization...")
209
  entities_viz = create_custom_entity_viz(json_data, text)
@@ -232,33 +272,40 @@ def process_and_visualize(text, progress=gr.Progress()):
232
  progress(1.0, desc="Complete!")
233
  return graph_html, entities_viz, json_data, stats
234
 
235
- # Example texts in different languages
236
  EXAMPLES = [
237
- [handle_text("""The family of Azerbaijan President Ilham Aliyev leads a charmed, glamorous life, thanks in part to financial interests in almost every sector of the economy. His wife, Mehriban, comes from the privileged and powerful Pashayev family that owns banks, insurance and construction companies, a television station and a line of cosmetics. She has led the Heydar Aliyev Foundation, Azerbaijan’s pre-eminent charity behind the construction of schools, hospitals and the country’s major sports complex. Their eldest daughter, Leyla, editor of Baku magazine, and her sister, Arzu, have financial stakes in a firm that won rights to mine for gold in the western village of Chovdar and Azerfon, the country’s largest mobile phone business. Arzu is also a significant shareholder in SW Holding, which controls nearly every operation related to Azerbaijan Airlines (β€œAzal”), from meals to airport taxis. Both sisters and brother Heydar own property in Dubai valued at roughly $75 million in 2010; Heydar is the legal owner of nine luxury mansions in Dubai purchased for some $44 million.""")],
238
-
239
- [handle_text("""Legendary rock band Aerosmith has officially announced their retirement from touring after 54 years, citing
240
- lead singer Steven Tyler's unrecoverable vocal cord injury.
241
- The decision comes after months of unsuccessful treatment for Tyler's fractured larynx,
242
- which he suffered in September 2023.""")],
 
 
 
 
 
 
243
 
244
  [handle_text("""Pop star Justin Timberlake, 43, had his driver's license suspended by a New York judge during a virtual
245
- court hearing on August 2, 2024. The suspension follows Timberlake's arrest for driving while intoxicated (DWI)
246
- in Sag Harbor on June 18. Timberlake, who is currently on tour in Europe,
247
- pleaded not guilty to the charges.""")],
248
  ]
249
 
250
- # Function to preprocess the first example when the app starts
251
  def generate_first_example_cache():
252
- """Generate cache for the first example if it doesn't exist"""
 
 
 
253
  if not os.path.exists(EXAMPLE_CACHE_FILE):
254
  print("Generating cache for first example...")
 
255
  try:
256
  text = EXAMPLES[0][0]
257
- # model = MODEL_LIST[0] if MODEL_LIST else None
258
-
259
- # if model:
260
  # Extract data
261
- json_data = extract(text, model)
262
  entities_viz = create_custom_entity_viz(json_data, text)
263
  graph_html = create_graph(json_data)
264
 
@@ -267,17 +314,18 @@ def generate_first_example_cache():
267
  stats = f"πŸ“Š Extracted {node_count} entities and {edge_count} relationships"
268
 
269
  # Save to cache
270
- cache_data = {
271
  "graph_html": graph_html,
272
  "entities_viz": entities_viz,
273
  "json_data": json_data,
274
  "stats": stats
275
  }
 
276
  with open(EXAMPLE_CACHE_FILE, 'wb') as f:
277
- pickle.dump(cache_data, f)
278
-
279
  print("First example cache generated successfully")
280
- return cache_data
 
281
  except Exception as e:
282
  print(f"Error generating first example cache: {str(e)}")
283
  else:
@@ -291,6 +339,10 @@ def generate_first_example_cache():
291
  return None
292
 
293
  def create_ui():
 
 
 
 
294
  # Try to generate/load the first example cache
295
  first_example_cache = generate_first_example_cache()
296
 
@@ -299,9 +351,6 @@ def create_ui():
299
  gr.Markdown(f"# {TITLE}")
300
  gr.Markdown(f"{SUBTITLE}")
301
 
302
- with gr.Row():
303
- gr.Markdown("🌍 **Multilingual Support Available**")
304
-
305
  # Main content area - redesigned layout
306
  with gr.Row():
307
  # Left panel - Input controls
@@ -381,10 +430,10 @@ def create_ui():
381
 
382
  # Footer
383
  gr.Markdown("---")
384
- gr.Markdown("πŸ“‹ **Instructions:** Enter text in any language, select a model, and click 'Extract & Visualize' to generate a knowledge graph.")
385
- gr.Markdown("πŸ› οΈ Powered by Phi-3 Instruct Graph | Emergent Methods")
386
 
387
  return demo
388
 
389
  demo = create_ui()
390
- demo.launch(share=False)
 
1
  # import spaces
2
+ import os
3
+ import spacy
4
+ import pickle
5
+ import random
6
+ import rapidjson
7
  import gradio as gr
8
+ import networkx as nx
9
+
10
  from phi3_instruct_graph import Phi3InstructGraph
 
11
  from pyvis.network import Network
 
 
12
  from spacy import displacy
13
  from spacy.tokens import Span
 
 
 
14
 
15
  # Constants
16
+ TITLE = "🌐 Text2Graph: Extract Knowledge Graphs from Natural Language"
17
+ SUBTITLE = "✨ Extract and visualize knowledge graphs from texts in any language!"
18
 
19
  # Basic CSS for styling
20
  CUSTOM_CSS = """
 
30
  # Create cache directory if it doesn't exist
31
  os.makedirs(CACHE_DIR, exist_ok=True)
32
 
 
def get_random_light_color(low=140, high=255):
    """Return a random light colour as a ``#rrggbb`` hex string.

    Each of the three channels is drawn independently and uniformly from
    the inclusive range ``[low, high]``. The defaults (140-255) keep every
    channel in the upper part of the range, which yields pastel shades.

    Args:
        low: Smallest allowed channel value (inclusive).
        high: Largest allowed channel value (inclusive).

    Returns:
        A seven-character string such as ``"#c8e6ff"``.

    Raises:
        ValueError: If the bounds are not ``0 <= low <= high <= 255``;
            anything else could not be rendered as two hex digits.
    """
    if not 0 <= low <= high <= 255:
        raise ValueError("channel bounds must satisfy 0 <= low <= high <= 255")

    # Draw order (red, green, blue) is kept so seeded runs stay reproducible.
    r = random.randint(low, high)
    g = random.randint(low, high)
    b = random.randint(low, high)

    return f"#{r:02x}{g:02x}{b:02x}"
43
 
def handle_text(text=""):
    """Normalise whitespace in ``text``.

    Every run of whitespace (spaces, tabs, newlines) is collapsed to a
    single space and leading/trailing whitespace is removed, so multi-line
    literals can be written freely in the source.

    Args:
        text: The string to clean. Any falsy value (``""`` or ``None``)
            is treated as empty input.

    Returns:
        The cleaned single-line string, or ``""`` for empty input.
    """
    # Guard first: ``None`` has no ``split`` method.
    if not text:
        return ""

    return " ".join(text.split())
54
 
55
+ #
56
  # @spaces.GPU
def extract_kg(text="", model=None):
    """Extract a knowledge graph from ``text``.

    Args:
        text: Input text; must be non-empty.
        model: Optional extractor object exposing ``extract(text)``. When
            omitted, a new ``Phi3InstructGraph`` is created. The parameter
            exists because ``generate_first_example_cache`` calls
            ``extract_kg(text, model)``; with a one-argument signature that
            call raises ``TypeError`` and the example cache is never built.

    Returns:
        The model output parsed by ``rapidjson.loads``.

    Raises:
        gr.Error: If ``text`` is empty, or if model construction, the
            extraction call, or JSON parsing fails.
    """
    # Catch empty text before doing any model work.
    if not text:
        raise gr.Error("⚠️ Text must be provided!")

    try:
        if model is None:
            model = Phi3InstructGraph()

        result = model.extract(text)
        return rapidjson.loads(result)
    except Exception as e:
        # UI boundary: surface any failure as a Gradio error, keeping the
        # original exception chained for debugging.
        raise gr.Error(f"❌ Extraction error: {str(e)}") from e
71
 
72
  def find_token_indices(doc, substring, text):
73
+ """
74
+ Find token indices for a given substring in the text
75
+ based on the provided spaCy doc.
76
+ """
77
+
78
  result = []
79
+ start_idx = text.find(substring)
80
 
81
+ while start_idx != -1:
82
+ end_idx = start_idx + len(substring)
83
  start_token = None
84
  end_token = None
85
 
86
  for token in doc:
87
+ if token.idx == start_idx:
88
  start_token = token.i
89
+ if token.idx + len(token) == end_idx:
90
  end_token = token.i + 1
91
 
92
  if start_token is not None and end_token is not None:
 
96
  })
97
 
98
  # Search for next occurrence
99
+ start_idx = text.find(substring, end_idx)
100
 
101
  return result
102
 
103
  def create_custom_entity_viz(data, full_text):
104
+ """
105
+ Create custom entity visualization using spaCy's displacy
106
+ """
107
+
108
  nlp = spacy.blank("xx")
109
  doc = nlp(full_text)
 
110
  spans = []
111
  colors = {}
112
+
113
  for node in data["nodes"]:
114
  entity_spans = find_token_indices(doc, node["id"], full_text)
115
+
116
+ for entity in entity_spans:
117
+ start = entity["start"]
118
+ end = entity["end"]
119
 
120
  if start < len(doc) and end <= len(doc):
121
  # Check for overlapping spans
122
  overlapping = any(s.start < end and start < s.end for s in spans)
123
+
124
  if not overlapping:
125
  node_type = node.get("type", "Entity")
126
  span = Span(doc, start, end, label=node_type)
127
  spans.append(span)
128
+
129
  if node_type not in colors:
130
  colors[node_type] = get_random_light_color()
131
 
132
  doc.set_ents(spans, default="unmodified")
133
  doc.spans["sc"] = spans
 
134
  options = {
135
  "colors": colors,
136
  "ents": list(colors.keys()),
 
141
  html = displacy.render(doc, style="span", options=options)
142
  # Add custom styling to the entity visualization
143
  styled_html = f"""
144
+ <div style="padding: 20px; border-radius: 12px; background-color: gray; box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);">
145
  {html}
146
  </div>
147
  """
148
+
149
  return styled_html
150
 
151
  def create_graph(json_data):
152
+ """
153
+ Create interactive knowledge graph using pyvis
154
+ """
155
+
156
  G = nx.Graph()
157
 
158
+ # Add nodes with tooltips and error handling for missing keys
159
  for node in json_data['nodes']:
160
  # Get node type with fallback
161
+ type = node.get("type", "Entity")
162
+
163
  # Get detailed type with fallback
164
+ detailed_type = node.get("detailed_type", type)
165
 
166
  # Use node ID and type info for the tooltip
167
+ G.add_node(node['id'], title=f"{type}: {detailed_type}")
168
 
169
  # Add edges with labels
170
  for edge in json_data['edges']:
 
174
  G.add_edge(edge['from'], edge['to'], title=label, label=label)
175
 
176
  # Create network visualization
177
+ network = Network(
178
  width="100%",
179
  height="700px",
180
+ directed=False,
181
  notebook=False,
182
  bgcolor="#f8fafc",
183
  font_color="#1e293b"
184
  )
185
 
186
  # Configure network display
187
+ network.from_nx(G)
188
+ network.barnes_hut(
189
  gravity=-3000,
190
  central_gravity=0.3,
191
  spring_length=50,
 
195
  )
196
 
197
  # Customize edge appearance
198
+ for edge in network.edges:
199
  edge['width'] = 2
200
+ edge['arrows'] = {'to': {'enabled': False, 'type': 'arrow'}}
201
  edge['color'] = {'color': '#6366f1', 'highlight': '#4f46e5'}
202
  edge['font'] = {'size': 12, 'color': '#4b5563', 'face': 'Arial'}
203
 
204
  # Customize node appearance
205
+ for node in network.nodes:
206
  node['color'] = {'background': '#e0e7ff', 'border': '#6366f1', 'highlight': {'background': '#c7d2fe', 'border': '#4f46e5'}}
207
  node['font'] = {'size': 14, 'color': '#1e293b'}
208
  node['shape'] = 'dot'
209
  node['size'] = 25
210
 
211
  # Generate HTML with iframe to isolate styles
212
+ html = network.generate_html()
213
  html = html.replace("'", '"')
214
 
215
  return f"""<iframe style="width: 100%; height: 700px; margin: 0 auto; border-radius: 12px; box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.1), 0 4px 6px -4px rgba(0, 0, 0, 0.1);"
 
219
  allowpaymentrequest="" frameborder="0" srcdoc='{html}'></iframe>"""
220
 
221
  def process_and_visualize(text, progress=gr.Progress()):
222
+ """
223
+ Process text and visualize knowledge graph and entities
224
+ """
225
+
226
  if not text:
227
  raise gr.Error("⚠️ Text must be provided!")
228
 
 
240
  return cache_data["graph_html"], cache_data["entities_viz"], cache_data["json_data"], cache_data["stats"]
241
  except Exception as e:
242
  print(f"Cache loading error: {str(e)}")
243
+
244
+ # Continue with normal processing if cache fails
245
  progress(0, desc="Starting extraction...")
246
+ json_data = extract_kg(text)
247
 
248
  progress(0.5, desc="Creating entity visualization...")
249
  entities_viz = create_custom_entity_viz(json_data, text)
 
272
  progress(1.0, desc="Complete!")
273
  return graph_html, entities_viz, json_data, stats
274
 
# Example texts offered in the UI. Each entry is a one-element list holding
# the text; handle_text() collapses the line breaks and indentation used here
# for readability into single spaces.
EXAMPLES = [
    [handle_text("""The family of Azerbaijan President Ilham Aliyev leads a charmed, glamorous life, thanks in part to financial interests in almost every sector of the economy.
        His wife, Mehriban, comes from the privileged and powerful Pashayev family that owns banks, insurance and construction companies,
        a television station and a line of cosmetics. She has led the Heydar Aliyev Foundation, Azerbaijan’s pre-eminent charity behind the construction of schools,
        hospitals and the country’s major sports complex. Their eldest daughter, Leyla, editor of Baku magazine, and her sister, Arzu,
        have financial stakes in a firm that won rights to mine for gold in the western village of Chovdar and Azerfon, the country’s largest mobile phone business.
        Arzu is also a significant shareholder in SW Holding, which controls nearly every operation related to Azerbaijan Airlines (“Azal”), from meals to airport taxis.
        Both sisters and brother Heydar own property in Dubai valued at roughly $75 million in 2010;
        Heydar is the legal owner of nine luxury mansions in Dubai purchased for some $44 million.""")],

    [handle_text("""Legendary rock band Aerosmith has officially announced their retirement from touring after 54 years,
        citing lead singer Steven Tyler's unrecoverable vocal cord injury.
        The decision comes after months of unsuccessful treatment for Tyler's fractured larynx, which he suffered in September 2023.""")],

    [handle_text("""Pop star Justin Timberlake, 43, had his driver's license suspended by a New York judge during a virtual
        court hearing on August 2, 2024. The suspension follows Timberlake's arrest for driving while intoxicated (DWI)
        in Sag Harbor on June 18. Timberlake, who is currently on tour in Europe, pleaded not guilty to the charges.""")],
]
294
 
 
295
  def generate_first_example_cache():
296
+ """
297
+ Generate cache for the first example if it doesn't exist when the app starts
298
+ """
299
+
300
  if not os.path.exists(EXAMPLE_CACHE_FILE):
301
  print("Generating cache for first example...")
302
+
303
  try:
304
  text = EXAMPLES[0][0]
305
+ model = Phi3InstructGraph()
306
+
 
307
  # Extract data
308
+ json_data = extract_kg(text, model)
309
  entities_viz = create_custom_entity_viz(json_data, text)
310
  graph_html = create_graph(json_data)
311
 
 
314
  stats = f"πŸ“Š Extracted {node_count} entities and {edge_count} relationships"
315
 
316
  # Save to cache
317
+ cached_data = {
318
  "graph_html": graph_html,
319
  "entities_viz": entities_viz,
320
  "json_data": json_data,
321
  "stats": stats
322
  }
323
+
324
  with open(EXAMPLE_CACHE_FILE, 'wb') as f:
325
+ pickle.dump(cached_data, f)
 
326
  print("First example cache generated successfully")
327
+
328
+ return cached_data
329
  except Exception as e:
330
  print(f"Error generating first example cache: {str(e)}")
331
  else:
 
339
  return None
340
 
341
  def create_ui():
342
+ """
343
+ Create the Gradio UI
344
+ """
345
+
346
  # Try to generate/load the first example cache
347
  first_example_cache = generate_first_example_cache()
348
 
 
351
  gr.Markdown(f"# {TITLE}")
352
  gr.Markdown(f"{SUBTITLE}")
353
 
 
 
 
354
  # Main content area - redesigned layout
355
  with gr.Row():
356
  # Left panel - Input controls
 
430
 
431
  # Footer
432
  gr.Markdown("---")
433
+ gr.Markdown("πŸ“‹ **Instructions:** Enter text in any language and click 'Extract & Visualize' to generate a knowledge graph.")
434
+ gr.Markdown("πŸ› οΈ Powered by [Phi-3-mini-128k-instruct-graph](https://huggingface.co/EmergentMethods/Phi-3-mini-128k-instruct-graph)")
435
 
436
  return demo
437
 
438
  demo = create_ui()
439
+ demo.launch(share=False)
app-backup.py → app_old.py RENAMED
File without changes
phi3_instruct_graph.py CHANGED
@@ -16,10 +16,17 @@ client = InferenceClient(
16
 
17
  class Phi3InstructGraph:
18
  def __init__(self, model = "EmergentMethods/Phi-3-mini-4k-instruct-graph"):
19
-
 
 
 
20
  self.model_path = model
21
 
22
  def _generate(self, messages):
 
 
 
 
23
  # Use the chat_completion method
24
  response = client.chat_completion(
25
  messages=messages,
@@ -31,6 +38,10 @@ class Phi3InstructGraph:
31
  return generated_text
32
 
33
  def _get_messages(self, text):
 
 
 
 
34
  context = dedent("""\n
35
  A chat between a curious user and an artificial intelligence Assistant. The Assistant is an expert at identifying entities and relationships in text. The Assistant responds in JSON output only.
36
 
@@ -66,31 +77,28 @@ class Phi3InstructGraph:
66
  -------Text end-------
67
  """)
68
 
69
- if self.model_path == "EmergentMethods/Phi-3-medium-128k-instruct-graph":
70
- # model without system message -- why??
71
- messages = [
72
- {
73
- "role": "user",
74
- "content": f"{context}\n Input: {user_message}",
75
- }
76
- ]
77
- return messages
78
- else:
79
- messages = [
80
- {
81
- "role": "system",
82
- "content": context
83
- },
84
- {
85
- "role": "user",
86
- "content": user_message
87
- }
88
- ]
89
- return messages
90
-
91
 
92
  def extract(self, text):
 
 
 
 
93
  messages = self._get_messages(text)
94
  generated_text = self._generate(messages)
95
- # return pipe_output[0]["generated_text"]
96
  return generated_text
 
16
 
17
  class Phi3InstructGraph:
18
  def __init__(self, model = "EmergentMethods/Phi-3-mini-4k-instruct-graph"):
19
+ """
20
+ Initialize the Phi3InstructGraph with a specified model.
21
+ """
22
+
23
  self.model_path = model
24
 
25
  def _generate(self, messages):
26
+ """
27
+ Generate a response from the model based on the provided messages.
28
+ """
29
+
30
  # Use the chat_completion method
31
  response = client.chat_completion(
32
  messages=messages,
 
38
  return generated_text
39
 
40
  def _get_messages(self, text):
41
+ """
42
+ Construct the message list for the chat model.
43
+ """
44
+
45
  context = dedent("""\n
46
  A chat between a curious user and an artificial intelligence Assistant. The Assistant is an expert at identifying entities and relationships in text. The Assistant responds in JSON output only.
47
 
 
77
  -------Text end-------
78
  """)
79
 
80
+ # if self.model_path == "EmergentMethods/Phi-3-medium-128k-instruct-graph":
81
+ messages = [
82
+ {
83
+ "role": "system",
84
+ "content": context
85
+ },
86
+ {
87
+ "role": "user",
88
+ "content": user_message
89
+ }
90
+ ]
91
+ # else:
92
+ # # TODO: update for other models
93
+
94
+ return messages
 
 
 
 
 
 
 
95
 
96
  def extract(self, text):
97
+ """
98
+ Extract knowledge graph from text
99
+ """
100
+
101
  messages = self._get_messages(text)
102
  generated_text = self._generate(messages)
103
+
104
  return generated_text