Spaces:
Runtime error
Runtime error
| import json | |
| import numpy as np | |
| import treegraph as tg | |
| import colorama | |
| from colorama import Fore | |
| import networkx as nx | |
| import utils | |
| import re | |
| import logger as lg | |
# Global switch: when True, print_debug() emits its messages.
DEBUG = True
# Console colors (colorama Fore codes) for each message category.
INPUT_COLOR = Fore.LIGHTGREEN_EX
DEBUG_COLOR = Fore.LIGHTBLACK_EX
OUTPUT_COLOR = Fore.LIGHTMAGENTA_EX
INFO_COLOR = Fore.BLUE
HELP_COLOR = Fore.CYAN
def print_debug(*args, color=DEBUG_COLOR):
    """Print each argument on its own line in *color*, but only when DEBUG is True."""
    if not DEBUG:
        return
    for item in args:
        print(color + str(item))
class ReportInterface:
    """Interactive interface for filling out a knee-MRI report tree.

    Free-text findings are matched (via the LLM) to leaf nodes of the
    report tree; navigation metrics (jumps, jump lengths, distance and
    visited nodes) are tracked as the user moves between nodes.
    """

    def __init__(
        self,
        llm: utils.LLM,
        system_prompt: str,
        tree_graph: nx.Graph,
        nodes_dict: dict[str, tg.Node],
        api_key: str = None,
    ):
        self.llm = llm
        self.system_prompt = system_prompt
        self.tree_graph = tree_graph
        self.nodes_dict = nodes_dict
        self.api_key = api_key
        self.build()

    def build(self):
        """Initialise the prompt, the report state and the navigation metrics."""
        utils.set_api_key(self.api_key)
        self.system_prompt = utils.make_message("system", self.system_prompt)
        self.visitable_nodes = self._get_visitable_nodes()
        self.report_dict = self._get_report_dict()
        self.active_node: tg.Node = self.nodes_dict["root"]
        self.unique_visited_nodes = set()  # set of nodes visited
        self.node_journey = []  # list of nodes visited, in order
        self.distance_travelled = 0  # number of edges travelled
        self.jumps = 0  # number of jumps
        self.jump_lengths = []  # list of jump lengths
        self.counter = 0  # number of questions asked
        colorama.init(autoreset=True)  # to reset the color after each print statement
        self.help_message = f"""You are presented with a Knee MRI.
You are asked to fill out a radiology report.
Please only report the findings in the MRI.
Please mention your findings with the corresponding anatomical structures.
There are {len(self.visitable_nodes.keys())} visitable nodes in the tree.
You must visit as many nodes as possible, while avoiding too many jumps."""

    def _get_visitable_nodes(self):
        """Return a mapping of leaf-node name -> node, excluding the root.

        A single dict comprehension replaces the original zip of two
        identical filtered comprehensions (same order, one pass).
        """
        return {
            node.name: node
            for node in self.tree_graph.nodes
            if node.name != "root" and node.has_children() is False
        }

    def _get_report_dict(self):
        """Return fresh, empty-valued copies of all visitable nodes."""
        return {
            node.name: tg.Node(node.name, "", node.children)
            for node in self.visitable_nodes.values()
        }

    def _check_question_validity(
        self,
        question: str,
    ):
        """Ask the LLM whether *question* matches a visitable node.

        Returns a tuple ``(available, node)`` where ``available`` is the
        model's 'yes'/'no' answer (lower-cased, stripped) and ``node`` is
        the matched template key, or 'none'.
        """
        # let's ask the question from the model and check if it's valid
        template_json = json.dumps(
            {key: node.value for key, node in self.visitable_nodes.items()},
            indent=4,
        )
        q = f"""the following is a Knee MRI report "template" in a JSON format with keys and values.
You are given a "finding" phrase from a radiologist.
Match as best as possible the "finding" with one of keys in the "template".
<template>
{template_json}
</template>
<finding>
{question}
</finding>
"available": [Is the "finding" relevant to any key in the "template"? say "yes" or "no".
Make sure the "finding" is relevant to Knee MRI and knee anatomy otherwise say 'no'.
Do not answer irrelevant phrases.]
"node": [if the above answer is 'yes', write only the KEY of the most relevant node to the "finding". otherwise, say 'none'.]
"""
        keys = ["available", "node"]
        prompt = [self.system_prompt] + [
            utils.make_question(utils.JSON_TEMPLATE, question=q, keys=keys)
        ]
        response = self.llm(prompt)
        print_debug(
            prompt,
            response,
        )
        # Parse the JSON response once instead of twice.
        parsed = utils.json2dict(response)
        available = parsed["available"].strip().lower()
        node = parsed["node"]
        return available, node

    def _update_node(self, node_name, findings):
        """Append *findings* to the report node and return a confirmation message."""
        self.report_dict[node_name].value += str(findings) + "\n"
        response = f"Updated node '{node_name}' with finding '{findings}'"
        print(OUTPUT_COLOR + response)
        return response

    def save_report(self, filename: str):
        """Serialize the performance metrics and the report to *filename* as JSON."""
        # convert performance metrics to json
        metrics = {
            "distance_travelled": self.distance_travelled,
            "jumps": self.jumps,
            "jump_lengths": self.jump_lengths,
            "unique_visited_nodes": [node.name for node in self.unique_visited_nodes],
            "node_journey": [node.name for node in self.node_journey],
            "report": {
                node_name: node.value for node_name, node in self.report_dict.items()
            },
        }
        # save the report
        with open(filename, "w") as file:
            json.dump(metrics, file, indent=4)

    def prime_model(self):
        """
        Primes the model with the system prompt.

        Returns True if the model answers 'yes' to the readiness check.
        """
        q = "Are you ready to begin?\nSay 'yes' or 'no'."
        keys = ["answer"]
        response = self.llm(
            [
                self.system_prompt,
                utils.make_question(utils.JSON_TEMPLATE, question=q, keys=keys),
            ],
        )
        print_debug(q, response)
        if utils.json2dict(response)["answer"].lower() == "yes":
            print(INFO_COLOR + "The model is ready.")
            return True
        else:
            print(INFO_COLOR + "The model is not ready.")
            return False

    def performance_summary(self):
        """Print navigation metrics, the non-empty report entries and LLM cost."""
        # print out the summary info
        print(INFO_COLOR + "Performance Summary:")
        print(
            INFO_COLOR + f"Total distance travelled: {self.distance_travelled} edge(s)"
        )
        print(INFO_COLOR + f"Jump lengths: {self.jump_lengths}")
        if self.jump_lengths:
            print(INFO_COLOR + f"Jump lengths mean: {np.mean(self.jump_lengths):.1f}")
            print(INFO_COLOR + f"Jump lengths SD: {np.std(self.jump_lengths):.1f}")
        else:
            # Avoid numpy's nan result (and RuntimeWarning) on an empty list.
            print(INFO_COLOR + "Jump lengths mean: 0.0")
            print(INFO_COLOR + "Jump lengths SD: 0.0")
        print(INFO_COLOR + f"Nodes visited in order: {self.node_journey}")
        print(INFO_COLOR + f"Unique nodes visited: {self.unique_visited_nodes}")
        print(
            INFO_COLOR
            + f"You have explored {len(self.unique_visited_nodes)/len(self.visitable_nodes):.1%} ({len(self.unique_visited_nodes)}/{len(self.visitable_nodes)}) of the tree."
        )
        print_debug("\n")
        print_debug("Report Summary:".rjust(20))
        for name, node in self.report_dict.items():
            if node.value != "":
                print_debug(f"{name}: {node.value}")
        print(INFO_COLOR + f"total cost: ${self.llm.cost:.4f}")
        print(INFO_COLOR + f"total tokens used: {self.llm.token_counter}")

    def get_stats(self):
        """Return the navigation metrics and the filled report as a dict."""
        report_string = ""
        for name, node in self.report_dict.items():
            if node.value != "":
                report_string += f"{name}: <{node.value}> \n"
        return {
            "Lengths travelled": self.distance_travelled,
            "Number of jumps": self.jumps,
            "Jump lengths": self.jump_lengths,
            "Unique nodes visited": [node.name for node in self.unique_visited_nodes],
            "Visited Nodes": [node.name for node in self.node_journey],
            "Report": report_string,
        }

    def visualize_tree(self, **kwargs):
        """Visualize the journey overlaid on the report tree."""
        tg.visualize_graph(tg.from_list(self.node_journey), self.tree_graph, **kwargs)

    def get_plot(self, **kwargs):
        """Return a plot object of the journey overlaid on the report tree."""
        return tg.get_graph(tg.from_list(self.node_journey), self.tree_graph, **kwargs)

    def process_input(self, input_text: str):
        """Handle one user input: a finding, 'quit' or 'help'.

        Returns "quit"/"help" for commands, "n/a" when no node matched,
        "exception" on failure, or the node-update confirmation message.
        """
        res = "n/a"
        try:
            finding = input_text
            if finding.strip().lower() == "quit":
                print(INFO_COLOR + "Exiting...")
                return "quit"
            elif finding.strip().lower() == "help":
                return "help"
            available, node = self._check_question_validity(finding)
            # Single guard replaces the two duplicated "not found" branches.
            if available != "yes" or node not in self.visitable_nodes.keys():
                print(
                    OUTPUT_COLOR
                    + "Could not find a relevant node.\nWrite more clearly and provide more details."
                )
                return "n/a"
            # modify the tree to update the node with findings
            res = self._update_node(node, finding)
            print(
                INFO_COLOR
                + f"jumping from node '{self.active_node}' to node '{node}'..."
            )
            distance = tg.num_edges_between_nodes(
                self.tree_graph, self.active_node, self.nodes_dict[node]
            )
            print(INFO_COLOR + f"distance travelled: {distance} edge(s)")
            self.active_node = self.nodes_dict[node]
            self.jumps += 1
            self.jump_lengths.append(distance)
            self.distance_travelled += distance
            if self.active_node.name != "root":
                self.unique_visited_nodes.add(self.active_node)
                self.node_journey.append(self.active_node)
        except Exception as ex:
            print_debug(ex, color=Fore.LIGHTRED_EX)
            return "exception"
        self.counter += 1
        try:
            self.performance_summary()
        except Exception as ex:
            print_debug(ex, color=Fore.LIGHTRED_EX)
        return res
