cxr-diffusion / post_process.py
pyamy's picture
Upload 31 files
0a0f923 verified
# post_process.py
import os
import cv2
import numpy as np
import torch
from pathlib import Path
import matplotlib.pyplot as plt
from PIL import Image, ImageOps, ImageFilter, ImageEnhance
from xray_generator.inference import XrayGenerator
# Filesystem layout: every path is anchored at the directory holding this script.
BASE_DIR = Path(__file__).parent
# Diffusion checkpoint (epoch 480) that main() loads.
MODEL_PATH = BASE_DIR.joinpath("outputs", "diffusion_checkpoints", "checkpoint_epoch_480.pt")
# Destination for raw and enhanced images; created as a side effect of import.
OUTPUT_DIR = BASE_DIR.joinpath("outputs", "enhanced_xrays")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Text prompts that main() feeds to the generator, one image per prompt.
TEST_PROMPTS = [
    "Normal chest X-ray with clear lungs and no abnormalities.",
    "Right lower lobe pneumonia with focal consolidation.",
    "Bilateral pleural effusions, greater on the right.",
]
def apply_windowing(image, window_center=0.5, window_width=0.8):
    """
    Apply window/level adjustment (similar to radiological windowing).

    Pixel values are divided by 255, linearly remapped so that the interval
    [window_center - window_width / 2, window_center + window_width / 2]
    spans the full output range, clipped to [0, 1] and scaled back to 8 bits.

    Args:
        image: PIL image or array-like; values are assumed to be on a
            0-255 scale because they are divided by 255.
        window_center: Centre of the window on the normalised 0-1 scale.
        window_width: Width of the window on the normalised scale; must be > 0.

    Returns:
        A new PIL image built from the windowed uint8 array.

    Raises:
        ValueError: If ``window_width`` is not strictly positive.
    """
    # A zero width divides by zero (NaN/inf pixels) and a negative width
    # silently inverts the mapping, so reject both before touching the data.
    if window_width <= 0:
        raise ValueError(f"window_width must be positive, got {window_width}")

    img_array = np.array(image).astype(np.float32) / 255.0

    # Lower/upper bounds of the window on the normalised scale.
    min_val = window_center - window_width / 2
    max_val = window_center + window_width / 2

    # Linear remap of [min_val, max_val] -> [0, 1]; everything outside saturates.
    img_array = np.clip((img_array - min_val) / (max_val - min_val), 0, 1)
    return Image.fromarray((img_array * 255).astype(np.uint8))
def apply_edge_enhancement(image, amount=1.5):
    """Sharpen an image with PIL's ``ImageEnhance.Sharpness``.

    Args:
        image: PIL image, or a numpy array that is first wrapped with
            ``Image.fromarray``.
        amount: Factor forwarded unchanged to ``Sharpness.enhance`` (per
            Pillow's documentation, 1.0 leaves the image as is and larger
            values sharpen more).

    Returns:
        The sharpened PIL image.
    """
    pil_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
    return ImageEnhance.Sharpness(pil_image).enhance(amount)
def apply_median_filter(image, size=3):
    """Reduce noise with a median blur.

    Args:
        image: PIL image or numpy array.
        size: Requested kernel size. It is truncated to an int, raised to at
            least 3, and bumped to the next odd number when even.

    Returns:
        The filtered result as a PIL image.
    """
    # Normalise the input to a PIL image first, exactly as the array path expects.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    # Force an odd aperture of at least 3: setting the lowest bit turns an
    # even value into the next odd one and leaves odd values untouched.
    kernel = max(3, int(size))
    kernel |= 1

    # The blur itself is done by OpenCV on the pixel array rather than by PIL.
    pixels = np.array(image)
    return Image.fromarray(cv2.medianBlur(pixels, kernel))
def apply_clahe(image, clip_limit=2.0, grid_size=(8, 8)):
    """Boost local contrast with OpenCV's CLAHE.

    Args:
        image: PIL image (converted to an array) or an array used as is.
        clip_limit: Passed to ``cv2.createCLAHE`` as ``clipLimit``.
        grid_size: Passed to ``cv2.createCLAHE`` as ``tileGridSize``.

    Returns:
        The equalised result as a PIL image.
    """
    pixels = np.array(image) if isinstance(image, Image.Image) else image
    equalizer = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=grid_size)
    return Image.fromarray(equalizer.apply(pixels))
def apply_histogram_equalization(image):
    """Equalise the image histogram via ``ImageOps.equalize``.

    Args:
        image: PIL image, or a numpy array that is first wrapped with
            ``Image.fromarray``.

    Returns:
        The equalised PIL image.
    """
    pil_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
    return ImageOps.equalize(pil_image)
