import base64
import io
import logging
import re

import cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
|
|
|
|
|
# Module-level logger named after this module, so its output can be configured
# through the standard logging hierarchy.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
def plot_image_prediction(image, predictions, title=None, figsize=(10, 8), top_n=5):
    """
    Plot an image next to a horizontal bar chart of its most likely predictions.

    Args:
        image (PIL.Image or str): Image or path to image. Single-channel images
            are shown in grayscale.
        predictions (list): List of (label, probability) tuples. Probabilities
            are drawn on a fixed 0-1 axis.
        title (str, optional): Figure-level title.
        figsize (tuple): Figure size in inches.
        top_n (int): Maximum number of predictions to show, highest first
            (default 5).

    Returns:
        matplotlib.figure.Figure: The figure object. If plotting fails, the
        error is logged and a placeholder figure showing the error message is
        returned instead.
    """
    fig = None
    try:
        if isinstance(image, str):
            # Copy the pixels so the file handle is closed immediately instead
            # of lingering until the lazily loaded image is garbage collected.
            with Image.open(image) as opened:
                img = opened.copy()
        else:
            img = image

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)

        # cmap only applies to single-channel data; RGB(A) images ignore it.
        ax1.imshow(img, cmap="gray")
        ax1.set_title("X-ray Image")
        ax1.axis("off")

        shown = []
        if predictions:
            shown = sorted(predictions, key=lambda pred: pred[1], reverse=True)
            shown = shown[: max(top_n, 0)]

        if shown:
            labels = [pred[0] for pred in shown]
            probs = [pred[1] for pred in shown]

            y_pos = np.arange(len(shown))
            ax2.barh(y_pos, probs, align="center")
            ax2.set_yticks(y_pos)
            ax2.set_yticklabels(labels)
            # barh draws index 0 at the bottom; invert so the most likely
            # prediction is listed first, reading top to bottom.
            ax2.invert_yaxis()
            ax2.set_xlabel("Probability")
            ax2.set_title("Top Predictions")
            ax2.set_xlim(0, 1)

            for i, prob in enumerate(probs):
                if prob > 0.85:
                    # Near-certain bars reach the right edge; put the label
                    # inside the bar so it is not drawn outside the axes.
                    ax2.text(
                        prob - 0.02, i, f"{prob:.1%}", va="center", ha="right", color="white"
                    )
                else:
                    ax2.text(prob + 0.02, i, f"{prob:.1%}", va="center")
        else:
            ax2.axis("off")
            ax2.text(0.5, 0.5, "No predictions available", ha="center", va="center")

        if title:
            fig.suptitle(title, fontsize=16)

        fig.tight_layout()
        return fig

    except Exception as e:
        logger.exception("Error plotting image prediction: %s", e)
        if fig is not None:
            # pyplot keeps every figure alive until it is closed explicitly,
            # so release the partially drawn one before making the placeholder.
            plt.close(fig)
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.axis("off")
        ax.text(0.5, 0.5, f"Error: {str(e)}", ha="center", va="center")
        return fig
