File size: 3,500 Bytes
594b244
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119



import json
import torchvision.transforms as transforms
from torch.utils.data.dataset import Dataset 
#from torchvision.io import read_image
from PIL import Image
import os
import torch 
import torchvision.transforms.functional as F
def tokenize_captions(caption, tokenizer):
    """Tokenize one caption into a fixed-length batch of input ids.

    The caption is wrapped in a one-element batch and padded/truncated to
    ``tokenizer.model_max_length`` with PyTorch tensors requested.

    Args:
        caption: Text to tokenize.
        tokenizer: Callable tokenizer exposing ``model_max_length``
            (presumably a Hugging Face tokenizer — confirm against caller).

    Returns:
        The ``input_ids`` attribute of the tokenizer output.
    """
    batch = [caption]
    max_len = tokenizer.model_max_length
    encoded = tokenizer(
        batch,
        max_length=max_len,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    return encoded.input_ids




class SquarePad:
    """Pad a PIL image with a constant border so that it becomes square.

    The shorter side is padded on both ends; when the size difference is odd
    the extra pixel goes to the right/bottom edge, so the output is exactly
    ``max(w, h)`` pixels on each side.

    Args:
        fill: Border colour passed to ``F.pad``. The default white RGB tuple
            matches the previous hard-coded value; a grayscale ('L') image
            presumably needs a single value such as ``255`` instead.
        padding_mode: Passed to ``F.pad``; defaults to ``'constant'``.
    """

    def __init__(self, fill=(255, 255, 255), padding_mode='constant'):
        self.fill = fill
        self.padding_mode = padding_mode

    @staticmethod
    def _padding_for(width, height):
        """Return the ``(left, top, right, bottom)`` padding that squares a
        ``width`` x ``height`` image."""
        side = max(width, height)
        extra_w = side - width
        extra_h = side - height
        left = extra_w // 2
        top = extra_h // 2
        # Give the odd leftover pixel to the right/bottom edge; splitting
        # int(diff / 2) onto both sides left the result one pixel short.
        return (left, top, extra_w - left, extra_h - top)

    def __call__(self, image):
        """Return *image* padded to a square; *image* must expose a PIL-style
        ``.size`` of ``(width, height)``."""
        w, h = image.size
        return F.pad(image, self._padding_for(w, h), self.fill, self.padding_mode)

class NormalSegDataset(Dataset):
    """Paired (rgb, normal map, segmentation mask, caption) training dataset.

    Samples are described by ``<path>/meta_train_seg.json``, read as
    ``{"keys": [...], "data": ...}``: ``keys`` lists the sample ids and
    ``data[key]`` holds the ``"rgb"``, ``"normal"`` and ``"seg"`` image paths
    plus a ``"caption"`` list (only its first entry is used).

    With probability ``cfg_prob`` the caption is replaced by ``""`` and,
    independently, the normal/seg conditioning is replaced by a blank image
    (white normal map, black mask) — presumably for classifier-free guidance
    training; confirm with the training script.
    """

    def __init__(self, args, path, tokenizer, cfg_prob):
        """
        Args:
            args: Config object; only ``args.resolution`` (the output size
                given to ``RandomResizedCrop``) is read.
            path: Directory containing ``meta_train_seg.json``.
            tokenizer: Passed to ``tokenize_captions`` for every sample.
            cfg_prob: Probability of dropping the caption and, independently,
                of blanking the conditioning images.
        """
        # The same pipeline is applied to rgb, normal and seg images; the RNG
        # state is shared in __getitem__ so they receive matching crops.
        # NOTE(review): seg masks are also resampled bilinearly, which blends
        # values at boundaries — confirm this is intended for the mask format.
        self.image_transforms = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    args.resolution,
                    scale=(0.9, 1.0),
                    interpolation=transforms.InterpolationMode.BILINEAR,
                ),
                transforms.ToTensor(),
            ]
        )

        # Applied to the rgb target only: maps ToTensor's [0, 1] to [-1, 1].
        self.additional_image_transforms = transforms.Compose(
            [transforms.Normalize([0.5], [0.5])]
        )

        meta_path = os.path.join(path, 'meta_train_seg.json')
        with open(meta_path, 'r', encoding='utf-8') as f:
            meta = json.load(f)

        self.keys = meta['keys']
        self.meta = meta['data']

        self.tokenizer = tokenizer
        self.cfg_prob = cfg_prob

    def __len__(self):
        """Return the number of samples listed under ``keys``."""
        return len(self.keys)

    @staticmethod
    def _load_image(path, mode):
        """Load the image at *path*, convert it to PIL *mode*, and close the file."""
        with Image.open(path) as img:
            return img.convert(mode)

    def __getitem__(self, index):
        """Return ``(image, conditioning_image, prompt, text_prompt)`` for *index*.

        - ``image``: rgb crop as a float tensor normalized to [-1, 1].
        - ``conditioning_image``: the normal map (3 channels) concatenated
          with the segmentation mask (1 channel) along dim 0, in [0, 1].
        - ``prompt``: output of ``tokenize_captions`` for ``text_prompt``.
        - ``text_prompt``: the first caption, or ``""`` when dropped.
        """
        meta_data = self.meta[self.keys[index]]

        rgb_path = meta_data['rgb']
        normal_path = meta_data['normal']
        seg_path = meta_data['seg']
        text_prompt = meta_data['caption'][0]

        # Draw both dropout decisions before snapshotting the RNG state below.
        # That state is rewound to share one crop across the images; drawing
        # the conditioning decision after the snapshot made the rewind replay
        # it as the next sample's caption-dropout draw.
        drop_caption = torch.rand(1).item() < self.cfg_prob
        drop_condition = torch.rand(1).item() < self.cfg_prob

        if drop_caption:
            text_prompt = ""

        image = self._load_image(rgb_path, "RGB")

        # Rewinding to the same state before each transform gives equally
        # sized inputs identical random-crop parameters.
        crop_state = torch.get_rng_state()
        image = self.image_transforms(image)
        after_crop_state = torch.get_rng_state()

        if drop_condition:
            # PIL sizes are (width, height); the tensor is (C, H, W).
            blank_size = (image.shape[2], image.shape[1])
            normal_image = Image.new('RGB', blank_size, (255, 255, 255))
            seg_image = Image.new('L', blank_size, 0)
        else:
            normal_image = self._load_image(normal_path, "RGB")
            seg_image = self._load_image(seg_path, "L")

        torch.set_rng_state(crop_state)
        normal_image = self.image_transforms(normal_image)

        torch.set_rng_state(crop_state)
        seg_image = self.image_transforms(seg_image)

        # Resume from the state reached after the rgb crop rather than
        # wherever the last rewound transform happened to leave it.
        torch.set_rng_state(after_crop_state)

        conditioning_image = torch.cat([normal_image, seg_image], dim=0)

        image = self.additional_image_transforms(image)

        prompt = tokenize_captions(text_prompt, self.tokenizer)

        return image, conditioning_image, prompt, text_prompt