|
''' |
|
Author: Chris Xiao yl.xiao@mail.utoronto.ca |
|
Date: 2023-09-30 16:14:13 |
|
LastEditors: Chris Xiao yl.xiao@mail.utoronto.ca |
|
LastEditTime: 2023-12-17 01:50:37 |
|
FilePath: /EndoSAM/endoSAM/test.py |
|
Description: fine-tune inference script |
|
I Love IU |
|
Copyright (c) 2023 by Chris Xiao yl.xiao@mail.utoronto.ca, All Rights Reserved. |
|
''' |
|
import argparse |
|
from omegaconf import OmegaConf |
|
from torch.utils.data import DataLoader |
|
import os |
|
from dataset import EndoVisDataset |
|
from utils import make_if_dont_exist, one_hot_embedding_3d |
|
import torch |
|
from model import EndoSAMAdapter |
|
import numpy as np |
|
from segment_anything.build_sam import sam_model_registry |
|
from loss import jaccard |
|
import cv2 |
|
import json |
|
import wget |
|
|
|
# Official Segment Anything checkpoint download locations, keyed by the
# SAM model type named in the config. 'default' is an alias for ViT-H.
_SAM_VIT_H_URL = 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'

COMMON_MODEL_LINKS = dict(
    default=_SAM_VIT_H_URL,
    vit_h=_SAM_VIT_H_URL,
    vit_l='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
    vit_b='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
)
|
|
|
|
|
def parse_command():
    """Read the command-line options of the inference script.

    Returns:
        argparse.Namespace: Parsed options. ``cfg`` holds the config file
        path as a string, or ``None`` when ``--cfg`` was not supplied.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--cfg',
        type=str,
        default=None,
        help='path to config file',
    )
    return cli.parse_args()
|
|
|
def load_config(cfg_path):
    """Load the experiment configuration file.

    Args:
        cfg_path: Path supplied through ``--cfg``, or ``None`` when the
            option was omitted.

    Returns:
        The configuration object returned by ``OmegaConf.load``.

    Raises:
        ValueError: If ``cfg_path`` is ``None``.
        FileNotFoundError: If ``cfg_path`` does not exist on disk.
    """
    if cfg_path is None:
        raise ValueError('config file not specified')
    if not os.path.exists(cfg_path):
        raise FileNotFoundError(f'config file {cfg_path} not found')
    return OmegaConf.load(cfg_path)


def ensure_sam_checkpoint(cfg, cfg_path):
    """Download the SAM checkpoint when the config does not point at one.

    If ``cfg.model.sam_model_dir`` is absent, missing, or refers to a path
    that does not exist, the checkpoint matching ``cfg.model.sam_model_type``
    is downloaded into ``<parent of cwd>/sam_ckpts``, the config is updated
    in place with the new path, and the updated config is written back to
    ``cfg_path``.

    Args:
        cfg: Loaded OmegaConf configuration (mutated when a download occurs).
        cfg_path: Path of the config file to overwrite with the new setting.
    """
    model_keys = OmegaConf.to_container(cfg)['model'].keys()
    has_checkpoint = (
        'sam_model_dir' in model_keys
        and not OmegaConf.is_missing(cfg.model, 'sam_model_dir')
        and os.path.exists(cfg.model.sam_model_dir)
    )
    if has_checkpoint:
        return

    print("Didn't find SAM Checkpoint. Downloading from Facebook AI...")
    # os.path.dirname rather than splitting on '/': the manual split breaks
    # on platforms whose path separator is not '/'.
    parent_dir = os.path.dirname(os.getcwd())
    model_dir = os.path.join(parent_dir, 'sam_ckpts')
    make_if_dont_exist(model_dir, overwrite=True)
    checkpoint = os.path.join(model_dir, cfg.model.sam_model_type + '.pth')
    wget.download(COMMON_MODEL_LINKS[cfg.model.sam_model_type], checkpoint)
    OmegaConf.update(cfg, 'model.sam_model_dir', checkpoint)
    OmegaConf.save(cfg, cfg_path)


def blend_mask_overlay(img_bgr, binary_mask, alpha=0.5):
    """Blend a red highlight of the predicted mask onto the image.

    Args:
        img_bgr: Image array indexed as ``[row, col, channel]``; channel 2
            is treated as red (OpenCV BGR order).
        binary_mask: 2-D array in which foreground pixels equal 255.
        alpha: Weight of the highlight layer; the image gets ``1 - alpha``.

    Returns:
        The blended image produced by ``cv2.addWeighted``.
    """
    overlay = np.zeros_like(img_bgr)
    overlay[:, :, 2][binary_mask == 255] = 255
    return cv2.addWeighted(img_bgr, 1 - alpha, overlay, alpha, 0)


def mean_valid_iou(ious):
    """Average per-image IoU scores while ignoring NaN placeholders.

    Images whose prediction was degenerate are recorded as NaN; a plain
    mean would turn the whole summary into NaN as soon as one such image
    exists.

    Args:
        ious: Iterable of float scores, possibly containing NaN.

    Returns:
        float: Mean of the non-NaN scores, or NaN when there are none.
    """
    scores = np.asarray(list(ious), dtype=float)
    valid = scores[~np.isnan(scores)]
    if valid.size == 0:
        return float('nan')
    return float(valid.mean())


if __name__ == '__main__':
    args = parse_command()
    cfg_path = args.cfg
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    cfg = load_config(cfg_path)
    ensure_sam_checkpoint(cfg, cfg_path)

    # Resolve experiment-specific input and output locations.
    exp = cfg.experiment_name
    root_dir = cfg.dataset.dataset_dir
    img_format = cfg.dataset.img_format
    ann_format = cfg.dataset.ann_format
    model_path = cfg.model_folder
    model_exp_path = os.path.join(model_path, exp)
    test_path = cfg.test_folder
    test_exp_path = os.path.join(test_path, exp)
    test_exp_mask_path = os.path.join(test_exp_path, 'mask')
    test_exp_overlay_path = os.path.join(test_exp_path, 'overlay')

    make_if_dont_exist(test_exp_path)
    make_if_dont_exist(test_exp_mask_path)
    make_if_dont_exist(test_exp_overlay_path)

    test_dataset = EndoVisDataset(
        root_dir,
        ann_format=ann_format,
        img_format=img_format,
        mode='test',
        encoder_size=cfg.model.encoder_size,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=cfg.test_bs,
        shuffle=False,
        num_workers=cfg.num_workers,
    )

    sam_mask_encoder, sam_prompt_encoder, sam_mask_decoder = sam_model_registry[cfg.model.sam_model_type](
        checkpoint=cfg.model.sam_model_dir,
        customized=cfg.model.sam_model_customized,
    )
    model = EndoSAMAdapter(
        device,
        cfg.model.class_num,
        sam_mask_encoder,
        sam_prompt_encoder,
        sam_mask_decoder,
        num_token=cfg.num_token,
    ).to(device)
    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    weights = torch.load(os.path.join(model_exp_path, 'model.pth'), map_location=device)['endosam_state_dict']
    model.load_state_dict(weights)

    model.eval()

    iou_dict = {}
    ious = []
    with torch.no_grad():
        # NOTE(review): name[0] and squeeze(0) below only address a single
        # sample, so this loop appears to assume cfg.test_bs == 1 - confirm.
        for img, ann, name, img_bgr in test_loader:
            # Close the preview window left over from the previous image.
            cv2.destroyAllWindows()
            img = img.to(device)
            ann = ann.to(device).unsqueeze(1).long()
            ann = one_hot_embedding_3d(ann, class_num=cfg.model.class_num)
            pred, _ = model(img)

            # NaN marks images whose prediction holds a single value, for
            # which no IoU is computed.
            mask_iou = np.nan
            if torch.unique(pred).size()[0] > 1:
                iou = jaccard(ann, pred)
                mask_iou = iou.item()
            iou_dict[name[0]] = mask_iou
            ious.append(mask_iou)

            # Collapse class scores to a label map, then binarise it:
            # every non-background label becomes 255.
            pred = torch.argmax(pred, dim=1)
            numpy_pred = pred.squeeze(0).detach().cpu().numpy()
            numpy_pred[numpy_pred != 0] = 255
            img_bgr = img_bgr.squeeze(0).detach().cpu().numpy()

            result = blend_mask_overlay(img_bgr, numpy_pred, alpha=0.5)
            cv2.imshow('Result', result)

            # Show the preview for up to one second; 'q' closes it early.
            if cv2.waitKey(1000) == ord('q'):
                cv2.destroyAllWindows()

            cv2.imwrite(os.path.join(test_exp_mask_path, f'{name[0]}.png'), numpy_pred.astype(np.uint8))
            cv2.imwrite(os.path.join(test_exp_overlay_path, f'{name[0]}.png'), result)

    # The context manager closes the file; no explicit close() is needed.
    with open(os.path.join(test_exp_path, 'mask_ious.json'), 'w') as f:
        json.dump(iou_dict, f, indent=4, sort_keys=False)

    avg_iou = mean_valid_iou(ious)
    print(f'average intersection over union of mask: {avg_iou}')