|
|
|
|
|
def create_heatmap_overlay(image, heatmap, alpha=0.4):
    """
    Blend a JET-coloured heatmap over an image to highlight areas of interest.

    The heatmap is resized to the image, negative values are clipped to zero,
    and the result is scaled so that its maximum maps to full intensity.

    Args:
        image (PIL.Image, str or numpy.ndarray): Image, path to an image, or an
            array. Arrays are treated as OpenCV-style BGR, BGRA or grayscale
            data (presumably uint8; other dtypes may fail in the blend).
        heatmap (numpy.ndarray): 2-D heatmap array of any real dtype.
        alpha (float): Weight of the heatmap in the blend (0 = image only,
            1 = heatmap only).

    Returns:
        PIL.Image: RGB image with the heatmap overlay. If blending fails, the
        error is logged and the unmodified image is returned instead. For a
        path that cannot be opened at all, that fallback itself raises.
    """
    try:
        if isinstance(image, str):
            img = cv2.imread(image)
            if img is None:
                raise ValueError(f"Could not load image: {image}")
        elif isinstance(image, Image.Image):
            # Normalise grayscale/palette/RGBA images to RGB first. A 2-D array
            # from an "L" image would make COLOR_RGB2BGR raise.
            img = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
        else:
            img = image

        # The colour map below yields 3-channel BGR; addWeighted needs matching
        # channel counts.
        if img.ndim == 2 or img.shape[2] == 1:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif img.shape[2] == 4:
            img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)

        # cv2.resize rejects some dtypes (e.g. int64, bool), so work in float32.
        heat = np.asarray(heatmap, dtype=np.float32)
        heat = cv2.resize(heat, (img.shape[1], img.shape[0]))
        heat = np.maximum(np.nan_to_num(heat), 0)

        # An all-zero (or all-negative) heatmap has a zero maximum. Dividing by
        # it would produce NaNs and garbage colours after the uint8 cast.
        peak = float(heat.max()) if heat.size else 0.0
        if peak > 0:
            heat = heat / peak
        heat = np.uint8(255 * np.minimum(heat, 1))

        colored = cv2.applyColorMap(heat, cv2.COLORMAP_JET)
        overlay = cv2.addWeighted(img, 1 - alpha, colored, alpha, 0)

        return Image.fromarray(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB))

    except Exception as e:
        logger.exception("Error creating heatmap overlay: %s", e)

        if isinstance(image, str):
            return Image.open(image)
        if isinstance(image, Image.Image):
            return image
        if np.ndim(image) == 2:
            # Grayscale arrays cannot go through COLOR_BGR2RGB.
            return Image.fromarray(image)
        return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
|
|
|
|
|
def _segment_entities(text, entities):
    """
    Split ``text`` into ``(segment, category)`` pairs.

    ``category`` is the entity category for a matched entity and ``None`` for
    plain text. Matching is case-sensitive and never overlaps. Longer entities
    win over shorter ones, and only whole words match: an entity may not be
    preceded or followed by a word character, so "pain" does not match inside
    "painful".
    """
    if not text:
        return []

    category_of = {}
    for category, items in entities.items():
        for item in items or ():
            item = str(item)
            # The first category listed wins when the same string appears twice.
            if item and item not in category_of:
                category_of[item] = category
    if not category_of:
        return [(text, None)]

    # Longest alternatives first, so the regex prefers the longest entity at
    # each position. Lookarounds are used instead of \b so that entities
    # starting or ending with punctuation still match.
    alternatives = sorted(category_of, key=len, reverse=True)
    pattern = re.compile(
        r"(?<!\w)(?:" + "|".join(map(re.escape, alternatives)) + r")(?!\w)"
    )

    segments = []
    cursor = 0
    for match in pattern.finditer(text):
        if match.start() > cursor:
            segments.append((text[cursor : match.start()], None))
        segments.append((match.group(0), category_of[match.group(0)]))
        cursor = match.end()
    if cursor < len(text):
        segments.append((text[cursor:], None))
    return segments


def _wrap_segments(segments, width):
    """
    Greedily wrap ``(segment, category)`` pairs into lines of about ``width`` characters.

    Plain text breaks at whitespace and at explicit newlines, and each whitespace
    run is collapsed to a single space. Entity spans are never split. Returns a
    list of lines; each line is a list of ``(text, category)`` runs in which
    adjacent runs of the same category are merged.
    """
    tokens = []
    for segment, category in segments:
        if category is not None:
            tokens.append((segment, category))
            continue
        for index, part in enumerate(segment.split("\n")):
            if index:
                tokens.append(("\n", None))
            tokens.extend(
                (" " if piece.isspace() else piece, None)
                for piece in re.findall(r"\S+|\s+", part)
            )

    lines = [[]]
    length = 0
    for token, category in tokens:
        if token == "\n" and category is None:
            lines.append([])
            length = 0
            continue
        if length and length + len(token) > width:
            lines.append([])
            length = 0
            if token.isspace():
                continue  # drop the whitespace the line was broken at
        line = lines[-1]
        if line and line[-1][1] == category:
            line[-1] = (line[-1][0] + token, category)
        else:
            line.append((token, category))
        length += len(token)
    return lines


