# Amarthya7's picture
# Upload 21 files
# 86a74e6 verified
import base64
import io
import logging
import re

import cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
# Module-level logger; the functions below use it to log errors they catch.
logger = logging.getLogger(__name__)
def plot_image_prediction(image, predictions, title=None, figsize=(10, 8), top_k=5):
    """
    Plot an image next to a bar chart of its highest-probability predictions.

    Args:
        image (PIL.Image or str): Image or path to image. Single-channel
            (grayscale) images are drawn with a gray colormap.
        predictions (list): List of (label, probability) tuples. The x-axis is
            fixed to [0, 1], so probabilities are expected in that range.
        title (str, optional): Overall figure title.
        figsize (tuple): Figure size.
        top_k (int): Maximum number of predictions to show (default 5).

    Returns:
        matplotlib.figure.Figure: The figure object. If anything fails, the
        error is logged and a figure showing the error message is returned.
    """
    fig = None
    try:
        img = Image.open(image) if isinstance(image, str) else image

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)

        # cmap only applies to 2-D (single-channel) data and is ignored for
        # RGB(A) images; without it grayscale X-rays get the default viridis map.
        ax1.imshow(img, cmap="gray")
        ax1.set_title("X-ray Image")
        ax1.axis("off")

        if predictions:
            top = sorted(predictions, key=lambda pred: pred[1], reverse=True)
            top = top[: max(int(top_k), 0)]
            labels = [pred[0] for pred in top]
            probs = [pred[1] for pred in top]

            y_pos = np.arange(len(top))
            ax2.barh(y_pos, probs, align="center")
            ax2.set_yticks(y_pos)
            ax2.set_yticklabels(labels)
            # barh draws index 0 at the bottom; flip so the most likely label is on top.
            ax2.invert_yaxis()
            ax2.set_xlabel("Probability")
            ax2.set_title("Top Predictions")
            ax2.set_xlim(0, 1)

            # Annotate each bar with its probability as a percentage.
            for i, prob in enumerate(probs):
                ax2.text(prob + 0.02, i, f"{prob:.1%}", va="center")
        else:
            ax2.axis("off")
            ax2.text(0.5, 0.5, "No predictions available", ha="center", va="center")

        if title:
            fig.suptitle(title, fontsize=16)
        fig.tight_layout()
        return fig
    except Exception as e:
        logger.exception("Error plotting image prediction: %s", e)
        # Don't leave the half-built figure registered (and open) in pyplot.
        if fig is not None:
            plt.close(fig)
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.text(0.5, 0.5, f"Error: {str(e)}", ha="center", va="center")
        return fig
def create_heatmap_overlay(image, heatmap, alpha=0.4):
    """
    Blend a JET-coloured heatmap over an X-ray image to highlight areas of interest.

    Args:
        image (PIL.Image or str or numpy.ndarray): Image, path to an image file,
            or an array in OpenCV channel order (BGR / BGRA) or single-channel
            grayscale.
        heatmap (numpy.ndarray): Heatmap array. It is resized to the image size,
            negative values are clipped to 0, and it is scaled so its maximum is 1.
        alpha (float): Blend weight of the heatmap (0 = image only, 1 = heatmap only).

    Returns:
        PIL.Image: RGB image with the heatmap overlay. On failure the error is
        logged and the original image is returned instead.
    """
    try:
        if isinstance(image, str):
            img = cv2.imread(image)  # default flags load a 3-channel BGR image
            if img is None:
                raise ValueError(f"Could not load image: {image}")
        elif isinstance(image, Image.Image):
            # convert("RGB") normalises grayscale, 16-bit, palette and RGBA images;
            # COLOR_RGB2BGR cannot be applied to a 2-D (single-channel) array.
            img = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
        else:
            img = image

        # addWeighted needs both inputs to have the same 3-channel shape.
        if img.ndim == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif img.ndim == 3 and img.shape[2] == 4:
            img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)

        # cv2.resize takes the target size as (width, height).
        heatmap = cv2.resize(
            np.asarray(heatmap, dtype=np.float32), (img.shape[1], img.shape[0])
        )

        # Normalize to [0, 1]. An all-zero (or all-negative) heatmap would
        # otherwise divide by zero and yield NaNs, whose uint8 cast is undefined.
        heatmap = np.maximum(heatmap, 0)
        peak = float(heatmap.max())
        if peak > 0:
            heatmap = heatmap / peak

        heatmap_color = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_JET)
        overlay = cv2.addWeighted(img, 1 - alpha, heatmap_color, alpha, 0)

        # Back to RGB channel order for PIL.
        return Image.fromarray(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB))
    except Exception as e:
        logger.exception("Error creating heatmap overlay: %s", e)
        # Best effort: hand back the original image in PIL form.
        if isinstance(image, str):
            return Image.open(image)
        if isinstance(image, Image.Image):
            return image
        if np.ndim(image) == 2:
            return Image.fromarray(image)
        return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
