|
|
|
|
|
import os
|
|
|
import cv2
|
|
|
import numpy as np
|
|
|
import torch
|
|
|
from pathlib import Path
|
|
|
import matplotlib.pyplot as plt
|
|
|
from PIL import Image, ImageOps, ImageFilter, ImageEnhance
|
|
|
|
|
|
from xray_generator.inference import XrayGenerator
|
|
|
|
|
|
|
|
|
BASE_DIR = Path(__file__).parent
|
|
|
MODEL_PATH = BASE_DIR / "outputs" / "diffusion_checkpoints" / "checkpoint_epoch_480.pt"
|
|
|
OUTPUT_DIR = BASE_DIR / "outputs" / "enhanced_xrays"
|
|
|
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
|
|
|
TEST_PROMPTS = [
|
|
|
"Normal chest X-ray with clear lungs and no abnormalities.",
|
|
|
"Right lower lobe pneumonia with focal consolidation.",
|
|
|
"Bilateral pleural effusions, greater on the right."
|
|
|
]
|
|
|
|
|
|
def apply_windowing(image, window_center=0.5, window_width=0.8):
|
|
|
"""
|
|
|
Apply window/level adjustment (similar to radiological windowing).
|
|
|
"""
|
|
|
img_array = np.array(image).astype(np.float32) / 255.0
|
|
|
|
|
|
|
|
|
min_val = window_center - window_width / 2
|
|
|
max_val = window_center + window_width / 2
|
|
|
|
|
|
img_array = np.clip((img_array - min_val) / (max_val - min_val), 0, 1)
|
|
|
|
|
|
return Image.fromarray((img_array * 255).astype(np.uint8))
|
|
|
|
|
|
def apply_edge_enhancement(image, amount=1.5):
|
|
|
"""Apply edge enhancement using unsharp mask."""
|
|
|
|
|
|
if isinstance(image, np.ndarray):
|
|
|
image = Image.fromarray(image)
|
|
|
|
|
|
|
|
|
enhancer = ImageEnhance.Sharpness(image)
|
|
|
return enhancer.enhance(amount)
|
|
|
|
|
|
def apply_median_filter(image, size=3):
|
|
|
"""Apply median filter to reduce noise."""
|
|
|
|
|
|
if isinstance(image, np.ndarray):
|
|
|
image = Image.fromarray(image)
|
|
|
|
|
|
|
|
|
size = max(3, int(size))
|
|
|
if size % 2 == 0:
|
|
|
size += 1
|
|
|
|
|
|
|
|
|
img_array = np.array(image)
|
|
|
filtered = cv2.medianBlur(img_array, size)
|
|
|
|
|
|
return Image.fromarray(filtered)
|
|
|
|
|
|
def apply_clahe(image, clip_limit=2.0, grid_size=(8, 8)):
|
|
|
"""Apply CLAHE to enhance contrast."""
|
|
|
|
|
|
if isinstance(image, Image.Image):
|
|
|
img_array = np.array(image)
|
|
|
else:
|
|
|
img_array = image
|
|
|
|
|
|
|
|
|
clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=grid_size)
|
|
|
enhanced = clahe.apply(img_array)
|
|
|
|
|
|
return Image.fromarray(enhanced)
|
|
|
|
|
|
def apply_histogram_equalization(image):
|
|
|
"""Apply histogram equalization to enhance contrast."""
|
|
|
|
|
|
if isinstance(image, np.ndarray):
|
|
|
image = Image.fromarray(image)
|
|
|
|
|
|
return ImageOps.equalize(image)
|
|
|
|
|
|
def apply_vignette(image, amount=0.85):
|
|
|
"""Apply vignette effect (darker edges) to mimic X-ray effect."""
|
|
|
|
|
|
img_array = np.array(image).astype(np.float32)
|
|
|
|
|
|
|
|
|
height, width = img_array.shape
|
|
|
center_x, center_y = width // 2, height // 2
|
|
|
radius = np.sqrt(width**2 + height**2) / 2
|
|
|
|
|
|
|
|
|
y, x = np.ogrid[:height, :width]
|
|
|
dist_from_center = np.sqrt((x - center_x)**2 + (y - center_y)**2)
|
|
|
|
|
|
|
|
|
mask = 1 - amount * (dist_from_center / radius)
|
|
|
mask = np.clip(mask, 0, 1)
|
|
|
|
|
|
|
|
|
img_array = img_array * mask
|
|
|
|
|
|
return Image.fromarray(np.clip(img_array, 0, 255).astype(np.uint8))
|
|
|
|
|
|
def enhance_xray(image, params=None):
|
|
|
"""
|
|
|
Apply a sequence of enhancements to make the image look more like an authentic X-ray.
|
|
|
"""
|
|
|
|
|
|
if params is None:
|
|
|
params = {
|
|
|
'window_center': 0.5,
|
|
|
'window_width': 0.8,
|
|
|
'edge_amount': 1.3,
|
|
|
'median_size': 3,
|
|
|
'clahe_clip': 2.5,
|
|
|
'clahe_grid': (8, 8),
|
|
|
'vignette_amount': 0.25,
|
|
|
'apply_hist_eq': True
|
|
|
}
|
|
|
|
|
|
|
|
|
if isinstance(image, np.ndarray):
|
|
|
image = Image.fromarray(image)
|
|
|
|
|
|
|
|
|
image = apply_windowing(image, params['window_center'], params['window_width'])
|
|
|
|
|
|
|
|
|
image_np = np.array(image)
|
|
|
image = apply_clahe(image_np, params['clahe_clip'], params['clahe_grid'])
|
|
|
|
|
|
|
|
|
image = apply_median_filter(image, params['median_size'])
|
|
|
|
|
|
|
|
|
image = apply_edge_enhancement(image, params['edge_amount'])
|
|
|
|
|
|
|
|
|
if params['apply_hist_eq']:
|
|
|
image = apply_histogram_equalization(image)
|
|
|
|
|
|
|
|
|
image = apply_vignette(image, params['vignette_amount'])
|
|
|
|
|
|
return image
|
|
|
|
|
|
def generate_and_enhance(generator, prompt, params_list=None):
|
|
|
"""
|
|
|
Generate an X-ray and apply different enhancement parameter sets.
|
|
|
"""
|
|
|
|
|
|
results = generator.generate(prompt=prompt, num_inference_steps=100, guidance_scale=10.0)
|
|
|
raw_image = results['images'][0]
|
|
|
|
|
|
|
|
|
if params_list is None:
|
|
|
params_list = [{
|
|
|
'window_center': 0.5,
|
|
|
'window_width': 0.8,
|
|
|
'edge_amount': 1.3,
|
|
|
'median_size': 3,
|
|
|
'clahe_clip': 2.5,
|
|
|
'clahe_grid': (8, 8),
|
|
|
'vignette_amount': 0.25,
|
|
|
'apply_hist_eq': True
|
|
|
}]
|
|
|
|
|
|
|
|
|
enhanced_images = []
|
|
|
for i, params in enumerate(params_list):
|
|
|
enhanced = enhance_xray(raw_image, params)
|
|
|
enhanced_images.append({
|
|
|
'image': enhanced,
|
|
|
'params': params,
|
|
|
'index': i+1
|
|
|
})
|
|
|
|
|
|
return {
|
|
|
'raw_image': raw_image,
|
|
|
'enhanced_images': enhanced_images,
|
|
|
'prompt': prompt
|
|
|
}
|
|
|
|
|
|
def save_results(results, output_dir):
|
|
|
"""Save all generated and enhanced images."""
|
|
|
prompt_clean = results['prompt'].replace(" ", "_").replace(".", "").lower()[:30]
|
|
|
|
|
|
|
|
|
raw_path = Path(output_dir) / f"raw_{prompt_clean}.png"
|
|
|
results['raw_image'].save(raw_path)
|
|
|
|
|
|
|
|
|
for item in results['enhanced_images']:
|
|
|
enhanced_path = Path(output_dir) / f"enhanced_{item['index']}_{prompt_clean}.png"
|
|
|
item['image'].save(enhanced_path)
|
|
|
|
|
|
|
|
|
params_path = Path(output_dir) / f"params_{item['index']}_{prompt_clean}.txt"
|
|
|
with open(params_path, 'w') as f:
|
|
|
for key, value in item['params'].items():
|
|
|
f.write(f"{key}: {value}\n")
|
|
|
|
|
|
return raw_path
|
|
|
|
|
|
def display_results(results):
|
|
|
"""Display the raw and enhanced images for comparison."""
|
|
|
n_enhanced = len(results['enhanced_images'])
|
|
|
fig, axes = plt.subplots(1, n_enhanced+1, figsize=(4*(n_enhanced+1), 4))
|
|
|
|
|
|
|
|
|
axes[0].imshow(results['raw_image'], cmap='gray')
|
|
|
axes[0].set_title("Original (Raw)")
|
|
|
axes[0].axis('off')
|
|
|
|
|
|
|
|
|
for i, item in enumerate(results['enhanced_images']):
|
|
|
axes[i+1].imshow(item['image'], cmap='gray')
|
|
|
axes[i+1].set_title(f"Enhanced {item['index']}")
|
|
|
axes[i+1].axis('off')
|
|
|
|
|
|
plt.suptitle(f"Prompt: {results['prompt']}")
|
|
|
plt.tight_layout()
|
|
|
return fig
|
|
|
|
|
|
def main():
|
|
|
"""Main function to load model and generate enhanced X-rays."""
|
|
|
|
|
|
print(f"Loading model from: {MODEL_PATH}")
|
|
|
generator = XrayGenerator(
|
|
|
model_path=str(MODEL_PATH),
|
|
|
device="cuda" if torch.cuda.is_available() else "cpu"
|
|
|
)
|
|
|
|
|
|
|
|
|
params_sets = [
|
|
|
|
|
|
{
|
|
|
'window_center': 0.5,
|
|
|
'window_width': 0.8,
|
|
|
'edge_amount': 1.3,
|
|
|
'median_size': 3,
|
|
|
'clahe_clip': 2.5,
|
|
|
'clahe_grid': (8, 8),
|
|
|
'vignette_amount': 0.25,
|
|
|
'apply_hist_eq': True
|
|
|
},
|
|
|
|
|
|
{
|
|
|
'window_center': 0.45,
|
|
|
'window_width': 0.7,
|
|
|
'edge_amount': 1.5,
|
|
|
'median_size': 3,
|
|
|
'clahe_clip': 3.0,
|
|
|
'clahe_grid': (8, 8),
|
|
|
'vignette_amount': 0.3,
|
|
|
'apply_hist_eq': True
|
|
|
},
|
|
|
|
|
|
{
|
|
|
'window_center': 0.55,
|
|
|
'window_width': 0.85,
|
|
|
'edge_amount': 1.8,
|
|
|
'median_size': 3,
|
|
|
'clahe_clip': 2.0,
|
|
|
'clahe_grid': (6, 6),
|
|
|
'vignette_amount': 0.2,
|
|
|
'apply_hist_eq': False
|
|
|
}
|
|
|
]
|
|
|
|
|
|
|
|
|
for i, prompt in enumerate(TEST_PROMPTS):
|
|
|
print(f"Processing prompt {i+1}/{len(TEST_PROMPTS)}: {prompt}")
|
|
|
|
|
|
|
|
|
results = generate_and_enhance(generator, prompt, params_sets)
|
|
|
|
|
|
|
|
|
output_path = save_results(results, OUTPUT_DIR)
|
|
|
print(f"Saved results to {output_path.parent}")
|
|
|
|
|
|
|
|
|
fig = display_results(results)
|
|
|
fig_path = Path(OUTPUT_DIR) / f"comparison_{i+1}.png"
|
|
|
fig.savefig(fig_path)
|
|
|
plt.close(fig)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
main() |