"""
Script test và đánh giá mô hình
"""
import argparse
import json
import sys
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from sklearn.metrics import jaccard_score, precision_score, recall_score
from tqdm import tqdm
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor


class MedicalImageSegmentationTester:
    """Runs inference and evaluation with a fine-tuned SegFormer checkpoint."""

    def __init__(self, model_path, device="auto"):
        # "auto" picks CUDA when available; otherwise honour the device the caller passed
        if device == "auto":
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)
        print(f"🖥️ Device: {self.device}")
        print(f"📁 Loading model from: {model_path}")

        # Load model
        self.model = SegformerForSemanticSegmentation.from_pretrained(model_path)
        self.model.to(self.device)
        self.model.eval()

        # Load processor
        self.processor = SegformerImageProcessor.from_pretrained(model_path)
        print("✓ Model loaded successfully")

    def predict_single(self, image_path, return_probs=False):
        """Predict on a single image."""
        # Load image
        image = Image.open(image_path).convert("RGB")
        original_size = image.size[::-1]  # PIL gives (W, H); interpolate expects (H, W)

        # Process image
        inputs = self.processor(images=image, return_tensors="pt")

        # Inference
        with torch.no_grad():
            outputs = self.model(pixel_values=inputs["pixel_values"].to(self.device))
            logits = outputs.logits

        # Interpolate logits back to the original resolution
        upsampled_logits = F.interpolate(
            logits,
            size=original_size,
            mode="bilinear",
            align_corners=False
        )
        pred_mask = upsampled_logits.argmax(dim=1)[0].cpu().numpy()

        if return_probs:
            probs = torch.softmax(upsampled_logits, dim=1)[0].cpu().numpy()
            return pred_mask, probs
        return pred_mask
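
    # Example (a minimal sketch; the checkpoint and image paths below are
    # illustrative, not part of the original script):
    #   tester = MedicalImageSegmentationTester("./outputs/segformer-finetuned")
    #   mask = tester.predict_single("./data/test/images/case123_slice_0001.png")
    #   mask, probs = tester.predict_single("./data/test/images/case123_slice_0001.png",
    #                                       return_probs=True)  # probs: (num_classes, H, W)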

    def evaluate_dataset(self, image_dir, mask_dir, output_dir=None):
        """Evaluate the model over an entire dataset."""
        image_dir = Path(image_dir)
        mask_dir = Path(mask_dir)

        image_paths = sorted(image_dir.glob("*.png"))
        print(f"\n📊 Evaluating {len(image_paths)} images...")

        if output_dir:
            output_dir = Path(output_dir)
            output_dir.mkdir(parents=True, exist_ok=True)

        metrics_list = []
        all_true = []
        all_pred = []

        for img_path in tqdm(image_paths):
            img_id = img_path.stem
            mask_path = mask_dir / f"{img_id}_mask.png"
            if not mask_path.exists():
                continue

            # Predict
            pred_mask = self.predict_single(img_path)

            # Load ground truth
            true_mask = np.array(Image.open(mask_path))

            # Calculate per-image metrics
            metrics = self.calculate_metrics(true_mask, pred_mask)
            metrics['image_id'] = img_id
            metrics_list.append(metrics)

            all_true.extend(true_mask.flatten())
            all_pred.extend(pred_mask.flatten())

            # Save prediction if output_dir provided
            if output_dir:
                # Scale class IDs (0-3) so the mask is visible as a grayscale PNG
                pred_img = Image.fromarray((pred_mask * 50).astype(np.uint8))
                pred_img.save(output_dir / f"{img_id}_pred.png")
        # Overall metrics (support-weighted average across classes; cast to float
        # so json.dump can serialize the numpy scalars returned by sklearn)
        overall_metrics = {
            'mIoU': float(jaccard_score(all_true, all_pred, average='weighted')),
            'precision': float(precision_score(all_true, all_pred, average='weighted', zero_division=0)),
            'recall': float(recall_score(all_true, all_pred, average='weighted', zero_division=0)),
        }

        # Per-class metrics: 1=large_bowel, 2=small_bowel, 3=stomach
        for class_id in range(1, 4):
            class_true = (np.array(all_true) == class_id).astype(int)
            class_pred = (np.array(all_pred) == class_id).astype(int)
            if class_true.sum() > 0:
                overall_metrics[f'class_{class_id}_IoU'] = float(jaccard_score(class_true, class_pred))

        print("\n" + "=" * 60)
        print("📈 Evaluation Results")
        print("=" * 60)

        print("\nOverall Metrics:")
        for metric, value in overall_metrics.items():
            print(f"  {metric:20}: {value:.4f}")

        print(f"\nPer-image Statistics ({len(metrics_list)} images):")
        if metrics_list:
            for key in metrics_list[0].keys():
                if key != 'image_id':
                    values = [m[key] for m in metrics_list]
                    print(f"  {key:20}: mean={np.mean(values):.4f}, std={np.std(values):.4f}")

        # Save results
        results = {
            'overall_metrics': overall_metrics,
            'per_image_metrics': metrics_list
        }
        if output_dir:
            with open(output_dir / "evaluation_results.json", 'w') as f:
                json.dump(results, f, indent=2)
            print(f"\n✓ Results saved to {output_dir / 'evaluation_results.json'}")

        return results

    @staticmethod
    def calculate_metrics(true_mask, pred_mask):
        """Calculate metrics for a single image."""
        true_flat = true_mask.flatten()
        pred_flat = pred_mask.flatten()
        iou = jaccard_score(true_flat, pred_flat, average='weighted')
        precision = precision_score(true_flat, pred_flat,
                                    average='weighted', zero_division=0)
        recall = recall_score(true_flat, pred_flat,
                              average='weighted', zero_division=0)
        # Cast to float so the values stay JSON-serializable
        return {
            'iou': float(iou),
            'precision': float(precision),
            'recall': float(recall)
        }

    def visualize_predictions(self, image_dir, mask_dir, output_dir, num_samples=5):
        """Create visualizations of the predictions."""
        # Import pyplot lazily so the backend chosen in __main__ is respected
        import matplotlib.pyplot as plt

        image_dir = Path(image_dir)
        mask_dir = Path(mask_dir)
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        image_paths = sorted(image_dir.glob("*.png"))[:num_samples]
        print(f"\n🎨 Visualizing {len(image_paths)} predictions...")

        for img_path in tqdm(image_paths):
            img_id = img_path.stem

            # Load original image
            image = Image.open(img_path).convert("RGB")

            # Predict (with per-class probabilities for the confidence map)
            pred_mask, probs = self.predict_single(img_path, return_probs=True)

            # Three panels: original image, prediction mask, confidence map
            fig, axes = plt.subplots(1, 3, figsize=(15, 5))

            # Original
            axes[0].imshow(image)
            axes[0].set_title("Original Image")
            axes[0].axis('off')

            # Prediction
            axes[1].imshow(pred_mask, cmap='viridis')
            axes[1].set_title("Prediction")
            axes[1].axis('off')

            # Confidence (max class probability per pixel)
            confidence = np.max(probs, axis=0)
            axes[2].imshow(confidence, cmap='hot')
            axes[2].set_title("Confidence")
            axes[2].axis('off')

            plt.tight_layout()
            plt.savefig(output_dir / f"{img_id}_visualization.png", dpi=100, bbox_inches='tight')
            plt.close(fig)

        print(f"✓ Visualizations saved to {output_dir}")


def main():
    parser = argparse.ArgumentParser(description="Test and evaluate medical image segmentation model")
    parser.add_argument("--model", type=str, required=True,
                        help="Path to trained model")
    parser.add_argument("--test-images", type=str,
                        help="Path to test images directory")
    parser.add_argument("--test-masks", type=str,
                        help="Path to test masks directory")
    parser.add_argument("--output-dir", type=str, default="./test_results",
                        help="Output directory for results")
    parser.add_argument("--visualize", action="store_true",
                        help="Create visualizations")
    parser.add_argument("--num-samples", type=int, default=5,
                        help="Number of samples to visualize")
    args = parser.parse_args()

    # Initialize tester
    tester = MedicalImageSegmentationTester(args.model)

    # Evaluate
    if args.test_images and args.test_masks:
        results = tester.evaluate_dataset(
            args.test_images,
            args.test_masks,
            args.output_dir
        )

        # Visualize
        if args.visualize:
            tester.visualize_predictions(
                args.test_images,
                args.test_masks,
                Path(args.output_dir) / "visualizations",
                args.num_samples
            )
    else:
        print("Please provide --test-images and --test-masks directories")
        return False

    return True


if __name__ == "__main__":
    # Use a non-interactive backend so figures can be saved without a display
    import matplotlib
    matplotlib.use('Agg')

    success = main()
    sys.exit(0 if success else 1)
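
# Example invocation (a sketch; the script name and paths are illustrative):
#   python test_model.py \
#       --model ./outputs/segformer-finetuned \
#       --test-images ./data/test/images \
#       --test-masks ./data/test/masks \
#       --output-dir ./test_results \
#       --visualize --num-samples 5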