# NoMAISI/scripts/diff_model_infer.py
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import argparse
import logging
import os
import random
from datetime import datetime

import nibabel as nib
import numpy as np
import torch
import torch.distributed as dist
from monai.inferers.inferer import SlidingWindowInferer
from monai.networks.schedulers import RFlowScheduler
from monai.utils import set_determinism
from tqdm import tqdm

from .diff_model_setting import initialize_distributed, load_config, setup_logging
from .sample import ReconModel, check_input
from .utils import define_instance, dynamic_infer


def set_random_seed(seed: int | None) -> int:
    """
    Set the random seed for reproducibility.

    Args:
        seed (int | None): Random seed. If None, a seed is drawn at random.

    Returns:
        int: The seed that was set.
    """
    random_seed = random.randint(0, 99999) if seed is None else seed
    set_determinism(random_seed)
    return random_seed


def load_models(args: argparse.Namespace, device: torch.device, logger: logging.Logger) -> tuple:
    """
    Load the autoencoder and UNet models.

    Args:
        args (argparse.Namespace): Configuration arguments.
        device (torch.device): Device to load models on.
        logger (logging.Logger): Logger for logging information.

    Returns:
        tuple: Loaded autoencoder, UNet model, and scale factor.
    """
autoencoder = define_instance(args, "autoencoder_def").to(device)
    try:
        checkpoint_autoencoder = torch.load(args.trained_autoencoder_path, weights_only=True)
        autoencoder.load_state_dict(checkpoint_autoencoder)
    except Exception as e:
        logger.error(f"Failed to load the autoencoder checkpoint from {args.trained_autoencoder_path}: {e}")
        raise
unet = define_instance(args, "diffusion_unet_def").to(device)
checkpoint = torch.load(f"{args.model_dir}/{args.model_filename}", map_location=device, weights_only=False)
unet.load_state_dict(checkpoint["unet_state_dict"], strict=True)
logger.info(f"checkpoints {args.model_dir}/{args.model_filename} loaded.")
scale_factor = checkpoint["scale_factor"]
logger.info(f"scale_factor -> {scale_factor}.")
return autoencoder, unet, scale_factor


def prepare_tensors(args: argparse.Namespace, device: torch.device) -> tuple:
    """
    Prepare the conditioning tensors needed for inference.

    Args:
        args (argparse.Namespace): Configuration arguments.
        device (torch.device): Device to load tensors on.

    Returns:
        tuple: Prepared top_region_index_tensor, bottom_region_index_tensor, spacing_tensor, and modality_tensor.
    """
top_region_index_tensor = np.array(args.diffusion_unet_inference["top_region_index"]).astype(float) * 1e2
bottom_region_index_tensor = np.array(args.diffusion_unet_inference["bottom_region_index"]).astype(float) * 1e2
spacing_tensor = np.array(args.diffusion_unet_inference["spacing"]).astype(float) * 1e2
top_region_index_tensor = torch.from_numpy(top_region_index_tensor[np.newaxis, :]).half().to(device)
bottom_region_index_tensor = torch.from_numpy(bottom_region_index_tensor[np.newaxis, :]).half().to(device)
spacing_tensor = torch.from_numpy(spacing_tensor[np.newaxis, :]).half().to(device)
modality_tensor = args.diffusion_unet_inference["modality"] * torch.ones(
(len(spacing_tensor)), dtype=torch.long
).to(device)
return top_region_index_tensor, bottom_region_index_tensor, spacing_tensor, modality_tensor


def run_inference(
    args: argparse.Namespace,
    device: torch.device,
    autoencoder: torch.nn.Module,
    unet: torch.nn.Module,
    scale_factor: float,
    top_region_index_tensor: torch.Tensor,
    bottom_region_index_tensor: torch.Tensor,
    spacing_tensor: torch.Tensor,
    modality_tensor: torch.Tensor,
    output_size: tuple,
    divisor: int,
    logger: logging.Logger,
) -> np.ndarray:
    """
    Run inference to generate a synthetic image.

    Args:
        args (argparse.Namespace): Configuration arguments.
        device (torch.device): Device to run inference on.
        autoencoder (torch.nn.Module): Autoencoder model.
        unet (torch.nn.Module): UNet model.
        scale_factor (float): Scale factor for the latent space.
        top_region_index_tensor (torch.Tensor): Top region index tensor.
        bottom_region_index_tensor (torch.Tensor): Bottom region index tensor.
        spacing_tensor (torch.Tensor): Spacing tensor.
        modality_tensor (torch.Tensor): Modality tensor.
        output_size (tuple): Output size of the synthetic image.
        divisor (int): Image-to-latent spatial downsampling factor.
        logger (logging.Logger): Logger for logging information.

    Returns:
        np.ndarray: Generated synthetic image data.
    """
include_body_region = unet.include_top_region_index_input
include_modality = unet.num_class_embeds is not None
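    # Sample the initial latent noise; the latent grid is the target image size
    # reduced by the image-to-latent downsampling factor.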
noise = torch.randn(
(
1,
args.latent_channels,
output_size[0] // divisor,
output_size[1] // divisor,
output_size[2] // divisor,
),
device=device,
)
logger.info(f"noise: {noise.device}, {noise.dtype}, {type(noise)}")
image = noise
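    # RFlowScheduler derives its timestep schedule from the latent size as well
    # as the step count; other schedulers only need num_inference_steps.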
noise_scheduler = define_instance(args, "noise_scheduler")
if isinstance(noise_scheduler, RFlowScheduler):
noise_scheduler.set_timesteps(
num_inference_steps=args.diffusion_unet_inference["num_inference_steps"],
input_img_size_numel=torch.prod(torch.tensor(noise.shape[2:])),
)
else:
noise_scheduler.set_timesteps(num_inference_steps=args.diffusion_unet_inference["num_inference_steps"])
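    # ReconModel wraps the autoencoder so the denoised latent can be decoded back
    # to image space, undoing the training-time scale_factor.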
recon_model = ReconModel(autoencoder=autoencoder, scale_factor=scale_factor).to(device)
autoencoder.eval()
unet.eval()
all_timesteps = noise_scheduler.timesteps
all_next_timesteps = torch.cat((all_timesteps[1:], torch.tensor([0], dtype=all_timesteps.dtype)))
progress_bar = tqdm(
zip(all_timesteps, all_next_timesteps),
total=min(len(all_timesteps), len(all_next_timesteps)),
)
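    # Reverse diffusion: iteratively denoise the latent, stepping from each
    # timestep t toward next_t (RFlowScheduler consumes both endpoints per step).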
with torch.amp.autocast("cuda", enabled=True):
for t, next_t in progress_bar:
# Create a dictionary to store the inputs
unet_inputs = {
"x": image,
"timesteps": torch.Tensor((t,)).to(device),
"spacing_tensor": spacing_tensor,
}
# Add extra arguments if include_body_region is True
if include_body_region:
unet_inputs.update(
{
"top_region_index_tensor": top_region_index_tensor,
"bottom_region_index_tensor": bottom_region_index_tensor,
}
)
if include_modality:
unet_inputs.update(
{
"class_labels": modality_tensor,
}
)
model_output = unet(**unet_inputs)
if not isinstance(noise_scheduler, RFlowScheduler):
image, _ = noise_scheduler.step(model_output, t, image) # type: ignore
else:
image, _ = noise_scheduler.step(model_output, t, image, next_t) # type: ignore
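    # Decode the denoised latent patch-wise (80^3 windows, Gaussian-weighted
    # blending, 40% overlap) to bound GPU memory.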
inferer = SlidingWindowInferer(
roi_size=[80, 80, 80],
sw_batch_size=1,
progress=True,
mode="gaussian",
overlap=0.4,
sw_device=device,
device=device,
)
synthetic_images = dynamic_infer(inferer, recon_model, image)
data = synthetic_images.squeeze().cpu().detach().numpy()
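    # Rescale the decoded [0, 1] output to the intensity range used during
    # training, [-1000, 1000], then clip and cast to int16.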
a_min, a_max, b_min, b_max = -1000, 1000, 0, 1
data = (data - b_min) / (b_max - b_min) * (a_max - a_min) + a_min
data = np.clip(data, a_min, a_max)
return np.int16(data)


def save_image(
    data: np.ndarray,
    output_size: tuple,
    out_spacing: tuple,
    output_path: str,
    logger: logging.Logger,
) -> None:
    """
    Save the generated synthetic image to a file.

    Args:
        data (np.ndarray): Synthetic image data.
        output_size (tuple): Output size of the image.
        out_spacing (tuple): Spacing of the output image.
        output_path (str): Path to save the output image.
        logger (logging.Logger): Logger for logging information.
    """
out_affine = np.eye(4)
for i in range(3):
out_affine[i, i] = out_spacing[i]
new_image = nib.Nifti1Image(data, affine=out_affine)
os.makedirs(os.path.dirname(output_path), exist_ok=True)
nib.save(new_image, output_path)
logger.info(f"Saved {output_path}.")


@torch.inference_mode()
def diff_model_infer(env_config_path: str, model_config_path: str, model_def_path: str, num_gpus: int) -> None:
    """
    Main entry point for diffusion model inference.

    Args:
        env_config_path (str): Path to the environment configuration file.
        model_config_path (str): Path to the model configuration file.
        model_def_path (str): Path to the model definition file.
        num_gpus (int): Number of GPUs to use for distributed inference.
    """
args = load_config(env_config_path, model_config_path, model_def_path)
local_rank, world_size, device = initialize_distributed(num_gpus)
logger = setup_logging("inference")
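    # Offset the configured seed by local_rank so each GPU generates a distinct volume.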
random_seed = set_random_seed(
args.diffusion_unet_inference["random_seed"] + local_rank
if args.diffusion_unet_inference["random_seed"]
else None
)
logger.info(f"Using {device} of {world_size} with random seed: {random_seed}")
output_size = tuple(args.diffusion_unet_inference["dim"])
out_spacing = tuple(args.diffusion_unet_inference["spacing"])
output_prefix = args.output_prefix
ckpt_filepath = f"{args.model_dir}/{args.model_filename}"
if local_rank == 0:
logger.info(f"[config] ckpt_filepath -> {ckpt_filepath}.")
logger.info(f"[config] random_seed -> {random_seed}.")
logger.info(f"[config] output_prefix -> {output_prefix}.")
logger.info(f"[config] output_size -> {output_size}.")
logger.info(f"[config] out_spacing -> {out_spacing}.")
check_input(None, None, None, output_size, out_spacing, None)
autoencoder, unet, scale_factor = load_models(args, device, logger)
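    # Derive the image-to-latent downsampling factor from the number of UNet levels:
    # the latent grid is output_size divided by 2 ** (num_downsample_level - 2).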
num_downsample_level = max(
1,
(
len(args.diffusion_unet_def["num_channels"])
if isinstance(args.diffusion_unet_def["num_channels"], list)
else len(args.diffusion_unet_def["attention_levels"])
),
)
divisor = 2 ** (num_downsample_level - 2)
logger.info(f"num_downsample_level -> {num_downsample_level}, divisor -> {divisor}.")
top_region_index_tensor, bottom_region_index_tensor, spacing_tensor, modality_tensor = prepare_tensors(args, device)
data = run_inference(
args,
device,
autoencoder,
unet,
scale_factor,
top_region_index_tensor,
bottom_region_index_tensor,
spacing_tensor,
modality_tensor,
output_size,
divisor,
logger,
)
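    # The filename embeds seed, size, spacing, timestamp, and rank so parallel runs do not collide.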
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
output_path = "{0}/{1}_seed{2}_size{3:d}x{4:d}x{5:d}_spacing{6:.2f}x{7:.2f}x{8:.2f}_{9}_rank{10}.nii.gz".format(
args.output_dir,
output_prefix,
random_seed,
output_size[0],
output_size[1],
output_size[2],
out_spacing[0],
out_spacing[1],
out_spacing[2],
timestamp,
local_rank,
)
save_image(data, output_size, out_spacing, output_path, logger)
if dist.is_initialized():
dist.destroy_process_group()


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Diffusion Model Inference")
parser.add_argument(
"--env_config",
type=str,
default="./configs/environment_maisi_diff_model_train.json",
help="Path to environment configuration file",
)
parser.add_argument(
"--model_config",
type=str,
default="./configs/config_maisi_diff_model_train.json",
help="Path to model training/inference configuration",
)
parser.add_argument(
"--model_def",
type=str,
default="./configs/config_maisi.json",
help="Path to model definition file",
)
parser.add_argument(
"--num_gpus",
type=int,
default=1,
help="Number of GPUs to use for distributed inference",
)
args = parser.parse_args()
diff_model_infer(args.env_config, args.model_config, args.model_def, args.num_gpus)