|
|
|
|
|
|
|
import json |
|
import torchvision.transforms as transforms |
|
from torch.utils.data.dataset import Dataset |
|
|
|
from PIL import Image |
|
import os |
|
import torch |
|
import torchvision.transforms.functional as F |
|
def tokenize_captions(caption, tokenizer):
    """Tokenize a single caption and return its token ids.

    Args:
        caption: The caption to encode. It is wrapped in a one-element list
            before being handed to ``tokenizer``.
        tokenizer: A callable exposing ``model_max_length`` whose result has an
            ``input_ids`` attribute (presumably a Hugging Face tokenizer --
            confirm against the caller).

    Returns:
        The ``input_ids`` attribute of the tokenizer output. With
        ``return_tensors="pt"`` and a one-element batch this is presumably a
        tensor of shape ``(1, model_max_length)`` -- TODO confirm.
    """
    # Pad/truncate to the tokenizer's own maximum length so every caption
    # is encoded to the same length.
    encode_options = {
        "max_length": tokenizer.model_max_length,
        "padding": "max_length",
        "truncation": True,
        "return_tensors": "pt",
    }
    encoded = tokenizer([caption], **encode_options)
    return encoded.input_ids
|
|
|
|
|
|
|
|
|
class SquarePad:
    """Callable transform that pads an image with white so it becomes square."""

    def __call__(self, image):
        """Return ``image`` padded to ``max(width, height)`` on both axes.

        Args:
            image: An object whose ``size`` attribute is ``(width, height)``
                (a PIL image, judging by the ``.size`` access). The fill is an
                RGB triple, so a 3-channel image is presumably expected --
                TODO confirm for other modes.

        Returns:
            The result of ``F.pad`` with a constant white ``(255, 255, 255)``
            border; the output is exactly ``side x side``.
        """
        width, height = image.size
        side = max(width, height)

        # Split the missing pixels between the two sides. When the difference
        # is odd the extra pixel goes to the right/bottom edge; using the same
        # floored half on both sides would leave the result one pixel short
        # of square.
        left = (side - width) // 2
        top = (side - height) // 2
        right = side - width - left
        bottom = side - height - top

        # torchvision's pad expects (left, top, right, bottom).
        padding = (left, top, right, bottom)
        return F.pad(image, padding, (255, 255, 255), 'constant')
|
|
|
class NormalSegDataset(Dataset):
    """Dataset of RGB images paired with normal-map and segmentation conditioning.

    Metadata is read from ``<path>/meta_train_seg.json``. That file must hold
    a ``'keys'`` sequence and a ``'data'`` mapping from each key to a record
    with the entries ``'rgb'``, ``'normal'``, ``'seg'`` (image paths) and
    ``'caption'`` (a sequence whose first element is used as the prompt).

    With probability ``cfg_prob`` the caption is replaced by an empty string,
    and, decided independently, with probability ``cfg_prob`` the conditioning
    images are replaced by blank placeholders (presumably for classifier-free
    guidance training -- confirm with the training script).
    """

    def __init__(self, args, path, tokenizer, cfg_prob):
        """Build the transforms and load the metadata index.

        Args:
            args: Object providing ``resolution``, the output size passed to
                ``RandomResizedCrop``.
            path: Directory containing ``meta_train_seg.json``.
            tokenizer: Tokenizer forwarded to ``tokenize_captions``.
            cfg_prob: Probability used for both the caption dropout and the
                conditioning dropout.
        """
        # Random crop + tensor conversion applied to the RGB, normal and
        # segmentation images alike. ``interpolation=2`` is kept as in the
        # original (presumably PIL's bilinear constant -- confirm).
        self.image_transforms = transforms.Compose(
            [
                transforms.RandomResizedCrop(args.resolution, scale=(0.9, 1.0), interpolation=2),
                transforms.ToTensor(),
            ]
        )

        # Extra normalisation applied to the RGB target only.
        self.additional_image_transforms = transforms.Compose(
            [transforms.Normalize([0.5], [0.5])]
        )

        meta_path = os.path.join(path, 'meta_train_seg.json')
        # JSON is UTF-8 by specification; do not rely on the platform default.
        with open(meta_path, 'r', encoding='utf-8') as f:
            meta = json.load(f)

        self.keys = meta['keys']
        self.meta = meta['data']

        self.tokenizer = tokenizer
        self.cfg_prob = cfg_prob

    def __len__(self):
        """Return the number of keys listed in the metadata."""
        return len(self.keys)

    @staticmethod
    def _load_image(image_path, mode):
        """Open ``image_path``, convert it to ``mode`` and close the file."""
        with Image.open(image_path) as handle:
            return handle.convert(mode)

    def __getitem__(self, index):
        """Return ``(image, conditioning_image, prompt, text_prompt)``.

        Args:
            index: Position into ``self.keys``.

        Returns:
            A tuple of:
            * ``image`` -- the transformed RGB tensor after ``Normalize``;
            * ``conditioning_image`` -- normal and segmentation tensors
              concatenated along dim 0 (not normalised);
            * ``prompt`` -- output of ``tokenize_captions`` for the caption;
            * ``text_prompt`` -- the caption string actually used (``""``
              when the caption was dropped).
        """
        meta_data = self.meta[self.keys[index]]

        rgb_path = meta_data['rgb']
        normal_path = meta_data['normal']
        seg_path = meta_data['seg']
        text_prompt = meta_data['caption'][0]

        # Draw BOTH dropout decisions before the RNG snapshot below. The RNG
        # is rewound to that snapshot for the conditioning images, so a draw
        # made after the snapshot would be replayed as the first draw of the
        # next __getitem__ call, tying one sample's conditioning dropout to
        # the next sample's caption dropout.
        text_draw, cond_draw = torch.rand(2).tolist()
        if text_draw < self.cfg_prob:
            text_prompt = ""

        image = self._load_image(rgb_path, "RGB")

        # Snapshot the global torch RNG so the same random crop can be
        # replayed on the conditioning images. This assumes the transform
        # draws its randomness from torch's global generator -- verify for
        # the installed torchvision version.
        state = torch.get_rng_state()
        image = self.image_transforms(image)

        if cond_draw < self.cfg_prob:
            # Blank placeholders: white normal map, zero segmentation map.
            # The code indexes the tensor as (C, H, W), while PIL sizes are
            # (width, height).
            height, width = image.shape[1], image.shape[2]
            normal_image = Image.new('RGB', (width, height), (255, 255, 255))
            seg_image = Image.new('L', (width, height), (0))
        else:
            normal_image = self._load_image(normal_path, "RGB")
            seg_image = self._load_image(seg_path, "L")

        torch.set_rng_state(state)
        normal_image = self.image_transforms(normal_image)

        torch.set_rng_state(state)
        seg_image = self.image_transforms(seg_image)

        conditioning_image = torch.cat([normal_image, seg_image], dim=0)

        image = self.additional_image_transforms(image)

        prompt = tokenize_captions(text_prompt, self.tokenizer)

        return image, conditioning_image, prompt, text_prompt
|
|
|
|
|
|