def _mark_entities(text, entities):
    """
    Return *text* with every entity occurrence wrapped in square brackets.

    Matching is case-sensitive and whole-token: an occurrence is skipped when it
    is directly preceded or followed by a word character. All entities are
    matched in one regex pass (longest alternatives first), so a short entity is
    never re-matched inside a longer one or inside an inserted marker.
    """
    terms = set()
    for items in entities.values():
        for item in items or []:
            term = str(item)
            if term.strip():  # an empty pattern would match everywhere
                terms.add(term)
    if not terms:
        return text

    alternation = "|".join(
        re.escape(term) for term in sorted(terms, key=len, reverse=True)
    )
    pattern = re.compile(rf"(?<!\w)(?:{alternation})(?!\w)")
    return pattern.sub(lambda match: f"[{match.group(0)}]", text)


def plot_report_entities(text, entities, figsize=(12, 8)):
    """
    Visualize entities extracted from a medical report.

    Lists the entities of each non-empty category (coloured by category), then
    shows the report text with every entity occurrence wrapped in square
    brackets. Matplotlib does not render per-substring colours or LaTeX
    textcolor markup in ordinary text, so brackets serve as the highlight.

    Args:
        text (str): Report text.
        entities (dict): Mapping of category name -> iterable of entity strings.
        figsize (tuple): Figure size.

    Returns:
        matplotlib.figure.Figure: The figure object. If anything fails, the
        error is logged and a figure showing the error message is returned.
    """
    fig = None
    try:
        fig, ax = plt.subplots(figsize=figsize)
        ax.axis("off")

        fig.patch.set_facecolor("#f8f9fa")
        ax.set_facecolor("#f8f9fa")

        ax.text(
            0.5,
            0.98,
            "Medical Report Analysis",
            ha="center",
            va="top",
            fontsize=18,
            fontweight="bold",
            color="#2c3e50",
        )

        y_pos = 0.9
        ax.text(
            0.05,
            y_pos,
            "Extracted Entities:",
            fontsize=14,
            fontweight="bold",
            color="#2c3e50",
        )
        y_pos -= 0.05

        category_colors = {
            "problem": "#e74c3c",  # Red
            "test": "#3498db",  # Blue
            "treatment": "#2ecc71",  # Green
            "anatomy": "#9b59b6",  # Purple
        }

        for category, items in entities.items():
            if not items:
                continue
            y_pos -= 0.05
            ax.text(
                0.1,
                y_pos,
                f"{str(category).capitalize()}:",
                fontsize=12,
                fontweight="bold",
            )
            y_pos -= 0.05
            # Escape "$" so matplotlib does not try to parse the text as mathtext.
            ax.text(
                0.15,
                y_pos,
                ", ".join(str(item) for item in items).replace("$", r"\$"),
                wrap=True,
                fontsize=11,
                color=category_colors.get(category, "black"),
            )

        y_pos -= 0.1
        ax.text(
            0.05,
            y_pos,
            "Report Text (entities marked in [brackets]):",
            fontsize=14,
            fontweight="bold",
            color="#2c3e50",
        )
        y_pos -= 0.05

        marked_text = _mark_entities(text or "", entities).replace("$", r"\$")
        ax.text(0.05, y_pos, marked_text, va="top", fontsize=10, wrap=True)

        fig.tight_layout(rect=[0, 0.03, 1, 0.97])
        return fig
    except Exception as e:
        logger.exception("Error plotting report entities: %s", e)
        # Don't leave the half-built figure registered (and open) in pyplot.
        if fig is not None:
            plt.close(fig)
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.text(0.5, 0.5, f"Error: {str(e)}", ha="center", va="center")
        return fig
def _draw_overview_panel(ax, fused_results):
    """Draw primary finding, colour-coded severity and agreement score on *ax*."""
    ax.axis("off")

    # `or` also covers keys that are present but explicitly set to None.
    severity = fused_results.get("severity") or {}
    severity_level = severity.get("level", "Unknown")
    severity_score = severity.get("score", 0)
    primary_finding = fused_results.get("primary_finding", "Unknown")
    agreement = fused_results.get("agreement_score") or 0

    severity_colors = {
        "Normal": "#2ecc71",  # Green
        "Mild": "#3498db",  # Blue
        "Moderate": "#f39c12",  # Orange
        "Severe": "#e74c3c",  # Red
        "Critical": "#c0392b",  # Dark Red
    }

    y_pos = 0.9
    ax.text(
        0.5,
        y_pos,
        "ANALYSIS OVERVIEW",
        fontsize=14,
        fontweight="bold",
        ha="center",
        va="center",
    )
    y_pos -= 0.15
    ax.text(
        0.1,
        y_pos,
        f"Primary Finding: {primary_finding}",
        fontsize=12,
        ha="left",
        va="center",
    )
    y_pos -= 0.1

    ax.text(0.1, y_pos, "Severity Level:", fontsize=12, ha="left", va="center")
    ax.text(
        0.4,
        y_pos,
        severity_level,
        fontsize=12,
        color=severity_colors.get(severity_level, "black"),
        fontweight="bold",
        ha="left",
        va="center",
    )
    ax.text(0.6, y_pos, f"({severity_score}/4)", fontsize=10, ha="left", va="center")
    y_pos -= 0.1

    # Green above 0.7, orange above 0.4, red otherwise.
    if agreement > 0.7:
        agreement_color = "#2ecc71"
    elif agreement > 0.4:
        agreement_color = "#f39c12"
    else:
        agreement_color = "#e74c3c"
    ax.text(0.1, y_pos, "Agreement Score:", fontsize=12, ha="left", va="center")
    ax.text(
        0.4,
        y_pos,
        f"{agreement:.0%}",
        fontsize=12,
        color=agreement_color,
        fontweight="bold",
        ha="left",
        va="center",
    )