def _draw_highlighted_lines(ax, lines, x, y, line_step, colors, fontsize):
    """
    Draw wrapped ``(text, category)`` runs, with the first line's top at (x, y).

    Each following line is placed ``line_step`` lower. Entity runs are drawn
    bold in their category colour. Runs on the same line are chained with
    ``annotate(xycoords=<previous text>)`` so each one starts where the
    previous one ends.
    """
    for row, line in enumerate(lines):
        previous = None
        for run, category in line:
            style = {"fontsize": fontsize, "color": "black"}
            if category is not None:
                style.update(color=colors.get(category, "black"), fontweight="bold")
            # Escape "$" so report text is never parsed as mathtext.
            run = run.replace("$", r"\$")
            if previous is None:
                previous = ax.text(x, y - row * line_step, run, va="top", **style)
            else:
                previous = ax.annotate(
                    run, xy=(1, 0), xycoords=previous, va="bottom", **style
                )


def plot_report_entities(text, entities, figsize=(12, 8)):
    """
    Visualize entities extracted from a medical report.

    The figure lists the entities per category, then shows the report text
    with every entity occurrence drawn bold in its category colour.

    Args:
        text (str): Report text.
        entities (dict): Mapping of category name to a list of entity strings.
            "problem", "test", "treatment" and "anatomy" have dedicated
            colours; any other category is drawn in black.
        figsize (tuple): Figure size in inches.

    Returns:
        matplotlib.figure.Figure: The figure object. If plotting fails, the
        error is logged and a placeholder figure showing the error message is
        returned instead.
    """
    fig = None
    try:
        entities = entities or {}

        fig, ax = plt.subplots(figsize=figsize)
        ax.axis("off")

        fig.patch.set_facecolor("#f8f9fa")
        ax.set_facecolor("#f8f9fa")

        ax.text(
            0.5,
            0.98,
            "Medical Report Analysis",
            ha="center",
            va="top",
            fontsize=18,
            fontweight="bold",
            color="#2c3e50",
        )

        y_pos = 0.9
        ax.text(
            0.05,
            y_pos,
            "Extracted Entities:",
            fontsize=14,
            fontweight="bold",
            color="#2c3e50",
        )
        y_pos -= 0.05

        category_colors = {
            "problem": "#e74c3c",
            "test": "#3498db",
            "treatment": "#2ecc71",
            "anatomy": "#9b59b6",
        }

        for category, items in entities.items():
            if not items:
                continue
            y_pos -= 0.05
            ax.text(
                0.1,
                y_pos,
                f"{category.capitalize()}:",
                fontsize=12,
                fontweight="bold",
            )
            y_pos -= 0.05
            ax.text(
                0.15,
                y_pos,
                ", ".join(map(str, items)),
                wrap=True,
                fontsize=11,
                color=category_colors.get(category, "black"),
            )

        y_pos -= 0.1
        ax.text(
            0.05,
            y_pos,
            "Report Text (with highlighted entities):",
            fontsize=14,
            fontweight="bold",
            color="#2c3e50",
        )
        y_pos -= 0.05

        # Approximate the wrap width and line spacing from the axes size in
        # inches (about 0.6 em per character, 1.5 em per line). Text spans
        # x = 0.05 .. 0.95 of the axes.
        fontsize = 10
        box = ax.get_position()
        width_in = fig.get_figwidth() * box.width * 0.9
        height_in = fig.get_figheight() * box.height
        max_chars = max(20, int(width_in * 72 / (0.6 * fontsize)))
        line_step = 1.5 * fontsize / 72 / height_in

        lines = _wrap_segments(_segment_entities(text, entities), max_chars)
        _draw_highlighted_lines(
            ax, lines, 0.05, y_pos, line_step, category_colors, fontsize
        )

        fig.tight_layout(rect=[0, 0.03, 1, 0.97])
        return fig

    except Exception as e:
        logger.exception("Error plotting report entities: %s", e)
        if fig is not None:
            # Release the partially drawn figure; pyplot would keep it open.
            plt.close(fig)
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.text(0.5, 0.5, f"Error: {str(e)}", ha="center", va="center")
        return fig