class ReportChecklistInterface:
    """Checks a free-text knee-MRI report against a checklist of tree nodes.

    The LLM extracts a finding for every checklist item from the raw
    report; the result is rendered as a checkmark/cross summary string.
    """

    def __init__(
        self,
        llm: utils.LLM,
        system_prompt: str,
        graph: nx.Graph,
        nodes_dict: dict[str, tg.Node],
        api_key: str = None,
        logger: lg.Logger = None,
        username: str = None,
    ):
        self.llm = llm
        self.system_prompt = system_prompt
        self.tree_graph: nx.Graph = graph
        self.nodes_dict = nodes_dict
        self.api_key = api_key
        self.logger = logger
        self.username = username
        self.build()

    def build(self):
        """Initialise the prompt, the visitable nodes and the help message."""
        utils.set_api_key(self.api_key)
        self.system_prompt = utils.make_message("system", self.system_prompt)
        self.visitable_nodes = self._get_visitable_nodes()
        colorama.init(autoreset=True)  # to reset the color after each print statement
        self.help_message = f"""You are presented with a Knee MRI.
You are asked to fill out a radiology report.
Please only report the findings in the MRI.
Please mention your findings with the corresponding anatomical structures.
There are {len(self.visitable_nodes.keys())} visitable nodes in the tree."""

    def _get_visitable_nodes(self):
        """Return a mapping of leaf-node name -> node, excluding the root.

        A single dict comprehension replaces the original zip of two
        identical filtered comprehensions (same order, one pass).
        """
        return {
            node.name: node
            for node in self.tree_graph.nodes
            if node.name != "root" and node.has_children() is False
        }

    def _check_report(
        self,
        report: str,
    ):
        """Ask the LLM to fill the checklist from *report*; return the dict."""
        # let's ask the question from the model and check if it's valid
        checklist_json = json.dumps(
            {key: node.value for key, node in self.visitable_nodes.items()},
            indent=4,
        )
        # NOTE: the prompt key below was previously misspelled "report_ckecked",
        # which conflicted with the "report_checked" key extracted from the response.
        q = f"""the following is a Knee MRI "checklist" in JSON format with keys as items and values as findings:
A knee MRI "report" is also provided in raw text format written by a radiologist:
<checklist>
{checklist_json}
</checklist>
<report>
{report}
</report>
Your task is to find all the corresponding items from the "checklist" in the "report" and fill out a JSON with the same keys as the "checklist" but extract the corresponding values from the "report".
If a key is not found in the "report", please set the value to "n/a", otherwise set it to the corresponding finding from the "report".
You must check the "report" phrases one by one and find a corresponding key(s) for EACH phrase in the "report" from the "checklist" and fill out the "report_checked" JSON.
Try to fill out as many items as possible.
ALL of the items in the "checklist" must be filled out.
Don't generate findings that are not present in the "report" (new findings).
Be comprehensive and don't miss any findings that are present in the "report".
Watch out for encompassing terms (e.g., "cruciate ligaments" means both "ACL" and "PCL").
"thought_process": [Think in steps on how you would do this task.]
"report_checked" : [a JSON with the same keys as the "checklist" but take the values from the "report", as described above.]
"""
        keys = ["thought_process", "report_checked"]
        prompt = [self.system_prompt] + [
            utils.make_question(utils.JSON_TEMPLATE, question=q, keys=keys)
        ]
        response = self.llm(prompt)
        print_debug(
            prompt,
            response,
        )
        if self.logger:
            # set name to class name
            self.logger(
                name=self.__class__.__name__,
                message=f"prompt: {prompt}\nresponse: {response}",
            )
        report_checked = utils.json2dict(response)
        return report_checked["report_checked"]

    def prime_model(self):
        """
        Primes the model with the system prompt.

        Returns True if the model answers 'yes' to the readiness check.
        """
        q = "Are you ready to begin?\nSay 'yes' or 'no'."
        keys = ["answer"]
        response = self.llm(
            [
                self.system_prompt,
                utils.make_question(utils.JSON_TEMPLATE, question=q, keys=keys),
            ],
        )
        print_debug(q, response)
        if utils.json2dict(response)["answer"].lower() == "yes":
            print(INFO_COLOR + "The model is ready.")
            return True
        else:
            print(INFO_COLOR + "The model is not ready.")
            return False

    def process_input(self, input_text: str):
        """Handle one user input: a raw report, 'quit' or 'help'.

        Returns "quit"/"help" for commands, "exception" on failure, or the
        rendered checklist summary string.
        """
        try:
            report = input_text
            if self.logger:
                self.logger(self.username, f"report: {report}")
            if report.strip().lower() == "quit":
                print(INFO_COLOR + "Exiting...")
                if self.logger:
                    self.logger(self.username, "Exiting...")
                return "quit"
            elif report.strip().lower() == "help":
                if self.logger:
                    self.logger(self.username, "Help")
                return "help"
            checked_report: dict = self._check_report(report)
            # make a string of the report
            # replace "n/a" with [cross emoji], everything else with [checkmark emoji]
            report_string = ""
            CHECKMARK = "\u2705"
            CROSS = "\u274C"

            # we need a regex to convert the camelCase keys to a readable format
            def camel2readable(camel: str):
                string = re.sub("([a-z])([A-Z])", r"\1 \2", camel)
                # capitalize every word
                string = " ".join([word.capitalize() for word in string.split()])
                return string

            for key, value in checked_report.items():
                if str(value).lower() == "n/a":
                    report_string += f"{camel2readable(key)}: {CROSS}\n"
                else:
                    report_string += f"{camel2readable(key)}: <{value}> {CHECKMARK}\n"
            # Guard against an empty checklist dict (avoids ZeroDivisionError).
            portion_visited: float = (
                report_string.count(CHECKMARK) / len(checked_report)
                if checked_report
                else 0.0
            )
            report_string += f"Portion of the checklist visited: {portion_visited:.1%}"
            if self.logger:
                self.logger(self.__class__.__name__, report_string)
            return report_string
        except Exception as ex:
            print_debug(ex, color=Fore.LIGHTRED_EX)
            if self.logger:
                # str(ex): concatenating an Exception object raised TypeError before.
                self.logger(self.__class__.__name__, "Exception: " + str(ex))
            return "exception"