def _draw_bullet_panel(ax, heading, items, empty_message, max_items=5):
    """Draw *heading* plus up to *max_items* bullet lines (or *empty_message*) on *ax*."""
    ax.axis("off")
    items = list(items or [])

    y_pos = 0.9
    ax.text(
        0.5,
        y_pos,
        heading,
        fontsize=14,
        fontweight="bold",
        ha="center",
        va="center",
    )
    y_pos -= 0.1

    if not items:
        ax.text(0.1, y_pos, empty_message, fontsize=11, ha="left", va="center")
        return

    for item in items[:max_items]:
        ax.text(0.05, y_pos, "•", fontsize=14, ha="left", va="center")
        ax.text(0.1, y_pos, item, fontsize=11, ha="left", va="center", wrap=True)
        y_pos -= 0.15

    # Beyond max_items the lines would run off the panel; say what was left out.
    hidden = len(items) - max_items
    if hidden > 0:
        ax.text(
            0.1,
            y_pos,
            f"... and {hidden} more",
            fontsize=10,
            style="italic",
            ha="left",
            va="center",
        )


def _draw_image_panel(ax, image):
    """Show *image* (a PIL image or a path) on *ax*, or a placeholder if it is None."""
    if image is None:
        ax.text(0.5, 0.5, "No image available", ha="center", va="center")
    else:
        img = Image.open(image) if isinstance(image, str) else image
        # Gray colormap for single-channel X-rays; matplotlib ignores it for RGB(A).
        ax.imshow(img, cmap="gray")
        ax.set_title("X-ray Image", fontsize=12)
    ax.axis("off")


def plot_multimodal_results(
    fused_results, image=None, report_text=None, figsize=(12, 10)
):
    """
    Visualize the results of multimodal analysis on a 2x2 grid.

    Panels: overview (top left), key findings (top right), image (bottom left)
    and follow-up recommendations (bottom right). The findings and
    recommendations panels each show at most 5 entries and note how many were
    omitted.

    Args:
        fused_results (dict): Results from multimodal fusion. Keys read here:
            "severity" (dict with "level" and "score"), "primary_finding",
            "agreement_score", "findings", "followup_recommendations".
        image (PIL.Image or str, optional): Image or path to image.
        report_text (str, optional): Report text. Currently not drawn; kept for
            interface compatibility.
        figsize (tuple): Figure size.

    Returns:
        matplotlib.figure.Figure: The figure object. If anything fails, the
        error is logged and a figure showing the error message is returned.
    """
    fig = None
    try:
        fig = plt.figure(figsize=figsize)
        gs = fig.add_gridspec(2, 2)

        fig.suptitle(
            "Multimodal Medical Analysis Results",
            fontsize=18,
            fontweight="bold",
            y=0.98,
        )

        _draw_overview_panel(fig.add_subplot(gs[0, 0]), fused_results)
        _draw_bullet_panel(
            fig.add_subplot(gs[0, 1]),
            "KEY FINDINGS",
            fused_results.get("findings"),
            "No specific findings detailed.",
        )
        _draw_image_panel(fig.add_subplot(gs[1, 0]), image)
        _draw_bullet_panel(
            fig.add_subplot(gs[1, 1]),
            "RECOMMENDATIONS",
            fused_results.get("followup_recommendations"),
            "No specific recommendations provided.",
        )

        fig.text(
            0.5,
            0.03,
            "DISCLAIMER: This analysis is for informational purposes only and should not replace professional medical advice.",
            fontsize=9,
            style="italic",
            ha="center",
        )

        # Leave room for the suptitle above and the disclaimer below.
        fig.tight_layout(rect=[0, 0.05, 1, 0.95])
        return fig
    except Exception as e:
        logger.exception("Error plotting multimodal results: %s", e)
        # Don't leave the half-built figure registered (and open) in pyplot.
        if fig is not None:
            plt.close(fig)
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.text(0.5, 0.5, f"Error: {str(e)}", ha="center", va="center")
        return fig
def figure_to_base64(fig):
    """
    Render a matplotlib figure to PNG and return it base64-encoded.

    Args:
        fig (matplotlib.figure.Figure): Figure to render.

    Returns:
        str: Base64 text of the PNG bytes, or an empty string if rendering or
        encoding fails (the error is logged).
    """
    try:
        with io.BytesIO() as png_buffer:
            fig.savefig(png_buffer, format="png", bbox_inches="tight")
            png_bytes = png_buffer.getvalue()
        return base64.b64encode(png_bytes).decode("utf-8")
    except Exception as e:
        logger.error(f"Error converting figure to base64: {e}")
        return ""