|
|
|
|
|
def _draw_bullet_panel(ax, heading, items, empty_message, max_items):
    """
    Draw ``heading`` and up to ``max_items`` bulleted ``items`` on an axis-less panel.

    Items beyond ``max_items`` are summarised by a "(+N more)" line. Drawing
    them would push text below the panel, where it overlaps neighbouring
    content.
    """
    ax.axis("off")

    y_pos = 0.9
    ax.text(
        0.5,
        y_pos,
        heading,
        fontsize=14,
        fontweight="bold",
        ha="center",
        va="center",
    )
    y_pos -= 0.1

    if items is None:
        items = []
    elif isinstance(items, str):
        items = [items]  # a bare string is one item, not a list of characters
    else:
        items = list(items)

    if not items:
        ax.text(0.1, y_pos, empty_message, fontsize=11, ha="left", va="center")
        return

    shown = items[: max(max_items, 0)]
    for item in shown:
        ax.text(0.05, y_pos, "•", fontsize=14, ha="left", va="center")
        ax.text(0.1, y_pos, item, fontsize=11, ha="left", va="center", wrap=True)
        y_pos -= 0.15

    hidden = len(items) - len(shown)
    if hidden > 0:
        ax.text(
            0.1,
            y_pos,
            f"(+{hidden} more)",
            fontsize=10,
            style="italic",
            ha="left",
            va="center",
        )


