import base64
import io
import logging
import re

import cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

logger = logging.getLogger(__name__)


def plot_image_prediction(image, predictions, title=None, figsize=(10, 8)):
    """
    Plot an image alongside a bar chart of its top predictions.

    Args:
        image (PIL.Image or str): Image or path to an image file
        predictions (list): List of (label, probability) tuples
        title (str, optional): Plot title
        figsize (tuple): Figure size

    Returns:
        matplotlib.figure.Figure: The figure object
    """
    try:
        # Accept either a file path or an already-loaded PIL image.
        if isinstance(image, str):
            img = Image.open(image)
        else:
            img = image

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)

        # Left panel: the X-ray itself.
        ax1.imshow(img)
        ax1.set_title("X-ray Image")
        ax1.axis("off")

        # Right panel: horizontal bar chart of the top predictions.
        if predictions:
            sorted_pred = sorted(predictions, key=lambda x: x[1], reverse=True)

            top_n = min(5, len(sorted_pred))
            labels = [pred[0] for pred in sorted_pred[:top_n]]
            probs = [pred[1] for pred in sorted_pred[:top_n]]

            y_pos = np.arange(top_n)
            ax2.barh(y_pos, probs, align="center")
            ax2.set_yticks(y_pos)
            ax2.set_yticklabels(labels)
            ax2.set_xlabel("Probability")
            ax2.set_title("Top Predictions")
            ax2.set_xlim(0, 1)

            # Annotate each bar with its probability.
            for i, prob in enumerate(probs):
                ax2.text(prob + 0.02, i, f"{prob:.1%}", va="center")

        if title:
            fig.suptitle(title, fontsize=16)

        fig.tight_layout()
        return fig

    except Exception as e:
        logger.error(f"Error plotting image prediction: {e}")

        # Fall back to a figure that surfaces the error message.
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.text(0.5, 0.5, f"Error: {str(e)}", ha="center", va="center")
        return fig
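
# Illustrative usage sketch (added as an example; not part of the original module's
# public API). The prediction labels and probabilities below are synthetic
# placeholders, not real model output.
def _example_plot_image_prediction():
    """Build a prediction plot from a synthetic grayscale image and made-up scores."""
    rng = np.random.default_rng(0)
    synthetic_img = Image.fromarray(
        rng.integers(0, 256, size=(224, 224), dtype=np.uint8)
    )
    predictions = [("Pneumonia", 0.82), ("Pleural Effusion", 0.11), ("Normal", 0.04)]
    return plot_image_prediction(
        synthetic_img, predictions, title="Example prediction"
    )
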
def create_heatmap_overlay(image, heatmap, alpha=0.4):
    """
    Create a heatmap overlay on an X-ray image to highlight areas of interest.

    Args:
        image (PIL.Image, str, or numpy.ndarray): Image, path to an image, or BGR array
        heatmap (numpy.ndarray): Heatmap array
        alpha (float): Transparency of the overlay (0 = image only, 1 = heatmap only)

    Returns:
        PIL.Image: Image with heatmap overlay
    """
    try:
        # Normalise the input to a BGR numpy array for OpenCV.
        if isinstance(image, str):
            img = cv2.imread(image)
            if img is None:
                raise ValueError(f"Could not load image: {image}")
        elif isinstance(image, Image.Image):
            img = np.array(image)
            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        else:
            img = image

        # Promote grayscale images to 3 channels so they can be blended.
        if len(img.shape) == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

        # Resize the heatmap to match the image dimensions.
        heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))

        # Clamp to [0, 1]; guard against an all-zero heatmap to avoid division by zero.
        heatmap = np.maximum(heatmap, 0)
        max_val = np.max(heatmap)
        if max_val > 0:
            heatmap = heatmap / max_val
        heatmap = np.minimum(heatmap, 1)

        # Convert to a JET colour map for visual contrast.
        heatmap = np.uint8(255 * heatmap)
        heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)

        # Blend the heatmap onto the image.
        overlay = cv2.addWeighted(img, 1 - alpha, heatmap, alpha, 0)

        # Convert back to RGB for PIL.
        overlay = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)
        overlay_img = Image.fromarray(overlay)

        return overlay_img

    except Exception as e:
        logger.error(f"Error creating heatmap overlay: {e}")

        # Fall back to returning the original image unmodified.
        if isinstance(image, str):
            return Image.open(image)
        elif isinstance(image, Image.Image):
            return image
        else:
            return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
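
# Illustrative usage sketch (added as an example). The flat grey base image and the
# Gaussian "attention" map below are synthetic stand-ins for a real X-ray and a real
# saliency/Grad-CAM heatmap.
def _example_create_heatmap_overlay():
    """Overlay a synthetic Gaussian heatmap on a plain grey image."""
    base = Image.new("RGB", (256, 256), color=(90, 90, 90))
    yy, xx = np.mgrid[0:64, 0:64]
    gaussian = np.exp(-((xx - 32) ** 2 + (yy - 32) ** 2) / (2 * 10.0**2))
    return create_heatmap_overlay(base, gaussian.astype(np.float32), alpha=0.5)
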
def plot_report_entities(text, entities, figsize=(12, 8)):
    """
    Visualize entities extracted from a medical report.

    Args:
        text (str): Report text
        entities (dict): Dictionary of entity lists keyed by category
        figsize (tuple): Figure size

    Returns:
        matplotlib.figure.Figure: The figure object
    """
    try:
        fig, ax = plt.subplots(figsize=figsize)
        ax.axis("off")

        # Light grey background for a report-like appearance.
        fig.patch.set_facecolor("#f8f9fa")
        ax.set_facecolor("#f8f9fa")

        # Title.
        ax.text(
            0.5,
            0.98,
            "Medical Report Analysis",
            ha="center",
            va="top",
            fontsize=18,
            fontweight="bold",
            color="#2c3e50",
        )

        # Section header for the entity list.
        y_pos = 0.9
        ax.text(
            0.05,
            y_pos,
            "Extracted Entities:",
            fontsize=14,
            fontweight="bold",
            color="#2c3e50",
        )
        y_pos -= 0.05

        # One colour per entity category.
        category_colors = {
            "problem": "#e74c3c",
            "test": "#3498db",
            "treatment": "#2ecc71",
            "anatomy": "#9b59b6",
        }

        # List each non-empty category with its entities.
        for category, items in entities.items():
            if items:
                y_pos -= 0.05
                ax.text(
                    0.1,
                    y_pos,
                    f"{category.capitalize()}:",
                    fontsize=12,
                    fontweight="bold",
                )
                y_pos -= 0.05
                ax.text(
                    0.15,
                    y_pos,
                    ", ".join(items),
                    wrap=True,
                    fontsize=11,
                    color=category_colors.get(category, "black"),
                )

        # Section header for the full report text.
        y_pos -= 0.1
        ax.text(
            0.05,
            y_pos,
            "Report Text (with highlighted entities):",
            fontsize=14,
            fontweight="bold",
            color="#2c3e50",
        )
        y_pos -= 0.05

        # Flatten the entity dictionary into (entity, category) pairs.
        all_entities = []
        for category, items in entities.items():
            for item in items:
                all_entities.append((item, category))

        # Replace longer entities first so shorter ones cannot break them up.
        all_entities.sort(key=lambda x: len(x[0]), reverse=True)

        highlighted_text = text
        for entity, category in all_entities:
            # Escape the regex metacharacters that commonly appear in entity names.
            entity_escaped = (
                entity.replace("(", r"\(")
                .replace(")", r"\)")
                .replace("[", r"\[")
                .replace("]", r"\]")
            )

            # Match whole words only and wrap them in LaTeX-style colour markup.
            # NOTE: matplotlib renders \textcolor only when usetex is enabled with a
            # colour-capable preamble; otherwise the markup appears verbatim.
            pattern = r"\b" + entity_escaped + r"\b"
            color_code = category_colors.get(category, "black")
            replacement = f"\\textcolor{{{color_code}}}{{{entity}}}"
            highlighted_text = re.sub(
                pattern, lambda _match: replacement, highlighted_text
            )

        ax.text(0.05, y_pos, highlighted_text, va="top", fontsize=10, wrap=True)

        fig.tight_layout(rect=[0, 0.03, 1, 0.97])
        return fig

    except Exception as e:
        logger.error(f"Error plotting report entities: {e}")

        # Fall back to a figure that surfaces the error message.
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.text(0.5, 0.5, f"Error: {str(e)}", ha="center", va="center")
        return fig
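
# Illustrative usage sketch (added as an example). The report text and entity
# dictionary below are hypothetical; in practice they would come from the
# report-processing pipeline that feeds this module.
def _example_plot_report_entities():
    """Render a short made-up report with a small entity dictionary."""
    report = (
        "Chest X-ray shows a right lower lobe consolidation. "
        "Recommend follow-up CT to exclude underlying mass."
    )
    entities = {
        "problem": ["consolidation", "mass"],
        "test": ["Chest X-ray", "CT"],
        "anatomy": ["right lower lobe"],
    }
    return plot_report_entities(report, entities)
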
def plot_multimodal_results(
    fused_results, image=None, report_text=None, figsize=(12, 10)
):
    """
    Visualize the results of multimodal analysis.

    Args:
        fused_results (dict): Results from multimodal fusion
        image (PIL.Image or str, optional): Image or path to an image file
        report_text (str, optional): Report text (currently not rendered in the figure)
        figsize (tuple): Figure size

    Returns:
        matplotlib.figure.Figure: The figure object
    """
    try:
        # 2 x 2 grid: overview, key findings, image, recommendations.
        fig = plt.figure(figsize=figsize)
        gs = fig.add_gridspec(2, 2)

        fig.suptitle(
            "Multimodal Medical Analysis Results",
            fontsize=18,
            fontweight="bold",
            y=0.98,
        )

        # Top-left panel: analysis overview.
        ax_overview = fig.add_subplot(gs[0, 0])
        ax_overview.axis("off")

        severity = fused_results.get("severity", {})
        severity_level = severity.get("level", "Unknown")
        severity_score = severity.get("score", 0)

        primary_finding = fused_results.get("primary_finding", "Unknown")

        agreement = fused_results.get("agreement_score", 0)

        overview_text = [
            "ANALYSIS OVERVIEW",
            f"Primary Finding: {primary_finding}",
            f"Severity Level: {severity_level} ({severity_score}/4)",
            f"Agreement Score: {agreement:.0%}",
        ]

        severity_colors = {
            "Normal": "#2ecc71",
            "Mild": "#3498db",
            "Moderate": "#f39c12",
            "Severe": "#e74c3c",
            "Critical": "#c0392b",
        }

        y_pos = 0.9
        ax_overview.text(
            0.5,
            y_pos,
            overview_text[0],
            fontsize=14,
            fontweight="bold",
            ha="center",
            va="center",
        )
        y_pos -= 0.15

        ax_overview.text(
            0.1, y_pos, overview_text[1], fontsize=12, ha="left", va="center"
        )
        y_pos -= 0.1

        # Severity, colour-coded by level.
        severity_color = severity_colors.get(severity_level, "black")
        ax_overview.text(
            0.1, y_pos, "Severity Level:", fontsize=12, ha="left", va="center"
        )
        ax_overview.text(
            0.4,
            y_pos,
            severity_level,
            fontsize=12,
            color=severity_color,
            fontweight="bold",
            ha="left",
            va="center",
        )
        ax_overview.text(
            0.6, y_pos, f"({severity_score}/4)", fontsize=10, ha="left", va="center"
        )
        y_pos -= 0.1

        # Agreement between modalities: green above 70%, amber above 40%, red otherwise.
        agreement_color = (
            "#2ecc71"
            if agreement > 0.7
            else "#f39c12"
            if agreement > 0.4
            else "#e74c3c"
        )
        ax_overview.text(
            0.1, y_pos, "Agreement Score:", fontsize=12, ha="left", va="center"
        )
        ax_overview.text(
            0.4,
            y_pos,
            f"{agreement:.0%}",
            fontsize=12,
            color=agreement_color,
            fontweight="bold",
            ha="left",
            va="center",
        )

        # Top-right panel: key findings.
        ax_findings = fig.add_subplot(gs[0, 1])
        ax_findings.axis("off")

        findings = fused_results.get("findings", [])

        y_pos = 0.9
        ax_findings.text(
            0.5,
            y_pos,
            "KEY FINDINGS",
            fontsize=14,
            fontweight="bold",
            ha="center",
            va="center",
        )
        y_pos -= 0.1

        if findings:
            for finding in findings[:5]:
                ax_findings.text(0.05, y_pos, "•", fontsize=14, ha="left", va="center")
                ax_findings.text(
                    0.1, y_pos, finding, fontsize=11, ha="left", va="center", wrap=True
                )
                y_pos -= 0.15
        else:
            ax_findings.text(
                0.1,
                y_pos,
                "No specific findings detailed.",
                fontsize=11,
                ha="left",
                va="center",
            )

        # Bottom-left panel: the X-ray image, if one was provided.
        ax_image = fig.add_subplot(gs[1, 0])

        if image is not None:
            if isinstance(image, str):
                img = Image.open(image)
            else:
                img = image

            ax_image.imshow(img)
            ax_image.set_title("X-ray Image", fontsize=12)
        else:
            ax_image.text(0.5, 0.5, "No image available", ha="center", va="center")

        ax_image.axis("off")

        # Bottom-right panel: follow-up recommendations.
        ax_rec = fig.add_subplot(gs[1, 1])
        ax_rec.axis("off")

        recommendations = fused_results.get("followup_recommendations", [])

        y_pos = 0.9
        ax_rec.text(
            0.5,
            y_pos,
            "RECOMMENDATIONS",
            fontsize=14,
            fontweight="bold",
            ha="center",
            va="center",
        )
        y_pos -= 0.1

        if recommendations:
            for rec in recommendations:
                ax_rec.text(0.05, y_pos, "•", fontsize=14, ha="left", va="center")
                ax_rec.text(
                    0.1, y_pos, rec, fontsize=11, ha="left", va="center", wrap=True
                )
                y_pos -= 0.15
        else:
            ax_rec.text(
                0.1,
                y_pos,
                "No specific recommendations provided.",
                fontsize=11,
                ha="left",
                va="center",
            )

        # Disclaimer footer.
        fig.text(
            0.5,
            0.03,
            "DISCLAIMER: This analysis is for informational purposes only and should not replace professional medical advice.",
            fontsize=9,
            style="italic",
            ha="center",
        )

        fig.tight_layout(rect=[0, 0.05, 1, 0.95])
        return fig

    except Exception as e:
        logger.error(f"Error plotting multimodal results: {e}")

        # Fall back to a figure that surfaces the error message.
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.text(0.5, 0.5, f"Error: {str(e)}", ha="center", va="center")
        return fig
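
# Illustrative usage sketch (added as an example). The dictionary below mirrors the
# keys this function reads ("primary_finding", "severity", "agreement_score",
# "findings", "followup_recommendations"); all values are hypothetical.
def _example_plot_multimodal_results():
    """Render a results figure from a hand-written fusion dictionary (no image)."""
    fused = {
        "primary_finding": "Right lower lobe pneumonia",
        "severity": {"level": "Moderate", "score": 2},
        "agreement_score": 0.75,
        "findings": [
            "Right lower lobe consolidation",
            "No pleural effusion",
        ],
        "followup_recommendations": [
            "Follow-up chest X-ray in 4-6 weeks",
            "Clinical correlation recommended",
        ],
    }
    return plot_multimodal_results(fused)
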
def figure_to_base64(fig):
    """
    Convert a matplotlib figure to a base64-encoded PNG string.

    Args:
        fig (matplotlib.figure.Figure): Figure object

    Returns:
        str: Base64-encoded PNG data, or an empty string on failure
    """
    try:
        # Render the figure to an in-memory PNG buffer, then encode it.
        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight")
        buf.seek(0)
        img_str = base64.b64encode(buf.read()).decode("utf-8")
        return img_str

    except Exception as e:
        logger.error(f"Error converting figure to base64: {e}")
        return ""
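
if __name__ == "__main__":
    # Minimal smoke test / usage sketch (added as an example): round-trip a trivial
    # figure through figure_to_base64. Runs with no external files or model output.
    demo_fig, demo_ax = plt.subplots()
    demo_ax.plot([0, 1], [0, 1])
    demo_ax.set_title("figure_to_base64 demo")
    encoded = figure_to_base64(demo_fig)
    print(f"Encoded PNG is {len(encoded)} base64 characters")
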