# AutoPharmaV2/utils/vision_model.py
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from io import BytesIO
def load_vision_model(model_name="flaviagiammarino/medsam-vit-base"):
"""
Load MedSAM model from Hugging Face
Args:
model_name (str): Model repository name
Returns:
tuple: (model, processor)
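    Example (a minimal sketch; assumes the checkpoint can be downloaded
    from the Hugging Face Hub):
        >>> model, processor = load_vision_model()
        >>> model.device  # cuda if a GPU is available, otherwise cpu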
"""
from transformers import SamModel, SamProcessor
try:
# Try loading the model
model = SamModel.from_pretrained(model_name)
processor = SamProcessor.from_pretrained(model_name)
# Move to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
return model, processor
except Exception as e:
raise RuntimeError(f"Failed to load vision model {model_name}: {e}")
def identify_image_type(image):
"""Identify the type of medical image based on visual characteristics"""
# Convert to numpy array if it's a PIL image
if isinstance(image, Image.Image):
img_array = np.array(image)
else:
img_array = image
# Get image dimensions and ratio
height, width = img_array.shape[:2]
aspect_ratio = width / height
# Basic image type detection logic
if aspect_ratio > 1.4: # Wide format
# Likely a panoramic X-ray or abdominal scan
return "Panoramic X-ray"
elif aspect_ratio < 0.7: # Tall format
# Likely a full spine X-ray
return "Full Spine X-ray"
else: # Square-ish format
# Check brightness distribution for chest vs other X-rays
# Chest X-rays typically have more contrast between dark (lungs) and bright (bones) areas
# If grayscale, use directly, otherwise convert
if len(img_array.shape) > 2: # Color image
gray_img = np.mean(img_array, axis=2)
else:
gray_img = img_array
# Normalize to 0-1
if gray_img.max() > 0:
gray_img = gray_img / gray_img.max()
# Check if has clear lung fields (darker regions in center)
center_region = gray_img[height//4:3*height//4, width//4:3*width//4]
        # Exclude the centre block when computing the edge statistic so the
        # comparison is centre-vs-border (overwriting the centre with 1s would
        # inflate the edge mean)
        edge_mask = np.ones_like(gray_img, dtype=bool)
        edge_mask[height//4:3*height//4, width//4:3*width//4] = False
        center_mean = np.mean(center_region)
        edges_mean = np.mean(gray_img[edge_mask])
# Chest X-rays typically have darker center (lung fields)
if center_mean < edges_mean * 0.85:
return "Chest X-ray"
else:
# Look for bone structures
high_intensity = np.percentile(gray_img, 95) * 0.95
bone_pixels = np.sum(gray_img > high_intensity) / (height * width)
if bone_pixels > 0.15: # Significant bone content
if height > width:
return "Spine X-ray"
else:
return "Extremity X-ray"
# Default
return "Medical X-ray"
def detect_abnormalities(image_type, mask, image_array):
"""Detect potential abnormalities based on image type and mask area"""
# Create more meaningful default findings
findings = {
"regions_of_interest": ["No specific abnormalities detected"],
"potential_findings": ["Normal study"],
"additional_notes": []
}
# Get mask properties
if len(mask.shape) > 2:
mask = mask[:,:,0] # Take first channel if multi-channel
# Extract masked region stats
if np.any(mask):
rows, cols = np.where(mask)
min_row, max_row = min(rows), max(rows)
min_col, max_col = min(cols), max(cols)
# Get region location
height, width = mask.shape
region_center_y = np.mean(rows)
region_center_x = np.mean(cols)
rel_y = region_center_y / height
rel_x = region_center_x / width
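        # rel_y/rel_x are normalised centroid coordinates of the mask:
        # 0 = top/left edge of the image, 1 = bottom/right edge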
# Get image intensity stats in masked region
if len(image_array.shape) > 2:
gray_img = np.mean(image_array, axis=2)
else:
gray_img = image_array
if gray_img.max() > 0:
gray_img = gray_img / gray_img.max()
# Get statistics of the region
mask_intensities = gray_img[mask]
if len(mask_intensities) > 0:
region_mean = np.mean(mask_intensities)
region_std = np.std(mask_intensities)
# Calculate stats outside the mask for comparison
inverse_mask = ~mask
outside_intensities = gray_img[inverse_mask]
if len(outside_intensities) > 0:
outside_mean = np.mean(outside_intensities)
intensity_diff = abs(region_mean - outside_mean)
else:
outside_mean = 0
intensity_diff = 0
# Identify regions of interest based on image type
if image_type == "Chest X-ray":
findings["regions_of_interest"] = []
# Identify anatomical regions in chest X-ray
if rel_y < 0.3: # Upper chest
if rel_x < 0.4:
findings["regions_of_interest"].append("Left upper lung field")
elif rel_x > 0.6:
findings["regions_of_interest"].append("Right upper lung field")
else:
findings["regions_of_interest"].append("Upper mediastinum")
elif rel_y < 0.6: # Mid chest
if rel_x < 0.4:
findings["regions_of_interest"].append("Left mid lung field")
elif rel_x > 0.6:
findings["regions_of_interest"].append("Right mid lung field")
else:
findings["regions_of_interest"].append("Central mediastinum")
findings["regions_of_interest"].append("Cardiac silhouette")
else: # Lower chest
if rel_x < 0.4:
findings["regions_of_interest"].append("Left lower lung field")
findings["regions_of_interest"].append("Left costophrenic angle")
elif rel_x > 0.6:
findings["regions_of_interest"].append("Right lower lung field")
findings["regions_of_interest"].append("Right costophrenic angle")
else:
findings["regions_of_interest"].append("Lower mediastinum")
findings["regions_of_interest"].append("Upper abdomen")
# Check for potential abnormalities based on intensity
findings["potential_findings"] = []
if region_mean < outside_mean * 0.7 and region_std < 0.15:
findings["potential_findings"].append("Potential hyperlucency/emphysematous changes")
elif region_mean > outside_mean * 1.3:
if region_std > 0.2:
findings["potential_findings"].append("Heterogeneous opacity")
else:
findings["potential_findings"].append("Homogeneous opacity/consolidation")
# Add size of area
mask_height = max_row - min_row
mask_width = max_col - min_col
if max(mask_height, mask_width) > min(height, width) * 0.25:
findings["additional_notes"].append(f"Large area of interest ({mask_height}x{mask_width} pixels)")
else:
findings["additional_notes"].append(f"Focal area of interest ({mask_height}x{mask_width} pixels)")
elif "Spine" in image_type:
# Vertebral analysis for spine X-rays
findings["regions_of_interest"] = []
if rel_y < 0.3:
findings["regions_of_interest"].append("Cervical spine region")
elif rel_y < 0.6:
findings["regions_of_interest"].append("Thoracic spine region")
else:
findings["regions_of_interest"].append("Lumbar spine region")
# Check for potential findings
findings["potential_findings"] = []
if region_std > 0.25: # High variability in vertebral region could indicate irregularity
findings["potential_findings"].append("Potential vertebral irregularity")
if intensity_diff > 0.3:
findings["potential_findings"].append("Area of abnormal density")
elif "Extremity" in image_type:
# Extremity X-ray analysis
findings["regions_of_interest"] = []
# Basic positioning
if rel_y < 0.5 and rel_x < 0.5:
findings["regions_of_interest"].append("Proximal joint region")
elif rel_y > 0.5 and rel_x > 0.5:
findings["regions_of_interest"].append("Distal joint region")
else:
findings["regions_of_interest"].append("Mid-shaft bone region")
# Check for potential findings
findings["potential_findings"] = []
if region_std > 0.25: # High variability could indicate irregular bone contour
findings["potential_findings"].append("Potential cortical irregularity")
if intensity_diff > 0.4:
findings["potential_findings"].append("Area of abnormal bone density")
# Default if no findings identified
if len(findings["potential_findings"]) == 0:
findings["potential_findings"] = ["No obvious abnormalities in segmented region"]
return findings
def analyze_medical_image(image_type, image, mask, metadata):
"""Generate a comprehensive medical image analysis"""
# Convert to numpy if PIL image
if isinstance(image, Image.Image):
image_array = np.array(image)
else:
image_array = image
# Detect abnormalities based on image type and region
abnormalities = detect_abnormalities(image_type, mask, image_array)
# Get mask properties
mask_area = metadata["mask_percentage"]
confidence = metadata["score"]
# Determine anatomical positioning
height, width = mask.shape if len(mask.shape) == 2 else mask.shape[:2]
if np.any(mask):
rows, cols = np.where(mask)
center_y = np.mean(rows) / height
center_x = np.mean(cols) / width
# Determine laterality
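        # NOTE: "left"/"right" here follow image-space column position
        # (center_x), not the patient's anatomical left/right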
if center_x < 0.4:
laterality = "Left side predominant"
elif center_x > 0.6:
laterality = "Right side predominant"
else:
laterality = "Midline/central"
# Determine superior/inferior position
if center_y < 0.4:
position = "Superior/upper region"
elif center_y > 0.6:
position = "Inferior/lower region"
else:
position = "Mid/central region"
else:
laterality = "Undetermined"
position = "Undetermined"
# Generate analysis text
if image_type == "Chest X-ray":
image_description = "anteroposterior (AP) or posteroanterior (PA) chest radiograph"
regions = ", ".join(abnormalities["regions_of_interest"])
findings = ", ".join(abnormalities["potential_findings"])
elif "Spine" in image_type:
image_description = "spinal radiograph"
regions = ", ".join(abnormalities["regions_of_interest"])
findings = ", ".join(abnormalities["potential_findings"])
elif "Extremity" in image_type:
image_description = "extremity radiograph"
regions = ", ".join(abnormalities["regions_of_interest"])
findings = ", ".join(abnormalities["potential_findings"])
else:
image_description = "medical radiograph"
regions = ", ".join(abnormalities["regions_of_interest"])
findings = ", ".join(abnormalities["potential_findings"])
# Finalize analysis text
analysis_text = f"""
## Radiological Analysis
**Image Type**: {image_type}
**Segmentation Details**:
- Region: {position} ({regions})
- Laterality: {laterality}
- Coverage: {mask_area:.1f}% of the image
**Findings**:
- {findings}
- {'; '.join(abnormalities["additional_notes"]) if abnormalities["additional_notes"] else 'No additional notes'}
**Technical Assessment**:
- Segmentation confidence: {confidence:.2f} (on a 0-1 scale)
- Image quality: {'Adequate' if confidence > 0.4 else 'Suboptimal'} for assessment
**Impression**:
This {image_description} demonstrates a highlighted area in the {position.lower()} with {laterality.lower()}.
{findings.capitalize() if findings else 'No significant abnormalities identified in the segmented region.'} Additional clinical correlation is recommended.
*Note: This is an automated analysis and should be reviewed by a qualified healthcare professional.*
"""
# Create analysis results as dict
analysis_results = {
"image_type": image_type,
"region": position,
"laterality": laterality,
"regions_of_interest": abnormalities["regions_of_interest"],
"potential_findings": abnormalities["potential_findings"],
"additional_notes": abnormalities["additional_notes"],
"coverage_percentage": mask_area,
"confidence_score": confidence
}
return analysis_text, analysis_results
def process_medical_image(image, model=None, processor=None):
"""
Process medical image with MedSAM using automatic segmentation
Args:
image (PIL.Image): Input image
model: SamModel instance (optional, will be loaded if not provided)
processor: SamProcessor instance (optional, will be loaded if not provided)
Returns:
tuple: (PIL.Image of segmentation, metadata dict, analysis text)
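    Example (a minimal sketch; "chest_xray.png" is a placeholder path):
        >>> from PIL import Image
        >>> overlay, meta, report = process_medical_image(Image.open("chest_xray.png"))
        >>> print(report)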
"""
# Load model and processor if not provided
if model is None or processor is None:
from transformers import SamModel, SamProcessor
model_name = "flaviagiammarino/medsam-vit-base"
model, processor = load_vision_model(model_name)
# Convert image if needed
if not isinstance(image, Image.Image):
image = Image.fromarray(image)
# Convert to grayscale if it's not already
if image.mode != 'L':
grayscale = image.convert('L')
# Convert back to RGB for processing
image_for_processing = grayscale.convert('RGB')
else:
# If already grayscale, convert to RGB for processing
image_for_processing = image.convert('RGB')
    # Resize to a fixed square size so the prompt box, mask, and visualization
    # all share the same dimensions
    image_size = 512
processed_image = image_for_processing.resize((image_size, image_size), Image.LANCZOS)
image_array = np.array(processed_image)
# Identify the type of medical image
image_type = identify_image_type(image)
try:
# For chest X-rays, target the full central region
# This ensures we analyze most of the image rather than just a tiny portion
height, width = image_array.shape[:2]
        # Build a single bounding-box prompt covering ~75% of the image.
        # SamProcessor expects input_boxes nested per batch image, i.e.
        # [[[x1, y1, x2, y2]]] for one box on one image
        # (here [[[64, 64, 448, 448]]] for the 512x512 input)
        margin = width // 8  # 12.5% margin on each side
        box = [[margin, margin, width - margin, height - margin]]
        # Prepare model inputs from the PIL image and the box prompt
        inputs = processor(
            images=processed_image,
            input_boxes=[box],  # one list of boxes per image in the batch
            return_tensors="pt"
        )
# Transfer inputs to the same device as the model
inputs = {k: v.to(model.device) if isinstance(v, torch.Tensor) else v
for k, v in inputs.items()}
# Run inference
with torch.no_grad():
outputs = model(**inputs)
        # Upscale the low-resolution masks back to the processed image size.
        # pred_masks has shape (batch, boxes, num_masks, 256, 256) and
        # post_process_masks expects the full tensor, so it is passed unsqueezed.
        masks = processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu()
        )
        # Pick the mask with the highest predicted IoU score
        scores = outputs.iou_scores[0, 0]  # one score per candidate mask
        best_idx = int(torch.argmax(scores))
        score_value = float(scores[best_idx].cpu())
        # Best candidate mask for the single box prompt on the single image
        mask = masks[0][0][best_idx].cpu().numpy() > 0
except Exception as e:
print(f"Error in MedSAM processing: {e}")
# Create a fallback mask covering most of the central image area
mask = np.zeros((image_size, image_size), dtype=bool)
margin = image_size // 8
mask[margin:image_size-margin, margin:image_size-margin] = True
score_value = 0.5
# Visualize results
fig, ax = plt.subplots(figsize=(12, 12))
# Use the grayscale image for visualization if it was an X-ray
ax.imshow(image_array, cmap='gray' if image.mode == 'L' else None)
# Show mask as overlay with improved visibility
color_mask = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.float32)
color_mask[mask] = [1, 0, 0, 0.4] # Semi-transparent red
ax.imshow(color_mask)
# Add title with image type
ax.set_title(f"Medical Image Segmentation: {image_type}", fontsize=14)
ax.axis('off')
# Convert plot to image
fig.patch.set_facecolor('white')
buf = BytesIO()
fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0.1,
facecolor='white', dpi=150)
plt.close(fig)
buf.seek(0)
result_image = Image.open(buf)
# Prepare metadata
metadata = {
"mask_percentage": float(np.mean(mask) * 100), # Percentage of image that is masked
"score": score_value,
"size": {
"width": mask.shape[1],
"height": mask.shape[0]
}
}
# Generate analysis
analysis_text, analysis_results = analyze_medical_image(image_type, processed_image, mask, metadata)
    # Return the overlay image along with the metadata and markdown analysis
    return result_image, metadata, analysis_text
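

# Minimal usage sketch (not part of the original module; assumes a local file
# "sample_xray.png" exists and the MedSAM checkpoint can be downloaded):
if __name__ == "__main__":
    sample = Image.open("sample_xray.png")
    overlay, meta, report = process_medical_image(sample)
    overlay.save("segmentation_overlay.png")
    print(f"Mask covers {meta['mask_percentage']:.1f}% of the image "
          f"(confidence {meta['score']:.2f})")
    print(report)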