Spaces:
Sleeping
Sleeping
from PIL import Image | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import cv2 | |
import os | |
from chexnet import ChexNet | |
from unet import Unet | |
from heatmap import HeatmapGenerator | |
from constant import IMAGENET_MEAN, IMAGENET_STD, CLASS_NAMES | |
import sys | |
script_dir = os.path.dirname(os.path.abspath(__file__)) | |
imgto3d_path = os.path.join(script_dir, '.') | |
sys.path.append(imgto3d_path) | |
from chestXray_utils import blend_segmentation | |
import torch | |
import pandas as pd | |
# Directory (relative to the current working directory, since it is not joined
# with script_dir) that process_image writes its PNG/CSV outputs into.
output_dir = "pages/images"
os.makedirs(output_dir, exist_ok=True)

# Identifiers handed to the model constructors as ``model_name``
# (presumably timestamps of trained checkpoints -- confirm in unet/chexnet).
unet_model, chexnet_model = '20190211-101020', '20180429-130928'

# Class names as an ndarray so they can be fancy-indexed with argsort output.
DISEASES = np.array(CLASS_NAMES)

# Build the networks once at import time: U-Net first, then CheXNet, then the
# CAM generator wrapping CheXNet; only afterwards switch both to eval mode.
unet, chexnet = (
    Unet(trained=True, model_name=unet_model),
    ChexNet(trained=True, model_name=chexnet_model),
)
heatmap_generator = HeatmapGenerator(chexnet, mode='cam')
for _net in (unet, chexnet):
    _net.eval()
del _net
def _to_uint8(array):
    """Min-max scale *array* to 0-255 and return it as ``uint8``.

    A constant array (max == min) would divide by zero and yield NaNs; it is
    mapped to all zeros instead.
    """
    array = np.asarray(array, dtype=np.float64)
    lo, hi = array.min(), array.max()
    if hi <= lo:
        return np.zeros(array.shape, dtype=np.uint8)
    return ((array - lo) / (hi - lo) * 255).astype(np.uint8)


def _paste_heatmap(heatmap, box, size):
    """Place a heatmap computed on the crop *box* into a full-image canvas.

    Args:
        heatmap: 2-D activation map for the cropped region (assumed shaped
            (crop_height, crop_width); it is resized to the box if not).
        box: ``(left, top, right, bottom)`` crop box in full-image pixels.
        size: ``(width, height)`` of the full image.

    Returns:
        A float32 ``(height, width)`` array. Pixels outside the box are set
        to the heatmap's minimum so they map to the low end of the colormap.
    """
    left, top, right, bottom = (int(v) for v in box)
    width, height = size
    crop = cv2.resize(np.asarray(heatmap, dtype=np.float32),
                      (right - left, bottom - top))
    full = np.full((height, width), crop.min(), dtype=np.float32)
    # Clip the box to the image so a box overhanging the border still fits.
    x0, y0 = max(left, 0), max(top, 0)
    x1, y1 = min(right, width), min(bottom, height)
    if x1 > x0 and y1 > y0:
        full[y0:y1, x0:x1] = crop[y0 - top:y1 - top, x0 - left:x1 - left]
    return full


def process_image(image_path):
    """Segment and classify a chest X-ray, saving visualisations and a CSV.

    The U-Net yields a bounding box and mask; CheXNet classifies the cropped
    region; a CAM heatmap for that crop is overlaid on the full image.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (path or file object).

    Returns:
        ``(result, segment_result_path, cam_result_path)`` where ``result`` is
        ``{'result': {disease: probability_str}}`` for the 10 highest-scoring
        classes (highest first, formatted with 3 significant digits), and the
        paths point at PNGs written into ``output_dir``. The same table is
        also written to ``<output_dir>/prediction_results.csv``.

    NOTE: output filenames are fixed, so each call overwrites the previous
    call's files.
    """
    with Image.open(image_path) as src:
        image = src.convert('RGB')

    # Segment, then classify only the cropped region.
    (t, l, b, r), mask = unet.segment(image)
    t, l, b, r = int(t), int(l), int(b), int(r)  # cv2 drawing needs ints
    cropped_image = image.crop((l, t, r, b))
    prob = chexnet.predict(cropped_image)

    # Segmentation overlay with the crop box drawn on it.
    blended = _to_uint8(blend_segmentation(image, mask))
    # Saved with plt.imsave, which reads 3-channel arrays as RGB: this is red.
    cv2.rectangle(blended, (l, t), (r, b), (255, 0, 0), 5)
    segment_result_path = os.path.join(output_dir, 'segment_result.png')
    plt.imsave(segment_result_path, blended)

    # CAM: the heatmap covers the crop only, so paste it back at the crop's
    # location instead of stretching it across the whole image.
    w, h = cropped_image.size
    heatmap, _ = heatmap_generator.from_prob(prob, w, h)
    heatmap_full = _paste_heatmap(heatmap, (l, t, r, b), image.size)
    colored = cv2.applyColorMap(_to_uint8(heatmap_full), cv2.COLORMAP_JET)
    # applyColorMap output and cv2.imwrite are both BGR; PIL gives RGB.
    background = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    cam = cv2.addWeighted(colored, 0.4, background, 0.6, 0)
    cam_result_path = os.path.join(output_dir, 'cam_result.png')
    cv2.imwrite(cam_result_path, cam)

    # Top-10 diseases by predicted probability.
    top_idx = np.argsort(-prob)[:10]
    prediction = dict(zip(DISEASES[top_idx], (f'{p:.3}' for p in prob[top_idx])))
    result = {'result': prediction}
    df = pd.DataFrame(list(prediction.items()), columns=['Disease', 'Probability'])
    df.to_csv(os.path.join(output_dir, 'prediction_results.csv'), index=False)
    return result, segment_result_path, cam_result_path
# if __name__ == '__main__': | |
# image_path = r'E:\NLP\KN2024\chestX-ray-14\src\fibrosis.jpg' # Replace with your image path | |
# result, segment_result_path, cam_result_path = process_image(image_path) | |
# print("Prediction Results:", result) | |
# print(f"Segmentation Result Saved to: {segment_result_path}") | |
# print(f"CAM Result Saved to: {cam_result_path}") | |