pythn's picture
Upload with huggingface_hub
4a1f918 verified
import random
import os
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
import pandas as pd
from data_transforms.atr_transform import ATR_Transform
class ATR_Dataset(Dataset):
    """Image/mask segmentation dataset driven by a fold CSV under ``root_path``.

    Rows are read from ``<root_path>/folds_masks/train0.csv`` when training and
    from ``<root_path>/folds_masks/val0.csv`` otherwise. Each row must provide a
    ``mask_path`` column (joined onto ``root_path``) and a ``tgt`` column naming
    the image's target class. The image file is looked up in
    ``<root_path>/imgs/`` under ``mask_path`` with its first 6 characters
    dropped (presumably a ``"masks/"`` prefix -- TODO confirm against the CSVs).

    In text mode (the default) every image yields one sample per entry of
    ``config['data']['label_names']``. In ``no_text_mode`` each image yields a
    single sample with an empty label string.
    """

    # A row is kept only if its derived image name contains one of these.
    # 'jpeg' must be listed on its own because 'jpg' is not a substring of it.
    _IMG_EXTENSIONS = ('jpg', 'jpeg', 'png', 'bmp')

    def __init__(self, config, is_train=False, shuffle_list=True, apply_norm=True, no_text_mode=False) -> None:
        """Load the fold CSV, build the per-sample lists and create the transform.

        Args:
            config: nested config dict. Reads ``config['data']['root_path']``,
                ``config['data']['label_names']`` and, in ``__getitem__``,
                ``config['data']['volume_channel']``. Also passed to ``ATR_Transform``.
            is_train: selects ``train0.csv`` instead of ``val0.csv``. It is
                also forwarded to the transform.
            shuffle_list: if True, permute the samples once, using the global
                ``random`` state.
            apply_norm: forwarded to the transform.
            no_text_mode: if True, produce one sample per image with an empty
                label instead of one sample per label name.
        """
        super().__init__()
        self.root_path = config['data']['root_path']
        # Parallel per-sample lists; index i in each describes the same sample.
        self.img_names = []
        self.img_path_list = []
        self.label_path_list = []
        self.label_list = []
        self.class_in_image = []
        self.is_train = is_train
        self.label_names = config['data']['label_names']
        self.num_classes = len(self.label_names)
        self.config = config
        self.apply_norm = apply_norm
        self.no_text_mode = no_text_mode

        csv_name = 'train0.csv' if self.is_train else 'val0.csv'
        self.df = pd.read_csv(os.path.join(self.root_path, 'folds_masks', csv_name))
        self.populate_lists()

        if shuffle_list:
            # Apply one permutation to every per-sample list so they stay aligned.
            p = list(range(len(self.img_path_list)))
            random.shuffle(p)
            self.img_path_list = [self.img_path_list[pi] for pi in p]
            self.img_names = [self.img_names[pi] for pi in p]
            self.label_path_list = [self.label_path_list[pi] for pi in p]
            self.label_list = [self.label_list[pi] for pi in p]
            self.class_in_image = [self.class_in_image[pi] for pi in p]

        # Define the data transform.
        self.data_transform = ATR_Transform(config=config)

    def __len__(self):
        """Return the number of samples (image/label-name pairs)."""
        return len(self.img_path_list)

    def populate_lists(self):
        """Append one entry per sample from ``self.df`` to the per-sample lists.

        Rows whose image name contains none of ``_IMG_EXTENSIONS`` are skipped.
        """
        for mask_rel, tgt in zip(self.df['mask_path'], self.df['tgt']):
            # Drop the 6-character directory prefix of mask_path to get the image name.
            img = mask_rel[6:]
            if not any(ext in img for ext in self._IMG_EXTENSIONS):
                continue
            img_path = os.path.join(self.root_path, 'imgs', img)
            mask_path = os.path.join(self.root_path, mask_rel)
            # Text mode: one sample per label name. No-text mode: a single unlabeled sample.
            labels = [''] if self.no_text_mode else self.label_names
            for label_name in labels:
                self.img_names.append(img)
                self.img_path_list.append(img_path)
                self.label_path_list.append(mask_path)
                self.label_list.append(label_name)
                self.class_in_image.append(tgt)

    def __getitem__(self, index):
        """Return ``(img, label, img_path, label_of_interest)`` for sample ``index``.

        ``img`` is the RGB image as a tensor, permuted to channels-first when
        ``config['data']['volume_channel'] == 2``, then passed through
        ``self.data_transform``. ``label`` is the 0/1 mask after the transform.
        When there are several label names and this sample's label differs
        from ``"<tgt> Vehicle"``, an all-zero mask is used instead of the mask
        file. Errors while reading files propagate to the caller.
        """
        with Image.open(self.img_path_list[index]) as pil_img:
            img = torch.as_tensor(np.array(pil_img.convert("RGB")))
        # Take the spatial size before any permute so it is correct for both layouts.
        height, width = img.shape[0], img.shape[1]
        if self.config['data']['volume_channel'] == 2:
            img = img.permute(2, 0, 1)

        label_of_interest = self.label_list[index]
        if self.num_classes > 1 and self.class_in_image[index] + ' Vehicle' != label_of_interest:
            # This prompt does not match the image's target class: use an empty (negative) mask.
            label = torch.zeros(height, width)
        else:
            with Image.open(self.label_path_list[index]) as pil_mask:
                label = torch.Tensor(np.array(pil_mask))
            if len(label.shape) == 3:
                # Multi-channel mask file: keep only the first channel.
                label = label[:, :, 0]

        label = label.unsqueeze(0)
        label = (label > 0) + 0
        img, label = self.data_transform(img, label, is_train=self.is_train, apply_norm=self.apply_norm)
        # The transform may resize and leave fractional values; threshold back to 0/1.
        label = (label >= 0.5) + 0
        label = label[0]
        return img, label, self.img_path_list[index], label_of_interest