def plot_multimodal_results(
    fused_results, image=None, report_text=None, figsize=(12, 10), max_items=5
):
    """
    Visualize the results of multimodal analysis as a 2x2 dashboard.

    The panels are: overview (top-left), key findings (top-right), image
    (bottom-left) and recommendations (bottom-right), plus a disclaimer footer.

    Args:
        fused_results (dict): Results from multimodal fusion. Reads the keys
            "severity" (dict with "level" and "score"), "primary_finding",
            "agreement_score" (fraction 0-1), "findings" and
            "followup_recommendations" (lists). Missing keys get defaults.
        image (PIL.Image or str, optional): Image or path to image.
            Single-channel images are shown in grayscale.
        report_text (str, optional): Report text. Currently not drawn; it is
            accepted for API compatibility.
        figsize (tuple): Figure size in inches.
        max_items (int): Maximum number of findings and recommendations listed
            per panel (default 5). Any remainder is summarised as "(+N more)".

    Returns:
        matplotlib.figure.Figure: The figure object. If plotting fails, the
        error is logged and a placeholder figure showing the error message is
        returned instead.
    """
    fig = None
    try:
        fused_results = fused_results or {}

        fig = plt.figure(figsize=figsize)
        gs = fig.add_gridspec(2, 2)

        fig.suptitle(
            "Multimodal Medical Analysis Results",
            fontsize=18,
            fontweight="bold",
            y=0.98,
        )

        # Overview panel (top-left).
        ax_overview = fig.add_subplot(gs[0, 0])
        ax_overview.axis("off")

        severity = fused_results.get("severity") or {}
        severity_level = severity.get("level", "Unknown")
        severity_score = severity.get("score", 0)
        primary_finding = fused_results.get("primary_finding", "Unknown")
        # A missing or None score is shown as 0% instead of failing the ":.0%" format.
        agreement = fused_results.get("agreement_score") or 0

        severity_colors = {
            "Normal": "#2ecc71",
            "Mild": "#3498db",
            "Moderate": "#f39c12",
            "Severe": "#e74c3c",
            "Critical": "#c0392b",
        }

        y_pos = 0.9
        ax_overview.text(
            0.5,
            y_pos,
            "ANALYSIS OVERVIEW",
            fontsize=14,
            fontweight="bold",
            ha="center",
            va="center",
        )
        y_pos -= 0.15

        ax_overview.text(
            0.1,
            y_pos,
            f"Primary Finding: {primary_finding}",
            fontsize=12,
            ha="left",
            va="center",
        )
        y_pos -= 0.1

        ax_overview.text(
            0.1, y_pos, "Severity Level:", fontsize=12, ha="left", va="center"
        )
        ax_overview.text(
            0.4,
            y_pos,
            severity_level,
            fontsize=12,
            color=severity_colors.get(severity_level, "black"),
            fontweight="bold",
            ha="left",
            va="center",
        )
        ax_overview.text(
            0.6, y_pos, f"({severity_score}/4)", fontsize=10, ha="left", va="center"
        )
        y_pos -= 0.1

        # Green above 70% agreement, amber above 40%, red otherwise.
        if agreement > 0.7:
            agreement_color = "#2ecc71"
        elif agreement > 0.4:
            agreement_color = "#f39c12"
        else:
            agreement_color = "#e74c3c"
        ax_overview.text(
            0.1, y_pos, "Agreement Score:", fontsize=12, ha="left", va="center"
        )
        ax_overview.text(
            0.4,
            y_pos,
            f"{agreement:.0%}",
            fontsize=12,
            color=agreement_color,
            fontweight="bold",
            ha="left",
            va="center",
        )

        # Key findings panel (top-right).
        _draw_bullet_panel(
            fig.add_subplot(gs[0, 1]),
            "KEY FINDINGS",
            fused_results.get("findings", []),
            "No specific findings detailed.",
            max_items,
        )

        # Image panel (bottom-left).
        ax_image = fig.add_subplot(gs[1, 0])
        if image is not None:
            if isinstance(image, str):
                # Copy the pixels so the file handle is closed immediately.
                with Image.open(image) as opened:
                    img = opened.copy()
            else:
                img = image
            # cmap only applies to single-channel data; RGB(A) images ignore it.
            ax_image.imshow(img, cmap="gray")
            ax_image.set_title("X-ray Image", fontsize=12)
        else:
            ax_image.text(0.5, 0.5, "No image available", ha="center", va="center")
        ax_image.axis("off")

        # Recommendations panel (bottom-right).
        _draw_bullet_panel(
            fig.add_subplot(gs[1, 1]),
            "RECOMMENDATIONS",
            fused_results.get("followup_recommendations", []),
            "No specific recommendations provided.",
            max_items,
        )

        fig.text(
            0.5,
            0.03,
            "DISCLAIMER: This analysis is for informational purposes only and should not replace professional medical advice.",
            fontsize=9,
            style="italic",
            ha="center",
        )

        fig.tight_layout(rect=[0, 0.05, 1, 0.95])
        return fig

    except Exception as e:
        logger.exception("Error plotting multimodal results: %s", e)
        if fig is not None:
            # Release the partially drawn figure; pyplot would keep it open.
            plt.close(fig)
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.text(0.5, 0.5, f"Error: {str(e)}", ha="center", va="center")
        return fig
|
|
|
|
|
def figure_to_base64(fig):
    """
    Render a figure as PNG and return the image bytes base64-encoded.

    Args:
        fig (matplotlib.figure.Figure): Figure to render. It is saved with
            ``bbox_inches="tight"`` so surrounding whitespace is trimmed.

    Returns:
        str: The PNG bytes as a base64 string, or ``""`` if rendering fails
        (the failure is logged).
    """
    try:
        with io.BytesIO() as png_buffer:
            fig.savefig(png_buffer, format="png", bbox_inches="tight")
            png_bytes = png_buffer.getvalue()
        return base64.b64encode(png_bytes).decode("utf-8")
    except Exception as e:
        logger.error(f"Error converting figure to base64: {e}")
        return ""
|
|
|