def apply_vignette(image, amount=0.85):
    """Apply vignette effect (darker edges) to mimic X-ray effect.

    Each pixel is multiplied by ``1 - amount * d / r``, where ``d`` is its
    distance from the image centre and ``r`` is half the image diagonal; the
    factor is clipped to [0, 1].

    Args:
        image: PIL image or array-like, either 2-D (grayscale) or 3-D with a
            trailing channel axis (H, W, C).
        amount: Strength of the darkening; 0 leaves the image unchanged.

    Returns:
        A new PIL image built from the attenuated uint8 array.
    """
    img_array = np.array(image).astype(np.float32)

    # Only the first two axes are spatial. Slicing (instead of unpacking the
    # full shape) lets multi-channel images through as well.
    height, width = img_array.shape[:2]
    center_x, center_y = width // 2, height // 2
    radius = np.sqrt(width**2 + height**2) / 2

    # Open grids broadcast to an (H, W) distance map without allocating
    # full coordinate arrays.
    y, x = np.ogrid[:height, :width]
    dist_from_center = np.sqrt((x - center_x)**2 + (y - center_y)**2)

    mask = np.clip(1 - amount * (dist_from_center / radius), 0, 1)
    if img_array.ndim == 3:
        # Add a channel axis so the same attenuation applies to every channel.
        mask = mask[..., np.newaxis]

    img_array = img_array * mask
    return Image.fromarray(np.clip(img_array, 0, 255).astype(np.uint8))
def enhance_xray(image, params=None):
    """
    Apply a sequence of enhancements to make the image look more like an authentic X-ray.

    The steps run in a fixed order: windowing, CLAHE, median filter, edge
    enhancement, optional histogram equalization, vignette.

    Args:
        image: PIL image or numpy array. Multi-band images (e.g. RGB) are
            converted to 8-bit grayscale first.
        params: Optional dict of overrides. Recognised keys are
            ``window_center``, ``window_width``, ``edge_amount``,
            ``median_size``, ``clahe_clip``, ``clahe_grid``,
            ``vignette_amount`` and ``apply_hist_eq``. Missing keys fall back
            to the defaults below.

    Returns:
        The enhanced PIL image.
    """
    defaults = {
        'window_center': 0.5,
        'window_width': 0.8,
        'edge_amount': 1.3,
        'median_size': 3,
        'clahe_clip': 2.5,
        'clahe_grid': (8, 8),
        'vignette_amount': 0.25,
        'apply_hist_eq': True,
    }
    # Overlay the caller's values on the defaults so a partial dict no longer
    # raises KeyError; a full dict behaves exactly as before.
    settings = {**defaults, **(params or {})}

    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    # The CLAHE and vignette steps work on a single channel (the vignette
    # unpacks a 2-D shape), so collapse colour input to grayscale up front.
    if len(image.getbands()) > 1:
        image = image.convert("L")

    # 1. Windowing for better global contrast.
    image = apply_windowing(image, settings['window_center'], settings['window_width'])

    # 2. CLAHE for adaptive local contrast.
    image = apply_clahe(np.array(image), settings['clahe_clip'], settings['clahe_grid'])

    # 3. Median filter to reduce noise.
    image = apply_median_filter(image, settings['median_size'])

    # 4. Edge enhancement to highlight lung markings.
    image = apply_edge_enhancement(image, settings['edge_amount'])

    # 5. Optional histogram equalization for a flatter grayscale distribution.
    if settings['apply_hist_eq']:
        image = apply_histogram_equalization(image)

    # 6. Vignette for darker edges.
    image = apply_vignette(image, settings['vignette_amount'])

    return image
def generate_and_enhance(generator, prompt, params_list=None):
    """
    Generate one X-ray for ``prompt`` and enhance it once per parameter set.

    Args:
        generator: Object whose ``generate(prompt=..., num_inference_steps=...,
            guidance_scale=...)`` returns a mapping with an ``'images'`` sequence.
        prompt: Text prompt forwarded to the generator.
        params_list: Iterable of parameter dicts for ``enhance_xray``. When
            ``None``, a single default parameter set is used.

    Returns:
        Dict with ``'raw_image'`` (first generated image),
        ``'enhanced_images'`` (list of dicts holding ``'image'``, ``'params'``
        and a 1-based ``'index'``) and ``'prompt'``.
    """
    generation = generator.generate(prompt=prompt, num_inference_steps=100, guidance_scale=10.0)
    raw_image = generation['images'][0]

    # Fall back to one default parameter set when the caller supplies none.
    if params_list is None:
        params_list = [dict(
            window_center=0.5,
            window_width=0.8,
            edge_amount=1.3,
            median_size=3,
            clahe_clip=2.5,
            clahe_grid=(8, 8),
            vignette_amount=0.25,
            apply_hist_eq=True,
        )]

    # One enhanced variant per parameter set, numbered from 1.
    enhanced_images = [
        {'image': enhance_xray(raw_image, params), 'params': params, 'index': position}
        for position, params in enumerate(params_list, start=1)
    ]

    return {
        'raw_image': raw_image,
        'enhanced_images': enhanced_images,
        'prompt': prompt,
    }
