# --------------------------------------------------------
# InstructDiffusion
# Based on instruct-pix2pix (https://github.com/timothybrooks/instruct-pix2pix)
# Modified by Binxin Yang (tennyson@mail.ustc.edu.cn)
# --------------------------------------------------------
from __future__ import annotations

import os
import random
import copy
import json
import math
from pathlib import Path
from typing import Any

import numpy as np
import torch
import torchvision
from einops import rearrange
from PIL import Image
from torch.utils.data import Dataset

from dataset.seg.refcoco import REFER
class RefCOCODataset(Dataset):
    """RefCOCO referring-expression data posed as an editing task: given an
    image and an instruction, paint the referred object in a named color."""

    def __init__(
        self,
        path: str,
        split: str = "train",
        min_resize_res: int = 256,
        max_resize_res: int = 256,
        crop_res: int = 256,
        flip_prob: float = 0.0,
        transparency: float = 0.0,
        test: bool = False,
    ):
        assert split in ("train", "val", "test")
        self.path = path
        self.min_resize_res = min_resize_res
        self.max_resize_res = max_resize_res
        self.crop_res = crop_res
        self.flip_prob = flip_prob
        self.G_ref_dataset = REFER(data_root=path)
        self.IMAGE_DIR = os.path.join(path, 'images/train2014')
        self.list_ref = self.G_ref_dataset.getRefIds(split=split)
        self.transparency = transparency
        self.test = test
        # Each line of this file is an instruction template with "{color}"
        # and "{object}" placeholders, filled in __getitem__ below.
        seg_diverse_prompt_path = 'dataset/prompt/prompt_seg.txt'
        self.seg_diverse_prompt_list = []
        with open(seg_diverse_prompt_path) as f:
            for line in f:
                line = line.strip('\n')
                if line:
                    self.seg_diverse_prompt_list.append(line)
        # Each usable line holds at least four space-separated fields: the
        # color name first, and an "R,G,B" triple as the fourth field.
        color_list_file_path = 'dataset/prompt/color_list_train_small.txt'
        self.color_list = []
        with open(color_list_file_path) as f:
            for line in f:
                line_split = line.strip('\n').split(" ")
                if len(line_split) > 1:
                    self.color_list.append(line_split[:4])
    def __len__(self) -> int:
        return len(self.list_ref)
    def _augmentation_new(self, image, label):
        # Randomly crop the longer dimension so image and mask become square.
        h, w = label.shape
        if h > w:
            start_h = random.randint(0, h - w)
            end_h = start_h + w
            image = image[start_h:end_h]
            label = label[start_h:end_h]
        elif h < w:
            start_w = random.randint(0, w - h)
            end_w = start_w + h
            image = image[:, start_w:end_w]
            label = label[:, start_w:end_w]

        # Resize: LANCZOS for the image, NEAREST for the mask so that
        # label values are never interpolated.
        image = Image.fromarray(image).resize(
            (self.min_resize_res, self.min_resize_res),
            resample=Image.Resampling.LANCZOS,
        )
        image = np.asarray(image, dtype=np.uint8)
        label = Image.fromarray(label).resize(
            (self.min_resize_res, self.min_resize_res),
            resample=Image.Resampling.NEAREST,
        )
        label = np.asarray(label, dtype=np.int64)
        return image, label
    def __getitem__(self, i: int) -> dict[str, Any]:
        ref_ids = self.list_ref[i]
        ref = self.G_ref_dataset.loadRefs(ref_ids)[0]

        # Sample a referring expression, a prompt template, and a target
        # color, then fill the template's {color}/{object} placeholders.
        sentences = random.choice(ref['sentences'])['sent']
        prompt = random.choice(self.seg_diverse_prompt_list)
        color = random.choice(self.color_list)
        color_name = color[0]
        prompt = prompt.format(color=color_name.lower(), object=sentences.lower())
        R, G, B = (int(c) for c in color[3].split(","))

        image_name = self.G_ref_dataset.loadImgs(ref['image_id'])[0]['file_name']
        image_path = os.path.join(self.IMAGE_DIR, image_name)
        mask = self.G_ref_dataset.getMask(ref=ref)['mask']

        image = Image.open(image_path).convert("RGB")
        image = np.asarray(image)
        image, mask = self._augmentation_new(image, mask)
        mask = (mask == 1)
        # Build the edited target: alpha-blend the chosen color over the
        # referred region (transparency = 1 keeps the original pixels).
        image_0 = Image.fromarray(image)
        image_1 = copy.deepcopy(image)
        image_1[:, :, 0][mask] = self.transparency * image_1[:, :, 0][mask] + (1 - self.transparency) * R
        image_1[:, :, 1][mask] = self.transparency * image_1[:, :, 1][mask] + (1 - self.transparency) * G
        image_1[:, :, 2][mask] = self.transparency * image_1[:, :, 2][mask] + (1 - self.transparency) * B
        image_1 = Image.fromarray(image_1)

        # Resize both images to a random resolution in [min_resize_res, max_resize_res].
        resize_res = torch.randint(self.min_resize_res, self.max_resize_res + 1, ()).item()
        image_0 = image_0.resize((resize_res, resize_res), Image.Resampling.LANCZOS)
        image_1 = image_1.resize((resize_res, resize_res), Image.Resampling.LANCZOS)

        # Scale to [-1, 1] and convert HWC -> CHW.
        image_0 = rearrange(2 * torch.tensor(np.array(image_0)).float() / 255 - 1, "h w c -> c h w")
        image_1 = rearrange(2 * torch.tensor(np.array(image_1)).float() / 255 - 1, "h w c -> c h w")
        # Concatenate along the channel axis so the same random crop and flip
        # are applied to source and target, then split them back apart.
        crop = torchvision.transforms.RandomCrop(self.crop_res)
        flip = torchvision.transforms.RandomHorizontalFlip(float(self.flip_prob))
        image_0, image_1 = flip(crop(torch.cat((image_0, image_1)))).chunk(2)
        mask = torch.tensor(mask).float()  # float mask (not used below)
        return dict(edited=image_1, edit=dict(c_concat=image_0, c_crossattn=prompt))
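

# Minimal usage sketch, assuming a RefCOCO checkout under "data/refcoco"
# with the standard REFER layout (annotations plus images/train2014) and the
# two prompt files referenced in __init__; the path is hypothetical, so
# adjust it for your setup.
if __name__ == "__main__":
    dataset = RefCOCODataset(
        path="data/refcoco",  # hypothetical data root
        split="train",
        flip_prob=0.5,
        transparency=0.5,
    )
    sample = dataset[0]
    # "edited" is the color-painted target; "edit" holds the conditioning:
    # the source image (c_concat) and the instruction text (c_crossattn).
    print(sample["edit"]["c_crossattn"])
    print(sample["edited"].shape, sample["edit"]["c_concat"].shape)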