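This is a trained ControlNet + Stable Diffusion model for generating synthetic CT/MRI images from segmentation maps. It does not run with the stock pipeline; it requires the customized ControlNet-StableDiffusion pipeline from the DiffCTSeg repository (see the inference steps below).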

The model is trained on the JHU dataset, which contains 5312 CT volumes with corresponding segmentation masks.

The CT volumes were converted into roughly 1.3M 2D slices for training.
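The card does not state the slicing axis, intensity windowing, or whether background-only slices were kept. As a rough scale check, ~1.3M slices from 5312 volumes is about 245 slices per volume. The following is a minimal sketch of pairing 2D image and mask slices. It assumes each CT volume and its label volume are NumPy arrays of shape (H, W, D) with axial slices on the last axis; `slice_volume_pair` is a hypothetical helper, not part of the DiffCTSeg repository.

```python
import numpy as np


def slice_volume_pair(ct, seg, axis=-1, skip_empty=False):
    """Yield matching 2D (image, mask) slices from a CT volume and its label volume.

    Hypothetical helper: assumes `ct` and `seg` have the same shape and that
    axial slices lie along `axis` (the last axis by default).
    """
    if ct.shape != seg.shape:
        raise ValueError(f"shape mismatch: {ct.shape} vs {seg.shape}")
    # Move the slicing axis to the front so iteration yields 2D slices
    ct = np.moveaxis(ct, axis, 0)
    seg = np.moveaxis(seg, axis, 0)
    for img2d, mask2d in zip(ct, seg):
        # Optionally drop slices whose mask is entirely background (label 0)
        if skip_empty and not mask2d.any():
            continue
        yield img2d, mask2d


# Toy example: a 4x4x3 volume where only the middle slice has a label
ct = np.random.rand(4, 4, 3).astype(np.float32)
seg = np.zeros((4, 4, 3), dtype=np.uint8)
seg[1:3, 1:3, 1] = 3  # label 3 = liver in the card's class_dict

print(len(list(slice_volume_pair(ct, seg))))                   # 3
print(len(list(slice_volume_pair(ct, seg, skip_empty=True))))  # 1
```

Whether empty slices should be skipped is a design choice the card leaves open, so the sketch keeps them by default.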

The training and inference code is available in the Diff_Synth_CT repository.

Training details

Hardware: 8x NVIDIA RTX A6000

Batch size: 8 x 4 x 32
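The three batch-size factors are not named on the card. Assuming they multiply, and reading them as 8 GPUs (matching the hardware above), a per-GPU batch of 4, and 32 gradient-accumulation steps (a guess; the product is the same in any order), the effective batch size per optimizer step works out as follows:

```python
# Assumed meaning of the three factors in "8 x 4 x 32" (not stated on the card)
num_gpus = 8
per_gpu_batch = 4
grad_accum_steps = 32

effective_batch = num_gpus * per_gpu_batch * grad_accum_steps
print(effective_batch)  # 1024

# Approximate optimizer steps per epoch over ~1.3M training slices
steps_per_epoch = 1_300_000 // effective_batch
print(steps_per_epoch)  # 1269
```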

For direct inference

Step 1: Clone the GitHub repository to get the customized ControlNet-StableDiffusion pipeline implementation

git clone https://github.com/Onkarsus13/DiffCTSeg

Step 2: Change into the repository directory and install the package and its dependencies

cd DiffCTSeg
pip install -e ".[torch]"
pip install -e ".[all,dev,notebooks]"

Step 3: Run python test_eraser.py, or run the code below

# diffusers is expected to resolve to the customized fork installed in Step 2
from diffusers import StableDiffusionControlNetInpaintPipeline, UniPCMultistepScheduler
import torch
from PIL import Image
import numpy as np


# RGB color used for each class ID in the color-coded segmentation mask
class_dict_BTCV = {
        0:(0, 0, 0),
        1:(255, 60, 0),
        2:(255, 60, 232),
        3:(134, 79, 117),
        4:(125, 0, 190),
        5:(117, 200, 191),
        6:(230, 91, 101),
        7:(255, 0, 155),
        8:(75, 205, 155),
        9:(100, 37, 200)
}

# Organ name for each class ID (used to build the text prompt)
class_dict = {
        0:"background",
        1:"aorta",
        2:"kidney_left",
        3:"liver",
        4:"postcava",
        5:"stomach",
        6:"gall_bladder",
        7:"kidney_right",
        8:"pancreas",
        9:"spleen"
}

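def rgb_to_onehot(rgb_arr, color_dict=class_dict_BTCV):
    """Convert an (H, W, 3) color-coded mask into an (H, W, num_classes) one-hot array."""
    num_classes = len(color_dict)
    shape = rgb_arr.shape[:2] + (num_classes,)
    arr = np.zeros(shape, dtype=np.int8)
    for i, color in color_dict.items():
        # Pixels that exactly match this class color get a 1 in channel i;
        # pixels matching no color stay all-zero and argmax to 0 (background)
        arr[:, :, i] = np.all(rgb_arr.reshape((-1, 3)) == color, axis=1).reshape(shape[:2])
    return arr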



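# Load the pipeline in half precision; the safety checker and its
# feature extractor are disabled since they are not needed for CT slices
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    "onkarsus13/PixArt_Dual_Tone_CT_SEG_V0.1",
    torch_dtype=torch.float16,
    safety_checker=None,
    feature_extractor=None,
)
pipe.scheduler = UniPCMultistepScheduler.from_pretrained(
    "onkarsus13/PixArt_Dual_Tone_CT_SEG_V0.1", subfolder="scheduler"
)
# Model CPU offload moves each sub-model to the GPU only when it runs,
# so the pipeline must not be moved to CUDA manually beforehand
pipe.enable_model_cpu_offload()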


generator = torch.Generator(device="cpu").manual_seed(1)

# Load the color-coded segmentation mask (replace the placeholder with your file path)
seg_mask = Image.open("<path to segmentation mask>").convert("RGB")
label_map = rgb_to_onehot(np.asarray(seg_mask)).argmax(-1)
unique_ids = np.unique(label_map)

# Build the prompt from the organs present in the mask
# ("containg" is kept as in the original script; it may match the training prompts)
prompt = "CT image containg " + " ".join(class_dict[i] for i in unique_ids)
print(prompt)

image = pipe(
    prompt,
    seg_mask,
    [seg_mask],
    num_inference_steps=30,
    generator=generator,
    controlnet_conditioning_scale=1.0,
).images[0]

image.save("./result.png")
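The inference code expects the segmentation mask as an RGB image whose pixel colors match class_dict_BTCV exactly; any other color ends up as background after the argmax. If your labels are stored as an integer map with IDs 0 to 9 following class_dict, the following is a minimal sketch for turning them into that color-coded image and building the same prompt the script uses. `labels_to_rgb` and `build_prompt` are hypothetical helpers, and the palette and class names are copied from the code above.

```python
import numpy as np
from PIL import Image

# Palette and class names copied from the inference code above
class_dict_BTCV = {
    0: (0, 0, 0), 1: (255, 60, 0), 2: (255, 60, 232), 3: (134, 79, 117),
    4: (125, 0, 190), 5: (117, 200, 191), 6: (230, 91, 101),
    7: (255, 0, 155), 8: (75, 205, 155), 9: (100, 37, 200),
}
class_dict = {
    0: "background", 1: "aorta", 2: "kidney_left", 3: "liver", 4: "postcava",
    5: "stomach", 6: "gall_bladder", 7: "kidney_right", 8: "pancreas", 9: "spleen",
}


def labels_to_rgb(label_map):
    """Map an (H, W) integer label map to an (H, W, 3) uint8 color-coded mask."""
    palette = np.array([class_dict_BTCV[i] for i in range(len(class_dict_BTCV))], dtype=np.uint8)
    # Fancy indexing looks up one RGB triple per pixel
    return palette[label_map]


def build_prompt(label_map):
    """Build the text prompt from the classes present, in the same format as the script."""
    ids = np.unique(label_map)  # sorted class IDs
    # "containg" is kept verbatim from the original script
    return "CT image containg " + " ".join(class_dict[int(i)] for i in ids)


labels = np.zeros((8, 8), dtype=np.int64)
labels[2:6, 2:6] = 3  # liver
labels[0, 0] = 9      # spleen

rgb = labels_to_rgb(labels)
Image.fromarray(rgb).save("seg_mask.png")  # PNG is lossless, so colors stay exact
print(build_prompt(labels))  # CT image containg background liver spleen
```

Save the mask as PNG rather than JPEG: lossy compression shifts pixel colors, so they would no longer match the palette and would be read as background.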

