# Provenance: uploaded to the Hugging Face Hub by introvoyz041
# ("Upload folder using huggingface_hub", commit 3f31c34, verified).
import os
import cv2
import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset
from utils import random_box, random_click
class DDTI(Dataset):
    """Dataset of image/mask pairs with click prompts derived from the mask.

    Files are read from ``<data_path>/<mode>/images`` and
    ``<data_path>/<mode>/masks``; an image and its mask share the same file
    name.
    """

    def __init__(self, args, data_path, transform=None, transform_msk=None,
                 mode='Training', prompt='click', plane=False):
        """Index the image directory and store the configuration.

        Args:
            args: Object exposing ``image_size`` (side length the mask is
                resized to before click points are sampled).
            data_path: Root directory of the dataset.
            transform: Optional callable applied to the RGB PIL image.
            transform_msk: Optional callable applied to the grayscale PIL
                mask.
            mode: Sub-directory of ``data_path`` to read from.
            prompt: Prompt type; only ``'click'`` is implemented.
            plane: Unused; accepted for signature compatibility.
        """
        super().__init__()
        # os.listdir() returns entries in arbitrary order; sorting makes the
        # index -> sample mapping reproducible across runs and machines.
        self.name_list = sorted(
            os.listdir(os.path.join(data_path, mode, 'images'))
        )
        self.data_path = data_path
        self.mode = mode
        self.prompt = prompt
        self.img_size = args.image_size
        self.transform = transform
        self.transform_msk = transform_msk

    def __len__(self):
        """Return the number of files found in the image directory."""
        return len(self.name_list)

    def find_connected_components(self, mask, min_area=400):
        """Sample two click points from the connected components of ``mask``.

        One point is drawn (via ``random_click``) from each component whose
        pixel count is strictly greater than ``min_area``, in label order;
        the first two are kept. A single point is duplicated so the result
        always describes exactly two clicks.

        Fallbacks that keep the output shape fixed:
          * no component exceeds ``min_area`` -> the largest component is
            used instead;
          * the mask has no foreground at all -> a placeholder is returned
            (see the note in the body).

        Args:
            mask: 2-D array; values are clipped to [0, 1] and cast to uint8
                before labelling, so for integer masks any non-zero pixel
                counts as foreground.
            min_area: Area threshold in pixels (exclusive).

        Returns:
            Tuple ``(point_labels, points)`` of NumPy arrays with two
            entries each. The coordinate layout of a point is whatever
            ``random_click`` returns.
        """
        binary = np.clip(mask, 0, 1).astype(np.uint8)
        num_labels, labels = cv2.connectedComponents(binary)

        # areas[k] is the pixel count of component k; label 0 is background.
        areas = np.bincount(labels.ravel(), minlength=num_labels)

        selected = [k for k in range(1, num_labels) if areas[k] > min_area]
        if not selected and num_labels > 1:
            # Every component is small: still prompt on the largest one
            # rather than returning no points.
            selected = [1 + int(np.argmax(areas[1:]))]

        points = []
        point_labels = []
        for k in selected:
            component_mask = np.where(labels == k, 1, 0)
            point_label, random_point = random_click(component_mask)
            points.append(random_point)
            point_labels.append(point_label)

        if not points:
            # Empty mask. Return a fixed-shape placeholder so batches can
            # still be stacked.
            # NOTE(review): label 0 is presumed to mean a background click
            # and points are presumed to be 2-D -- confirm against the
            # consumer of 'p_label' / 'pt'.
            return (np.zeros(2, dtype=np.int64),
                    np.zeros((2, 2), dtype=np.int64))

        if len(points) == 1:
            points.append(points[0])
            point_labels.append(point_labels[0])

        return np.array(point_labels[:2]), np.array(points[:2])

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            dict with keys ``'image'`` (transformed image), ``'label'``
            (int tensor mask clamped to {0, 1}), ``'p_label'`` and ``'pt'``
            (click labels and points), and ``'image_meta_dict'`` (holding
            the file name without a ``.jpg`` suffix).

        Raises:
            ValueError: If ``self.prompt`` is not ``'click'``.
        """
        # Fail before any I/O: no other prompt type produces 'pt'.
        if self.prompt != 'click':
            raise ValueError(
                f"Unsupported prompt type {self.prompt!r}; "
                "only 'click' is implemented"
            )

        name = self.name_list[index]
        img_path = os.path.join(self.data_path, self.mode, 'images', name)
        msk_path = os.path.join(self.data_path, self.mode, 'masks', name)

        with Image.open(img_path) as raw_img:
            img = raw_img.convert('RGB')
        with Image.open(msk_path) as raw_msk:
            mask = raw_msk.convert('L')

        # Click points are sampled on the resized mask.
        # NOTE(review): resize() uses Pillow's default resampling filter
        # here; interpolated edge pixels count as foreground for the click
        # sampling below -- confirm this is intended.
        mask = mask.resize((self.img_size, self.img_size))
        point_label, pt = self.find_connected_components(np.array(mask))

        if self.transform is not None:
            state = torch.get_rng_state()
            img = self.transform(img)
            if self.transform_msk is not None:
                # Rewind the RNG so the mask transform replays the same
                # random draws as the image transform. Only rewind when a
                # mask transform will consume the state; otherwise the
                # global RNG would never advance.
                torch.set_rng_state(state)
        if self.transform_msk is not None:
            mask = self.transform_msk(mask)

        if isinstance(mask, Image.Image):
            # No tensor conversion happened above: build a (1, H, W) float
            # tensor in [0, 1] so the clamp below works.
            mask = torch.from_numpy(
                np.array(mask, dtype=np.float32) / 255.0
            ).unsqueeze(0)

        mask = torch.clamp(mask, min=0, max=1).int()

        name = name.split('/')[-1].split(".jpg")[0]
        image_meta_dict = {'filename_or_obj': name}
        return {
            'image': img,
            'label': mask,
            'p_label': point_label,
            'pt': pt,
            'image_meta_dict': image_meta_dict,
        }