"""Run lung segmentation, ChexNet disease classification and CAM visualisation
on a single chest X-ray image.

Importing this module loads the U-Net and ChexNet models once and creates
``output_dir``. Every call to :func:`process_image` then (over)writes three
artefacts in that directory:

* ``segment_result.png``     - segmentation mask blended onto the image, with
  the lung bounding box drawn on top;
* ``cam_result.png``         - class-activation heatmap overlaid on the image;
* ``prediction_results.csv`` - the most probable diseases and their scores.

Example::

    result, segment_path, cam_path = process_image('path/to/xray.jpg')
    print("Prediction Results:", result)
    print(f"Segmentation Result Saved to: {segment_path}")
    print(f"CAM Result Saved to: {cam_path}")
"""
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import cv2
import os
from chexnet import ChexNet
from unet import Unet
from heatmap import HeatmapGenerator
from constant import IMAGENET_MEAN, IMAGENET_STD, CLASS_NAMES
import sys

# Make modules that live next to this script importable; chestXray_utils
# below is imported only after this path update.
script_dir = os.path.dirname(os.path.abspath(__file__))
imgto3d_path = os.path.join(script_dir, '.')
sys.path.append(imgto3d_path)
from chestXray_utils import blend_segmentation  # noqa: E402  (needs sys.path update)
import torch  # noqa: E402
import pandas as pd  # noqa: E402

# All artefacts are written here, relative to the current working directory.
output_dir = "pages/images"
os.makedirs(output_dir, exist_ok=True)

# Model identifiers passed to the Unet / ChexNet constructors.
unet_model = '20190211-101020'
chexnet_model = '20180429-130928'

# numpy array so it can be fancy-indexed with argsort results.
DISEASES = np.array(CLASS_NAMES)

# Initialize models once at import time and switch them to eval mode.
unet = Unet(trained=True, model_name=unet_model)
chexnet = ChexNet(trained=True, model_name=chexnet_model)
heatmap_generator = HeatmapGenerator(chexnet, mode='cam')
unet.eval()
chexnet.eval()


def _normalize_to_uint8(array):
    """Min-max scale *array* to 0-255 and return it as ``uint8``.

    A constant array (max == min) would divide by zero and yield NaNs, which
    then cast to arbitrary bytes. It is mapped to all zeros instead.
    """
    array = np.asarray(array)
    lo, hi = array.min(), array.max()
    if hi == lo:
        return np.zeros(array.shape, dtype=np.uint8)
    return ((array - lo) / (hi - lo) * 255).astype(np.uint8)


def _save_segmentation(image, mask, box):
    """Blend *mask* onto *image*, outline *box* and save it as a PNG.

    *box* is ``(top, left, bottom, right)`` in pixels. Returns the written path.
    """
    t, l, b, r = box
    blended = _normalize_to_uint8(blend_segmentation(image, mask))
    # blended presumably keeps the PIL image's RGB channel order, and it is
    # saved with matplotlib (which expects RGB), so (255, 0, 0) draws a red
    # box. cv2 drawing functions require integer pixel coordinates.
    cv2.rectangle(blended, (int(l), int(t)), (int(r), int(b)), (255, 0, 0), 5)
    path = os.path.join(output_dir, 'segment_result.png')
    plt.imsave(path, blended)
    return path


def _save_cam(image, cropped_image, prob):
    """Render the class-activation heatmap over *image* and save it as a PNG.

    Returns the written path. Raises OSError if OpenCV fails to write it.
    """
    w, h = cropped_image.size
    heatmap, _ = heatmap_generator.from_prob(prob, w, h)
    # NOTE(review): the heatmap is requested at the *cropped* size but is
    # stretched over the whole image here, so it may not line up with the
    # crop box. Confirm against HeatmapGenerator before relying on it.
    heatmap = cv2.resize(heatmap, (image.width, image.height))
    colored = cv2.applyColorMap(_normalize_to_uint8(heatmap), cv2.COLORMAP_JET)
    # applyColorMap produces BGR, and cv2.imwrite expects BGR, so the RGB
    # PIL image must be converted before blending. Otherwise the photo's red
    # and blue channels end up swapped in the saved overlay.
    image_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    cam = cv2.addWeighted(colored, 0.4, image_bgr, 0.6, 0)
    path = os.path.join(output_dir, 'cam_result.png')
    if not cv2.imwrite(path, cam):
        raise OSError(f"failed to write CAM image to {path}")
    return path


def _save_top_predictions(prob, top_k):
    """Return the *top_k* diseases by score and write them to a CSV file.

    The result maps disease name -> probability string (3 significant digits),
    in descending order of probability.
    """
    idx = np.argsort(-prob)[:top_k]
    prediction = {
        disease: f'{p:.3}' for disease, p in zip(DISEASES[idx], prob[idx])
    }
    df = pd.DataFrame(list(prediction.items()),
                      columns=['Disease', 'Probability'])
    df.to_csv(os.path.join(output_dir, 'prediction_results.csv'), index=False)
    return prediction


def process_image(image_path, top_k=10):
    """Segment, classify and visualise one chest X-ray.

    Args:
        image_path: path of the image to analyse; it is converted to RGB.
        top_k: number of most probable diseases to report (default 10).

    Returns:
        ``(result, segment_result_path, cam_result_path)``, where ``result`` is
        ``{'result': {disease_name: probability_string}}`` ordered from most to
        least probable, with probabilities formatted to 3 significant digits.

    Side effects:
        Overwrites ``segment_result.png``, ``cam_result.png`` and
        ``prediction_results.csv`` in ``output_dir``. The file names are fixed,
        so concurrent calls overwrite each other's results.
    """
    with Image.open(image_path) as raw:
        image = raw.convert('RGB')

    # Segment the lungs, then classify only the bounding-box crop.
    (t, l, b, r), mask = unet.segment(image)
    cropped_image = image.crop((l, t, r, b))
    prob = chexnet.predict(cropped_image)

    segment_result_path = _save_segmentation(image, mask, (t, l, b, r))
    cam_result_path = _save_cam(image, cropped_image, prob)
    prediction = _save_top_predictions(prob, top_k)

    return {'result': prediction}, segment_result_path, cam_result_path