Model Use
This model is a class-conditional DDPM (denoising diffusion probabilistic model) that generates ultrasound images. It is based on the diffusion model from the paper "Towards Realistic Ultrasound Fetal Brain Imaging Synthesis" and was trained on the FETAL_PLANES_DB dataset. The classes that can be generated, with their integer labels, are: Fetal abdomen (0), Fetal brain (1), Fetal femur (2), Fetal thorax (3), Maternal cervix (4), and Other (5). To generate an image of a given class, pass its integer label to the UNet together with the noisy image and the timestep.
The code below loads the model, generates an image, and displays it:
# !pip install --upgrade diffusers transformers accelerate scipy ftfy safetensors
from diffusers import DDPMPipeline
import torch
import matplotlib.pyplot as plt
import numpy as np
# Are we using a GPU or CPU?
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# load model and scheduler
model_id = "harveymannering/xfetus-ddpm-v2"
ddpm = DDPMPipeline.from_pretrained(model_id)
ddpm.to(device)
# Generate a single image
x = torch.randn(1, 3, 128, 128).to(device) # noise
for i, t in enumerate(ddpm.scheduler.timesteps):
    model_input = ddpm.scheduler.scale_model_input(x, t)
    with torch.no_grad():
        # Conditioning on the 'Fetal brain' class (with index 1)
        class_label = torch.ones(1, dtype=torch.int64)
        noise_pred = ddpm.unet(model_input, t, class_label.to(device))["sample"]
    # Take one reverse-diffusion step using the predicted noise
    x = ddpm.scheduler.step(noise_pred, t, x).prev_sample
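# Display image (shift by 0.5, then clip to the [0, 1] range that imshow expects for floats)
image = np.transpose(x[0].cpu().detach().numpy(), (1, 2, 0)) + 0.5
plt.imshow(np.clip(image, 0, 1))
plt.show()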
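To condition on a class other than 'Fetal brain', it can help to keep the labels listed above in a small lookup table. The following is a minimal sketch: the names `CLASS_LABELS` and `label_for` are hypothetical helpers and are not part of the model or of diffusers.

```python
# Class names and integer labels, as listed in this model card.
CLASS_LABELS = {
    "Fetal abdomen": 0,
    "Fetal brain": 1,
    "Fetal femur": 2,
    "Fetal thorax": 3,
    "Maternal cervix": 4,
    "Other": 5,
}


def label_for(class_name: str) -> int:
    """Return the integer label for a class name (hypothetical helper)."""
    try:
        return CLASS_LABELS[class_name]
    except KeyError:
        valid = ", ".join(CLASS_LABELS)
        raise ValueError(
            f"Unknown class {class_name!r}; expected one of: {valid}"
        ) from None


print(label_for("Fetal femur"))  # prints 2
```

In the sampling loop, the class tensor could then be built with `torch.tensor([label_for("Fetal femur")], dtype=torch.int64)` in place of `torch.ones(1, dtype=torch.int64)`.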
Example Outputs
The figure below includes examples of both real and synthetic images. The following preprocessing and augmentation steps were applied to all training images:
- Random Horizontal Flip
- Random Rotation (±45°)
- Resize to 128×128 using Bicubic Interpolation
Training Loss
The baseline model was trained exclusively on images from the 'Voluson E6' machine. Training and validation losses are presented below. Checkpoints were saved every 50 epochs; the checkpoint with the lowest validation loss was the one from epoch 250, and that is the model provided here.
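The checkpoint selection described above amounts to taking the minimum of the validation loss over the saved epochs. A minimal sketch follows, assuming the validation losses are available as a mapping from epoch to loss. Both function names are hypothetical, and the 200-epoch run in the example is an arbitrary illustration rather than the actual training length.

```python
from typing import Dict, List


def checkpoint_epochs(total_epochs: int, every: int = 50) -> List[int]:
    """List the epochs at which a checkpoint would be saved."""
    return list(range(every, total_epochs + 1, every))


def select_best_checkpoint(val_loss_by_epoch: Dict[int, float]) -> int:
    """Return the epoch whose checkpoint has the lowest validation loss."""
    if not val_loss_by_epoch:
        raise ValueError("No checkpoints to choose from")
    return min(val_loss_by_epoch, key=val_loss_by_epoch.get)


# Arbitrary example: a 200-epoch run saving every 50 epochs.
print(checkpoint_epochs(200))  # prints [50, 100, 150, 200]
```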