LVM-Med / segmentation_3d /MedSAM_3d.py
duynhm's picture
Initial commit
be2715b
raw
history blame contribute delete
No virus
11.4 kB
import numpy as np
import os
join = os.path.join
import gc
from tqdm import tqdm
import torch
import monai, random
from segment_anything import sam_model_registry
from dataloader.sam_transforms import ResizeLongestSide
from dataloader.dataloader import sam_dataloader
from utils.SurfaceDice import multiclass_iou
def fit(cfg,
sam_model,
train_loader,
valid_dataset,
optimizer,
criterion,
model_save_path):
"""
Function to fit model
"""
best_valid_iou3d = 0
device = cfg.base.gpu_id
num_epochs = cfg.train.num_epochs
for epoch in range(num_epochs):
sam_model.train()
epoch_loss = 0
valid_iou3d = 0
print(f"Epoch #{epoch+1}/{num_epochs}")
for step, batch in enumerate(tqdm(train_loader, desc='Model training', unit='batch', leave=True)):
"""
Load precomputed image embeddings to ease training process
We also load mask labels and bounding boxes directly computed from ground truth masks
"""
image, true_mask, boxes = batch['image'], batch['mask'], batch['bboxes']
sam_model = sam_model.to(f"cuda:{device}")
image = image.to(f"cuda:{device}")
true_mask = true_mask.to(f"cuda:{device}")
"""
We freeze image encoder & prompt encoder, only finetune mask decoder
"""
with torch.no_grad():
"""
Compute image embeddings from a batch of images with SAM's frozen encoder
"""
encoder = torch.nn.DataParallel(sam_model.image_encoder, device_ids=[3, 2, 1, 0], output_device=device)
encoder = encoder.to(f"cuda:{encoder.device_ids[0]}")
sam_model = sam_model.to(f"cuda:{encoder.device_ids[0]}")
image = image.to(f"cuda:{encoder.device_ids[0]}")
image = sam_model.preprocess(image[:, :, :])
image_embedding = encoder(image)
"""
Get bounding boxes to make segmentation prediction
We follow the work by Jun Ma & Bo Wang in Segment Anything in Medical Images (2023)
to get bounding boxes from the masks as the boxes prompt for SAM
"""
box_np = boxes.numpy()
sam_trans = ResizeLongestSide(sam_model.image_encoder.img_size)
box = sam_trans.apply_boxes(box_np, (true_mask.shape[-2], true_mask.shape[-1]))
box_torch = torch.as_tensor(box, dtype=torch.float, device=f"cuda:{device}")
if len(box_torch.shape) == 2:
box_torch = box_torch[:, None, :] # (B, 1, 4)
"""
Encode box prompts information with SAM's frozen prompt encoder
"""
prompt_encoder = torch.nn.DataParallel(sam_model.prompt_encoder, device_ids=[0,1,2,3], output_device=device)
prompt_encoder = prompt_encoder.to(f"cuda:{prompt_encoder.device_ids[0]}")
box_torch = box_torch.to(f"cuda:{prompt_encoder.device_ids[0]}")
sparse_embeddings, dense_embeddings = prompt_encoder(
points=None,
boxes=box_torch,
masks=None,
)
"""
We now finetune mask decoder
"""
sam_model = sam_model.to(f"cuda:{device}")
predicted_mask, iou_predictions = sam_model.mask_decoder(
image_embeddings=image_embedding.to(f"cuda:{device}"), # (B, 256, 64, 64)
image_pe=sam_model.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64)
sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256)
dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64)
multimask_output=False,
) # -> (B, 1, 256, 256)
predicted_mask = predicted_mask.to(f"cuda:{device}")
true_mask = true_mask.to(f"cuda:{device}")
loss = criterion(predicted_mask, true_mask)
"""
Upgrade model's params
"""
optimizer.zero_grad(set_to_none=True)
loss.backward()
clip_value = 1 # Clip gradient
torch.nn.utils.clip_grad_norm_(sam_model.mask_decoder.parameters(), clip_value)
optimizer.step()
epoch_loss += loss.item()
"""
Validation step with IoU as the metric
"""
with torch.no_grad():
valid_iou3d = eval_iou(sam_model,
valid_dataset,
device=device)
epoch_loss /= ((step + 1) * len(train_loader))
print(f'Loss: {epoch_loss}\n---')
"""
Save best model
"""
if best_valid_iou3d < valid_iou3d:
best_valid_iou3d = valid_iou3d
torch.save(sam_model.state_dict(), join(model_save_path, f'{cfg.base.best_valid_model_checkpoint}{cfg.base.random_seed}.pth'))
print(f"Valid 3D IoU: {valid_iou3d*100}")
print('=======================================')
print(f"Best valid 3D IoU: {best_valid_iou3d*100}")
def eval_iou(sam_model,
loader,
device):
"""
We use IoU to evalute 3D samples.
For 3D evaluation, we first concatenate 2D slices into 1 unified 3D volume and pass into model
However, due to limited computational resources, we could not perform 3D evaluation in GPU.
Hence, I set up to perform this function completely on CPU.
If you have enough resources, you could evaluate on multi-gpu the same as in training function.
"""
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
iou_score = 0
num_volume = 0
for _, batch in enumerate(tqdm(loader.get_3d_iter(), leave=False)):
"""
Load precomputed embeddings, mask labels and bounding boxes computed directly from ground truth masks
"""
image, true_mask, boxes = batch['image'], batch['mask'], batch['bboxes']
image = image.to(f"cpu")
true_mask = true_mask.to(f"cpu", dtype=torch.float32)
"""
Compute image embeddings
"""
sam_model = sam_model.to(f"cpu")
image = image.to(f"cpu")
image = sam_model.preprocess(image[:, :, :])
image_embedding = sam_model.image_encoder(image)
"""
Get bboxes
"""
box_np = boxes.numpy()
sam_trans = ResizeLongestSide(sam_model.image_encoder.img_size)
box = sam_trans.apply_boxes(box_np, (image_embedding.shape[0], image_embedding.shape[1]))
box_torch = torch.as_tensor(box, dtype=torch.float32, device=device)
if len(box_torch.shape) == 2:
box_torch = box_torch[:, None, :] # (B, 1, 4)
"""
Prompt encoder component
"""
box_torch = box_torch.to(f"cpu")
sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
points=None,
boxes=box_torch,
masks=None,
)
"""
Mask decoder component
"""
sam_model = sam_model.to(f"cpu")
mask_segmentation, iou_predictions = sam_model.mask_decoder(
image_embeddings=image_embedding.to(f"cpu"), # (B, 256, 64, 64)
image_pe=sam_model.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64)
sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256)
dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64)
multimask_output=False,
) # -> (B, 256, 256)
"""
Transform prediction and evaluate
"""
true_mask = true_mask.to("cpu")
medsam_seg_prob = torch.sigmoid(mask_segmentation)
medsam_seg = (medsam_seg_prob > 0.5).to(dtype=torch.float32)
iou_score += multiclass_iou((true_mask>0).to(dtype=torch.float32), (medsam_seg>0).to(dtype=torch.float32))
num_volume += 1
return iou_score.cpu().numpy()/num_volume
def medsam_3d(yml_args, cfg):
"""
Training warm up
"""
torch.multiprocessing.set_start_method('spawn')
random.seed(cfg.base.random_seed)
np.random.seed(cfg.base.random_seed)
torch.manual_seed(cfg.base.random_seed)
torch.cuda.manual_seed(cfg.base.random_seed)
torch.backends.cudnn.deterministic = True
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
"""
General configuration
"""
img_shape = (3, 1024) # hard settings image shape as 3 x 1024 x 1024
model_save_path = join("./work_dir", 'SAM-ViT-B')
os.makedirs(model_save_path, exist_ok=True)
print(f"Fine-tuned SAM (3D IoU) in {cfg.base.dataset_name} with {cfg.train.optimizer}, LR = {cfg.train.learning_rate}")
"""
Load SAM with its original checkpoint
"""
sam_model = sam_model_registry["vit_b"](checkpoint=cfg.base.original_checkpoint)
"""
Load precomputed embeddings
"""
train_loader, _, _, valid_dataset, test_dataset = sam_dataloader(cfg)
"""
Optimizer & learning rate scheduler config
"""
if cfg.train.optimizer == 'sgd':
optimizer = torch.optim.SGD(sam_model.mask_decoder.parameters(),
lr=float(cfg.train.learning_rate),
momentum=0.9)
elif cfg.train.optimizer == 'adam':
optimizer = torch.optim.Adam(sam_model.mask_decoder.parameters(),
lr=float(cfg.train.learning_rate),
weight_decay=0,
amsgrad=True)
elif cfg.train.optimizer == 'adamw':
optimizer = torch.optim.AdamW(sam_model.mask_decoder.parameters(),
lr=float(cfg.train.learning_rate),
weight_decay=0)
else:
raise NotImplementedError(f"Optimizer {cfg.train.optimizer} is not set up yet")
"""
Loss function
In this work, we use a combination of Dice and Cross Entropy Loss to measure SAM's loss values.
"""
criterion = monai.losses.DiceCELoss(sigmoid=True,
squared_pred=True,
reduction='mean')
"""
Train model
"""
if not yml_args.use_test_mode:
fit(cfg,
sam_model=sam_model,
train_loader=train_loader,
valid_dataset=valid_dataset,
optimizer=optimizer,
criterion=criterion,
model_save_path=model_save_path)
"""
Test model
"""
with torch.no_grad():
sam_model_test_iou = sam_model_registry["vit_b"](checkpoint=join(model_save_path, f'{cfg.base.best_valid_model_checkpoint}{cfg.base.random_seed}.pth'))
sam_model_test_iou.eval()
test_iou_score = eval_iou(sam_model_test_iou,
test_dataset,
device=cfg.base.gpu_id)
print(f"Test 3D IoU score after training with {cfg.train.optimizer}(lr = {cfg.train.learning_rate}): {test_iou_score *100}")