File size: 2,297 Bytes
0c7479d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import ast
import json
import os

import cv2
import numpy as np
import pycocotools.mask as maskUtils
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

import utils.transforms as custom_transforms


class SAMDataset(Dataset):
    """Dataset yielding (target image, prompt, segmentation hint) samples.

    The index file at ``txt_path`` holds one Python-literal record per line.
    Each record is indexed with the keys ``'source'`` (a JSON annotation
    file), ``'target'`` (an image file) and ``'prompt'``; both file names are
    resolved relative to ``data_path``.
    """

    def __init__(self, data_path='../data/files', txt_path='../data/data_85616.txt'):
        """Load the sample index and build the augmentation pipeline.

        Args:
            data_path: Directory that the per-record ``source`` / ``target``
                file names are joined onto.
            txt_path: Text file with one record literal per line.
        """
        self.data = self._read_records(txt_path)
        self.data_path = data_path
        randomresizedcrop = custom_transforms.RandomResizedCrop(
            512,
            scale=(0.9, 1),
        )
        self.transform = custom_transforms.Compose([
            randomresizedcrop,
            custom_transforms.RandomHorizontalFlip(p=0.5),
            custom_transforms.ToTensor(),
            transforms.Normalize(mean=0.5, std=0.5)
        ])

    @staticmethod
    def _read_records(txt_path):
        """Parse the index file into a list of records, skipping blank lines."""
        records = []
        with open(txt_path, 'rt', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # A trailing or stray empty line is not a record.
                    continue
                # SECURITY: this used to be eval(line), which executes any
                # expression found in the file. literal_eval accepts only
                # Python literals (dict/list/str/number/...) and raises
                # ValueError or SyntaxError on anything else.
                records.append(ast.literal_eval(line))
        return records

    @staticmethod
    def _index_map_to_rgb(index_map):
        """Pack an integer index map into an (H, W, 3) uint8 array.

        Channel 0 holds ``index % 256``, channel 1 holds ``index // 256`` and
        channel 2 stays zero.
        """
        packed = np.zeros(index_map.shape + (3,), dtype=np.uint8)
        packed[..., 0] = (index_map % 256).astype(np.uint8)
        packed[..., 1] = (index_map // 256).astype(np.uint8)
        return packed

    def __len__(self):
        """Return the number of records read from the index file."""
        return len(self.data)

    def load_rle_annotations_from_json(self, json_file_path, return_pil=True):
        """Decode every RLE annotation in a JSON file into one index map.

        Pixels covered by the ``k``-th annotation (1-based) are set to ``k``;
        uncovered pixels stay 0. Annotations are written in file order, so
        where masks overlap the later annotation wins.

        Args:
            json_file_path: Path to a JSON file providing ``annotations``
                (each with a ``segmentation`` entry) and
                ``image.height`` / ``image.width``.
            return_pil: When true, return the map packed into a 3-channel
                ``PIL.Image``; otherwise return the raw ``uint16`` array.

        Returns:
            A ``PIL.Image`` or an ``(height, width)`` ``uint16`` array.
        """
        with open(json_file_path, 'r', encoding='utf-8') as f:
            anno_data = json.load(f)
        annotations = anno_data['annotations']
        height = int(anno_data['image']['height'])
        width = int(anno_data['image']['width'])

        # NOTE(review): uint16 can only represent indices up to 65535; files
        # with more annotations than that are not handled here.
        index_map = np.zeros((height, width), dtype=np.uint16)
        for index, ann in enumerate(annotations, start=1):
            mask = maskUtils.decode(ann['segmentation'])
            index_map[mask != 0] = index

        if return_pil:
            return Image.fromarray(self._index_map_to_rgb(index_map))
        return index_map

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as ``dict(jpg=..., txt=..., hint=...)``.

        ``jpg`` is the transformed target image, ``txt`` the prompt and
        ``hint`` the transformed annotation map.
        """
        item = self.data[idx]

        source_filename = item['source']
        target_filename = item['target']
        prompt = item['prompt']

        source = self.load_rle_annotations_from_json(
            os.path.join(self.data_path, source_filename)
        )
        target = Image.open(os.path.join(self.data_path, target_filename))

        # The joint transform takes (image, hint) so both receive the same
        # random crop / flip parameters — presumably; verify against
        # utils.transforms.
        target, source = self.transform(target, source)

        # NOTE(review): only `target` is permuted; `source` keeps whatever
        # layout the transform produced — confirm this asymmetry is intended.
        target = target.permute(1, 2, 0)

        return dict(jpg=target, txt=prompt, hint=source)