# Uploaded by datnguyentien204 ("Upload 337 files", commit ce91ea1, verified).
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import cv2
import os
from chexnet import ChexNet
from unet import Unet
from heatmap import HeatmapGenerator
from constant import IMAGENET_MEAN, IMAGENET_STD, CLASS_NAMES
import sys
# Make modules that sit next to this file importable regardless of the current
# working directory (presumably chestXray_utils, imported right below — confirm).
# Note: the project imports above this point (chexnet, unet, heatmap, constant)
# run before this entry is added, so they rely on sys.path as it already was.
# The path is appended, so it is searched after every existing sys.path entry.
script_dir = os.path.dirname(os.path.abspath(__file__))
imgto3d_path = os.path.join(script_dir, '.')
sys.path.append(imgto3d_path)
from chestXray_utils import blend_segmentation
import torch
import pandas as pd
# Directory that process_image writes its artefacts into (segmentation overlay,
# CAM overlay, prediction CSV); created at import time if missing.
# NOTE(review): this path is relative to the current working directory, not to
# script_dir — confirm callers always run from the expected app root.
output_dir = "pages/images"
os.makedirs(output_dir, exist_ok=True)
# Names of the pretrained checkpoints to load (presumably timestamped training
# runs — confirm against the model store).
unet_model = '20190211-101020'
chexnet_model = '20180429-130928'
# Kept as an ndarray so it can be indexed with an argsort index array.
DISEASES = np.array(CLASS_NAMES)
# Initialize models once at import time so every process_image call reuses them
# instead of reloading weights per request.
unet = Unet(trained=True, model_name=unet_model)
chexnet = ChexNet(trained=True, model_name=chexnet_model)
heatmap_generator = HeatmapGenerator(chexnet, mode='cam')
# Put both networks in inference mode (presumably torch.nn.Module.eval —
# disables dropout / batch-norm statistic updates; confirm in unet/chexnet).
unet.eval()
chexnet.eval()
def _normalize_to_uint8(arr):
    """Min-max scale *arr* into 0-255 and return it as a new ``uint8`` array.

    A constant array (max == min) maps to all zeros instead of dividing by
    zero and producing NaNs that ``astype`` would turn into garbage.
    """
    arr = np.asarray(arr, dtype=np.float64)
    lo = arr.min()
    span = arr.max() - lo
    if span == 0:
        return np.zeros(arr.shape, dtype=np.uint8)
    return ((arr - lo) / span * 255).astype(np.uint8)


def _place_heatmap(heatmap, box, width, height):
    """Return a ``(height, width)`` float map with *heatmap* pasted over *box*.

    *heatmap* describes only the cropped region ``box = (l, t, r, b)``, so it is
    resized to the crop size and pasted at the crop's position in the full
    image. Everything outside the crop is filled with the heatmap's minimum, so
    it renders as the coolest colour after normalization.
    """
    l, t, r, b = box
    crop = cv2.resize(np.asarray(heatmap, dtype=np.float32), (r - l, b - t))
    canvas = np.full((height, width), crop.min(), dtype=np.float32)
    # PIL's crop accepts boxes extending past the image; clip the paste window
    # to the canvas and shift the source window by the same offsets.
    x0, y0 = max(l, 0), max(t, 0)
    x1, y1 = min(r, width), min(b, height)
    if x1 > x0 and y1 > y0:
        canvas[y0:y1, x0:x1] = crop[y0 - t:y1 - t, x0 - l:x1 - l]
    return canvas


def _save_segmentation(image, mask, box):
    """Write the segmentation overlay with the crop box drawn on it; return its path."""
    l, t, r, b = box
    blended = _normalize_to_uint8(blend_segmentation(image, mask))
    # plt.imsave interprets the array as RGB, so (255, 0, 0) is drawn in red.
    cv2.rectangle(blended, (l, t), (r, b), (255, 0, 0), 5)
    path = os.path.join(output_dir, 'segment_result.png')
    plt.imsave(path, blended)
    return path


def _save_cam(image, prob, box, crop_size):
    """Write the CAM heatmap blended over *image*; return its path.

    The heatmap is generated for the cropped region only and is positioned
    at *box* rather than stretched across the whole image.
    """
    w, h = crop_size
    heatmap, _ = heatmap_generator.from_prob(prob, w, h)
    heatmap = _normalize_to_uint8(_place_heatmap(heatmap, box, image.width, image.height))
    cam = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)  # BGR, same size as image
    # applyColorMap yields BGR and cv2.imwrite expects BGR, so convert the PIL
    # (RGB) image to BGR before blending to avoid a red/blue channel swap.
    background = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    cam = cv2.addWeighted(cam, 0.4, background, 0.6, 0)
    path = os.path.join(output_dir, 'cam_result.png')
    cv2.imwrite(path, cam)
    return path


def _save_top_predictions(prob, k=10):
    """Write the *k* most probable diseases to CSV; return ``{'result': {...}}``."""
    # Descending order of probability; dict insertion order preserves it.
    idx = np.argsort(-prob)[:k]
    prediction = {disease: f'{p:.3}' for disease, p in zip(DISEASES[idx], prob[idx])}
    df = pd.DataFrame(list(prediction.items()), columns=['Disease', 'Probability'])
    df.to_csv(os.path.join(output_dir, 'prediction_results.csv'), index=False)
    return {'result': prediction}


def process_image(image_path):
    """Segment and classify one chest X-ray, saving visualizations to ``output_dir``.

    The U-Net returns a bounding box and a mask. The image is cropped to the
    box and classified with CheXNet. Three artefacts are written: the
    segmentation overlay, the CAM heatmap overlay, and a CSV of the top-10
    predictions.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (path or file object).

    Returns:
        ``(result, segment_result_path, cam_result_path)``. ``result`` is
        ``{'result': {disease_name: probability_string}}`` for the ten most
        probable diseases, ordered from most to least probable.
    """
    image = Image.open(image_path).convert('RGB')

    # Run through net
    (t, l, b, r), mask = unet.segment(image)
    # PIL rounds crop coordinates itself; round identically here so cropping,
    # drawing (cv2 requires ints) and CAM placement all use the same box.
    l, t, r, b = (int(round(v)) for v in (l, t, r, b))
    box = (l, t, r, b)
    cropped_image = image.crop(box)
    prob = chexnet.predict(cropped_image)

    segment_result_path = _save_segmentation(image, mask, box)
    cam_result_path = _save_cam(image, prob, box, cropped_image.size)
    result = _save_top_predictions(prob)
    return result, segment_result_path, cam_result_path
# if __name__ == '__main__':
# image_path = r'E:\NLP\KN2024\chestX-ray-14\src\fibrosis.jpg' # Replace with your image path
# result, segment_result_path, cam_result_path = process_image(image_path)
# print("Prediction Results:", result)
# print(f"Segmentation Result Saved to: {segment_result_path}")
# print(f"CAM Result Saved to: {cam_result_path}")