Mahbodez commited on
Commit
72eff0b
1 Parent(s): e4c48df

Delete interface.py

Browse files
Files changed (1) hide show
  1. interface.py +0 -406
interface.py DELETED
@@ -1,406 +0,0 @@
1
- import json
2
- import numpy as np
3
- import treegraph as tg
4
- import colorama
5
- from colorama import Fore
6
- import networkx as nx
7
- import utils
8
- import re
9
-
10
- DEBUG = True
11
- INPUT_COLOR = Fore.LIGHTGREEN_EX
12
- DEBUG_COLOR = Fore.LIGHTBLACK_EX
13
- OUTPUT_COLOR = Fore.LIGHTMAGENTA_EX
14
- INFO_COLOR = Fore.BLUE
15
- HELP_COLOR = Fore.CYAN
16
-
17
-
18
- def print_debug(*args, color=DEBUG_COLOR):
19
- """
20
- Prints debug messages if DEBUG is set to True.
21
- """
22
- if DEBUG:
23
- for arg in args:
24
- print(color + str(arg))
25
-
26
-
27
- class ReportInterface:
28
- def __init__(
29
- self,
30
- llm: utils.LLM,
31
- system_prompt: str,
32
- tree_graph: nx.Graph,
33
- nodes_dict: dict[str, tg.Node],
34
- api_key: str = None,
35
- ):
36
- self.llm = llm
37
- self.system_prompt = system_prompt
38
- self.tree_graph = tree_graph
39
- self.nodes_dict = nodes_dict
40
- self.api_key = api_key
41
- self.build()
42
-
43
- def build(self):
44
- utils.set_api_key(self.api_key)
45
- self.system_prompt = utils.make_message("system", self.system_prompt)
46
- self.visitable_nodes = self._get_visitable_nodes()
47
- self.report_dict = self._get_report_dict()
48
-
49
- self.active_node: tg.Node = self.nodes_dict["root"]
50
- self.unique_visited_nodes = set() # set of nodes visited
51
- self.node_journey = [] # list of nodes visited
52
- self.distance_travelled = 0 # number of edges travelled
53
- self.jumps = 0 # number of jumps
54
- self.jump_lengths = [] # list of jump lengths
55
- self.counter = 0 # number of questions asked
56
-
57
- colorama.init(autoreset=True) # to reset the color after each print statement
58
-
59
- self.help_message = f"""You are presented with a Knee MRI.
60
- You are asked to fill out a radiology report.
61
- Please only report the findings in the MRI.
62
- Please mention your findings with the corresponding anatomical structures.
63
- There are {len(self.visitable_nodes.keys())} visitable nodes in the tree.
64
- You must visit as many nodes as possible, while avoiding too many jumps."""
65
-
66
- def _get_visitable_nodes(self):
67
- return dict(
68
- zip(
69
- [
70
- node.name
71
- for node in self.tree_graph.nodes
72
- if node.name != "root" and node.has_children() is False
73
- ],
74
- [
75
- node
76
- for node in self.tree_graph.nodes
77
- if node.name != "root" and node.has_children() is False
78
- ],
79
- )
80
- )
81
-
82
- def _get_report_dict(self):
83
- return {
84
- node.name: tg.Node(node.name, "", node.children)
85
- for node in self.visitable_nodes.values()
86
- }
87
-
88
- @utils.debug(DEBUG, print_debug)
89
- def _check_question_validity(
90
- self,
91
- question: str,
92
- ):
93
- # let's ask the question from the model and check if it's valid
94
- template_json = json.dumps(
95
- {key: node.value for key, node in self.visitable_nodes.items()},
96
- indent=4,
97
- )
98
- q = f"""the following is a Knee MRI report "template" in a JSON format with keys and values.
99
- You are given a "finding" phrase from a radiologist.
100
- Match as best as possible the "finding" with one of keys in the "template".
101
- <template>
102
- {template_json}
103
- </template>
104
- <finding>
105
- {question}
106
- </finding>
107
- "available": [Is the "finding" relevant to any key in the "template"? say "yes" or "no".
108
- Make sure the "finding" is relevant to Knee MRI and knee anatomy otherwise say 'no'.
109
- Do not answer irrelevant phrases.]
110
- "node": [if the above answer is 'yes', write only the KEY of the most relevant node to the "finding". otherwise, say 'none'.]
111
- """
112
-
113
- keys = ["available", "node"]
114
- prompt = [self.system_prompt] + [
115
- utils.make_question(utils.JSON_TEMPLATE, question=q, keys=keys)
116
- ]
117
- response = self.llm(prompt)
118
- print_debug(
119
- prompt,
120
- response,
121
- )
122
- available = utils.json2dict(response)["available"].strip().lower()
123
- node = utils.json2dict(response)["node"]
124
- return available, node
125
-
126
- def _update_node(self, node_name, findings):
127
- self.report_dict[node_name].value += str(findings) + "\n"
128
- response = f"Updated node '{node_name}' with finding '{findings}'"
129
- print(OUTPUT_COLOR + response)
130
- return response
131
-
132
- def save_report(self, filename: str):
133
- # convert performance metrics to json
134
- metrics = {
135
- "distance_travelled": self.distance_travelled,
136
- "jumps": self.jumps,
137
- "jump_lengths": self.jump_lengths,
138
- "unique_visited_nodes": [node.name for node in self.unique_visited_nodes],
139
- "node_journey": [node.name for node in self.node_journey],
140
- "report": {
141
- node_name: node.value for node_name, node in self.report_dict.items()
142
- },
143
- }
144
- # save the report
145
- with open(filename, "w") as file:
146
- json.dump(metrics, file, indent=4)
147
-
148
- def prime_model(self):
149
- """
150
- Primes the model with the system prompt.
151
- """
152
- q = "Are you ready to begin?\nSay 'yes' or 'no'."
153
- keys = ["answer"]
154
- response = self.llm(
155
- [
156
- self.system_prompt,
157
- utils.make_question(utils.JSON_TEMPLATE, question=q, keys=keys),
158
- ],
159
- )
160
- print_debug(q, response)
161
- if utils.json2dict(response)["answer"].lower() == "yes":
162
- print(INFO_COLOR + "The model is ready.")
163
- return True
164
- else:
165
- print(INFO_COLOR + "The model is not ready.")
166
- return False
167
-
168
- def performance_summary(self):
169
- # print out the summary info
170
- print(INFO_COLOR + "Performance Summary:")
171
- print(
172
- INFO_COLOR + f"Total distance travelled: {self.distance_travelled} edge(s)"
173
- )
174
- print(INFO_COLOR + f"Jump lengths: {self.jump_lengths}")
175
- print(INFO_COLOR + f"Jump lengths mean: {np.mean(self.jump_lengths):.1f}")
176
- print(INFO_COLOR + f"Jump lengths SD: {np.std(self.jump_lengths):.1f}")
177
- print(INFO_COLOR + f"Nodes visited in order: {self.node_journey}")
178
- print(INFO_COLOR + f"Unique nodes visited: {self.unique_visited_nodes}")
179
- print(
180
- INFO_COLOR
181
- + f"You have explored {len(self.unique_visited_nodes)/len(self.visitable_nodes):.1%} ({len(self.unique_visited_nodes)}/{len(self.visitable_nodes)}) of the tree."
182
- )
183
- print_debug("\n")
184
- print_debug("Report Summary:".rjust(20))
185
- for name, node in self.report_dict.items():
186
- if node.value != "":
187
- print_debug(f"{name}: {node.value}")
188
- print(INFO_COLOR + f"total cost: ${self.llm.cost:.4f}")
189
- print(INFO_COLOR + f"total tokens used: {self.llm.token_counter}")
190
-
191
- def get_stats(self):
192
- report_string = ""
193
- for name, node in self.report_dict.items():
194
- if node.value != "":
195
- report_string += f"{name}: <{node.value}> \n"
196
- return {
197
- "Lengths travelled": self.distance_travelled,
198
- "Number of jumps": self.jumps,
199
- "Jump lengths": self.jump_lengths,
200
- "Unique nodes visited": [node.name for node in self.unique_visited_nodes],
201
- "Visited Nodes": [node.name for node in self.node_journey],
202
- "Report": report_string,
203
- }
204
-
205
- def visualize_tree(self, **kwargs):
206
- tg.visualize_graph(tg.from_list(self.node_journey), self.tree_graph, **kwargs)
207
-
208
- def get_plot(self, **kwargs):
209
- return tg.get_graph(tg.from_list(self.node_journey), self.tree_graph, **kwargs)
210
-
211
- def process_input(self, input_text: str):
212
- res = "n/a"
213
- try:
214
- finding = input_text
215
- if finding.strip().lower() == "quit":
216
- print(INFO_COLOR + "Exiting...")
217
- return "quit"
218
- elif finding.strip().lower() == "help":
219
- return "help"
220
-
221
- available, node = self._check_question_validity(finding)
222
- if available != "yes":
223
- print(
224
- OUTPUT_COLOR
225
- + "Could not find a relevant node.\nWrite more clearly and provide more details."
226
- )
227
- return "n/a"
228
- if node not in self.visitable_nodes.keys():
229
- print(
230
- OUTPUT_COLOR
231
- + "Could not find a relevant node.\nWrite more clearly and provide more details."
232
- )
233
- return "n/a"
234
- else:
235
- # modify the tree to update the node with findings
236
- res = self._update_node(node, finding)
237
-
238
- print(
239
- INFO_COLOR
240
- + f"jumping from node '{self.active_node}' to node '{node}'..."
241
- )
242
- distance = tg.num_edges_between_nodes(
243
- self.tree_graph, self.active_node, self.nodes_dict[node]
244
- )
245
- print(INFO_COLOR + f"distance travelled: {distance} edge(s)")
246
-
247
- self.active_node = self.nodes_dict[node]
248
- self.jumps += 1
249
- self.jump_lengths.append(distance)
250
- self.distance_travelled += distance
251
- if self.active_node.name != "root":
252
- self.unique_visited_nodes.add(self.active_node)
253
- self.node_journey.append(self.active_node)
254
- except Exception as ex:
255
- print_debug(ex, color=Fore.LIGHTRED_EX)
256
- return "exception"
257
-
258
- self.counter += 1
259
- try:
260
- self.performance_summary()
261
- except Exception as ex:
262
- print_debug(ex, color=Fore.LIGHTRED_EX)
263
- return res
264
-
265
-
266
- class ReportChecklistInterface:
267
- def __init__(
268
- self,
269
- llm: utils.LLM,
270
- system_prompt: str,
271
- graph: nx.Graph,
272
- nodes_dict: dict[str, tg.Node],
273
- api_key: str = None,
274
- ):
275
- self.llm = llm
276
- self.system_prompt = system_prompt
277
- self.tree_graph: nx.Graph = graph
278
- self.nodes_dict = nodes_dict
279
- self.api_key = api_key
280
- self.build()
281
-
282
- def build(self):
283
- utils.set_api_key(self.api_key)
284
- self.system_prompt = utils.make_message("system", self.system_prompt)
285
- self.visitable_nodes = self._get_visitable_nodes()
286
-
287
- colorama.init(autoreset=True) # to reset the color after each print statement
288
-
289
- self.help_message = f"""You are presented with a Knee MRI.
290
- You are asked to fill out a radiology report.
291
- Please only report the findings in the MRI.
292
- Please mention your findings with the corresponding anatomical structures.
293
- There are {len(self.visitable_nodes.keys())} visitable nodes in the tree."""
294
-
295
- def _get_visitable_nodes(self):
296
- return dict(
297
- zip(
298
- [
299
- node.name
300
- for node in self.tree_graph.nodes
301
- if node.name != "root" and node.has_children() is False
302
- ],
303
- [
304
- node
305
- for node in self.tree_graph.nodes
306
- if node.name != "root" and node.has_children() is False
307
- ],
308
- )
309
- )
310
-
311
- @utils.debug(DEBUG, print_debug)
312
- def _check_report(
313
- self,
314
- report: str,
315
- ):
316
- # let's ask the question from the model and check if it's valid
317
- checklist_json = json.dumps(
318
- {key: node.value for key, node in self.visitable_nodes.items()},
319
- indent=4,
320
- )
321
- q = f"""the following is a Knee MRI "checklist" in JSON format with keys as items and values as findings:
322
- A knee MRI "report" is also provided in raw text format written by a radiologist:
323
- <checklist>
324
- {checklist_json}
325
- </checklist>
326
- <report>
327
- {report}
328
- </report>
329
- Your task is to find all the corresponding items from the "checklist" in the "report" and fill out a JSON with the same keys as the "checklist" but extract the corresponding values from the "report".
330
- If a key is not found in the "report", please set the value to "n/a", otherwise set it to the corresponding finding from the "report".
331
- You must check the "report" phrases one by one and find a corresponding key(s) for EACH phrase in the "report" from the "checklist" and fill out the "report_checked" JSON.
332
- Try to fill out as many items as possible.
333
- ALL of the items in the "checklist" must be filled out.
334
- Don't generate findings that are not present in the "report" (new findings).
335
- Be comprehensive and don't miss any findings that are present in the "report".
336
- Watch out for encompassing terms (e.g., "cruciate ligaments" means both "ACL" and "PCL").
337
- "thought_process": [Think in steps on how you would do this task.]
338
- "report_ckecked" : [a JSON with the same keys as the "checklist" but take the values from the "report", as described above.]
339
- """
340
-
341
- keys = ["thought_process", "report_checked"]
342
- prompt = [self.system_prompt] + [
343
- utils.make_question(utils.JSON_TEMPLATE, question=q, keys=keys)
344
- ]
345
- response = self.llm(prompt)
346
- print_debug(
347
- prompt,
348
- response,
349
- )
350
- report_checked = utils.json2dict(response)
351
- return report_checked["report_checked"]
352
-
353
- def prime_model(self):
354
- """
355
- Primes the model with the system prompt.
356
- """
357
- q = "Are you ready to begin?\nSay 'yes' or 'no'."
358
- keys = ["answer"]
359
- response = self.llm(
360
- [
361
- self.system_prompt,
362
- utils.make_question(utils.JSON_TEMPLATE, question=q, keys=keys),
363
- ],
364
- )
365
- print_debug(q, response)
366
- if utils.json2dict(response)["answer"].lower() == "yes":
367
- print(INFO_COLOR + "The model is ready.")
368
- return True
369
- else:
370
- print(INFO_COLOR + "The model is not ready.")
371
- return False
372
-
373
- def process_input(self, input_text: str):
374
- try:
375
- report = input_text
376
- if report.strip().lower() == "quit":
377
- print(INFO_COLOR + "Exiting...")
378
- return "quit"
379
- elif report.strip().lower() == "help":
380
- return "help"
381
-
382
- checked_report: dict = self._check_report(report)
383
- # make a string of the report
384
- # replace true with [checkmark emoji] and false with [cross emoji]
385
- report_string = ""
386
- CHECKMARK = "\u2705"
387
- CROSS = "\u274C"
388
-
389
- # we need a regex to convert the camelCase keys to a readable format
390
- def camel2readable(camel: str):
391
- string = re.sub("([a-z])([A-Z])", r"\1 \2", camel)
392
- # captialize every word
393
- string = " ".join([word.capitalize() for word in string.split()])
394
- return string
395
-
396
- for key, value in checked_report.items():
397
- if str(value).lower() == "true":
398
- report_string += f"{camel2readable(key)}: {CHECKMARK}\n"
399
- elif str(value).lower() == "n/a":
400
- report_string += f"{camel2readable(key)}: {CROSS}\n"
401
- else:
402
- report_string += f"{camel2readable(key)}: <{value}> {CHECKMARK}\n"
403
- return report_string
404
- except Exception as ex:
405
- print_debug(ex, color=Fore.LIGHTRED_EX)
406
- return "exception"