Mahbodez committed on
Commit 1d80bec
1 Parent(s): 1a15844

Upload 5 files

Files changed (5)
  1. app.py +65 -0
  2. interface.py +406 -0
  3. knee_template.json +359 -0
  4. treegraph.py +226 -0
  5. utils.py +361 -0
app.py ADDED
import gradio as gr
import interface
import utils
import treegraph as tg

system_prompt = """
You are a critical AI radiology assistant.
You are helping a radiologist correctly fill out a radiology report.
The report is regarding a Knee MRI.
"""

graph, nodes_dict = tg.build_tree_from_file("knee_template.json")
report_interface = interface.ReportChecklistInterface(
    llm=utils.LLM(model="gpt-3.5-turbo"),
    system_prompt=system_prompt,
    graph=graph,
    nodes_dict=nodes_dict,
)

if report_interface.prime_model() is False:
    print("Model priming failed. Please try again.")
    exit()
else:
    print("Model priming successful.")

with gr.Blocks(theme="soft") as demo:
    gr.Markdown("## Radiology Report Assistant")
    gr.Markdown(report_interface.help_message)

    # per-session flag; gr.State (the replacement for the deprecated
    # gr.components.Variable) only takes effect when wired through
    # event inputs/outputs
    running = gr.State(True)
    report_textbox = gr.TextArea(label="Report", lines=20, max_lines=50)
    check_btn = gr.Button(
        value="Check Report",
    )
    clear_btn = gr.ClearButton(
        value="Clear Messages",
    )
    quit_btn = gr.Button(
        value="Quit",
    )
    results_textbox = gr.TextArea(label="Results", lines=20, max_lines=50)
    clear_btn.add([results_textbox, report_textbox])

    def check_report(report, running):
        if not running:
            return "Model has been stopped."
        results = report_interface.process_input(report)
        if results == "quit":
            # typing "quit" in the report box only reports the stopped state;
            # the Quit button is what flips the session flag
            return "Model has been stopped."
        elif results == "help":
            return report_interface.help_message
        elif results == "exception":
            return "An exception occurred. Please try again."
        else:
            return results

    def quit_fn():
        # components are updated via returned values, not by mutating .value
        return False, "Model has been stopped."

    check_btn.click(
        fn=check_report, inputs=[report_textbox, running], outputs=[results_textbox]
    )
    quit_btn.click(fn=quit_fn, outputs=[running, results_textbox])

demo.launch()
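
A launch note (editor's sketch, not part of the upload): utils.set_api_key falls back to the OPENAI_API_KEY environment variable, so the app expects the key to be set before launch; an explicit key can also be passed through the interface:

    import os
    os.environ.setdefault("OPENAI_API_KEY", "<your-openai-key>")  # hypothetical placeholder

    # or, equivalently, pass it explicitly:
    report_interface = interface.ReportChecklistInterface(
        llm=utils.LLM(model="gpt-3.5-turbo"),
        system_prompt=system_prompt,
        graph=graph,
        nodes_dict=nodes_dict,
        api_key="<your-openai-key>",  # hypothetical placeholder
    )
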
interface.py ADDED
import json
import numpy as np
import treegraph as tg
import colorama
from colorama import Fore
import networkx as nx
import utils
import re

DEBUG = True
INPUT_COLOR = Fore.LIGHTGREEN_EX
DEBUG_COLOR = Fore.LIGHTBLACK_EX
OUTPUT_COLOR = Fore.LIGHTMAGENTA_EX
INFO_COLOR = Fore.BLUE
HELP_COLOR = Fore.CYAN


def print_debug(*args, color=DEBUG_COLOR):
    """
    Prints debug messages if DEBUG is set to True.
    """
    if DEBUG:
        for arg in args:
            print(color + str(arg))


class ReportInterface:
    def __init__(
        self,
        llm: utils.LLM,
        system_prompt: str,
        tree_graph: nx.Graph,
        nodes_dict: dict[str, tg.Node],
        api_key: str = None,
    ):
        self.llm = llm
        self.system_prompt = system_prompt
        self.tree_graph = tree_graph
        self.nodes_dict = nodes_dict
        self.api_key = api_key
        self.build()

    def build(self):
        utils.set_api_key(self.api_key)
        self.system_prompt = utils.make_message("system", self.system_prompt)
        self.visitable_nodes = self._get_visitable_nodes()
        self.report_dict = self._get_report_dict()

        self.active_node: tg.Node = self.nodes_dict["root"]
        self.unique_visited_nodes = set()  # set of nodes visited
        self.node_journey = []  # list of nodes visited
        self.distance_travelled = 0  # number of edges travelled
        self.jumps = 0  # number of jumps
        self.jump_lengths = []  # list of jump lengths
        self.counter = 0  # number of questions asked

        colorama.init(autoreset=True)  # to reset the color after each print statement

        self.help_message = f"""You are presented with a Knee MRI.
You are asked to fill out a radiology report.
Please only report the findings in the MRI.
Please mention your findings with the corresponding anatomical structures.
There are {len(self.visitable_nodes.keys())} visitable nodes in the tree.
You must visit as many nodes as possible, while avoiding too many jumps."""

    def _get_visitable_nodes(self):
        # leaf nodes (everything except the root and the grouping nodes) are visitable
        return {
            node.name: node
            for node in self.tree_graph.nodes
            if node.name != "root" and not node.has_children()
        }

    def _get_report_dict(self):
        # children must be passed by keyword: the third positional
        # argument of tg.Node is `parent`, not `children`
        return {
            node.name: tg.Node(node.name, "", children=node.children)
            for node in self.visitable_nodes.values()
        }

    @utils.debug(DEBUG, print_debug)
    def _check_question_validity(
        self,
        question: str,
    ):
        # ask the model to match the finding against the template
        template_json = json.dumps(
            {key: node.value for key, node in self.visitable_nodes.items()},
            indent=4,
        )
        q = f"""the following is a Knee MRI report "template" in a JSON format with keys and values.
You are given a "finding" phrase from a radiologist.
Match as best as possible the "finding" with one of the keys in the "template".
<template>
{template_json}
</template>
<finding>
{question}
</finding>
"available": [Is the "finding" relevant to any key in the "template"? say "yes" or "no".
Make sure the "finding" is relevant to Knee MRI and knee anatomy otherwise say 'no'.
Do not answer irrelevant phrases.]
"node": [if the above answer is 'yes', write only the KEY of the most relevant node to the "finding". otherwise, say 'none'.]
"""

        keys = ["available", "node"]
        prompt = [self.system_prompt] + [
            utils.make_question(utils.JSON_TEMPLATE, question=q, keys=keys)
        ]
        response = self.llm(prompt)
        print_debug(
            prompt,
            response,
        )
        # parse the JSON reply once and pull out both keys
        parsed = utils.json2dict(response)
        available = parsed["available"].strip().lower()
        node = parsed["node"]
        return available, node
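
    # Illustrative only (editor's note, not part of the upload): a well-formed
    # model reply to the prompt above is a JSON object such as
    #   {"available": "yes", "node": "kneeAclTearing"}
    # which json2dict parses into the ("yes", "kneeAclTearing") pair returned here.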

    def _update_node(self, node_name, findings):
        self.report_dict[node_name].value += str(findings) + "\n"
        response = f"Updated node '{node_name}' with finding '{findings}'"
        print(OUTPUT_COLOR + response)
        return response

    def save_report(self, filename: str):
        # convert performance metrics to json
        metrics = {
            "distance_travelled": self.distance_travelled,
            "jumps": self.jumps,
            "jump_lengths": self.jump_lengths,
            "unique_visited_nodes": [node.name for node in self.unique_visited_nodes],
            "node_journey": [node.name for node in self.node_journey],
            "report": {
                node_name: node.value for node_name, node in self.report_dict.items()
            },
        }
        # save the report
        with open(filename, "w") as file:
            json.dump(metrics, file, indent=4)

    def prime_model(self):
        """
        Primes the model with the system prompt.
        """
        q = "Are you ready to begin?\nSay 'yes' or 'no'."
        keys = ["answer"]
        response = self.llm(
            [
                self.system_prompt,
                utils.make_question(utils.JSON_TEMPLATE, question=q, keys=keys),
            ],
        )
        print_debug(q, response)
        if utils.json2dict(response)["answer"].lower() == "yes":
            print(INFO_COLOR + "The model is ready.")
            return True
        else:
            print(INFO_COLOR + "The model is not ready.")
            return False

    def performance_summary(self):
        # print out the summary info
        print(INFO_COLOR + "Performance Summary:")
        print(
            INFO_COLOR + f"Total distance travelled: {self.distance_travelled} edge(s)"
        )
        print(INFO_COLOR + f"Jump lengths: {self.jump_lengths}")
        print(INFO_COLOR + f"Jump lengths mean: {np.mean(self.jump_lengths):.1f}")
        print(INFO_COLOR + f"Jump lengths SD: {np.std(self.jump_lengths):.1f}")
        print(INFO_COLOR + f"Nodes visited in order: {self.node_journey}")
        print(INFO_COLOR + f"Unique nodes visited: {self.unique_visited_nodes}")
        print(
            INFO_COLOR
            + f"You have explored {len(self.unique_visited_nodes)/len(self.visitable_nodes):.1%} ({len(self.unique_visited_nodes)}/{len(self.visitable_nodes)}) of the tree."
        )
        print_debug("\n")
        print_debug("Report Summary:".rjust(20))
        for name, node in self.report_dict.items():
            if node.value != "":
                print_debug(f"{name}: {node.value}")
        print(INFO_COLOR + f"total cost: ${self.llm.cost:.4f}")
        print(INFO_COLOR + f"total tokens used: {self.llm.token_counter}")

    def get_stats(self):
        report_string = ""
        for name, node in self.report_dict.items():
            if node.value != "":
                report_string += f"{name}: <{node.value}> \n"
        return {
            "Lengths travelled": self.distance_travelled,
            "Number of jumps": self.jumps,
            "Jump lengths": self.jump_lengths,
            "Unique nodes visited": [node.name for node in self.unique_visited_nodes],
            "Visited Nodes": [node.name for node in self.node_journey],
            "Report": report_string,
        }

    def visualize_tree(self, **kwargs):
        tg.visualize_graph(tg.from_list(self.node_journey), self.tree_graph, **kwargs)

    def get_plot(self, **kwargs):
        return tg.get_graph(tg.from_list(self.node_journey), self.tree_graph, **kwargs)

    def process_input(self, input_text: str):
        res = "n/a"
        try:
            finding = input_text
            if finding.strip().lower() == "quit":
                print(INFO_COLOR + "Exiting...")
                return "quit"
            elif finding.strip().lower() == "help":
                return "help"

            available, node = self._check_question_validity(finding)
            if available != "yes" or node not in self.visitable_nodes:
                print(
                    OUTPUT_COLOR
                    + "Could not find a relevant node.\nWrite more clearly and provide more details."
                )
                return "n/a"

            # modify the tree to update the node with findings
            res = self._update_node(node, finding)

            print(
                INFO_COLOR
                + f"jumping from node '{self.active_node}' to node '{node}'..."
            )
            distance = tg.num_edges_between_nodes(
                self.tree_graph, self.active_node, self.nodes_dict[node]
            )
            print(INFO_COLOR + f"distance travelled: {distance} edge(s)")

            self.active_node = self.nodes_dict[node]
            self.jumps += 1
            self.jump_lengths.append(distance)
            self.distance_travelled += distance
            if self.active_node.name != "root":
                self.unique_visited_nodes.add(self.active_node)
                self.node_journey.append(self.active_node)
        except Exception as ex:
            print_debug(ex, color=Fore.LIGHTRED_EX)
            return "exception"

        self.counter += 1
        try:
            self.performance_summary()
        except Exception as ex:
            print_debug(ex, color=Fore.LIGHTRED_EX)
        return res


class ReportChecklistInterface:
    def __init__(
        self,
        llm: utils.LLM,
        system_prompt: str,
        graph: nx.Graph,
        nodes_dict: dict[str, tg.Node],
        api_key: str = None,
    ):
        self.llm = llm
        self.system_prompt = system_prompt
        self.tree_graph: nx.Graph = graph
        self.nodes_dict = nodes_dict
        self.api_key = api_key
        self.build()

    def build(self):
        utils.set_api_key(self.api_key)
        self.system_prompt = utils.make_message("system", self.system_prompt)
        self.visitable_nodes = self._get_visitable_nodes()

        colorama.init(autoreset=True)  # to reset the color after each print statement

        self.help_message = f"""You are presented with a Knee MRI.
You are asked to fill out a radiology report.
Please only report the findings in the MRI.
Please mention your findings with the corresponding anatomical structures.
There are {len(self.visitable_nodes.keys())} visitable nodes in the tree."""

    def _get_visitable_nodes(self):
        # leaf nodes (everything except the root and the grouping nodes) are visitable
        return {
            node.name: node
            for node in self.tree_graph.nodes
            if node.name != "root" and not node.has_children()
        }

    @utils.debug(DEBUG, print_debug)
    def _check_report(
        self,
        report: str,
    ):
        # ask the model to fill out the checklist from the free-text report
        checklist_json = json.dumps(
            {key: node.value for key, node in self.visitable_nodes.items()},
            indent=4,
        )
        q = f"""the following is a Knee MRI "checklist" in JSON format with keys as items and values as findings:
A knee MRI "report" is also provided in raw text format written by a radiologist:
<checklist>
{checklist_json}
</checklist>
<report>
{report}
</report>
Your task is to find all the corresponding items from the "checklist" in the "report" and fill out a JSON with the same keys as the "checklist" but extract the corresponding values from the "report".
If a key is not found in the "report", please set the value to "n/a", otherwise set it to the corresponding finding from the "report".
You must check the "report" phrases one by one and find a corresponding key(s) for EACH phrase in the "report" from the "checklist" and fill out the "report_checked" JSON.
Try to fill out as many items as possible.
ALL of the items in the "checklist" must be filled out.
Don't generate findings that are not present in the "report" (new findings).
Be comprehensive and don't miss any findings that are present in the "report".
Watch out for encompassing terms (e.g., "cruciate ligaments" means both "ACL" and "PCL").
"thought_process": [Think in steps on how you would do this task.]
"report_checked": [a JSON with the same keys as the "checklist" but take the values from the "report", as described above.]
"""

        keys = ["thought_process", "report_checked"]
        prompt = [self.system_prompt] + [
            utils.make_question(utils.JSON_TEMPLATE, question=q, keys=keys)
        ]
        response = self.llm(prompt)
        print_debug(
            prompt,
            response,
        )
        report_checked = utils.json2dict(response)
        return report_checked["report_checked"]

    def prime_model(self):
        """
        Primes the model with the system prompt.
        """
        q = "Are you ready to begin?\nSay 'yes' or 'no'."
        keys = ["answer"]
        response = self.llm(
            [
                self.system_prompt,
                utils.make_question(utils.JSON_TEMPLATE, question=q, keys=keys),
            ],
        )
        print_debug(q, response)
        if utils.json2dict(response)["answer"].lower() == "yes":
            print(INFO_COLOR + "The model is ready.")
            return True
        else:
            print(INFO_COLOR + "The model is not ready.")
            return False

    def process_input(self, input_text: str):
        try:
            report = input_text
            if report.strip().lower() == "quit":
                print(INFO_COLOR + "Exiting...")
                return "quit"
            elif report.strip().lower() == "help":
                return "help"

            checked_report: dict = self._check_report(report)
            # make a string of the report,
            # replacing true with a checkmark emoji and n/a with a cross emoji
            report_string = ""
            CHECKMARK = "\u2705"
            CROSS = "\u274C"

            # a regex to convert the camelCase keys to a readable format
            def camel2readable(camel: str):
                string = re.sub("([a-z])([A-Z])", r"\1 \2", camel)
                # capitalize every word
                string = " ".join([word.capitalize() for word in string.split()])
                return string

            for key, value in checked_report.items():
                if str(value).lower() == "true":
                    report_string += f"{camel2readable(key)}: {CHECKMARK}\n"
                elif str(value).lower() == "n/a":
                    report_string += f"{camel2readable(key)}: {CROSS}\n"
                else:
                    report_string += f"{camel2readable(key)}: <{value}> {CHECKMARK}\n"
            return report_string
        except Exception as ex:
            print_debug(ex, color=Fore.LIGHTRED_EX)
            return "exception"
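
A minimal usage sketch (editor's note, not part of the upload; assumes OPENAI_API_KEY is set and the files above are on the Python path):

    import utils
    import treegraph as tg
    import interface

    graph, nodes_dict = tg.build_tree_from_file("knee_template.json")
    checker = interface.ReportChecklistInterface(
        llm=utils.LLM(model="gpt-3.5-turbo"),
        system_prompt="You are a critical AI radiology assistant.",
        graph=graph,
        nodes_dict=nodes_dict,
    )
    # Each output line is "<Readable Key>: ✅" (true), "<Readable Key>: ❌" (n/a),
    # or "<Readable Key>: <finding> ✅", as rendered by process_input above.
    print(checker.process_input("Moderate joint effusion. The ACL is intact."))
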
knee_template.json ADDED
{
    "root": {
        "value": "root",
        "parent": null,
        "children": [
            "kneeJointEffusion",
            "kneeMeniscus",
            "kneeAclPcl",
            "kneeMcl",
            "kneePosterolateralCorner",
            "kneeExtensorMechanism",
            "kneeCartilage",
            "kneeBone",
            "kneeOther"
        ]
    },
    "kneeJointEffusion": {
        "value": "Presence and/or extent of joint effusion.",
        "parent": "root",
        "children": []
    },
    "kneeMeniscus": {
        "value": "",
        "parent": "root",
        "children": [
            "kneeMeniscusMedialTearing",
            "kneeMeniscusLateralTearing",
            "kneeMeniscusWrisberg",
            "kneeMeniscusRootTearing",
            "kneeMeniscusRampLesion"
        ]
    },
    "kneeMeniscusMedialTearing": {
        "value": "Presence and/or severity of medial meniscus tearing",
        "parent": "kneeMeniscus",
        "children": []
    },
    "kneeMeniscusLateralTearing": {
        "value": "Presence and/or severity of lateral meniscus tearing",
        "parent": "kneeMeniscus",
        "children": []
    },
    "kneeMeniscusWrisberg": {
        "value": "Presence and/or severity of Wrisberg variant",
        "parent": "kneeMeniscus",
        "children": []
    },
    "kneeMeniscusRootTearing": {
        "value": "Presence and/or severity of meniscus root tearing",
        "parent": "kneeMeniscus",
        "children": []
    },
    "kneeMeniscusRampLesion": {
        "value": "Presence and/or severity of ramp lesion",
        "parent": "kneeMeniscus",
        "children": []
    },
    "kneeAclPcl": {
        "value": "",
        "parent": "root",
        "children": [
            "kneeAcl",
            "kneePcl"
        ]
    },
    "kneeAcl": {
        "value": "",
        "parent": "kneeAclPcl",
        "children": [
            "kneeAclTearing",
            "kneeAclDegeneration",
            "kneeAclReconstruction"
        ]
    },
    "kneeAclTearing": {
        "value": "Presence and/or severity of ACL tearing",
        "parent": "kneeAcl",
        "children": []
    },
    "kneeAclDegeneration": {
        "value": "Presence and/or severity of ACL degeneration",
        "parent": "kneeAcl",
        "children": []
    },
    "kneeAclReconstruction": {
        "value": "ACL reconstruction status",
        "parent": "kneeAcl",
        "children": []
    },
    "kneePcl": {
        "value": "",
        "parent": "kneeAclPcl",
        "children": [
            "kneePclTearing",
            "kneePclDegeneration",
            "kneePclReconstruction"
        ]
    },
    "kneePclTearing": {
        "value": "Presence and/or severity of PCL tearing",
        "parent": "kneePcl",
        "children": []
    },
    "kneePclDegeneration": {
        "value": "Presence and/or severity of PCL degeneration",
        "parent": "kneePcl",
        "children": []
    },
    "kneePclReconstruction": {
        "value": "PCL reconstruction status",
        "parent": "kneePcl",
        "children": []
    },
    "kneeMcl": {
        "value": "",
        "parent": "root",
        "children": [
            "kneeMclTearing",
            "kneeMclDeepFibers",
            "kneeMclSuperficialFibers"
        ]
    },
    "kneeMclTearing": {
        "value": "Presence and/or severity of MCL tearing",
        "parent": "kneeMcl",
        "children": []
    },
    "kneeMclDeepFibers": {
        "value": "MCL deep fibers status",
        "parent": "kneeMcl",
        "children": []
    },
    "kneeMclSuperficialFibers": {
        "value": "MCL superficial fibers status",
        "parent": "kneeMcl",
        "children": []
    },
    "kneePosterolateralCorner": {
        "value": "",
        "parent": "root",
        "children": [
            "kneeIlioTibialBand",
            "kneeBicepsFemorisTendon",
            "kneeLateralCollateralLigament"
        ]
    },
    "kneeIlioTibialBand": {
        "value": "Presence and/or severity of ilio-tibial band findings",
        "parent": "kneePosterolateralCorner",
        "children": []
    },
    "kneeBicepsFemorisTendon": {
        "value": "Presence and/or severity of biceps femoris tendon findings",
        "parent": "kneePosterolateralCorner",
        "children": []
    },
    "kneeLateralCollateralLigament": {
        "value": "Presence and/or severity of lateral collateral ligament findings",
        "parent": "kneePosterolateralCorner",
        "children": []
    },
    "kneeExtensorMechanism": {
        "value": "",
        "parent": "root",
        "children": [
            "kneeQuadricepsTendon",
            "kneePatellarTendon"
        ]
    },
    "kneeQuadricepsTendon": {
        "value": "",
        "parent": "kneeExtensorMechanism",
        "children": [
            "kneeQuadricepsTendonTearing",
            "kneeQuadricepsTendinopathy"
        ]
    },
    "kneeQuadricepsTendonTearing": {
        "value": "Presence and/or severity of quadriceps tendon tearing",
        "parent": "kneeQuadricepsTendon",
        "children": []
    },
    "kneeQuadricepsTendinopathy": {
        "value": "Presence and/or severity of quadriceps tendinopathy",
        "parent": "kneeQuadricepsTendon",
        "children": []
    },
    "kneePatellarTendon": {
        "value": "",
        "parent": "kneeExtensorMechanism",
        "children": [
            "kneePatellarTendonTearing",
            "kneePatellarTendinopathy"
        ]
    },
    "kneePatellarTendonTearing": {
        "value": "Presence and/or severity of patellar tendon tearing",
        "parent": "kneePatellarTendon",
        "children": []
    },
    "kneePatellarTendinopathy": {
        "value": "Presence and/or severity of patellar tendinopathy",
        "parent": "kneePatellarTendon",
        "children": []
    },
    "kneeCartilage": {
        "value": "Articular cartilage status",
        "parent": "root",
        "children": [
            "kneeCartilageFemoral",
            "kneeCartilageTibial",
            "kneeCartilagePatellar",
            "kneeOsteochondralLesion"
        ]
    },
    "kneeCartilageFemoral": {
        "value": "",
        "parent": "kneeCartilage",
        "children": [
            "kneeCartilageFemoralMedial",
            "kneeCartilageFemoralLateral"
        ]
    },
    "kneeCartilageFemoralMedial": {
        "value": "Presence and/or severity of knee medial femoral cartilage findings",
        "parent": "kneeCartilageFemoral",
        "children": []
    },
    "kneeCartilageFemoralLateral": {
        "value": "Presence and/or severity of knee lateral femoral cartilage findings",
        "parent": "kneeCartilageFemoral",
        "children": []
    },
    "kneeCartilageTibial": {
        "value": "",
        "parent": "kneeCartilage",
        "children": [
            "kneeCartilageTibialMedial",
            "kneeCartilageTibialLateral"
        ]
    },
    "kneeCartilageTibialMedial": {
        "value": "Presence and/or severity of knee medial tibial cartilage findings",
        "parent": "kneeCartilageTibial",
        "children": []
    },
    "kneeCartilageTibialLateral": {
        "value": "Presence and/or severity of knee lateral tibial cartilage findings",
        "parent": "kneeCartilageTibial",
        "children": []
    },
    "kneeCartilagePatellar": {
        "value": "",
        "parent": "kneeCartilage",
        "children": [
            "kneeCartilagePatellarMedial",
            "kneeCartilagePatellarLateral"
        ]
    },
    "kneeOsteochondralLesion": {
        "value": "Presence and/or severity of knee osteochondral lesions/defects",
        "parent": "kneeCartilage",
        "children": []
    },
    "kneeCartilagePatellarMedial": {
        "value": "Presence and/or severity of knee medial patellar cartilage findings",
        "parent": "kneeCartilagePatellar",
        "children": []
    },
    "kneeCartilagePatellarLateral": {
        "value": "Presence and/or severity of knee lateral patellar cartilage findings",
        "parent": "kneeCartilagePatellar",
        "children": []
    },
    "kneeBone": {
        "value": "",
        "parent": "root",
        "children": [
            "kneeBoneFracture",
            "kneeBoneMarrowEdema",
            "kneeSubchondralFracture",
            "kneeOsteonecrosis",
            "kneeBoneAvn"
        ]
    },
    "kneeBoneFracture": {
        "value": "Presence and/or severity and/or location and/or type of knee bone fracture",
        "parent": "kneeBone",
        "children": []
    },
    "kneeBoneMarrowEdema": {
        "value": "Presence and/or severity of knee bone marrow edema/contusion",
        "parent": "kneeBone",
        "children": []
    },
    "kneeSubchondralFracture": {
        "value": "Presence and/or severity of knee subchondral fractures",
        "parent": "kneeBone",
        "children": []
    },
    "kneeOsteonecrosis": {
        "value": "Presence and/or severity of knee osteonecrosis",
        "parent": "kneeBone",
        "children": []
    },
    "kneeBoneAvn": {
        "value": "Presence and/or severity of knee avascular necrosis",
        "parent": "kneeBone",
        "children": []
    },
    "kneeOther": {
        "value": "Other knee findings",
        "parent": "root",
        "children": [
            "kneeBursa",
            "kneePoplitealCyst",
            "kneeGanglionCyst",
            "kneeLipoma",
            "kneeMass",
            "kneeSynovium",
            "other"
        ]
    },
    "kneeBursa": {
        "value": "Presence and/or severity of knee bursa findings, e.g. bursitis",
        "parent": "kneeOther",
        "children": []
    },
    "kneePoplitealCyst": {
        "value": "Presence and/or extent of knee popliteal/Baker's cyst",
        "parent": "kneeOther",
        "children": []
    },
    "kneeGanglionCyst": {
        "value": "Presence and/or extent of knee ganglion cyst",
        "parent": "kneeOther",
        "children": []
    },
    "kneeLipoma": {
        "value": "Presence and/or extent of knee lipoma",
        "parent": "kneeOther",
        "children": []
    },
    "kneeMass": {
        "value": "Presence and/or extent of knee mass",
        "parent": "kneeOther",
        "children": []
    },
    "kneeSynovium": {
        "value": "Presence and/or extent of knee synovial findings, e.g. synovitis, thickening",
        "parent": "kneeOther",
        "children": []
    },
    "other": {
        "value": "Any other findings not listed above",
        "parent": "kneeOther",
        "children": []
    }
}
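
A small sketch of how this template is consumed (editor's note, not part of the upload): treegraph.build_tree_from_file turns it into a networkx graph plus a name-to-Node dict, and the interfaces treat the childless leaf nodes as the checklist items:

    import treegraph as tg

    graph, nodes_dict = tg.build_tree_from_file("knee_template.json")
    leaves = [n.name for n in graph.nodes if n.name != "root" and not n.has_children()]
    print(len(leaves), leaves[:3])  # the "visitable" checklist items
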
treegraph.py ADDED
import networkx as nx
import json
import matplotlib.pyplot as plt


class Node:
    def __init__(self, name: str, value=None, parent=None, children=None):
        self.name = name
        # copy into a fresh set; a mutable default argument ([]) would be
        # shared across every Node instance
        self.children = set(children) if children is not None else set()
        self.parent = parent
        self.value = value

    def __repr__(self):
        return self.name

    def __str__(self):
        return self.name

    def __eq__(self, other):
        return self.name == other.name

    def __hash__(self) -> int:
        return hash(self.name)

    def to_dict(self):
        # a plain, JSON-serializable dict of the node's attributes
        # (defining this as __dict__ would shadow the real instance
        # attribute dictionary and break serialization)
        return {
            "name": self.name,
            "children": list(self.children),
            "parent": self.parent,
            "value": self.value,
        }

    # make serializable for pickling
    def __getstate__(self):
        return self.to_dict()

    def to_json(self):
        """
        Returns a JSON string representation of the node.
        """
        return json.dumps(self.to_dict())

    def add_child(self, child):
        self.children.add(child)

    def has_children(self):
        return len(self.children) > 0

    def set_parent(self, new_parent):
        self.parent = new_parent

    def set_value(self, new_value):
        self.value = new_value


def read_json(fname: str) -> dict:
    assert fname.endswith(".json"), "File must be a json file"
    with open(fname, "r") as f:
        data = json.load(f)
    return dict(data)


def build_tree_from_dict(data: dict, connect_children: bool = True):
    # every dict key is a node's name
    # dict value is a dict with keys "value", "parent", "children"
    # "value" is the node's value
    # "parent" is the node's parent's name
    # "children" is a list of the node's children's names
    # create a networkx graph
    G = nx.Graph()
    nodes_dict = dict()
    # build the nodes
    for name, info in data.items():
        value = info["value"]
        parent = info["parent"]
        children: list = info["children"]
        nodes_dict[name] = Node(
            name=name, parent=parent, children=children, value=value
        )
        G.add_node(nodes_dict[name], value=value)
    # build the edges
    for _, node in nodes_dict.items():
        for child in node.children:
            G.add_edge(node, nodes_dict[child])
            # connect siblings to each other if connect_children is True
            if connect_children:
                for child2 in node.children:
                    if child != child2:
                        G.add_edge(nodes_dict[child], nodes_dict[child2])
    return G, nodes_dict


def build_tree_from_file(fname: str):
    data = read_json(fname)
    return build_tree_from_dict(data)


# calculate the number of edges between two nodes
def num_edges_between_nodes(G, node1, node2):
    return len(nx.shortest_path(G, node1, node2)) - 1


def explore_bfs(G: nx.Graph, source: Node, nodes_dict: dict[str, Node]):
    # start from a source node and explore the graph in a breadth-first manner
    # prioritize nodes with non-empty values
    # explore the graph and return a list of nodes in the order they were explored
    explored_nodes = []
    queue = [source]
    while queue:
        node = queue.pop(0)
        explored_nodes.append(node)
        for child in node.children:
            if nodes_dict[child].value:
                queue.insert(0, nodes_dict[child])
            else:
                queue.append(nodes_dict[child])
    return explored_nodes


def from_list(node_list: list[Node], directional=True):
    # create a tree from a list of nodes
    # and label the edges from the first node to the last node from 1 to n
    if directional:
        G = nx.DiGraph()
    else:
        G = nx.Graph()
    G.add_nodes_from(node_list)
    for i in range(len(node_list) - 1):
        G.add_edge(node_list[i], node_list[i + 1], label=i + 1)
    return G


def visualize_graph(
    graph: nx.Graph,
    layout_graph: nx.Graph,
    title="BFS Tree",
    fig_size=(30, 20),
    title_fontsize=20,
    edge_width=1,
    font_size=9,
    node_size=500,
    node_shape="o",
    prog="dot",
):
    graphviz_args = "-Goverlap=false -Gsplines=true -Gsep=0.1 -Gnodesep=0.1 -Gmaxiter=1000 -Gepsilon=0.0001 -Gstart=0"
    _, ax = plt.subplots(figsize=fig_size)
    ax.set_title(title, fontsize=title_fontsize)
    # compute the layout once and reuse it for both the nodes and the edge labels
    pos = nx.nx_agraph.graphviz_layout(
        layout_graph, prog=prog, root="root", args=graphviz_args
    )
    # color the first visited node lightgreen, the last red, and the rest lightblue
    if len(graph.nodes) > 2:
        node_color = ["lightgreen"] + ["lightblue"] * (len(graph.nodes) - 2) + ["red"]
    elif len(graph.nodes) == 2:
        node_color = ["lightgreen", "red"]
    else:
        node_color = ["lightgreen"]
    nx.draw(
        graph,
        ax=ax,
        with_labels=True,
        node_color=node_color,
        edge_color="gray",
        width=edge_width,
        font_size=font_size,
        node_size=node_size,  # fixed size for every node
        node_shape=node_shape,  # "o" draws circular markers
        pos=pos,
    )
    # draw the jump-order numbers stored in the "label" edge attribute
    nx.draw_networkx_edge_labels(
        graph,
        pos=pos,
        edge_labels=nx.get_edge_attributes(graph, "label"),
        font_size=font_size,
    )
    plt.show()


def get_graph(
    graph: nx.Graph,
    layout_graph: nx.Graph,
    title="BFS Tree",
    fig_size=(30, 20),
    title_fontsize=20,
    edge_width=1,
    font_size=9,
    node_size=500,
    node_shape="o",
    prog="dot",
):
    # same as visualize_graph, but returns the figure instead of showing it
    graphviz_args = "-Goverlap=false -Gsplines=true -Gsep=0.1 -Gnodesep=0.1 -Gmaxiter=1000 -Gepsilon=0.0001 -Gstart=0"
    fig, ax = plt.subplots(figsize=fig_size)
    ax.set_title(title, fontsize=title_fontsize)
    pos = nx.nx_agraph.graphviz_layout(
        layout_graph, prog=prog, root="root", args=graphviz_args
    )
    # color the first visited node lightgreen, the last red, and the rest lightblue
    if len(graph.nodes) > 2:
        node_color = ["lightgreen"] + ["lightblue"] * (len(graph.nodes) - 2) + ["red"]
    elif len(graph.nodes) == 2:
        node_color = ["lightgreen", "red"]
    else:
        node_color = ["lightgreen"]
    nx.draw(
        graph,
        ax=ax,
        with_labels=True,
        node_color=node_color,
        edge_color="gray",
        width=edge_width,
        font_size=font_size,
        node_size=node_size,
        node_shape=node_shape,
        pos=pos,
    )
    nx.draw_networkx_edge_labels(
        graph,
        pos=pos,
        edge_labels=nx.get_edge_attributes(graph, "label"),
        font_size=font_size,
    )
    return fig, ax
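
A distance sketch (editor's note, not part of the upload): with connect_children=True, siblings are linked directly, which is what makes num_edges_between_nodes a useful "jump length" between checklist items:

    import treegraph as tg

    graph, nodes_dict = tg.build_tree_from_file("knee_template.json")
    # siblings under kneeAcl are directly connected, so this prints 1
    print(tg.num_edges_between_nodes(graph, nodes_dict["kneeAclTearing"], nodes_dict["kneeAclDegeneration"]))
    # nodes in different subtrees are several edges apart
    print(tg.num_edges_between_nodes(graph, nodes_dict["kneeAclTearing"], nodes_dict["kneePatellarTendinopathy"]))
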
utils.py ADDED
import colorama
from colorama import Fore, Style
import openai
from tenacity import retry, stop_after_attempt, wait_fixed
import json
import os
import tiktoken
import functools as ft
import time

JSON_TEMPLATE = """
{question}
The required key(s) are: {keys}.
Only and only respond with the key(s) and value(s) mentioned above.
Your answer in valid JSON format:\n
"""

MODEL_COST_DICT = {
    # USD per 1K tokens
    "gpt-3.5-turbo": {
        "input": 0.0015,
        "output": 0.002,
    },
    "gpt-4": {
        "input": 0.03,
        "output": 0.06,
    },
}


def set_api_key(key=None):
    """Sets the OpenAI API key, falling back to the OPENAI_API_KEY environment variable."""
    if key is None:
        key = os.environ.get("OPENAI_API_KEY")
    openai.api_key = key


def num_tokens_from_string(string: str, encoding_name: str) -> int:
    """Returns the number of tokens in a text string."""
    encoding = tiktoken.get_encoding(encoding_name)
    num_tokens = len(encoding.encode(string))
    return num_tokens


def num_tokens_from_messages(messages: list[dict], model="gpt-3.5-turbo-0613"):
    """Returns the number of tokens used by a list of messages."""
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        encoding = tiktoken.get_encoding("cl100k_base")
    if model == "gpt-3.5-turbo-0613":  # note: future models may deviate from this
        num_tokens = 0
        for message in messages:
            num_tokens += (
                4  # every message follows <im_start>{role/name}\n{content}<im_end>\n
            )
            for key, value in message.items():
                num_tokens += len(encoding.encode(value))
                if key == "name":  # if there's a name, the role is omitted
                    num_tokens += -1  # role is always required and always 1 token
        num_tokens += 2  # every reply is primed with <im_start>assistant
        return num_tokens
    else:
        raise NotImplementedError(
            f"""num_tokens_from_messages() is not presently implemented for model {model}.
See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
        )


@retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
def chat(messages: list[dict], model="gpt-3.5-turbo", temperature=0.0):
    # ChatCompletion.create is a classmethod in the pre-1.0 openai library;
    # the class should not be instantiated first
    response = openai.ChatCompletion.create(
        model=model,
        messages=messages,
        temperature=temperature,
    )
    return response["choices"][0]["message"]["content"]


def make_message(role: str, content: str) -> dict:
    return {
        "role": role,
        "content": content,
    }


def make_prompt(template: str, **kwargs):
    return template.format(**kwargs)


def unravel_messages(messages: list[dict]) -> list[str]:
    """Returns a string representation of a list of messages."""
    return [f"{message['role']}: {message['content']}" for message in messages]


class LLM:
    def __init__(self, model="gpt-3.5-turbo", temperature=0.0):
        self.model = model
        self.temperature = temperature
        self.token_counter = 0
        self.cost = 0.0

    @retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
    def chat(self, messages: list[dict]):
        response = openai.ChatCompletion.create(
            model=self.model,
            messages=messages,
            temperature=self.temperature,
        )
        self.token_counter += int(response["usage"]["total_tokens"])
        # prices in MODEL_COST_DICT are per 1K tokens
        self.cost += (
            response["usage"]["prompt_tokens"]
            / 1000
            * MODEL_COST_DICT[self.model]["input"]
            + response["usage"]["completion_tokens"]
            / 1000
            * MODEL_COST_DICT[self.model]["output"]
        )
        return response["choices"][0]["message"]["content"]

    def reset(self):
        self.token_counter = 0
        self.cost = 0.0

    def __call__(self, messages: list[dict]):
        return self.chat(messages)


class SummaryMemory:
    """
    A class that manages a memory of messages and automatically summarizes them when the maximum token limit is reached.

    Attributes:
        max_token_limit (int): The maximum number of tokens allowed in the memory before summarization occurs.
        messages (list[dict]): A list of messages in the memory.
        model (str): The name of the GPT model to use for chat completion.
        ai_role (str): The role of the AI in the conversation.
        human_role (str): The role of the human in the conversation.
        auto_summarize (bool): Whether to automatically summarize the messages when the maximum token limit is reached.
    """

    # ...
    summary_template = "Summarize the following messages into a paragraph and replace '{user}' with '{human_role}', and '{assistant}' with '{ai_role}':\n{messages}"

    def __init__(
        self,
        system_prompt="",
        max_token_limit=4000,
        model="gpt-3.5-turbo",
        ai_role="answer",
        human_role="question/exam",
        auto_summarize=False,
    ):
        # system_prompt is expected to be a message dict (see make_message)
        self.max_token_limit = max_token_limit
        self.messages: list[dict] = []
        self.model = model
        self.ai_role = ai_role
        self.human_role = human_role
        self.auto_summarize = auto_summarize
        self.system_prompt = system_prompt
        self.reset()

    def reset(self):
        self.messages = [self.system_prompt]

    def remove_last(self):
        if len(self.messages) > 1:  # don't remove the system prompt
            self.messages.pop()

    def remove(
        self, index: int
    ):  # don't remove the system prompt and start counting from 1
        if index > 0 and index < len(self.messages):
            self.messages.pop(index)

    def replace(self, index: int, message: dict):
        if index > 0 and index < len(self.messages):
            self.messages[index] = message

    def change_system_prompt(self, new_prompt: str):
        self.system_prompt = new_prompt
        self.messages[0] = new_prompt

    def remove_first(self):
        # don't remove the system prompt
        if len(self.messages) > 1:
            self.messages.pop(1)  # remove the first message after the system prompt

    def append(self, message: dict):
        total_tokens = num_tokens_from_messages(self.messages + [message])

        while (
            self.auto_summarize and total_tokens > self.max_token_limit
        ):  # keep summarizing until we're under the limit
            self.summarize()
            total_tokens = num_tokens_from_messages(self.messages + [message])

        self.messages.append(message)

    def summarize(self):
        prompt = make_prompt(
            self.summary_template,
            user="user",
            human_role=self.human_role,
            assistant="assistant",
            ai_role=self.ai_role,
            messages="\n".join(
                unravel_messages(self.messages[1:])
            ),  # don't include the system prompt
        )
        summary = chat(
            messages=[make_message("user", prompt)],
            model=self.model,
        )
        self.reset()
        self.append(make_message("user", summary))

    def get_messages(self):
        return self.messages[1:]  # don't include the system prompt

    def get_unraveled_messages(self):
        return unravel_messages(self.messages[1:])


class MemoryBuffer:
    """
    A class that manages a buffer of messages and clips them to a maximum token limit.

    Attributes:
        max_token_limit (int): The maximum number of tokens allowed in the buffer.
        messages (list[dict]): A list of messages in the buffer.
    """

    def __init__(
        self,
        system_prompt,
        max_token_limit=1000,
    ):
        """
        Initializes a new instance of the MemoryBuffer class.

        Args:
            system_prompt (dict): The system message (see make_message) that is always kept.
            max_token_limit (int, optional): The maximum number of tokens allowed in the buffer. Defaults to 1000.
        """
        self.max_token_limit = max_token_limit
        self.messages = []
        self.system_prompt = system_prompt
        self.reset()

    def reset(self):
        """
        Resets the buffer to just the system prompt.
        """
        self.messages = [self.system_prompt]

    def add(self, message: dict):
        """
        Adds a message to the buffer and clips the buffer to the maximum token limit.

        Args:
            message (dict): The message to add to the buffer.
        """
        total_tokens = num_tokens_from_messages(self.messages + [message])
        if total_tokens > self.max_token_limit:
            # drop messages from the front of the list until the total
            # number of tokens is under the max token limit
            while total_tokens > self.max_token_limit:
                self.messages = self.messages[1:]
                total_tokens = num_tokens_from_messages(self.messages + [message])
        self.messages.append(message)

    def remove(self, message: dict):
        """
        Removes a message from the buffer.

        Args:
            message (dict): The message to remove from the buffer.
        """
        if message in self.messages:
            self.messages.remove(message)

    def remove_last(self):
        """
        Removes the last message from the buffer.
        """
        if len(self.messages) > 0:
            self.messages.pop()

    def remove_first(self):
        """
        Removes the first message from the buffer.
        """
        if len(self.messages) > 0:
            self.messages.pop(0)
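
# Illustrative only (editor's note, not part of the upload): neither memory class
# is wired into the interfaces above, but the intended usage is along these lines:
#   buf = MemoryBuffer(make_message("system", "You are a radiology assistant."))
#   buf.add(make_message("user", "Describe the meniscus findings."))
# Once the token limit is exceeded, messages are dropped from the front of the list.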


def json2dict(string: str) -> dict:
    """Returns a dictionary of variables from a string containing JSON."""
    try:
        return json.loads(string)
    except json.decoder.JSONDecodeError:
        print("Error: JSONDecodeError")
        return {}


def print_help(num_nodes, color):
    """
    Prints the help message for the AI assistant.
    """
    colorama.init()
    print(color + "The AI assistant presents a clinical case and asks for a diagnosis.")
    print(
        color + "You need to explore the case by asking questions to the AI assistant."
    )
    print(
        color
        + "You have to ask questions in a logical order, conforming to the clinical guidelines."
    )
    print(
        color
        + "You need to minimize the number of jumps between subjects, while covering as many subjects as possible."
    )
    print(color + f"There are a total of {num_nodes} visitable nodes in the tree.")
    print(
        color
        + "You have to explore the tree as much as possible while avoiding jumps and travelling excessively."
    )
    print(Style.RESET_ALL)


def make_question(template=JSON_TEMPLATE, role="user", **kwargs) -> dict:
    prompt = make_prompt(template=template, **kwargs)
    message = make_message(role, prompt)
    return message
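
# Illustrative only (editor's note, not part of the upload): with the default
# JSON_TEMPLATE,
#   make_question(question="Are you ready?", keys=["answer"])
# returns a message dict of the form
#   {"role": "user", "content": "\nAre you ready?\nThe required key(s) are: ['answer'].\n..."}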


# a debugging decorator; functools preserves the function name and docstring
# the decorator takes DEBUG as an argument to turn debugging on or off
def debug(DEBUG, print_func, measure_time=True):
    def decorator(func):
        @ft.wraps(func)
        def wrapper(*args, **kwargs):
            if DEBUG:
                print_func(f"\nCalling {func.__name__}")
            if measure_time and DEBUG:
                start = time.time()
            result = func(*args, **kwargs)
            if measure_time and DEBUG:
                end = time.time()
                print_func(f"Elapsed time: {end - start:.2f}s")
            if DEBUG:
                print_func(f"Returning {func.__name__}")
            return result

        return wrapper

    return decorator


# to use the decorator, add @debug(DEBUG, print_func) above the function definition
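
# Illustrative only (editor's note, not part of the upload): as used in
# interface.py, e.g.
#
#   @utils.debug(DEBUG, print_debug)
#   def _check_report(self, report): ...
#
# which prints the call/return markers and the elapsed time when DEBUG is True.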