Spaces:
Runtime error
Runtime error
File size: 2,297 Bytes
0c7479d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import ast
import json
import os

import cv2
import numpy as np
import pycocotools.mask as maskUtils
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

import utils.transforms as custom_transforms


class SAMDataset(Dataset):
    """Dataset pairing an image with an instance-id map decoded from RLE masks.

    The index file ``txt_path`` holds one Python dict literal per non-blank
    line. Each dict has the keys ``'source'``, ``'target'`` and ``'prompt'``:

    * ``'source'`` is the path of a JSON annotation file with RLE masks,
      relative to ``data_path``.
    * ``'target'`` is the path of the image, also relative to ``data_path``.
    * ``'prompt'`` is a text prompt.

    ``__getitem__`` returns ``dict(jpg=..., txt=..., hint=...)``:

    * ``jpg`` is the transformed image, permuted with ``(1, 2, 0)``.
    * ``txt`` is the prompt.
    * ``hint`` is the transformed instance-id map.
    """

    # Instance ids are stored in a uint16 map and encoded into two 8-bit
    # channels, so at most 0xFFFF distinct ids (1..65535) fit.
    _MAX_INSTANCES = 0xFFFF

    def __init__(self, data_path='../data/files', txt_path='../data/data_85616.txt'):
        """Read the index file and build the augmentation pipeline.

        Args:
            data_path: Directory that the ``'source'``/``'target'`` paths in
                the index are relative to.
            txt_path: Index file with one dict literal per line.

        Raises:
            ValueError / SyntaxError: if a non-blank index line is not a
                plain Python literal (raised by ``ast.literal_eval``).
        """
        self.data = []
        with open(txt_path, 'rt', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    continue
                # literal_eval accepts the same dict/str/number literals as
                # eval() but cannot execute code from a tampered index file.
                self.data.append(ast.literal_eval(line))
        self.data_path = data_path
        randomresizedcrop = custom_transforms.RandomResizedCrop(
            512,
            scale=(0.9, 1),
        )
        # NOTE(review): custom_transforms.Compose is called with
        # (target, source) in __getitem__; presumably it applies the same
        # random crop/flip to both images - confirm in utils.transforms.
        self.transform = custom_transforms.Compose([
            randomresizedcrop,
            custom_transforms.RandomHorizontalFlip(p=0.5),
            custom_transforms.ToTensor(),
            transforms.Normalize(mean=0.5, std=0.5),
        ])

    def __len__(self):
        """Return the number of entries read from the index file."""
        return len(self.data)

    def load_rle_annotations_from_json(self, json_file_path, return_pil=True):
        """Decode all RLE masks in a JSON file into one instance-id map.

        Pixels covered by annotation ``i`` (0-based) get id ``i + 1``. Pixels
        covered by no mask stay 0. Where masks overlap, the later annotation
        wins.

        Args:
            json_file_path: JSON file with ``annotations[*].segmentation``
                (RLE, decoded with ``pycocotools.mask.decode``) and
                ``image.height`` / ``image.width``.
            return_pil: When true, return an RGB ``PIL.Image``. The low byte
                of each id goes in R, the high byte in G, and B is 0. When
                false, return the raw ``(height, width)`` uint16 array.

        Raises:
            ValueError: if the file has more annotations than fit in 16 bits.
        """
        with open(json_file_path, 'r', encoding='utf-8') as f:
            anno_data = json.load(f)
        annotations = anno_data['annotations']
        if len(annotations) > self._MAX_INSTANCES:
            # Fail before decoding anything; ids beyond 65535 would wrap or
            # overflow the uint16 map and could not be encoded in R+G.
            raise ValueError(
                f'{json_file_path}: {len(annotations)} annotations exceed the '
                f'{self._MAX_INSTANCES} instance ids representable in 16 bits'
            )
        height = int(anno_data['image']['height'])
        width = int(anno_data['image']['width'])

        id_map = np.zeros((height, width), dtype=np.uint16)
        for instance_id, ann in enumerate(annotations, start=1):
            mask = maskUtils.decode(ann['segmentation'])
            id_map[mask != 0] = instance_id

        if not return_pil:
            return id_map

        # Build the 8-bit RGB encoding directly (no float64 round trip).
        rgb = np.zeros((height, width, 3), dtype=np.uint8)
        rgb[:, :, 0] = id_map & 0xFF   # low byte
        rgb[:, :, 1] = id_map >> 8     # high byte
        return Image.fromarray(rgb)

    def __getitem__(self, idx):
        """Load, jointly transform and return the sample at ``idx``."""
        item = self.data[idx]
        source_filename = item['source']
        target_filename = item['target']
        prompt = item['prompt']
        source = self.load_rle_annotations_from_json(
            os.path.join(self.data_path, source_filename))
        # Open the image in a context manager so its file handle is closed
        # here instead of whenever the lazy PIL image is garbage-collected.
        # convert('RGB') forces the pixel data to load before the file closes
        # and guarantees a 3-channel image.
        with Image.open(os.path.join(self.data_path, target_filename)) as img:
            target = img.convert('RGB')
        target, source = self.transform(target, source)
        # Presumably CHW -> HWC, assuming custom ToTensor yields CHW like
        # torchvision's - TODO confirm.
        target = target.permute(1, 2, 0)
        return dict(jpg=target, txt=prompt, hint=source)
|