def save_results(results, output_dir):
    """Save all generated and enhanced images.

    File names are derived from the prompt: spaces become underscores,
    periods are dropped, the text is lower-cased, characters that are not
    safe in file names are replaced with underscores, and the result is cut
    to 30 characters.

    Args:
        results: Dict with ``'prompt'``, ``'raw_image'`` and
            ``'enhanced_images'`` (each item holding ``'image'``, ``'params'``
            and ``'index'``). Image objects must provide ``save(path)``.
        output_dir: Directory to write into; created if it does not exist.

    Returns:
        Path of the saved raw image.
    """
    import re  # local import keeps this block self-contained

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    prompt_clean = results['prompt'].replace(" ", "_").replace(".", "").lower()
    # Prompts such as "AP/PA view" contain path separators, which would point
    # the file into a non-existent sub-directory. Replace those and other
    # characters that common filesystems reject in file names.
    prompt_clean = re.sub(r'[<>:"/\\|?*\x00-\x1f]', "_", prompt_clean)[:30]

    # Save raw image
    raw_path = output_dir / f"raw_{prompt_clean}.png"
    results['raw_image'].save(raw_path)

    # Save each enhanced image next to a plain-text dump of its parameters.
    for item in results['enhanced_images']:
        index = item['index']
        item['image'].save(output_dir / f"enhanced_{index}_{prompt_clean}.png")

        # One "key: value" line per parameter (plain text, not JSON).
        params_path = output_dir / f"params_{index}_{prompt_clean}.txt"
        with open(params_path, 'w', encoding="utf-8") as f:
            f.writelines(f"{key}: {value}\n" for key, value in item['params'].items())

    return raw_path
def display_results(results):
    """Display the raw and enhanced images for comparison.

    Builds a single-row figure: the raw image first, then one panel per
    enhanced image, with the prompt as the figure title.

    Args:
        results: Dict with ``'raw_image'``, ``'enhanced_images'`` (items
            holding ``'image'`` and ``'index'``) and ``'prompt'``.

    Returns:
        The matplotlib figure; the caller is responsible for saving/closing it.
    """
    enhanced_items = results['enhanced_images']
    n_panels = len(enhanced_items) + 1

    # squeeze=False always yields a 2-D array of Axes. Without it a
    # single-panel figure (no enhanced images) returns a bare Axes object and
    # indexing it with [0] fails.
    fig, axes_grid = plt.subplots(1, n_panels, figsize=(4 * n_panels, 4), squeeze=False)
    axes = axes_grid[0]

    # Plot raw image
    axes[0].imshow(results['raw_image'], cmap='gray')
    axes[0].set_title("Original (Raw)")
    axes[0].axis('off')

    # Plot enhanced images in the remaining panels.
    for ax, item in zip(axes[1:], enhanced_items):
        ax.imshow(item['image'], cmap='gray')
        ax.set_title(f"Enhanced {item['index']}")
        ax.axis('off')

    plt.suptitle(f"Prompt: {results['prompt']}")
    plt.tight_layout()
    return fig
def main():
    """Load the checkpoint, then generate, enhance, save and plot each test prompt.

    For every entry in ``TEST_PROMPTS`` the raw image is enhanced with three
    parameter sets; images and parameter dumps go to ``OUTPUT_DIR`` together
    with a ``comparison_<n>.png`` figure.
    """
    print(f"Loading model from: {MODEL_PATH}")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    generator = XrayGenerator(model_path=str(MODEL_PATH), device=device)

    # Parameter Set 1: Balanced enhancement
    balanced = {
        'window_center': 0.5,
        'window_width': 0.8,
        'edge_amount': 1.3,
        'median_size': 3,
        'clahe_clip': 2.5,
        'clahe_grid': (8, 8),
        'vignette_amount': 0.25,
        'apply_hist_eq': True,
    }
    # Parameter Set 2: More contrast
    more_contrast = {
        'window_center': 0.45,
        'window_width': 0.7,
        'edge_amount': 1.5,
        'median_size': 3,
        'clahe_clip': 3.0,
        'clahe_grid': (8, 8),
        'vignette_amount': 0.3,
        'apply_hist_eq': True,
    }
    # Parameter Set 3: Sharper lung markings
    sharper_markings = {
        'window_center': 0.55,
        'window_width': 0.85,
        'edge_amount': 1.8,
        'median_size': 3,
        'clahe_clip': 2.0,
        'clahe_grid': (6, 6),
        'vignette_amount': 0.2,
        'apply_hist_eq': False,
    }
    params_sets = [balanced, more_contrast, sharper_markings]

    for i, prompt in enumerate(TEST_PROMPTS):
        print(f"Processing prompt {i+1}/{len(TEST_PROMPTS)}: {prompt}")

        # Generate once, enhance with every parameter set.
        results = generate_and_enhance(generator, prompt, params_sets)

        # Persist images and parameter dumps.
        output_path = save_results(results, OUTPUT_DIR)
        print(f"Saved results to {output_path.parent}")

        # Render the side-by-side comparison straight to disk, then free the figure.
        comparison = display_results(results)
        comparison.savefig(Path(OUTPUT_DIR) / f"comparison_{i+1}.png")
        plt.close(comparison)


if __name__ == "__main__":
    main()