Spaces:
Configuration error
Configuration error
import os | |
import cv2 | |
import numpy as np | |
import pandas as pd | |
import torch | |
from PIL import Image | |
from torch.utils.data import Dataset | |
from utils import random_box, random_click | |
class DDTI(Dataset): | |
def __init__(self, args, data_path , transform = None, transform_msk = None, mode = 'Training',prompt = 'click', plane = False): | |
self.name_list = os.listdir(os.path.join(data_path,mode,'images')) | |
self.data_path = data_path | |
self.mode = mode | |
self.prompt = prompt | |
self.img_size = args.image_size | |
self.transform = transform | |
self.transform_msk = transform_msk | |
def __len__(self): | |
return len(self.name_list) | |
def find_connected_components(self,mask): | |
mask = np.clip(mask,0,1) | |
num_labels, labels = cv2.connectedComponents(mask.astype(np.uint8)) | |
point = [] | |
point_labels = [] | |
for label in range(1, num_labels): | |
component_mask = np.where(labels == label, 1, 0) | |
area = np.sum(component_mask) | |
if area > 400: | |
point_label, random_point = random_click(component_mask) | |
point.append(random_point) | |
point_labels.append(point_label) | |
# print(f"Random point in component {label}: {random_point}, label: {point_labels}") | |
if(len(point)==1): | |
point.append(point[0]) | |
point_labels.append(point_labels[0]) | |
if(len(point)>2): | |
point = point[:2] | |
point_labels = point_labels[:2] | |
point = np.array(point) | |
point_labels = np.array(point_labels) | |
return point_labels,point | |
def __getitem__(self, index): | |
point_label = 1 | |
"""Get the images""" | |
name = self.name_list[index] | |
img_path = os.path.join(self.data_path, self.mode, 'images', name) | |
msk_path = os.path.join(self.data_path, self.mode, 'masks', name) | |
img = Image.open(img_path).convert('RGB') | |
mask = Image.open(msk_path).convert('L') | |
# if self.mode == 'Training': | |
# label = 0 if self.label_list[index] == 'benign' else 1 | |
# else: | |
# label = int(self.label_list[index]) | |
newsize = (self.img_size, self.img_size) | |
mask = mask.resize(newsize) | |
if self.prompt == 'click': | |
# two prompt | |
point_label, pt = self.find_connected_components(np.array(mask)) | |
# one prompt | |
# point_label, pt = random_click(np.array(mask) / 255, point_label) | |
if self.transform: | |
state = torch.get_rng_state() | |
img = self.transform(img) | |
torch.set_rng_state(state) | |
if self.transform_msk: | |
mask = self.transform_msk(mask) | |
# if (inout == 0 and point_label == 1) or (inout == 1 and point_label == 0): | |
# mask = 1 - mask | |
mask = torch.clamp(mask,min=0,max=1).int() | |
name = name.split('/')[-1].split(".jpg")[0] | |
image_meta_dict = {'filename_or_obj':name} | |
return { | |
'image':img, | |
'label': mask, | |
'p_label':point_label, | |
'pt':pt, | |
'image_meta_dict':image_meta_dict, | |
} |