EndoSAM / endoSAM /dataset.py
Chris Xiao
init model
2df812d
raw
history blame
2.84 kB
'''
Author: Chris Xiao yl.xiao@mail.utoronto.ca
Date: 2023-09-16 17:41:29
LastEditors: Chris Xiao yl.xiao@mail.utoronto.ca
LastEditTime: 2023-12-17 18:22:42
FilePath: /EndoSAM/endoSAM/dataset.py
Description: EndoVisDataset class
I Love IU
Copyright (c) 2023 by Chris Xiao yl.xiao@mail.utoronto.ca, All Rights Reserved.
'''
from torch.utils.data import Dataset
import os
import glob
import numpy as np
import cv2
from utils import ResizeLongestSide, preprocess
import torch
modes = ['train', 'val', 'test']


class EndoVisDataset(Dataset):
    """Paired image / annotation dataset for the EndoVis data.

    Files are discovered under ``root`` with the following layout::

        <root>/img/<mode>/<name>.<img_format>
        <root>/ann/<mode>/<name>.<ann_format>

    A sample is addressed by its position in the sorted list of image files;
    the annotation is looked up by the image's base name.
    """

    def __init__(self, root,
                 ann_format='png',
                 img_format='jpg',
                 mode='train',
                 encoder_size=1024):
        """Collect the image and annotation files for one split.

        Args:
            root (str): directory containing the ``img`` and ``ann`` folders.
            ann_format (str, optional): annotation file extension without the
                leading dot. Defaults to "png".
            img_format (str, optional): image file extension without the
                leading dot. Defaults to "jpg".
            mode (str, optional): one of "train", "val" or "test".
                Defaults to "train".
            encoder_size (int, optional): size handed to ``ResizeLongestSide``
                and ``preprocess``. Defaults to 1024.

        Raises:
            ValueError: if ``mode`` is not one of the supported modes.
        """
        super().__init__()
        self.root = root
        self.mode = mode
        self.ann_format = ann_format
        self.img_format = img_format
        self.encoder_size = encoder_size
        self.ann_path = os.path.join(self.root, 'ann')
        self.img_path = os.path.join(self.root, 'img')
        if self.mode not in modes:
            raise ValueError('Invalid mode: {}'.format(self.mode))
        self.img_mode_path = os.path.join(self.img_path, self.mode)
        self.ann_mode_path = os.path.join(self.ann_path, self.mode)
        # glob returns files in arbitrary order; sort so that an index always
        # refers to the same sample across runs and machines.
        self.imgs = sorted(
            glob.glob(os.path.join(self.img_mode_path, '*.{}'.format(self.img_format))))
        self.anns = sorted(
            glob.glob(os.path.join(self.ann_mode_path, '*.{}'.format(self.ann_format))))
        # NOTE(review): ResizeLongestSide comes from the project's utils module;
        # presumably it rescales so the longest side equals encoder_size.
        self.transform = ResizeLongestSide(self.encoder_size)

    @staticmethod
    def _stem(path, ext):
        """Return the file name of ``path`` without its ``.<ext>`` suffix."""
        base = os.path.basename(path)
        suffix = '.' + ext
        # Strip the known extension rather than cutting at the first dot, so
        # names such as "frame.001.jpg" keep their full stem ("frame.001").
        if base.lower().endswith(suffix.lower()):
            return base[:-len(suffix)]
        return os.path.splitext(base)[0]

    @staticmethod
    def _read(path, flags):
        """Load ``path`` with OpenCV, failing loudly when it cannot be read."""
        data = cv2.imread(path, flags)
        # cv2.imread signals failure by returning None instead of raising.
        if data is None:
            raise FileNotFoundError(
                'Unable to read image file (missing or undecodable): {}'.format(path))
        return data

    @staticmethod
    def _binarize(mask):
        """Return a copy of ``mask`` in which every non-zero value is 1."""
        binary = np.array(mask)
        binary[binary != 0] = 1
        return binary

    def __len__(self):
        """Return the number of samples in the selected split.

        Raises:
            ValueError: if the mode is invalid or the number of image files
                differs from the number of annotation files.
        """
        if self.mode not in modes:
            raise ValueError('Invalid mode: {}'.format(self.mode))
        # Explicit raise instead of assert: asserts are removed under -O.
        if len(self.imgs) != len(self.anns):
            raise ValueError(
                'Number of images ({}) does not match number of annotations ({})'.format(
                    len(self.imgs), len(self.anns)))
        return len(self.imgs)

    def __getitem__(self, index) -> tuple:
        """Load one sample.

        Returns:
            tuple: ``(img, ann, name, img_bgr)`` where ``img`` is the output of
            ``preprocess`` applied to the resized RGB image in channel-first
            layout, ``ann`` is the annotation read as grayscale with every
            non-zero pixel set to 1 (it is not resized), ``name`` is the image
            file name without extension and ``img_bgr`` is the image exactly
            as loaded by OpenCV.

        Raises:
            FileNotFoundError: if the image or its annotation cannot be read.
        """
        img_file = self.imgs[index]
        name = self._stem(img_file, self.img_format)

        img_bgr = self._read(img_file, cv2.IMREAD_COLOR)
        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        input_image = self.transform.apply_image(img_rgb)
        # HWC -> CHW before handing the tensor to preprocess.
        input_image_torch = torch.as_tensor(input_image).permute(2, 0, 1).contiguous()
        img = preprocess(input_image_torch, self.encoder_size)

        ann_file = os.path.join(self.ann_mode_path, f"{name}.{self.ann_format}")
        ann = self._binarize(self._read(ann_file, cv2.IMREAD_GRAYSCALE))
        return img, ann, name, img_bgr