| |
|
| | import os
|
| | import torch
|
| | import numpy as np
|
| | from PIL import Image
|
| | import logging
|
| | from typing import Union, List, Dict, Tuple, Optional
|
| | from transformers import AutoTokenizer
|
| | from tqdm.auto import tqdm
|
| | from pathlib import Path
|
| |
|
| | from .models.diffusion import DiffusionModel
|
| | from .utils.processing import get_device, apply_clahe
|
| |
|
# Module-level logger, following the standard one-logger-per-module convention.
logger = logging.getLogger(__name__)
|
| |
|
class XrayGenerator:
    """Generate chest X-ray images from free-text prompts.

    Wraps a trained ``DiffusionModel`` (VAE + text encoder + UNet) restored
    from a checkpoint, and exposes a simple ``generate`` / ``save_images``
    API for inference.
    """

    def __init__(
        self,
        model_path: str,
        device: Optional[torch.device] = None,
        tokenizer_name: str = "dmis-lab/biobert-base-cased-v1.1",
    ):
        """
        Initialize the X-ray generator.

        Args:
            model_path: Path to the saved model weights
            device: Device to run the model on (defaults to CUDA if available)
            tokenizer_name: Name of the HuggingFace tokenizer

        Raises:
            RuntimeError: If the tokenizer or the model checkpoint cannot
                be loaded.
        """
        self.device = device if device is not None else get_device()
        self.model_path = Path(model_path)

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
            logger.info("Loaded tokenizer: %s", tokenizer_name)
        except Exception as e:
            logger.error("Error loading tokenizer: %s", e)
            # Chain the original exception so the root cause is preserved.
            raise RuntimeError(f"Failed to load tokenizer: {e}") from e

        self.model = self._load_model()

        # Inference only: switch every sub-module to eval mode so dropout
        # and normalization layers behave deterministically.
        self.model.vae.eval()
        self.model.text_encoder.eval()
        self.model.unet.eval()

        logger.info("XrayGenerator initialized successfully")

    @staticmethod
    def _tensor_to_pil(img: torch.Tensor) -> Image.Image:
        """Convert a CHW float tensor (expected range [0, 1]) to a PIL image.

        Values are clipped to [0, 1] before the uint8 cast to avoid
        wraparound artifacts from slightly out-of-range model outputs.
        Single-channel images are squeezed to 2-D so PIL produces a
        grayscale image.
        """
        img_np = img.detach().cpu().numpy().transpose(1, 2, 0)
        img_np = (np.clip(img_np, 0.0, 1.0) * 255).astype(np.uint8)
        if img_np.shape[-1] == 1:
            img_np = img_np.squeeze(-1)
        return Image.fromarray(img_np)

    def _load_model(self) -> DiffusionModel:
        """Load the diffusion model from saved weights.

        Reads hyper-parameters from the checkpoint's ``config`` entry
        (falling back to training defaults), rebuilds each sub-module on
        ``self.device``, and restores whichever state dicts the checkpoint
        contains.

        Raises:
            RuntimeError: If the checkpoint cannot be read or restored.
        """
        logger.info("Loading model from %s", self.model_path)

        try:
            # NOTE(security): torch.load unpickles arbitrary objects --
            # only load checkpoints from trusted sources.
            checkpoint = torch.load(self.model_path, map_location=self.device)

            # Imported lazily to avoid circular imports at module load time.
            from .models.vae import MedicalVAE
            from .models.text_encoder import MedicalTextEncoder
            from .models.unet import DiffusionUNet

            config = checkpoint.get('config', {})
            latent_channels = config.get('latent_channels', 8)
            model_channels = config.get('model_channels', 48)

            vae = MedicalVAE(
                in_channels=1,
                out_channels=1,
                latent_channels=latent_channels,
                hidden_dims=[model_channels, model_channels * 2,
                             model_channels * 4, model_channels * 8]
            ).to(self.device)

            text_encoder = MedicalTextEncoder(
                model_name=config.get('text_model', "dmis-lab/biobert-base-cased-v1.1"),
                projection_dim=768,
                freeze_base=True
            ).to(self.device)

            unet = DiffusionUNet(
                in_channels=latent_channels,
                model_channels=model_channels,
                out_channels=latent_channels,
                num_res_blocks=2,
                attention_resolutions=(8, 16, 32),
                dropout=0.1,
                channel_mult=(1, 2, 4, 8),
                context_dim=768
            ).to(self.device)

            # State dicts are optional so partially-trained checkpoints
            # (e.g. VAE-only) can still be loaded.
            if 'vae_state_dict' in checkpoint:
                vae.load_state_dict(checkpoint['vae_state_dict'])
                logger.info("Loaded VAE weights")

            if 'text_encoder_state_dict' in checkpoint:
                text_encoder.load_state_dict(checkpoint['text_encoder_state_dict'])
                logger.info("Loaded text encoder weights")

            if 'unet_state_dict' in checkpoint:
                unet.load_state_dict(checkpoint['unet_state_dict'])
                logger.info("Loaded UNet weights")

            model = DiffusionModel(
                vae=vae,
                unet=unet,
                text_encoder=text_encoder,
                scheduler_type=config.get('scheduler_type', "ddim"),
                num_train_timesteps=config.get('num_train_timesteps', 1000),
                beta_schedule=config.get('beta_schedule', "linear"),
                prediction_type=config.get('prediction_type', "epsilon"),
                guidance_scale=config.get('guidance_scale', 7.5),
                device=self.device
            )

            return model

        except Exception as e:
            # logger.exception records the full traceback automatically,
            # replacing the manual traceback.format_exc() dance.
            logger.exception("Error loading model: %s", e)
            raise RuntimeError(f"Failed to load model: {e}") from e

    @torch.no_grad()
    def generate(
        self,
        prompt: Union[str, List[str]],
        height: int = 256,
        width: int = 256,
        num_inference_steps: int = 50,
        guidance_scale: float = 10.0,
        eta: float = 0.0,
        output_type: str = "pil",
        return_dict: bool = True,
        seed: Optional[int] = None,
    ) -> Union[Dict, List[Image.Image]]:
        """
        Generate chest X-rays from text prompts.

        Args:
            prompt: Text prompt(s) describing the X-ray
            height: Output image height
            width: Output image width
            num_inference_steps: Number of denoising steps (more = higher quality, slower)
            guidance_scale: Controls adherence to the text prompt (higher = more faithful)
            eta: Controls randomness in sampling (0 = deterministic, 1 = stochastic)
            output_type: Output format, one of ["pil", "np", "tensor"]
            return_dict: Whether to return a dictionary with additional metadata
            seed: Random seed for reproducible generation

        Returns:
            Images and optionally metadata

        Raises:
            ValueError: If ``output_type`` is not one of the supported formats.
        """
        if seed is not None:
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                # Seed *all* CUDA devices so multi-GPU runs are reproducible too.
                torch.cuda.manual_seed_all(seed)

        try:
            results = self.model.sample(
                text=prompt,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                eta=eta,
                tokenizer=self.tokenizer
            )

            images_tensor = results['images']

            if output_type == "tensor":
                images = images_tensor
            elif output_type == "np":
                # HWC float arrays, left unscaled for downstream processing.
                images = [img.cpu().numpy().transpose(1, 2, 0) for img in images_tensor]
            elif output_type == "pil":
                images = [self._tensor_to_pil(img) for img in images_tensor]
            else:
                raise ValueError(f"Unknown output type: {output_type}")

            if return_dict:
                return {
                    'images': images,
                    'latents': results['latents'].cpu(),
                    'prompt': prompt,
                    'parameters': {
                        'height': height,
                        'width': width,
                        'num_inference_steps': num_inference_steps,
                        'guidance_scale': guidance_scale,
                        'eta': eta,
                        'seed': seed
                    }
                }
            return images

        except Exception as e:
            logger.exception("Error generating images: %s", e)
            raise

    def save_images(self, images, output_dir, base_filename="generated", add_prompt=True, prompts=None):
        """
        Save generated images to disk.

        Args:
            images: List of images (PIL, numpy HWC arrays in [0, 1], or CHW tensors)
            output_dir: Directory to save images (created if missing)
            base_filename: Base name for saved files
            add_prompt: Whether to include prompt in filename
            prompts: List of prompts (or a single prompt string) corresponding to images
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        # Nothing to do for an empty list (also avoids the images[0] probe below).
        if not images:
            logger.warning("save_images called with no images")
            return

        # Normalize every supported input format to PIL before saving.
        if isinstance(images[0], torch.Tensor):
            images = [self._tensor_to_pil(img) for img in images]
        elif isinstance(images[0], np.ndarray):
            images_pil = []
            for img in images:
                # Clip before the uint8 cast to avoid wraparound artifacts.
                img_np = (np.clip(img, 0.0, 1.0) * 255).astype(np.uint8)
                if img_np.shape[-1] == 1:
                    img_np = img_np.squeeze(-1)
                images_pil.append(Image.fromarray(img_np))
            images = images_pil

        for i, img in enumerate(images):
            prompt_str = ""
            if add_prompt and prompts is not None:
                if isinstance(prompts, list):
                    # Guard against prompt lists shorter than the image list.
                    prompt_str = prompts[i] if i < len(prompts) else ""
                else:
                    prompt_str = prompts
                # Sanitize the prompt into a safe, bounded filename fragment.
                prompt_str = prompt_str.replace(" ", "_").replace(".", "").lower()
                prompt_str = ''.join(c for c in prompt_str if c.isalnum() or c == '_')
                prompt_str = prompt_str[:50]

            if prompt_str:
                filename = f"{base_filename}_{i+1}_{prompt_str}.png"
            else:
                filename = f"{base_filename}_{i+1}.png"

            file_path = output_dir / filename
            img.save(file_path)
            logger.info("Saved image to %s", file_path)