File size: 3,500 Bytes
594b244 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
import json
import torchvision.transforms as transforms
from torch.utils.data.dataset import Dataset
#from torchvision.io import read_image
from PIL import Image
import os
import torch
import torchvision.transforms.functional as F
def tokenize_captions(caption, tokenizer):
    """Tokenize a single caption string into padded token ids.

    The caption is wrapped in a one-element batch, padded/truncated to the
    tokenizer's ``model_max_length``, and returned as the ``input_ids``
    tensor of the tokenizer output.
    """
    tokenized = tokenizer(
        [caption],
        max_length=tokenizer.model_max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    return tokenized.input_ids
class SquarePad:
    """Pad a PIL image with white so it becomes square.

    The output side length is ``max(w, h)``; padding is split as evenly as
    possible between the two sides of the shorter dimension.
    """

    def __call__(self, image):
        w, h = image.size
        max_wh = max(w, h)
        # Split the deficit between the two sides, giving any odd leftover
        # pixel to the right/bottom. The previous int((max_wh - w) / 2) on
        # BOTH sides dropped one pixel whenever the deficit was odd, so the
        # result was not actually square.
        left = (max_wh - w) // 2
        top = (max_wh - h) // 2
        right = max_wh - w - left
        bottom = max_wh - h - top
        # torchvision F.pad takes padding as (left, top, right, bottom);
        # fill=(255, 255, 255) pads with white for RGB inputs.
        return F.pad(image, (left, top, right, bottom), (255, 255, 255), 'constant')
class NormalSegDataset(Dataset):
    """Dataset yielding (rgb, normal+seg conditioning, token ids, caption).

    Reads ``meta_train_seg.json`` under ``path``; the file is expected to
    have the layout ``{"keys": [...], "data": {key: {"rgb": path,
    "normal": path, "seg": path, "caption": [str, ...]}}}``.

    ``cfg_prob`` is the classifier-free-guidance drop probability, applied
    *independently* to the text prompt and to the conditioning images.
    """

    def __init__(self, args, path, tokenizer, cfg_prob):
        # Geometric transform shared by rgb/normal/seg. It is replayed with
        # an identical RNG state in __getitem__ so all three images receive
        # the exact same random crop.
        self.image_transforms = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    args.resolution,
                    scale=(0.9, 1.0),
                    # Equivalent to the deprecated integer code 2.
                    interpolation=transforms.InterpolationMode.BILINEAR,
                ),
                transforms.ToTensor(),
            ]
        )
        # Maps [0, 1] tensors to [-1, 1]; applied to the rgb image only,
        # not to the conditioning images.
        self.additional_image_transforms = transforms.Compose(
            [transforms.Normalize([0.5], [0.5])]
        )
        meta_path = os.path.join(path, 'meta_train_seg.json')
        with open(meta_path, 'r') as f:
            meta = json.load(f)
        self.keys = meta['keys']
        self.meta = meta['data']
        self.tokenizer = tokenizer
        self.cfg_prob = cfg_prob

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, index):
        meta_data = self.meta[self.keys[index]]
        rgb_path = meta_data['rgb']
        normal_path = meta_data['normal']
        seg_path = meta_data['seg']
        text_prompt = meta_data['caption'][0]
        # CFG: drop the caption with probability cfg_prob.
        if torch.rand(1).item() < self.cfg_prob:
            text_prompt = ""
        image = Image.open(rgb_path).convert("RGB")
        # Snapshot the RNG state BEFORE the random crop so it can be
        # replayed on the conditioning images below.
        state = torch.get_rng_state()
        image = self.image_transforms(image)
        # CFG: independently drop the conditioning pair with probability
        # cfg_prob, substituting a white normal map and an all-zero seg mask.
        if torch.rand(1).item() < self.cfg_prob:
            # image is a (C, H, W) tensor here, while PIL sizes are (W, H),
            # so the correct size is (shape[2], shape[1]). The old code
            # passed (shape[1], shape[2]) — harmless only because the crop
            # above always produces a square image.
            blank_size = (image.shape[2], image.shape[1])
            normal_image = Image.new('RGB', blank_size, (255, 255, 255))
            seg_image = Image.new('L', blank_size, 0)
        else:
            normal_image = Image.open(normal_path).convert("RGB")
            seg_image = Image.open(seg_path).convert("L")
        # Replay the saved state twice so normal and seg each get the same
        # random crop that was applied to the rgb image.
        torch.set_rng_state(state)
        normal_image = self.image_transforms(normal_image)
        torch.set_rng_state(state)
        seg_image = self.image_transforms(seg_image)
        # 4-channel conditioning tensor: 3 normal channels + 1 seg channel.
        conditioning_image = torch.cat([normal_image, seg_image], dim=0)
        image = self.additional_image_transforms(image)
        prompt = tokenize_captions(text_prompt, self.tokenizer)
        return image, conditioning_image, prompt, text_prompt
|