'''
Author: Chris Xiao yl.xiao@mail.utoronto.ca
Date: 2023-09-30 16:14:13
LastEditors: Chris Xiao yl.xiao@mail.utoronto.ca
LastEditTime: 2023-12-17 01:50:37
FilePath: /EndoSAM/endoSAM/test.py
Description: fine-tune inference script
I Love IU
Copyright (c) 2023 by Chris Xiao yl.xiao@mail.utoronto.ca, All Rights Reserved.
'''
import argparse
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
import os
from dataset import EndoVisDataset
from utils import make_if_dont_exist, one_hot_embedding_3d
import torch
from model import EndoSAMAdapter
import numpy as np
from segment_anything.build_sam import sam_model_registry
from loss import jaccard
import cv2
import json
import wget

# Official Segment Anything checkpoint URLs, keyed by ``cfg.model.sam_model_type``.
COMMON_MODEL_LINKS = {
    'default': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
    'vit_h': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
    'vit_l': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
    'vit_b': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
}


def parse_command():
    """Parse command-line arguments.

    Returns:
        argparse.Namespace whose ``cfg`` attribute is the config-file path,
        or None when ``--cfg`` was not given.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', default=None, type=str, help='path to config file')
    args = parser.parse_args()
    return args


def _load_config(cfg_path):
    """Load the OmegaConf config stored at ``cfg_path``.

    Raises:
        ValueError: if ``cfg_path`` is None.
        FileNotFoundError: if ``cfg_path`` does not exist.
    """
    if cfg_path is None:
        raise ValueError('config file not specified')
    if not os.path.exists(cfg_path):
        raise FileNotFoundError(f'config file {cfg_path} not found')
    return OmegaConf.load(cfg_path)


def _sam_checkpoint_missing(cfg):
    """Return True when ``model.sam_model_dir`` is absent, left as ``???``, null, or not an existing path."""
    if 'sam_model_dir' not in OmegaConf.to_container(cfg)['model']:
        return True
    if OmegaConf.is_missing(cfg.model, 'sam_model_dir'):
        return True
    sam_model_dir = cfg.model.sam_model_dir
    # A null entry would make os.path.exists raise TypeError; treat it as "not downloaded yet".
    return sam_model_dir is None or not os.path.exists(sam_model_dir)


def _download_sam_checkpoint(cfg, cfg_path):
    """Download the SAM checkpoint for ``cfg.model.sam_model_type`` and record its path.

    The file is saved as ``<parent of cwd>/sam_ckpts/<sam_model_type>.pth``. The path is
    written into ``cfg.model.sam_model_dir`` and the config is saved back to ``cfg_path``,
    so later runs reuse the download.

    Raises:
        ValueError: if ``sam_model_type`` has no entry in ``COMMON_MODEL_LINKS``.
    """
    model_type = cfg.model.sam_model_type
    # Validate before touching the filesystem so a typo does not trigger directory creation.
    if model_type not in COMMON_MODEL_LINKS:
        raise ValueError(f'unknown sam_model_type {model_type!r}; '
                         f'expected one of {sorted(COMMON_MODEL_LINKS)}')
    print("Didn't find SAM Checkpoint. Downloading from Facebook AI...")
    parent_dir = os.path.dirname(os.getcwd())
    model_dir = os.path.join(parent_dir, 'sam_ckpts')
    # NOTE(review): overwrite=True is kept from the original; confirm make_if_dont_exist
    # does not wipe other checkpoints already stored in sam_ckpts.
    make_if_dont_exist(model_dir, overwrite=True)
    checkpoint = os.path.join(model_dir, model_type + '.pth')
    wget.download(COMMON_MODEL_LINKS[model_type], checkpoint)
    OmegaConf.update(cfg, 'model.sam_model_dir', checkpoint)
    OmegaConf.save(cfg, cfg_path)


def _red_overlay(img_bgr, binary_mask, alpha=0.5):
    """Blend a red highlight of ``binary_mask`` (255 = foreground) onto ``img_bgr``.

    ``img_bgr`` is indexed as (H, W, C), with channel 2 used as red (OpenCV BGR order).
    ``alpha`` is the weight of the red layer; the image gets ``1 - alpha``.
    """
    overlay = np.zeros_like(img_bgr)
    overlay[:, :, 2][binary_mask == 255] = 255
    return cv2.addWeighted(img_bgr, 1 - alpha, overlay, alpha, 0)


def main():
    """Run the fine-tuned EndoSAM model on the test split and save its predictions.

    Outputs are written under ``<test_folder>/<experiment_name>/``:
      * ``mask/<name>.png``: binary mask, 255 wherever argmax predicts a non-zero class;
      * ``overlay/<name>.png``: the input image with the mask blended in red;
      * ``mask_ious.json``: per-image IoU from ``loss.jaccard`` (NaN when undefined).

    Each result is previewed in an OpenCV window for up to 1 s. Press 'q' to stop the
    preview, or set ``show_result: false`` in the config (e.g. on headless machines).
    """
    args = parse_command()
    cfg_path = args.cfg
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    cfg = _load_config(cfg_path)

    if _sam_checkpoint_missing(cfg):
        _download_sam_checkpoint(cfg, cfg_path)

    exp = cfg.experiment_name
    model_exp_path = os.path.join(cfg.model_folder, exp)
    test_exp_path = os.path.join(cfg.test_folder, exp)
    test_exp_mask_path = os.path.join(test_exp_path, 'mask')
    test_exp_overlay_path = os.path.join(test_exp_path, 'overlay')
    for path in (test_exp_path, test_exp_mask_path, test_exp_overlay_path):
        make_if_dont_exist(path)

    test_dataset = EndoVisDataset(cfg.dataset.dataset_dir,
                                  ann_format=cfg.dataset.ann_format,
                                  img_format=cfg.dataset.img_format,
                                  mode='test',
                                  encoder_size=cfg.model.encoder_size)
    test_loader = DataLoader(test_dataset, batch_size=cfg.test_bs, shuffle=False,
                             num_workers=cfg.num_workers)

    sam_mask_encoder, sam_prompt_encoder, sam_mask_decoder = sam_model_registry[cfg.model.sam_model_type](
        checkpoint=cfg.model.sam_model_dir, customized=cfg.model.sam_model_customized)
    model = EndoSAMAdapter(device, cfg.model.class_num, sam_mask_encoder, sam_prompt_encoder,
                           sam_mask_decoder, num_token=cfg.num_token).to(device)
    weights = torch.load(os.path.join(model_exp_path, 'model.pth'),
                         map_location=device)['endosam_state_dict']
    model.load_state_dict(weights)
    model.eval()

    show_result = cfg.get('show_result', True)
    iou_dict = {}
    ious = []
    with torch.no_grad():
        for img, ann, names, imgs_bgr in test_loader:
            img = img.to(device)
            ann = ann.to(device).unsqueeze(1).long()
            ann = one_hot_embedding_3d(ann, class_num=cfg.model.class_num)
            pred, _ = model(img)
            # Per-pixel class index; dim 1 is reduced, as in the original argmax.
            label_maps = torch.argmax(pred, dim=1).cpu().numpy()
            imgs_bgr = imgs_bgr.cpu().numpy()

            # Handle each sample on its own so test_bs > 1 works. With test_bs == 1 this is
            # exactly the original whole-batch computation.
            for i, name in enumerate(names):
                sample_pred = pred[i:i + 1]
                # NaN marks images whose prediction tensor is constant (IoU left undefined).
                mask_iou = np.nan
                if torch.unique(sample_pred).size()[0] > 1:
                    mask_iou = jaccard(ann[i:i + 1], sample_pred).item()
                iou_dict[name] = mask_iou
                ious.append(mask_iou)

                binary_mask = np.where(label_maps[i] != 0, 255, 0).astype(np.uint8)
                result = _red_overlay(imgs_bgr[i], binary_mask)

                if show_result:
                    # Reuse the same window; wait up to 1 s so each result is visible.
                    cv2.imshow('Result', result)
                    if (cv2.waitKey(1000) & 0xFF) == ord('q'):
                        # 'q' turns the preview off for the rest of the run.
                        cv2.destroyAllWindows()
                        show_result = False

                cv2.imwrite(os.path.join(test_exp_mask_path, f'{name}.png'), binary_mask)
                cv2.imwrite(os.path.join(test_exp_overlay_path, f'{name}.png'), result)

    if show_result:
        cv2.destroyAllWindows()

    with open(os.path.join(test_exp_path, 'mask_ious.json'), 'w') as f:
        json.dump(iou_dict, f, indent=4, sort_keys=False)

    # A single NaN would make np.mean return NaN for the whole run, so average only the
    # images with a defined IoU and report how many were left out.
    valid_ious = [v for v in ious if not np.isnan(v)]
    skipped = len(ious) - len(valid_ious)
    if skipped:
        print(f'{skipped} image(s) with a constant prediction were excluded from the average (IoU undefined)')
    if valid_ious:
        avg_iou = np.mean(valid_ious, axis=0)
        print(f'average intersection over union of mask: {avg_iou}')
    else:
        print('average intersection over union of mask: nan (no image had a defined IoU)')


if __name__ == '__main__':
    main()