# Spaces:
# Runtime error
# Runtime error
import json | |
import numpy as np | |
import treegraph as tg | |
import colorama | |
from colorama import Fore | |
import networkx as nx | |
import utils | |
import re | |
import logger as lg | |
DEBUG: bool = True  # gates all output from print_debug()

# Console colors (colorama ``Fore`` codes, prepended to printed text).
INPUT_COLOR: str = Fore.LIGHTGREEN_EX  # not referenced in this section; presumably for echoed user input
DEBUG_COLOR: str = Fore.LIGHTBLACK_EX  # default color used by print_debug()
OUTPUT_COLOR: str = Fore.LIGHTMAGENTA_EX  # node-update and "no relevant node" messages
INFO_COLOR: str = Fore.BLUE  # status, summary and metric lines
HELP_COLOR: str = Fore.CYAN  # not referenced in this section; presumably for help text
def print_debug(*args, color=DEBUG_COLOR):
    """
    Print every argument on its own line, prefixed with ``color``.

    Each argument is converted with ``str()``. Nothing is printed unless the
    module-level DEBUG flag is true.
    """
    if not DEBUG:
        return
    for message in map(str, args):
        print(color + message)
class ReportInterface:
    """Interactive, finding-by-finding report filling driven by an LLM.

    Each free-text finding is matched by the LLM to one visitable node (a leaf
    of ``tree_graph`` other than ``"root"``). The finding is appended to that
    node's entry in ``report_dict``, and the move from the previously active
    node is recorded as a jump whose length is the number of edges between the
    two nodes (computed by ``tg.num_edges_between_nodes``).
    """

    # Project types in the annotations are quoted so they are not evaluated
    # when the class is defined.
    def __init__(
        self,
        llm: "utils.LLM",
        system_prompt: str,
        tree_graph: "nx.Graph",
        nodes_dict: "dict[str, tg.Node]",
        api_key: "str | None" = None,
    ):
        """Store collaborators and initialise the session via ``build``.

        Args:
            llm: chat model called with a list of messages; its response is
                parsed with ``utils.json2dict``. It must also expose ``cost``
                and ``token_counter`` (read by ``performance_summary``).
            system_prompt: raw system prompt text; ``build`` wraps it into a
                system message.
            tree_graph: graph whose nodes are ``tg.Node`` objects.
            nodes_dict: mapping from node name to node; must contain ``"root"``.
            api_key: forwarded to ``utils.set_api_key``.
        """
        self.llm = llm
        self.system_prompt = system_prompt
        self.tree_graph = tree_graph
        self.nodes_dict = nodes_dict
        self.api_key = api_key
        self.build()

    def build(self):
        """Initialise derived state: API key, system message, node maps and metrics.

        Note: ``self.system_prompt`` is wrapped in place, so calling this a
        second time would wrap the already-wrapped message again.
        """
        utils.set_api_key(self.api_key)
        self.system_prompt = utils.make_message("system", self.system_prompt)
        self.visitable_nodes = self._get_visitable_nodes()
        self.report_dict = self._get_report_dict()
        # Every session starts at the root of the tree.
        self.active_node: tg.Node = self.nodes_dict["root"]
        self.unique_visited_nodes = set()  # set of nodes visited
        self.node_journey = []  # list of nodes visited, in order
        self.distance_travelled = 0  # total number of edges travelled
        self.jumps = 0  # number of jumps
        self.jump_lengths = []  # length (in edges) of each jump
        self.counter = 0  # number of findings successfully recorded
        colorama.init(autoreset=True)  # reset the color after each print statement
        self.help_message = f"""You are presented with a Knee MRI.
You are asked to fill out a radiology report.
Please only report the findings in the MRI.
Please mention your findings with the corresponding anatomical structures.
There are {len(self.visitable_nodes)} visitable nodes in the tree.
You must visit as many nodes as possible, while avoiding too many jumps."""

    def _get_visitable_nodes(self):
        """Return ``{name: node}`` for every leaf node of the tree except ``"root"``."""
        return {
            node.name: node
            for node in self.tree_graph.nodes
            if node.name != "root" and node.has_children() is False
        }

    def _get_report_dict(self):
        """Return an empty report entry (a fresh ``tg.Node`` with value ``""``) per visitable node."""
        return {
            node.name: tg.Node(node.name, "", node.children)
            for node in self.visitable_nodes.values()
        }

    def _check_question_validity(
        self,
        question: str,
    ):
        """Ask the LLM which visitable node, if any, a finding belongs to.

        Args:
            question: the radiologist's free-text finding.

        Returns:
            ``(available, node)``: ``available`` is the model's yes/no answer,
            stripped and lower-cased; ``node`` is the stripped key it proposed.
            The caller validates ``node`` against ``visitable_nodes``.
        """
        # let's ask the question from the model and check if it's valid
        template_json = json.dumps(
            {key: node.value for key, node in self.visitable_nodes.items()},
            indent=4,
        )
        q = f"""the following is a Knee MRI report "template" in a JSON format with keys and values.
You are given a "finding" phrase from a radiologist.
Match as best as possible the "finding" with one of keys in the "template".
<template>
{template_json}
</template>
<finding>
{question}
</finding>
"available": [Is the "finding" relevant to any key in the "template"? say "yes" or "no".
Make sure the "finding" is relevant to Knee MRI and knee anatomy otherwise say 'no'.
Do not answer irrelevant phrases.]
"node": [if the above answer is 'yes', write only the KEY of the most relevant node to the "finding". otherwise, say 'none'.]
"""
        keys = ["available", "node"]
        prompt = [
            self.system_prompt,
            utils.make_question(utils.JSON_TEMPLATE, question=q, keys=keys),
        ]
        response = self.llm(prompt)
        print_debug(
            prompt,
            response,
        )
        # Parse once; str() tolerates non-string JSON values (e.g. booleans or null),
        # and strip() tolerates stray whitespace around the model's answers.
        parsed = utils.json2dict(response)
        available = str(parsed["available"]).strip().lower()
        node = str(parsed["node"]).strip()
        return available, node

    def _update_node(self, node_name, findings):
        """Append ``findings`` (plus a newline) to the report entry ``node_name``.

        Returns:
            The confirmation message, which is also printed.
        """
        self.report_dict[node_name].value += str(findings) + "\n"
        response = f"Updated node '{node_name}' with finding '{findings}'"
        print(OUTPUT_COLOR + response)
        return response

    def save_report(self, filename: str):
        """Write the navigation metrics and the per-node report to ``filename`` as JSON."""
        metrics = {
            "distance_travelled": self.distance_travelled,
            "jumps": self.jumps,
            "jump_lengths": self.jump_lengths,
            "unique_visited_nodes": [node.name for node in self.unique_visited_nodes],
            "node_journey": [node.name for node in self.node_journey],
            "report": {
                node_name: node.value for node_name, node in self.report_dict.items()
            },
        }
        with open(filename, "w", encoding="utf-8") as file:
            json.dump(metrics, file, indent=4)

    def prime_model(self):
        """
        Primes the model with the system prompt.

        Returns:
            True if the model answers "yes" (case- and whitespace-insensitive),
            otherwise False.
        """
        q = "Are you ready to begin?\nSay 'yes' or 'no'."
        keys = ["answer"]
        response = self.llm(
            [
                self.system_prompt,
                utils.make_question(utils.JSON_TEMPLATE, question=q, keys=keys),
            ],
        )
        print_debug(q, response)
        if str(utils.json2dict(response)["answer"]).strip().lower() == "yes":
            print(INFO_COLOR + "The model is ready.")
            return True
        print(INFO_COLOR + "The model is not ready.")
        return False

    def performance_summary(self):
        """Print navigation metrics, the debug-only report summary, and LLM usage."""
        print(INFO_COLOR + "Performance Summary:")
        print(
            INFO_COLOR + f"Total distance travelled: {self.distance_travelled} edge(s)"
        )
        print(INFO_COLOR + f"Jump lengths: {self.jump_lengths}")
        if self.jump_lengths:
            print(INFO_COLOR + f"Jump lengths mean: {np.mean(self.jump_lengths):.1f}")
            print(INFO_COLOR + f"Jump lengths SD: {np.std(self.jump_lengths):.1f}")
        else:
            # np.mean / np.std of an empty list warn and return nan; say n/a instead.
            print(INFO_COLOR + "Jump lengths mean: n/a")
            print(INFO_COLOR + "Jump lengths SD: n/a")
        print(INFO_COLOR + f"Nodes visited in order: {self.node_journey}")
        print(INFO_COLOR + f"Unique nodes visited: {self.unique_visited_nodes}")
        n_visited = len(self.unique_visited_nodes)
        n_total = len(self.visitable_nodes)
        explored = n_visited / n_total if n_total else 0.0
        print(
            INFO_COLOR
            + f"You have explored {explored:.1%} ({n_visited}/{n_total}) of the tree."
        )
        print_debug("\n")
        print_debug("Report Summary:".rjust(20))
        for name, node in self.report_dict.items():
            if node.value != "":
                print_debug(f"{name}: {node.value}")
        print(INFO_COLOR + f"total cost: ${self.llm.cost:.4f}")
        print(INFO_COLOR + f"total tokens used: {self.llm.token_counter}")

    def get_stats(self):
        """Return the navigation metrics plus the non-empty report entries as one string."""
        report_string = "".join(
            f"{name}: <{node.value}> \n"
            for name, node in self.report_dict.items()
            if node.value != ""
        )
        return {
            "Lengths travelled": self.distance_travelled,
            "Number of jumps": self.jumps,
            "Jump lengths": self.jump_lengths,
            "Unique nodes visited": [node.name for node in self.unique_visited_nodes],
            "Visited Nodes": [node.name for node in self.node_journey],
            "Report": report_string,
        }

    def visualize_tree(self, **kwargs):
        """Draw the visited path over ``tree_graph`` with ``tg.visualize_graph`` (kwargs forwarded)."""
        tg.visualize_graph(tg.from_list(self.node_journey), self.tree_graph, **kwargs)

    def get_plot(self, **kwargs):
        """Return whatever ``tg.get_graph`` builds for the visited path (kwargs forwarded)."""
        return tg.get_graph(tg.from_list(self.node_journey), self.tree_graph, **kwargs)

    def process_input(self, input_text: str):
        """Handle one line of user input.

        Returns:
            ``"quit"`` or ``"help"`` for those commands (case-insensitive,
            surrounding whitespace ignored); ``"n/a"`` when no visitable node
            matches the finding; ``"exception"`` if processing raised; otherwise
            the confirmation message from ``_update_node``.
        """
        try:
            finding = input_text
            command = finding.strip().lower()
            if command == "quit":
                print(INFO_COLOR + "Exiting...")
                return "quit"
            if command == "help":
                return "help"
            available, node = self._check_question_validity(finding)
            if available != "yes" or node not in self.visitable_nodes:
                print(
                    OUTPUT_COLOR
                    + "Could not find a relevant node.\nWrite more clearly and provide more details."
                )
                return "n/a"
            # Resolve the target and the distance before mutating anything, so a
            # failure here cannot leave the report updated but the metrics not.
            target = self.nodes_dict[node]
            distance = tg.num_edges_between_nodes(
                self.tree_graph, self.active_node, target
            )
            res = self._update_node(node, finding)
            print(
                INFO_COLOR
                + f"jumping from node '{self.active_node.name}' to node '{node}'..."
            )
            print(INFO_COLOR + f"distance travelled: {distance} edge(s)")
            self.active_node = target
            self.jumps += 1
            self.jump_lengths.append(distance)
            self.distance_travelled += distance
            if target.name != "root":
                self.unique_visited_nodes.add(target)
            self.node_journey.append(target)
        except Exception as ex:
            # Best-effort: the error is only visible when DEBUG is on; the caller
            # receives "exception" and may continue the session.
            print_debug(ex, color=Fore.LIGHTRED_EX)
            return "exception"
        self.counter += 1
        try:
            self.performance_summary()
        except Exception as ex:
            print_debug(ex, color=Fore.LIGHTRED_EX)
        return res
class ReportChecklistInterface:
    """Score a complete free-text report against the checklist of visitable nodes.

    A single LLM call maps the report onto every visitable node (a leaf of the
    graph other than ``"root"``). The result is rendered as one line per
    checklist item, with a cross for items the model marked ``"n/a"`` and a
    checkmark otherwise, followed by the covered fraction.
    """

    # Project types in the annotations are quoted so they are not evaluated
    # when the class is defined.
    def __init__(
        self,
        llm: "utils.LLM",
        system_prompt: str,
        graph: "nx.Graph",
        nodes_dict: "dict[str, tg.Node]",
        api_key: "str | None" = None,
        logger: "lg.Logger | None" = None,
        username: "str | None" = None,
    ):
        """Store collaborators and initialise derived state via ``build``.

        Args:
            llm: chat model called with a list of messages; its response is
                parsed with ``utils.json2dict``.
            system_prompt: raw system prompt text; ``build`` wraps it into a
                system message.
            graph: graph whose nodes are ``tg.Node`` objects (stored as
                ``tree_graph``).
            nodes_dict: mapping from node name to node.
            api_key: forwarded to ``utils.set_api_key``.
            logger: optional callable used to record inputs, prompts and results.
            username: passed to ``logger`` when recording user input.
        """
        self.llm = llm
        self.system_prompt = system_prompt
        self.tree_graph: nx.Graph = graph
        self.nodes_dict = nodes_dict
        self.api_key = api_key
        self.logger = logger
        self.username = username
        self.build()

    def build(self):
        """Initialise derived state: API key, system message, visitable nodes, help text.

        Note: ``self.system_prompt`` is wrapped in place, so calling this a
        second time would wrap the already-wrapped message again.
        """
        utils.set_api_key(self.api_key)
        self.system_prompt = utils.make_message("system", self.system_prompt)
        self.visitable_nodes = self._get_visitable_nodes()
        colorama.init(autoreset=True)  # reset the color after each print statement
        self.help_message = f"""You are presented with a Knee MRI.
You are asked to fill out a radiology report.
Please only report the findings in the MRI.
Please mention your findings with the corresponding anatomical structures.
There are {len(self.visitable_nodes)} visitable nodes in the tree."""

    def _get_visitable_nodes(self):
        """Return ``{name: node}`` for every leaf node of the graph except ``"root"``."""
        return {
            node.name: node
            for node in self.tree_graph.nodes
            if node.name != "root" and node.has_children() is False
        }

    def _check_report(
        self,
        report: str,
    ):
        """Ask the LLM to fill the checklist from ``report``.

        Returns:
            The ``"report_checked"`` value of the model's JSON answer. The prompt
            asks for the checklist keys mapped to findings, or ``"n/a"``.
        """
        checklist_json = json.dumps(
            {key: node.value for key, node in self.visitable_nodes.items()},
            indent=4,
        )
        # The output-field name below must match the "report_checked" key read
        # at the end of this method (it was previously misspelled in the prompt).
        q = f"""the following is a Knee MRI "checklist" in JSON format with keys as items and values as findings:
A knee MRI "report" is also provided in raw text format written by a radiologist:
<checklist>
{checklist_json}
</checklist>
<report>
{report}
</report>
Your task is to find all the corresponding items from the "checklist" in the "report" and fill out a JSON with the same keys as the "checklist" but extract the corresponding values from the "report".
If a key is not found in the "report", please set the value to "n/a", otherwise set it to the corresponding finding from the "report".
You must check the "report" phrases one by one and find a corresponding key(s) for EACH phrase in the "report" from the "checklist" and fill out the "report_checked" JSON.
Try to fill out as many items as possible.
ALL of the items in the "checklist" must be filled out.
Don't generate findings that are not present in the "report" (new findings).
Be comprehensive and don't miss any findings that are present in the "report".
Watch out for encompassing terms (e.g., "cruciate ligaments" means both "ACL" and "PCL").
"thought_process": [Think in steps on how you would do this task.]
"report_checked" : [a JSON with the same keys as the "checklist" but take the values from the "report", as described above.]
"""
        keys = ["thought_process", "report_checked"]
        prompt = [
            self.system_prompt,
            utils.make_question(utils.JSON_TEMPLATE, question=q, keys=keys),
        ]
        response = self.llm(prompt)
        print_debug(
            prompt,
            response,
        )
        if self.logger:
            # set name to class name
            self.logger(
                name=self.__class__.__name__,
                message=f"prompt: {prompt}\nresponse: {response}",
            )
        return utils.json2dict(response)["report_checked"]

    def prime_model(self):
        """
        Primes the model with the system prompt.

        Returns:
            True if the model answers "yes" (case- and whitespace-insensitive),
            otherwise False.
        """
        q = "Are you ready to begin?\nSay 'yes' or 'no'."
        keys = ["answer"]
        response = self.llm(
            [
                self.system_prompt,
                utils.make_question(utils.JSON_TEMPLATE, question=q, keys=keys),
            ],
        )
        print_debug(q, response)
        if str(utils.json2dict(response)["answer"]).strip().lower() == "yes":
            print(INFO_COLOR + "The model is ready.")
            return True
        print(INFO_COLOR + "The model is not ready.")
        return False

    @staticmethod
    def _camel2readable(camel: str) -> str:
        """Turn a camelCase key into capitalised words, e.g. ``"medialMeniscus"`` -> ``"Medial Meniscus"``."""
        spaced = re.sub("([a-z])([A-Z])", r"\1 \2", camel)
        return " ".join(word.capitalize() for word in spaced.split())

    @classmethod
    def _format_checked_report(cls, checked_report: dict) -> str:
        """Render the checked checklist, one line per item, plus the covered fraction.

        Items whose value is ``"n/a"`` (case- and whitespace-insensitive) get a
        cross; all others get ``<value>`` and a checkmark. The fraction counts
        the checked items directly rather than checkmark characters in the text,
        so a value that itself contains a checkmark is not counted twice. An
        empty checklist yields 0%.
        """
        CHECKMARK = "\u2705"
        CROSS = "\u274C"
        lines = []
        found = 0
        for key, value in checked_report.items():
            label = cls._camel2readable(key)
            if str(value).strip().lower() == "n/a":
                lines.append(f"{label}: {CROSS}\n")
            else:
                found += 1
                lines.append(f"{label}: <{value}> {CHECKMARK}\n")
        portion_visited = found / len(checked_report) if checked_report else 0.0
        lines.append(f"Portion of the checklist visited: {portion_visited:.1%}")
        return "".join(lines)

    def process_input(self, input_text: str):
        """Handle one submitted report.

        Returns:
            ``"quit"`` or ``"help"`` for those commands (case-insensitive,
            surrounding whitespace ignored); ``"exception"`` if processing raised;
            otherwise the rendered checklist from ``_format_checked_report``.
        """
        try:
            report = input_text
            if self.logger:
                self.logger(self.username, f"report: {report}")
            command = report.strip().lower()
            if command == "quit":
                print(INFO_COLOR + "Exiting...")
                if self.logger:
                    self.logger(self.username, "Exiting...")
                return "quit"
            if command == "help":
                if self.logger:
                    self.logger(self.username, "Help")
                return "help"
            checked_report: dict = self._check_report(report)
            report_string = self._format_checked_report(checked_report)
            if self.logger:
                self.logger(self.__class__.__name__, report_string)
            return report_string
        except Exception as ex:
            print_debug(ex, color=Fore.LIGHTRED_EX)
            if self.logger:
                # Format rather than concatenate: `"..." + ex` raises TypeError
                # from inside this handler.
                self.logger(self.__class__.__name__, f"Exception: {ex}")
            return "exception"